From f04328d5c5c3ab5fa5d747e2a157d4df219f3987 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Mon, 25 Nov 2024 15:51:35 +0000 Subject: [PATCH 01/10] allow mm_mapper execution in the frontend process --- examples/offline_inference_vision_language.py | 9 +++++++- vllm/config.py | 7 +++++- vllm/engine/arg_utils.py | 8 +++++++ vllm/entrypoints/llm.py | 5 ++++ vllm/inputs/data.py | 23 ++++++++++++++++++- vllm/v1/engine/__init__.py | 4 +++- vllm/v1/engine/core.py | 9 ++++---- vllm/v1/engine/processor.py | 18 ++++++++++++++- vllm/v1/request.py | 9 ++++++-- 9 files changed, 81 insertions(+), 11 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index f08f22eec164a..3ff29fe16a7d4 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -5,6 +5,7 @@ For most models, the prompt format should follow corresponding examples on HuggingFace model repository. """ +import time from transformers import AutoTokenizer from vllm import LLM, SamplingParams @@ -23,7 +24,10 @@ def run_llava(question: str, modality: str): prompt = f"USER: \n{question}\nASSISTANT:" - llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096) + llm = LLM(model="llava-hf/llava-1.5-7b-hf", + max_model_len=4096) + #mm_disable_frontend_processor=True) + stop_token_ids = None return llm, prompt, stop_token_ids @@ -514,7 +518,10 @@ def main(args): }, } for _ in range(args.num_prompts)] + start_time = time.time() outputs = llm.generate(inputs, sampling_params=sampling_params) + elapsed_time = time.time() - start_time + print("generate time = {}".format(elapsed_time)) for o in outputs: generated_text = o.outputs[0].text diff --git a/vllm/config.py b/vllm/config.py index 5f50d65ec87e1..2fe677026e016 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -131,6 +131,8 @@ class ModelConfig: HuggingFace config. mm_processor_kwargs: Arguments to be forwarded to the model's processor for multi-modal data, e.g., image processor. + mm_disable_frontend_processor: Disables multi-modal HF preprocessor/mapper + execution in the frontend process (not recommended) override_neuron_config: Initialize non default neuron config or override default neuron config that are specific to Neuron devices, this argument will be used to configure the neuron config that @@ -169,6 +171,7 @@ def __init__( config_format: ConfigFormat = ConfigFormat.AUTO, hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, + mm_disable_frontend_processor: bool = False, override_neuron_config: Optional[Dict[str, Any]] = None, override_pooler_config: Optional["PoolerConfig"] = None) -> None: self.model = model @@ -235,6 +238,7 @@ def __init__( self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc self.mm_processor_kwargs = mm_processor_kwargs + self.mm_disable_frontend_processor = mm_disable_frontend_processor # Set enforce_eager to False if the value is unset. if self.enforce_eager is None: @@ -2525,7 +2529,8 @@ def __str__(self): "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, " "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s") % \ + "use_async_output_proc=%s, mm_processor_kwargs=%s, " + "mm_disable_frontend_processor=%s") % \ (self.model_config.model, self.speculative_config, self.model_config.tokenizer, self.model_config.skip_tokenizer_init, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4aa0eebd976c9..30d92031eac7d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -143,6 +143,7 @@ class EngineArgs: tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None limit_mm_per_prompt: Optional[Mapping[str, int]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None + mm_disable_frontend_processor: bool = False enable_lora: bool = False enable_lora_bias: bool = False max_loras: int = 1 @@ -592,6 +593,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help=('Overrides for the multimodal input mapping/processing, ' 'e.g., image processor. For example: {"num_crops": 4}.')) + parser.add_argument( + '--mm-disable-frontend-processor', + action='store_true', + default=EngineArgs.mm_disable_frontend_processor, + help="Disable multi-modal frontend processing (not recommended)") + # LoRA related configs parser.add_argument('--enable-lora', action='store_true', @@ -963,6 +970,7 @@ def create_model_config(self) -> ModelConfig: use_async_output_proc=not self.disable_async_output_proc, config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, + mm_disable_frontend_processor=self.mm_disable_frontend_processor, override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, ) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a25c401b4ea10..28749c95d89d4 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -168,6 +168,7 @@ def __init__( disable_async_output_proc: bool = False, hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, + mm_disable_frontend_processor: bool = False, # After positional args are removed, move this right below `model` task: TaskOption = "auto", override_pooler_config: Optional[PoolerConfig] = None, @@ -219,6 +220,7 @@ def __init__( disable_async_output_proc=disable_async_output_proc, hf_overrides=hf_overrides, mm_processor_kwargs=mm_processor_kwargs, + mm_disable_frontend_processor=mm_disable_frontend_processor, override_pooler_config=override_pooler_config, compilation_config=compilation_config_instance, **kwargs, @@ -549,6 +551,7 @@ def chat( continue_final_message: bool = False, tools: Optional[List[Dict[str, Any]]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, + mm_disable_frontend_processor: bool = False ) -> List[RequestOutput]: """ Generate responses for a chat conversation. @@ -590,6 +593,8 @@ def chat( ``True`` if ``add_generation_prompt`` is also ``True``. mm_processor_kwargs: Multimodal processor kwarg overrides for this chat request. Only used for offline requests. + mm_disable_frontend_processor: Disable multi-modal frontend + processing (not recommended) Returns: A list of ``RequestOutput`` objects containing the generated diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index fb7dbbebd7b90..448e909765d80 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -7,7 +7,7 @@ from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never if TYPE_CHECKING: - from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict + from vllm.multimodal import MultiModalDataDict, MultiModalKwargs, MultiModalPlaceholderDict from vllm.multimodal.inputs import MultiModalInputsV2 @@ -150,6 +150,12 @@ class TokenInputs(TypedDict): if the model supports it. """ + multi_modal_inputs: NotRequired["MultiModalKwargs"] + """ + Optional multi-modal inputs to pass to the model, + if the model supports it. + """ + multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"] """ Placeholder ranges for the multi-modal data. @@ -169,6 +175,7 @@ def token_inputs( token_type_ids: Optional[List[int]] = None, prompt: Optional[str] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, + multi_modal_inputs: Optional["MultiModalKwargs"] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> TokenInputs: @@ -181,6 +188,8 @@ def token_inputs( inputs["token_type_ids"] = token_type_ids if multi_modal_data is not None: inputs["multi_modal_data"] = multi_modal_data + if multi_modal_inputs is not None: + inputs["multi_modal_inputs"] = multi_modal_inputs if multi_modal_placeholders is not None: inputs["multi_modal_placeholders"] = multi_modal_placeholders if mm_processor_kwargs is not None: @@ -273,6 +282,18 @@ def multi_modal_data(self) -> "MultiModalDataDict": assert_never(inputs) + @cached_property + def multi_modal_inputs(self) -> "MultiModalKwargs": + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_inputs", {}) + + if inputs["type"] == "multimodal": + return inputs.get("mm_inputs", {}) + + assert_never(inputs) + @cached_property def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": inputs = self.inputs diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 967124fd850ea..19cdc1581ab9e 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -5,7 +5,8 @@ import msgspec from vllm.lora.request import LoRARequest -from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict +from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs, + MultiModalPlaceholderDict) from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -36,6 +37,7 @@ class EngineCoreRequest: prompt: Optional[str] prompt_token_ids: List[int] mm_data: Optional[MultiModalDataDict] + mm_inputs: Optional[List[MultiModalKwargs]] mm_placeholders: Optional[MultiModalPlaceholderDict] mm_processor_kwargs: Optional[Dict[str, Any]] sampling_params: SamplingParams diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 34f99dd30ef2e..044ae96b05aeb 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -15,6 +15,7 @@ from vllm.config import CacheConfig, VllmConfig from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext +from vllm.multimodal import MultiModalDataDict, MultiModalKwargs from vllm.v1.core.scheduler import Scheduler from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreProfile, EngineCoreRequest, @@ -84,14 +85,14 @@ def _initialize_kv_caches(self, def add_request(self, request: EngineCoreRequest): """Add request to the scheduler.""" - req = Request.from_engine_core_request(request) - # FIXME(woosuk): The input mapping (e.g., PIL images to tensors) may - # take 10-50 ms, which can cause a spike in the latency. We should - # consider moving this to a separate thread. + + # Apply multi-modal mapper (if necessary) if req.mm_data: + assert req.mm_inputs is None or req.mm_inputs == [] req.mm_inputs = self.mm_input_mapper.process_inputs( req.mm_data, req.mm_processor_kwargs) + self.scheduler.add_request(req) def abort_requests(self, request_ids: List[str]): diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 5c1577190c75a..069ac5bb4d06d 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -14,6 +14,7 @@ from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest +from vllm.v1.engine.mm_input_mapper import MMInputMapper class Processor: @@ -39,6 +40,9 @@ def __init__( self.input_processor = input_registry.create_input_processor( model_config) + # Multi-modal (huggingface) input mapper + self.mm_input_mapper = MMInputMapper(model_config) + # TODO: run in an ThreadpoolExecutor or BackgroundProcess. # This ideally should releases the GIL, so we should not block the # asyncio loop while this is running. @@ -96,6 +100,17 @@ def process_inputs( sampling_params.update_from_generation_config( self.generation_config_fields, eos_token_id) + # Process multi-modal data via (huggingface) preprocessor + # here in the frontend process (if enabled) + mm_data = decoder_inputs.multi_modal_data + mm_inputs = None + if (not self.model_config.mm_disable_frontend_processor + and mm_data is not None): + mm_inputs = self.mm_input_mapper.process_inputs( + decoder_inputs.multi_modal_data, + decoder_inputs.mm_processor_kwargs) + mm_data = None + # Make Request for Detokenizer. detokenizer_request = DetokenizerRequest( request_id, @@ -113,7 +128,8 @@ def process_inputs( request_id, decoder_inputs.prompt, decoder_inputs.prompt_token_ids, - decoder_inputs.multi_modal_data, + mm_data, + mm_inputs, decoder_inputs.multi_modal_placeholders, decoder_inputs.mm_processor_kwargs, sampling_params, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 51fb4003e5fe0..5bcc98bcd87ad 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -55,7 +55,11 @@ def __init__( else: self.mm_positions = [] # Output of the mm input mapper (e.g., image tensors). - self.mm_inputs: List[MultiModalKwargs] = [] + # (May be already provided by the frontend) + if self.inputs.multi_modal_inputs: + self.mm_inputs = self.inputs.multi_modal_inputs + else: + self.mm_inputs: List[MultiModalKwargs] = [] @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": @@ -65,6 +69,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": prompt_token_ids=request.prompt_token_ids, prompt=request.prompt, multi_modal_data=request.mm_data, + multi_modal_inputs=request.mm_inputs, multi_modal_placeholders=request.mm_placeholders, mm_processor_kwargs=request.mm_processor_kwargs, ), @@ -110,7 +115,7 @@ def get_finished_reason(self) -> Union[str, None]: return RequestStatus.get_finished_reason(self.status) def has_encoder_inputs(self) -> bool: - return len(self.mm_data) > 0 + return len(self.mm_data) > 0 or len(self.mm_inputs) > 0 @property def num_encoder_inputs(self) -> int: From 37f22c344824a0732b38525060afe2dad39fd41e Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Mon, 2 Dec 2024 14:38:40 +0000 Subject: [PATCH 02/10] remove disable arg --- examples/offline_inference_vision_language.py | 1 - vllm/config.py | 7 +------ vllm/engine/arg_utils.py | 8 -------- vllm/entrypoints/llm.py | 7 +------ vllm/v1/engine/core.py | 8 +++----- vllm/v1/engine/processor.py | 3 +-- 6 files changed, 6 insertions(+), 28 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 3ff29fe16a7d4..359f8b965461d 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -26,7 +26,6 @@ def run_llava(question: str, modality: str): llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096) - #mm_disable_frontend_processor=True) stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/vllm/config.py b/vllm/config.py index 2fe677026e016..5f50d65ec87e1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -131,8 +131,6 @@ class ModelConfig: HuggingFace config. mm_processor_kwargs: Arguments to be forwarded to the model's processor for multi-modal data, e.g., image processor. - mm_disable_frontend_processor: Disables multi-modal HF preprocessor/mapper - execution in the frontend process (not recommended) override_neuron_config: Initialize non default neuron config or override default neuron config that are specific to Neuron devices, this argument will be used to configure the neuron config that @@ -171,7 +169,6 @@ def __init__( config_format: ConfigFormat = ConfigFormat.AUTO, hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, - mm_disable_frontend_processor: bool = False, override_neuron_config: Optional[Dict[str, Any]] = None, override_pooler_config: Optional["PoolerConfig"] = None) -> None: self.model = model @@ -238,7 +235,6 @@ def __init__( self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc self.mm_processor_kwargs = mm_processor_kwargs - self.mm_disable_frontend_processor = mm_disable_frontend_processor # Set enforce_eager to False if the value is unset. if self.enforce_eager is None: @@ -2529,8 +2525,7 @@ def __str__(self): "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, " "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s, " - "mm_disable_frontend_processor=%s") % \ + "use_async_output_proc=%s, mm_processor_kwargs=%s") % \ (self.model_config.model, self.speculative_config, self.model_config.tokenizer, self.model_config.skip_tokenizer_init, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 30d92031eac7d..4aa0eebd976c9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -143,7 +143,6 @@ class EngineArgs: tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None limit_mm_per_prompt: Optional[Mapping[str, int]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None - mm_disable_frontend_processor: bool = False enable_lora: bool = False enable_lora_bias: bool = False max_loras: int = 1 @@ -593,12 +592,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help=('Overrides for the multimodal input mapping/processing, ' 'e.g., image processor. For example: {"num_crops": 4}.')) - parser.add_argument( - '--mm-disable-frontend-processor', - action='store_true', - default=EngineArgs.mm_disable_frontend_processor, - help="Disable multi-modal frontend processing (not recommended)") - # LoRA related configs parser.add_argument('--enable-lora', action='store_true', @@ -970,7 +963,6 @@ def create_model_config(self) -> ModelConfig: use_async_output_proc=not self.disable_async_output_proc, config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, - mm_disable_frontend_processor=self.mm_disable_frontend_processor, override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, ) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 28749c95d89d4..1fc4451df1297 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -168,7 +168,6 @@ def __init__( disable_async_output_proc: bool = False, hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, - mm_disable_frontend_processor: bool = False, # After positional args are removed, move this right below `model` task: TaskOption = "auto", override_pooler_config: Optional[PoolerConfig] = None, @@ -220,7 +219,6 @@ def __init__( disable_async_output_proc=disable_async_output_proc, hf_overrides=hf_overrides, mm_processor_kwargs=mm_processor_kwargs, - mm_disable_frontend_processor=mm_disable_frontend_processor, override_pooler_config=override_pooler_config, compilation_config=compilation_config_instance, **kwargs, @@ -551,7 +549,6 @@ def chat( continue_final_message: bool = False, tools: Optional[List[Dict[str, Any]]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, - mm_disable_frontend_processor: bool = False ) -> List[RequestOutput]: """ Generate responses for a chat conversation. @@ -593,9 +590,7 @@ def chat( ``True`` if ``add_generation_prompt`` is also ``True``. mm_processor_kwargs: Multimodal processor kwarg overrides for this chat request. Only used for offline requests. - mm_disable_frontend_processor: Disable multi-modal frontend - processing (not recommended) - + Returns: A list of ``RequestOutput`` objects containing the generated responses in the same order as the input messages. diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 044ae96b05aeb..fcf89d66f4821 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -87,11 +87,9 @@ def add_request(self, request: EngineCoreRequest): """Add request to the scheduler.""" req = Request.from_engine_core_request(request) - # Apply multi-modal mapper (if necessary) - if req.mm_data: - assert req.mm_inputs is None or req.mm_inputs == [] - req.mm_inputs = self.mm_input_mapper.process_inputs( - req.mm_data, req.mm_processor_kwargs) + # Sanity check to verify that the multi-modal preprocessor + # ran in the frontend P0 process + assert req.mm_data is None or req.mm_data == {} self.scheduler.add_request(req) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 069ac5bb4d06d..30e4e17c360b2 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -104,8 +104,7 @@ def process_inputs( # here in the frontend process (if enabled) mm_data = decoder_inputs.multi_modal_data mm_inputs = None - if (not self.model_config.mm_disable_frontend_processor - and mm_data is not None): + if mm_data is not None: mm_inputs = self.mm_input_mapper.process_inputs( decoder_inputs.multi_modal_data, decoder_inputs.mm_processor_kwargs) From 7b0d9c4406329f42d968b71b1797d4f39c96f968 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Mon, 2 Dec 2024 15:07:09 +0000 Subject: [PATCH 03/10] Nick's comment --- vllm/v1/engine/__init__.py | 2 -- vllm/v1/engine/core.py | 5 ----- vllm/v1/engine/processor.py | 15 ++++----------- vllm/v1/request.py | 11 ++++------- 4 files changed, 8 insertions(+), 25 deletions(-) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 19cdc1581ab9e..8019a1f82b7db 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -36,10 +36,8 @@ class EngineCoreRequest: # always be tokenized? prompt: Optional[str] prompt_token_ids: List[int] - mm_data: Optional[MultiModalDataDict] mm_inputs: Optional[List[MultiModalKwargs]] mm_placeholders: Optional[MultiModalPlaceholderDict] - mm_processor_kwargs: Optional[Dict[str, Any]] sampling_params: SamplingParams eos_token_id: Optional[int] arrival_time: float diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index fcf89d66f4821..0345db2248140 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -86,11 +86,6 @@ def _initialize_kv_caches(self, def add_request(self, request: EngineCoreRequest): """Add request to the scheduler.""" req = Request.from_engine_core_request(request) - - # Sanity check to verify that the multi-modal preprocessor - # ran in the frontend P0 process - assert req.mm_data is None or req.mm_data == {} - self.scheduler.add_request(req) def abort_requests(self, request_ids: List[str]): diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 30e4e17c360b2..53e723686407b 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -100,15 +100,10 @@ def process_inputs( sampling_params.update_from_generation_config( self.generation_config_fields, eos_token_id) - # Process multi-modal data via (huggingface) preprocessor - # here in the frontend process (if enabled) - mm_data = decoder_inputs.multi_modal_data - mm_inputs = None - if mm_data is not None: - mm_inputs = self.mm_input_mapper.process_inputs( - decoder_inputs.multi_modal_data, - decoder_inputs.mm_processor_kwargs) - mm_data = None + # Preprocess multi-modal data + mm_inputs = self.mm_input_mapper.process_inputs( + decoder_inputs.multi_modal_data, decoder_inputs.mm_processor_kwargs + ) if decoder_inputs.multi_modal_data is not None else None # Make Request for Detokenizer. detokenizer_request = DetokenizerRequest( @@ -127,10 +122,8 @@ def process_inputs( request_id, decoder_inputs.prompt, decoder_inputs.prompt_token_ids, - mm_data, mm_inputs, decoder_inputs.multi_modal_placeholders, - decoder_inputs.mm_processor_kwargs, sampling_params, eos_token_id, arrival_time, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 5bcc98bcd87ad..50806f910bb4b 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -45,17 +45,14 @@ def __init__( self._all_token_ids: List[int] = self.prompt_token_ids.copy() self.num_computed_tokens = 0 - # Raw multimodal data before the mm input mapper (e.g., PIL images). - self.mm_data = self.inputs.multi_modal_data - self.mm_processor_kwargs = self.inputs.mm_processor_kwargs mm_positions = self.inputs.multi_modal_placeholders if mm_positions: # FIXME(woosuk): Support other modalities. self.mm_positions = mm_positions.get("image", []) else: self.mm_positions = [] + # Output of the mm input mapper (e.g., image tensors). - # (May be already provided by the frontend) if self.inputs.multi_modal_inputs: self.mm_inputs = self.inputs.multi_modal_inputs else: @@ -68,10 +65,10 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": inputs=token_inputs( prompt_token_ids=request.prompt_token_ids, prompt=request.prompt, - multi_modal_data=request.mm_data, + multi_modal_data=None, multi_modal_inputs=request.mm_inputs, multi_modal_placeholders=request.mm_placeholders, - mm_processor_kwargs=request.mm_processor_kwargs, + mm_processor_kwargs=None, ), sampling_params=request.sampling_params, eos_token_id=request.eos_token_id, @@ -115,7 +112,7 @@ def get_finished_reason(self) -> Union[str, None]: return RequestStatus.get_finished_reason(self.status) def has_encoder_inputs(self) -> bool: - return len(self.mm_data) > 0 or len(self.mm_inputs) > 0 + return len(self.mm_inputs) > 0 @property def num_encoder_inputs(self) -> int: From 59e6495b8b37ca23542d118b02b31614d655bd9c Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Mon, 2 Dec 2024 15:30:41 +0000 Subject: [PATCH 04/10] format --- examples/offline_inference_vision_language.py | 4 ++-- vllm/inputs/data.py | 7 ++++--- vllm/v1/engine/__init__.py | 5 ++--- vllm/v1/engine/core.py | 1 - 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 359f8b965461d..e2afe64d2021b 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -6,6 +6,7 @@ on HuggingFace model repository. """ import time + from transformers import AutoTokenizer from vllm import LLM, SamplingParams @@ -24,8 +25,7 @@ def run_llava(question: str, modality: str): prompt = f"USER: \n{question}\nASSISTANT:" - llm = LLM(model="llava-hf/llava-1.5-7b-hf", - max_model_len=4096) + llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096) stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 448e909765d80..554d96c0989d1 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -7,7 +7,8 @@ from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never if TYPE_CHECKING: - from vllm.multimodal import MultiModalDataDict, MultiModalKwargs, MultiModalPlaceholderDict + from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs, + MultiModalPlaceholderDict) from vllm.multimodal.inputs import MultiModalInputsV2 @@ -283,14 +284,14 @@ def multi_modal_data(self) -> "MultiModalDataDict": assert_never(inputs) @cached_property - def multi_modal_inputs(self) -> "MultiModalKwargs": + def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]: inputs = self.inputs if inputs["type"] == "token": return inputs.get("multi_modal_inputs", {}) if inputs["type"] == "multimodal": - return inputs.get("mm_inputs", {}) + return inputs.get("mm_kwargs", {}) assert_never(inputs) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 8019a1f82b7db..3cf0e610ae7af 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -1,12 +1,11 @@ import enum from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional, Union import msgspec from vllm.lora.request import LoRARequest -from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs, - MultiModalPlaceholderDict) +from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict from vllm.sampling_params import RequestOutputKind, SamplingParams diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0345db2248140..397a33eed3896 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -15,7 +15,6 @@ from vllm.config import CacheConfig, VllmConfig from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.multimodal import MultiModalDataDict, MultiModalKwargs from vllm.v1.core.scheduler import Scheduler from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreProfile, EngineCoreRequest, From 259fbf217cce5f3cf4e173f0fa1ffa8bb0a01e6d Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Mon, 2 Dec 2024 17:58:40 +0000 Subject: [PATCH 05/10] Roger's comment --- vllm/v1/engine/processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 53e723686407b..02d825df2b4d9 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -103,7 +103,7 @@ def process_inputs( # Preprocess multi-modal data mm_inputs = self.mm_input_mapper.process_inputs( decoder_inputs.multi_modal_data, decoder_inputs.mm_processor_kwargs - ) if decoder_inputs.multi_modal_data is not None else None + ) if not decoder_inputs.multi_modal_data else None # Make Request for Detokenizer. detokenizer_request = DetokenizerRequest( From 593ae17fba1bdaca7997b08ac1dae50e8584a852 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Mon, 2 Dec 2024 18:07:01 +0000 Subject: [PATCH 06/10] fix --- vllm/v1/engine/processor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 02d825df2b4d9..7a1ea2530abda 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -102,8 +102,9 @@ def process_inputs( # Preprocess multi-modal data mm_inputs = self.mm_input_mapper.process_inputs( - decoder_inputs.multi_modal_data, decoder_inputs.mm_processor_kwargs - ) if not decoder_inputs.multi_modal_data else None + decoder_inputs.multi_modal_data, + decoder_inputs.mm_processor_kwargs) if len( + decoder_inputs.multi_modal_data) > 0 else None # Make Request for Detokenizer. detokenizer_request = DetokenizerRequest( From 4e730ab6108ddead8cc26351402dabff15af91dc Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 2 Dec 2024 16:16:29 -0500 Subject: [PATCH 07/10] Revert offline_inference_vision_language.py --- examples/offline_inference_vision_language.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index e2afe64d2021b..f08f22eec164a 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -5,8 +5,6 @@ For most models, the prompt format should follow corresponding examples on HuggingFace model repository. """ -import time - from transformers import AutoTokenizer from vllm import LLM, SamplingParams @@ -26,7 +24,6 @@ def run_llava(question: str, modality: str): prompt = f"USER: \n{question}\nASSISTANT:" llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096) - stop_token_ids = None return llm, prompt, stop_token_ids @@ -517,10 +514,7 @@ def main(args): }, } for _ in range(args.num_prompts)] - start_time = time.time() outputs = llm.generate(inputs, sampling_params=sampling_params) - elapsed_time = time.time() - start_time - print("generate time = {}".format(elapsed_time)) for o in outputs: generated_text = o.outputs[0].text From 22e89142322dc68472f2f15bca0ed677c76280ef Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Mon, 2 Dec 2024 16:17:42 -0800 Subject: [PATCH 08/10] format Signed-off-by: Roger Wang --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1fc4451df1297..a25c401b4ea10 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -590,7 +590,7 @@ def chat( ``True`` if ``add_generation_prompt`` is also ``True``. mm_processor_kwargs: Multimodal processor kwarg overrides for this chat request. Only used for offline requests. - + Returns: A list of ``RequestOutput`` objects containing the generated responses in the same order as the input messages. From 7953df33359a19070fd84e73a715e39e8cc9031d Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Mon, 2 Dec 2024 19:00:07 -0800 Subject: [PATCH 09/10] trigger new CI build Signed-off-by: Roger Wang --- vllm/v1/request.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 50806f910bb4b..6bc1e4d5c769f 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -51,7 +51,6 @@ def __init__( self.mm_positions = mm_positions.get("image", []) else: self.mm_positions = [] - # Output of the mm input mapper (e.g., image tensors). if self.inputs.multi_modal_inputs: self.mm_inputs = self.inputs.multi_modal_inputs From 382fc0be9e168fe8ae47176ba54fbdc126f36940 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Mon, 2 Dec 2024 23:52:35 -0800 Subject: [PATCH 10/10] fix test Signed-off-by: Roger Wang --- tests/v1/engine/test_engine_core.py | 3 +-- tests/v1/engine/test_engine_core_client.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index bd11ff1877064..fef44ac29c41f 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -27,9 +27,8 @@ def make_request() -> EngineCoreRequest: request_id=uuid.uuid4(), prompt=PROMPT, prompt_token_ids=PROMPT_TOKENS, - mm_data=None, + mm_inputs=None, mm_placeholders=None, - mm_processor_kwargs=None, sampling_params=SamplingParams(), eos_token_id=None, arrival_time=time.time(), diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 582192196aaf9..4e003a25e91d2 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -29,9 +29,8 @@ def make_request(params: SamplingParams) -> EngineCoreRequest: request_id=str(uuid.uuid4()), prompt=PROMPT, prompt_token_ids=PROMPT_TOKENS, - mm_data=None, + mm_inputs=None, mm_placeholders=None, - mm_processor_kwargs=None, sampling_params=params, eos_token_id=None, arrival_time=time.time(),