
Commit

[V1] VLM - Run the mm_mapper preprocessor in the frontend process (vllm-project#10640)

Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
3 people authored Dec 3, 2024
1 parent f6084f6 commit 3bc94ca
Showing 7 changed files with 47 additions and 25 deletions.
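
In short: the multi-modal input mapper now runs in the frontend Processor, and EngineCoreRequest carries the already-mapped mm_inputs, so EngineCore.add_request no longer does the 10-50 ms PIL-image-to-tensor work on the scheduling path. A minimal, self-contained sketch of that hand-off (the *Stub classes and simplified signatures below are illustrative only; the real MMInputMapper, Processor, and EngineCore in vllm/v1 take a model_config and more fields):

from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# Stand-in for vllm.multimodal.MultiModalKwargs (e.g. {"pixel_values": <tensor>}).
MultiModalKwargs = Dict[str, Any]


@dataclass
class EngineCoreRequestStub:
    request_id: str
    prompt_token_ids: List[int]
    # After this commit the request carries mapper *outputs* (mm_inputs),
    # not raw mm_data / mm_processor_kwargs.
    mm_inputs: Optional[List[MultiModalKwargs]]


class MMInputMapperStub:
    def process_inputs(self, mm_data, mm_processor_kwargs):
        # The real mapper runs the HF image processor (PIL images -> tensors).
        return [{"pixel_values": mm_data["image"]}]


class ProcessorStub:
    """Frontend process: now owns the (10-50 ms) mapping step."""

    def __init__(self):
        self.mm_input_mapper = MMInputMapperStub()

    def process_inputs(self, request_id, prompt_token_ids, mm_data):
        mm_inputs = (self.mm_input_mapper.process_inputs(mm_data, None)
                     if mm_data else None)
        return EngineCoreRequestStub(request_id, prompt_token_ids, mm_inputs)


class EngineCoreStub:
    """Engine core process: just schedules, no input mapping on this path."""

    def add_request(self, request: EngineCoreRequestStub) -> None:
        print(request.request_id, "->", len(request.mm_inputs or []), "mm input(s)")


if __name__ == "__main__":
    EngineCoreStub().add_request(
        ProcessorStub().process_inputs("req-0", [1, 2, 3],
                                       {"image": "<PIL image>"}))
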
3 changes: 1 addition & 2 deletions tests/v1/engine/test_engine_core.py
@@ -27,9 +27,8 @@ def make_request() -> EngineCoreRequest:
         request_id=uuid.uuid4(),
         prompt=PROMPT,
         prompt_token_ids=PROMPT_TOKENS,
-        mm_data=None,
+        mm_inputs=None,
         mm_placeholders=None,
-        mm_processor_kwargs=None,
         sampling_params=SamplingParams(),
         eos_token_id=None,
         arrival_time=time.time(),
3 changes: 1 addition & 2 deletions tests/v1/engine/test_engine_core_client.py
@@ -29,9 +29,8 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
         request_id=str(uuid.uuid4()),
         prompt=PROMPT,
         prompt_token_ids=PROMPT_TOKENS,
-        mm_data=None,
+        mm_inputs=None,
         mm_placeholders=None,
-        mm_processor_kwargs=None,
         sampling_params=params,
         eos_token_id=None,
         arrival_time=time.time(),
24 changes: 23 additions & 1 deletion vllm/inputs/data.py
@@ -7,7 +7,8 @@
 from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never

 if TYPE_CHECKING:
-    from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
+    from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs,
+                                 MultiModalPlaceholderDict)
     from vllm.multimodal.inputs import MultiModalInputsV2


@@ -150,6 +151,12 @@ class TokenInputs(TypedDict):
     if the model supports it.
     """

+    multi_modal_inputs: NotRequired["MultiModalKwargs"]
+    """
+    Optional multi-modal inputs to pass to the model,
+    if the model supports it.
+    """
+
     multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
     """
     Placeholder ranges for the multi-modal data.
@@ -169,6 +176,7 @@ def token_inputs(
     token_type_ids: Optional[List[int]] = None,
     prompt: Optional[str] = None,
     multi_modal_data: Optional["MultiModalDataDict"] = None,
+    multi_modal_inputs: Optional["MultiModalKwargs"] = None,
     multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
     mm_processor_kwargs: Optional[Dict[str, Any]] = None,
 ) -> TokenInputs:
@@ -181,6 +189,8 @@ def token_inputs(
         inputs["token_type_ids"] = token_type_ids
     if multi_modal_data is not None:
         inputs["multi_modal_data"] = multi_modal_data
+    if multi_modal_inputs is not None:
+        inputs["multi_modal_inputs"] = multi_modal_inputs
     if multi_modal_placeholders is not None:
         inputs["multi_modal_placeholders"] = multi_modal_placeholders
     if mm_processor_kwargs is not None:
@@ -273,6 +283,18 @@ def multi_modal_data(self) -> "MultiModalDataDict":

         assert_never(inputs)

+    @cached_property
+    def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
+        inputs = self.inputs
+
+        if inputs["type"] == "token":
+            return inputs.get("multi_modal_inputs", {})
+
+        if inputs["type"] == "multimodal":
+            return inputs.get("mm_kwargs", {})
+
+        assert_never(inputs)
+
     @cached_property
     def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
         inputs = self.inputs
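
Usage sketch for the new multi_modal_inputs field on the token-inputs path, assuming this commit is checked out; the plain dict stands in for a real MultiModalKwargs produced by the mapper, and the prompt string is arbitrary:

from vllm.inputs.data import token_inputs

# Attach already-mapped multi-modal kwargs when building token inputs; real
# values are whatever the model's HF processor returns (e.g. pixel_values).
inputs = token_inputs(
    prompt_token_ids=[1, 2, 3],
    prompt="<image> Describe the image.",
    multi_modal_inputs={"pixel_values": ...},  # placeholder for real tensors
)

assert inputs["type"] == "token"
# The field is optional (NotRequired), so read it back with a default.
assert "pixel_values" in inputs.get("multi_modal_inputs", {})
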
7 changes: 3 additions & 4 deletions vllm/v1/engine/__init__.py
@@ -1,11 +1,11 @@
 import enum
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Union
+from typing import List, Optional, Union

 import msgspec

 from vllm.lora.request import LoRARequest
-from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
+from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
 from vllm.sampling_params import RequestOutputKind, SamplingParams


@@ -35,9 +35,8 @@ class EngineCoreRequest:
     # always be tokenized?
     prompt: Optional[str]
     prompt_token_ids: List[int]
-    mm_data: Optional[MultiModalDataDict]
+    mm_inputs: Optional[List[MultiModalKwargs]]
     mm_placeholders: Optional[MultiModalPlaceholderDict]
-    mm_processor_kwargs: Optional[Dict[str, Any]]
     sampling_params: SamplingParams
     eos_token_id: Optional[int]
     arrival_time: float
7 changes: 0 additions & 7 deletions vllm/v1/engine/core.py
@@ -84,14 +84,7 @@ def _initialize_kv_caches(self,

     def add_request(self, request: EngineCoreRequest):
         """Add request to the scheduler."""
-
         req = Request.from_engine_core_request(request)
-        # FIXME(woosuk): The input mapping (e.g., PIL images to tensors) may
-        # take 10-50 ms, which can cause a spike in the latency. We should
-        # consider moving this to a separate thread.
-        if req.mm_data:
-            req.mm_inputs = self.mm_input_mapper.process_inputs(
-                req.mm_data, req.mm_processor_kwargs)
         self.scheduler.add_request(req)

     def abort_requests(self, request_ids: List[str]):
13 changes: 11 additions & 2 deletions vllm/v1/engine/processor.py
@@ -14,6 +14,7 @@
 from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
+from vllm.v1.engine.mm_input_mapper import MMInputMapper


 class Processor:
@@ -39,6 +40,9 @@ def __init__(
         self.input_processor = input_registry.create_input_processor(
             model_config)

+        # Multi-modal (huggingface) input mapper
+        self.mm_input_mapper = MMInputMapper(model_config)
+
         # TODO: run in an ThreadpoolExecutor or BackgroundProcess.
         # This ideally should releases the GIL, so we should not block the
         # asyncio loop while this is running.
@@ -96,6 +100,12 @@ def process_inputs(
         sampling_params.update_from_generation_config(
             self.generation_config_fields, eos_token_id)

+        # Preprocess multi-modal data
+        mm_inputs = self.mm_input_mapper.process_inputs(
+            decoder_inputs.multi_modal_data,
+            decoder_inputs.mm_processor_kwargs) if len(
+                decoder_inputs.multi_modal_data) > 0 else None
+
         # Make Request for Detokenizer.
         detokenizer_request = DetokenizerRequest(
             request_id,
@@ -113,9 +123,8 @@ def process_inputs(
             request_id,
             decoder_inputs.prompt,
             decoder_inputs.prompt_token_ids,
-            decoder_inputs.multi_modal_data,
+            mm_inputs,
             decoder_inputs.multi_modal_placeholders,
-            decoder_inputs.mm_processor_kwargs,
             sampling_params,
             eos_token_id,
             arrival_time,
15 changes: 8 additions & 7 deletions vllm/v1/request.py
@@ -45,17 +45,17 @@ def __init__(
         self._all_token_ids: List[int] = self.prompt_token_ids.copy()
         self.num_computed_tokens = 0

-        # Raw multimodal data before the mm input mapper (e.g., PIL images).
-        self.mm_data = self.inputs.multi_modal_data
-        self.mm_processor_kwargs = self.inputs.mm_processor_kwargs
         mm_positions = self.inputs.multi_modal_placeholders
         if mm_positions:
             # FIXME(woosuk): Support other modalities.
             self.mm_positions = mm_positions.get("image", [])
         else:
             self.mm_positions = []
         # Output of the mm input mapper (e.g., image tensors).
-        self.mm_inputs: List[MultiModalKwargs] = []
+        if self.inputs.multi_modal_inputs:
+            self.mm_inputs = self.inputs.multi_modal_inputs
+        else:
+            self.mm_inputs: List[MultiModalKwargs] = []

     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
@@ -64,9 +64,10 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
             inputs=token_inputs(
                 prompt_token_ids=request.prompt_token_ids,
                 prompt=request.prompt,
-                multi_modal_data=request.mm_data,
+                multi_modal_data=None,
+                multi_modal_inputs=request.mm_inputs,
                 multi_modal_placeholders=request.mm_placeholders,
-                mm_processor_kwargs=request.mm_processor_kwargs,
+                mm_processor_kwargs=None,
             ),
             sampling_params=request.sampling_params,
             eos_token_id=request.eos_token_id,
@@ -110,7 +111,7 @@ def get_finished_reason(self) -> Union[str, None]:
         return RequestStatus.get_finished_reason(self.status)

     def has_encoder_inputs(self) -> bool:
-        return len(self.mm_data) > 0
+        return len(self.mm_inputs) > 0

     @property
     def num_encoder_inputs(self) -> int:
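
The request-side effect reduces to the toy mirror below (RequestStub and the plain-dict MultiModalKwargs alias are illustrative, not vLLM API): mm_inputs comes from the frontend if attached, otherwise stays an empty list, and has_encoder_inputs now keys off mm_inputs instead of the removed mm_data.

from typing import Any, Dict, List, Optional

MultiModalKwargs = Dict[str, Any]  # stand-in for vllm.multimodal.MultiModalKwargs


class RequestStub:
    """Toy mirror of the new vllm/v1 Request multi-modal bookkeeping."""

    def __init__(self, multi_modal_inputs: Optional[List[MultiModalKwargs]]):
        # Keep mapper outputs if the frontend attached them, else empty list.
        if multi_modal_inputs:
            self.mm_inputs = multi_modal_inputs
        else:
            self.mm_inputs: List[MultiModalKwargs] = []

    def has_encoder_inputs(self) -> bool:
        # Was `len(self.mm_data) > 0` before this commit.
        return len(self.mm_inputs) > 0


assert RequestStub([{"pixel_values": ...}]).has_encoder_inputs()
assert not RequestStub(None).has_encoder_inputs()
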
