Skip to content

Commit

Permalink
[V1] Support VLMs with fine-grained scheduling (vllm-project#9871)
Browse files Browse the repository at this point in the history
Signed-off-by: Woosuk Kwon <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
  • Loading branch information
WoosukKwon and ywang96 authored Nov 13, 2024
1 parent 0d4ea3f commit bbd3e86
Show file tree
Hide file tree
Showing 12 changed files with 542 additions and 96 deletions.
11 changes: 9 additions & 2 deletions vllm/model_executor/models/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,11 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
Expand Down Expand Up @@ -263,16 +265,21 @@ def __init__(
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.wte(input_ids)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states

def compute_logits(
Expand Down
7 changes: 6 additions & 1 deletion vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,16 +538,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
normalize=False,
softmax=False)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
attn_metadata, intermediate_tensors,
inputs_embeds)
return model_output

def compute_logits(
Expand Down
46 changes: 28 additions & 18 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

Expand Down Expand Up @@ -448,13 +449,33 @@ def _process_image_input(self,
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)

def process_mm_inputs(self, **kwargs):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
vision_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if vision_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LLaVA-1.5.
Expand Down Expand Up @@ -494,24 +515,13 @@ def forward(
"""
if intermediate_tensors is not None:
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)

inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
else:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)

# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
elif inputs_embeds is None:
vision_embeddings = self.process_mm_inputs(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None

hidden_states = self.language_model.model(input_ids,
positions,
Expand Down
7 changes: 6 additions & 1 deletion vllm/model_executor/models/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,16 +360,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states

def compute_logits(
Expand Down
63 changes: 40 additions & 23 deletions vllm/model_executor/models/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_list_of
Expand Down Expand Up @@ -500,23 +501,29 @@ def input_processor_for_phi3v(ctx: InputContext,

# TODO: Move this to utils or integrate with clip.
new_token_ids: List[int] = []
placeholder_ranges: List[PlaceholderRange] = []
placeholder_idx = 0
while merged_token_ids:
token_id = merged_token_ids.pop(0)
if token_id == _IMAGE_TOKEN_ID:
new_token_ids.extend(
repeat_and_pad_token(
_IMAGE_TOKEN_ID,
repeat_count=image_feature_size[placeholder_idx],
))
replacement_ids = repeat_and_pad_token(
_IMAGE_TOKEN_ID,
repeat_count=image_feature_size[placeholder_idx],
)
placeholder_ranges.append({
"offset": len(new_token_ids),
"length": len(replacement_ids)
})
new_token_ids.extend(replacement_ids)
placeholder_idx += 1
else:
new_token_ids.append(token_id)

# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})


@MULTIMODAL_REGISTRY.register_image_input_mapper()
Expand Down Expand Up @@ -669,32 +676,42 @@ def _process_image_input(

return image_embeds

def process_mm_inputs(self, **kwargs):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
vision_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
if vision_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
return inputs_embeds

def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object):
if intermediate_tensors is not None:
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else:
inputs_embeds = self.language_model.model.embed_tokens(
input_ids)

# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
elif inputs_embeds is None:
vision_embeddings = self.process_mm_inputs(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None

hidden_states = self.language_model.model(input_ids,
positions,
Expand Down
7 changes: 6 additions & 1 deletion vllm/model_executor/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,16 +441,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states

def compute_logits(
Expand Down
48 changes: 48 additions & 0 deletions vllm/v1/core/encoder_cache_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Dict, List, Set, Tuple

from vllm.v1.request import Request


class EncoderCacheManager:

def __init__(self, cache_size: int):
self.cache_size = cache_size
self.num_free_slots = cache_size
# req_id -> cached input ids
self.cached: Dict[str, Set[int]] = {}
# List of [req_id, input_id]
self.freed: List[Tuple[str, int]] = []

def has_cache(self, request: Request, input_id: int) -> bool:
req_id = request.request_id
return req_id in self.cached and input_id in self.cached[req_id]

def can_allocate(self, request: Request, input_id: int) -> bool:
num_tokens = request.get_num_encoder_tokens(input_id)
return num_tokens <= self.num_free_slots

def allocate(self, request: Request, input_id: int) -> None:
req_id = request.request_id
if req_id not in self.cached:
self.cached[req_id] = set()
self.cached[req_id].add(input_id)
self.num_free_slots -= request.get_num_encoder_tokens(input_id)

def get_cached_input_ids(self, request: Request) -> Set[int]:
return self.cached.get(request.request_id, set())

def free(self, request: Request, input_id: int) -> None:
req_id = request.request_id
if req_id not in self.cached:
return

self.cached[req_id].discard(input_id)
if len(self.cached[req_id]) == 0:
del self.cached[req_id]
self.num_free_slots += request.get_num_encoder_tokens(input_id)
self.freed.append((req_id, input_id))

def get_freed_ids(self) -> List[Tuple[str, int]]:
freed = self.freed
self.freed = []
return freed
Loading

0 comments on commit bbd3e86

Please sign in to comment.