[V1] Initial support of multimodal models for V1 re-arch #10699

Merged · 31 commits · Dec 8, 2024

Commits (31)
246d75b  internvl (ywang96, Nov 27, 2024)
2a081bb  fix token id (ywang96, Nov 27, 2024)
e4d6bb2  Merge branch 'vllm-project:main' into v1-initial (ywang96, Nov 28, 2024)
94d66cc  Pixtral (ywang96, Nov 30, 2024)
79f24c6  use special ids (ywang96, Nov 30, 2024)
7a88433  comment (ywang96, Nov 30, 2024)
af1dbab  cleanup for pixtral (ywang96, Nov 30, 2024)
39dd4f2  Merge branch 'vllm-project:main' into v1-initial (ywang96, Nov 30, 2024)
6d0df5a  qwen2vl (ywang96, Dec 1, 2024)
124b0c1  Merge branch 'vllm-project:main' into v1-initial (ywang96, Dec 2, 2024)
8c4da46  molmo (ywang96, Dec 2, 2024)
3e3a346  minor changes on interfaces (ywang96, Dec 2, 2024)
1c50613  typo (ywang96, Dec 2, 2024)
6d8ddff  pad (ywang96, Dec 2, 2024)
7ddf7d9  Merge branch 'vllm-project:main' into v1-initial (ywang96, Dec 2, 2024)
f1fa769  remove print (ywang96, Dec 3, 2024)
ee8e0ae  Merge branch 'vllm-project:main' into v1-initial (ywang96, Dec 3, 2024)
319e689  Merge branch 'vllm-project:main' into v1-initial (ywang96, Dec 4, 2024)
77256d9  change check order (ywang96, Dec 4, 2024)
bdd8da6  Merge branch 'main' into v1-initial (ywang96, Dec 5, 2024)
e32efd5  Merge branch 'vllm-project:main' into v1-initial (ywang96, Dec 6, 2024)
0176b7b  molmo (ywang96, Dec 6, 2024)
69f4e5f  fix launch args (ywang96, Dec 6, 2024)
8b7e746  fix qwen2-vl (ywang96, Dec 6, 2024)
bb15b01  typing (ywang96, Dec 6, 2024)
610e662  add documentation (ywang96, Dec 6, 2024)
2b5fdd7  minor fix (ywang96, Dec 6, 2024)
a5a38dd  typehint (ywang96, Dec 6, 2024)
fbf9cd0  Merge branch 'main' into v1-initial (ywang96, Dec 7, 2024)
8d1d80e  iterate (ywang96, Dec 8, 2024)
4a79255  revert changes in qwen2vl (ywang96, Dec 8, 2024)
16 changes: 8 additions & 8 deletions vllm/engine/arg_utils.py
@@ -1053,9 +1053,12 @@ def create_engine_config(self,
# long context (> 32K) models. This is to avoid OOM errors in the
# initial memory profiling phase.

# Chunked prefill is currently disabled for multimodal models by
# default.
if use_long_context and not model_config.is_multimodal_model:
# For multimodal models, chunked prefill is disabled by default in
# V0, but enabled by design in V1
if model_config.is_multimodal_model:
self.enable_chunked_prefill = bool(envs.VLLM_USE_V1)

elif use_long_context:
is_gpu = device_config.device_type == "cuda"
use_sliding_window = (model_config.get_sliding_window()
is not None)
@@ -1244,12 +1247,9 @@ def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
Override the EngineConfig's configs based on the usage context for V1.
"""
assert envs.VLLM_USE_V1, "V1 is not enabled"
# TODO (ywang96): Enable APC by default when VLM supports it.
if engine_config.model_config.is_multimodal_model:
logger.warning(
"Prefix caching is currently not supported for multimodal "
"models and has been disabled.")
engine_config.cache_config.enable_prefix_caching = False
# TODO (ywang96): Enable APC by default when VLM supports it.
assert not engine_config.cache_config.enable_prefix_caching


@dataclass
5 changes: 5 additions & 0 deletions vllm/model_executor/models/interfaces.py
@@ -36,6 +36,11 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
"""
Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings.

The output embeddings must be one of the following formats:
- A list or tuple of 2D tensors, where each tensor corresponds to
each input image.
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
"""
...
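
A minimal sketch of the two accepted return formats described in the docstring above; the sizes are made up for illustration and only the shapes matter:

import torch

hidden_size = 4096                      # assumed hidden size, illustration only
feature_sizes = [256, 576]              # per-image feature sizes may differ

# Format 1: a list/tuple of 2D tensors, one per input image.
embeds_as_list = [torch.randn(n, hidden_size) for n in feature_sizes]

# Format 2: a single 3D tensor; requires every image to share one feature size.
embeds_as_batch = torch.randn(len(feature_sizes), 256, hidden_size)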

68 changes: 57 additions & 11 deletions vllm/model_executor/models/internvl.py
@@ -26,7 +26,7 @@
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@@ -52,12 +52,18 @@ class InternVLImagePixelInputs(TypedDict):
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
"""
patches_per_image: List[int]
"""
Total number of patches for each image in the batch.
"""


class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
data: NestedTensors
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`

`hidden_size` must match the hidden size of language model backbone.
"""
@@ -349,10 +355,32 @@ def input_processor(
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
num_patches)
new_prompt_token_ids = tokenizer.encode(new_prompt)
img_context_token_id = tokenizer.encode(self.img_context_token,
add_special_tokens=False)
assert len(img_context_token_id) == 1, \
(f"Invalid image token '{self.img_context_token}': A valid image "
f"token encodes to a single token ID, got {img_context_token_id}.")
img_context_token_id = img_context_token_id[0]

# Get precise tracking of placeholder positions
token_idx = image_idx = 0
placeholder_ranges = []
while token_idx < len(new_prompt_token_ids):
if new_prompt_token_ids[token_idx] == img_context_token_id:
curr_image_feature_size = image_feature_sizes[image_idx]
placeholder_ranges.append(
PlaceholderRange(offset=token_idx,
length=curr_image_feature_size))
image_idx += 1
token_idx += curr_image_feature_size
else:
token_idx += 1

return token_inputs(prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
return token_inputs(
prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})
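
The placeholder-tracking loop above can be reproduced standalone. The token IDs below are made up (9999 stands in for img_context_token_id), and PlaceholderRange is shown as a plain dict for portability:

img_context_token_id = 9999
image_feature_sizes = [3, 2]            # two images in this toy prompt
new_prompt_token_ids = [1, 9999, 9999, 9999, 2, 9999, 9999, 3]

token_idx = image_idx = 0
placeholder_ranges = []
while token_idx < len(new_prompt_token_ids):
    if new_prompt_token_ids[token_idx] == img_context_token_id:
        size = image_feature_sizes[image_idx]
        placeholder_ranges.append({"offset": token_idx, "length": size})
        image_idx += 1
        token_idx += size               # jump over the whole placeholder block
    else:
        token_idx += 1

print(placeholder_ranges)
# [{'offset': 1, 'length': 3}, {'offset': 5, 'length': 2}]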

def input_mapper(
self,
@@ -614,26 +642,46 @@ def _parse_and_validate_image_input(
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

patches_per_image = []
for request_pixel_values in pixel_values:
for image_pixel_values in request_pixel_values:
patches_per_image.append(image_pixel_values.shape[0])
# We need to flatten (B, N, P) to (B*N*P),
# so we call flatten_bn twice.
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(flatten_bn(pixel_values), concat=True)),
)
patches_per_image=patches_per_image)

raise AssertionError("This line should be unreachable.")

def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> torch.Tensor:
) -> Tuple[torch.Tensor]:
if image_input["type"] == "image_embeds":
return image_input["data"]

assert self.vision_model is not None

image_embeds = self.extract_feature(image_input["data"])

patches_per_image = image_input["patches_per_image"]
if len(patches_per_image) == 1:
image_embeds = image_embeds.unsqueeze(0)
return image_embeds

# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
feature_size = image_embeds.shape[1]
image_embeds = image_embeds.view(-1,
self.config.text_config.hidden_size)
image_feature_sizes = [
num_patches * feature_size for num_patches in patches_per_image
]
image_embeds = image_embeds.split(image_feature_sizes)
return image_embeds
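
A self-contained sketch of the split performed above, with made-up sizes (two images with 7 and 5 patches, 256 feature tokens per patch, hidden size 64):

import torch

hidden_size, feature_size = 64, 256
patches_per_image = [7, 5]
image_embeds = torch.randn(sum(patches_per_image), feature_size, hidden_size)

# Flatten all patch tokens, then split back into one tensor per image,
# each sized by num_patches * feature_size.
flat = image_embeds.view(-1, hidden_size)
image_feature_sizes = [n * feature_size for n in patches_per_image]
per_image_embeds = flat.split(image_feature_sizes)

print([t.shape for t in per_image_embeds])
# [torch.Size([1792, 64]), torch.Size([1280, 64])]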

def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -696,13 +744,11 @@ def forward(
"inputs_embeds": inputs_embeds,
}

# Only required if the model is mono-architecture
if self.visual_token_mask is not None:
# overwrite visual_token_mask and img_context_token_id back to None,
# so that this doesn't need to depend on encoder output
forward_kwargs.update(
{"visual_token_mask": self.visual_token_mask})
self.visual_token_mask = None
self.img_context_token_id = None

hidden_states = self.language_model.model(**forward_kwargs)
return hidden_states
72 changes: 63 additions & 9 deletions vllm/model_executor/models/molmo.py
@@ -37,7 +37,7 @@
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
@@ -46,12 +46,16 @@
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
maybe_prefix, merge_multimodal_embeddings)

# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
NUM_PREFIX_TOKENS = 1
ADDITIONAL_VOCAB_SIZE = 128
DEFAULT_IMAGE_PATCH_TOKEN_ID = 152066
DEFAULT_IM_START_TOKEN_ID = 152067
DEFAULT_IM_END_TOKEN_ID = 152064
DEFAULT_IM_COL_TOKEN_ID = 152065


class MolmoImageInputs(TypedDict):
Expand All @@ -75,6 +79,11 @@ class MolmoImageInputs(TypedDict):
`(batch_size, num_crops, num_patch)`
"""

image_start_end: Tuple[int, int]
"""Starting and ending index of placeholder
tokens
"""


@dataclass
class VisionBackboneConfig:
@@ -918,6 +927,8 @@ def image_input_mapper_for_molmo(
ctx: InputContext,
data: object,
):
if isinstance(data, list):
data = data[0]
return MultiModalKwargs(data)


@@ -967,7 +978,22 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
if "image_masks" in out:
dummy_imgdata["image_masks"] = out["image_masks"]
dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long)
return DummyData(dummy_seqdata, {"image": dummy_imgdata})
size = 0
offset = -1
for i in range(len(token_ids)):
if token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID,
DEFAULT_IM_START_TOKEN_ID, DEFAULT_IM_END_TOKEN_ID,
DEFAULT_IM_COL_TOKEN_ID):
if offset < 0:
offset = i
size += 1
dummy_imgdata["image_start_end"] = (offset, offset + size)
return DummyData(seq_data=dummy_seqdata,
multi_modal_data={"image": dummy_imgdata},
multi_modal_placeholders={
"image":
[PlaceholderRange(offset=offset, length=size)]
})


def pad_images(
@@ -1055,19 +1081,34 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
if image_masks is not None:
image_data["image_masks"] = image_masks

image_data["seq_len"] = torch.tensor(len(out["input_ids"]),
new_prompt_token_ids = out["input_ids"].tolist()
image_data["seq_len"] = torch.tensor(len(new_prompt_token_ids),
dtype=torch.long)

multi_modal_data = dict(image=image_data)
size = 0
offset = -1
for i in range(len(new_prompt_token_ids)):
if new_prompt_token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID,
DEFAULT_IM_START_TOKEN_ID,
DEFAULT_IM_END_TOKEN_ID,
DEFAULT_IM_COL_TOKEN_ID):
if offset < 0:
offset = i
size += 1
image_data["image_start_end"] = (offset, offset + size)

prompt = inputs.get("prompt")
if prompt is None:
prompt = tokenizer.decode(out["input_ids"])
prompt = tokenizer.decode(new_prompt_token_ids)

return token_inputs(
prompt_token_ids=out["input_ids"],
prompt_token_ids=new_prompt_token_ids,
prompt=prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={
"image": [PlaceholderRange(offset=offset, length=size)]
},
)
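
The offset/size scan used above (and in dummy_data_for_molmo) can be illustrated with toy token IDs; only the four image-related IDs come from the constants at the top of molmo.py, the rest are arbitrary text tokens, and a single contiguous image block is assumed:

IMAGE_TOKEN_IDS = {152066, 152067, 152064, 152065}   # patch, im_start, im_end, im_col
new_prompt_token_ids = [10, 152067, 152066, 152066, 152065, 152066, 152064, 11]

size, offset = 0, -1
for i, tok in enumerate(new_prompt_token_ids):
    if tok in IMAGE_TOKEN_IDS:
        if offset < 0:
            offset = i                  # first image-related token
        size += 1                       # count every image-related token

print(offset, size)                     # 1 6 -> PlaceholderRange(offset=1, length=6)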


@@ -1113,6 +1154,7 @@ def _parse_and_validate_image_input(
) -> Optional[MolmoImageInputs]:
images = kwargs.pop("images", None)
image_masks = kwargs.pop("image_masks", None)
image_start_end = kwargs.pop("image_start_end", None)
if images is None:
return None

Expand All @@ -1130,6 +1172,7 @@ def _parse_and_validate_image_input(
image_input_idx=image_input_idx,
seq_len=seq_len,
image_masks=image_masks,
image_start_end=image_start_end,
)

def _process_image_input(
Expand Down Expand Up @@ -1178,9 +1221,16 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:

# Note: In this original implementation from AI2, the final
# vision_embeddings will always be the same length
# of input embeddings, which is not very efficient.
# TODO(ywang96): see if this can be optimized.
# of input embeddings.
vision_embeddings = torch.einsum('nd,nm->md', image_features, mat)

# Split by the sizes of the input sequences. For each full embedding,
# extract the actual vision embeddings to be merged.
vision_embeddings = list(vision_embeddings.split(seq_len.tolist()))
for i in range(len(vision_embeddings)):
start, end = image_input['image_start_end'][i]
vision_embeddings[i] = vision_embeddings[i][start:end]

return vision_embeddings
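
A minimal sketch of the split-and-slice step above, assuming made-up sizes (two requests of lengths 8 and 6, hidden size 4) and hypothetical image_start_end spans:

import torch

hidden_size = 4
seq_len = torch.tensor([8, 6])
full_embeddings = torch.randn(int(seq_len.sum()), hidden_size)
image_start_end = [(1, 5), (2, 5)]      # per-request image token spans

# Split the flat embeddings per request, then keep only the rows that fall
# inside each request's image placeholder span.
vision_embeddings = list(full_embeddings.split(seq_len.tolist()))
for i, (start, end) in enumerate(image_start_end):
    vision_embeddings[i] = vision_embeddings[i][start:end]

print([t.shape for t in vision_embeddings])
# [torch.Size([4, 4]), torch.Size([3, 4])]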

def get_input_embeddings(
Expand All @@ -1190,7 +1240,11 @@ def get_input_embeddings(
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = inputs_embeds + multimodal_embeddings
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
DEFAULT_IMAGE_PATCH_TOKEN_ID, DEFAULT_IM_START_TOKEN_ID,
DEFAULT_IM_END_TOKEN_ID, DEFAULT_IM_COL_TOKEN_ID
])
return inputs_embeds
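
A hedged sketch of what the merge achieves (not the actual merge_multimodal_embeddings implementation): rows of inputs_embeds whose token ID is image-related are overwritten in order by the vision embedding rows, instead of adding a full-length tensor as in the previous version. All sizes below are made up:

import torch

IMAGE_TOKEN_IDS = torch.tensor([152066, 152067, 152064, 152065])
hidden_size = 4                               # illustration only

input_ids = torch.tensor([10, 152067, 152066, 152066, 152064, 11])
inputs_embeds = torch.zeros(len(input_ids), hidden_size)
vision_embeds = torch.randn(4, hidden_size)   # one row per image-related token

mask = torch.isin(input_ids, IMAGE_TOKEN_IDS)
inputs_embeds[mask] = vision_embeds           # positional overwrite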

def forward(