[V1] VLM - enable processor cache by default (#11305)
Signed-off-by: Alexander Matveev <[email protected]>
alexm-neuralmagic authored Dec 18, 2024
1 parent ca5f54a commit fdea8ec
Showing 7 changed files with 72 additions and 48 deletions.
50 changes: 25 additions & 25 deletions examples/offline_inference_vision_language.py
@@ -28,7 +28,7 @@ def run_aria(question: str, modality: str):
tokenizer_mode="slow",
trust_remote_code=True,
dtype="bfloat16",
- mm_cache_preprocessor=args.mm_cache_preprocessor)
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
"<|im_end|>\n<|im_start|>assistant\n")
@@ -45,7 +45,7 @@ def run_blip2(question: str, modality: str):
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompt = f"Question: {question} Answer:"
llm = LLM(model="Salesforce/blip2-opt-2.7b",
- mm_cache_preprocessor=args.mm_cache_preprocessor)
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids

@@ -57,7 +57,7 @@ def run_chameleon(question: str, modality: str):
prompt = f"{question}<image>"
llm = LLM(model="facebook/chameleon-7b",
max_model_len=4096,
- mm_cache_preprocessor=args.mm_cache_preprocessor)
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids

@@ -70,7 +70,7 @@ def run_fuyu(question: str, modality: str):
llm = LLM(model="adept/fuyu-8b",
max_model_len=2048,
max_num_seqs=2,
- mm_cache_preprocessor=args.mm_cache_preprocessor)
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids

@@ -85,7 +85,7 @@ def run_glm4v(question: str, modality: str):
max_num_seqs=2,
trust_remote_code=True,
enforce_eager=True,
- mm_cache_preprocessor=args.mm_cache_preprocessor)
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
prompt = question
stop_token_ids = [151329, 151336, 151338]
return llm, prompt, stop_token_ids
@@ -101,7 +101,7 @@ def run_h2ovl(question: str, modality: str):
model=model_name,
trust_remote_code=True,
max_model_len=8192,
- mm_cache_preprocessor=args.mm_cache_preprocessor,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -134,7 +134,7 @@ def run_idefics3(question: str, modality: str):
"longest_edge": 3 * 364
},
},
- mm_cache_preprocessor=args.mm_cache_preprocessor,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompt = (
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
@@ -153,7 +153,7 @@ def run_internvl(question: str, modality: str):
model=model_name,
trust_remote_code=True,
max_model_len=4096,
- mm_cache_preprocessor=args.mm_cache_preprocessor,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -180,7 +180,7 @@ def run_llava(question: str, modality: str):

llm = LLM(model="llava-hf/llava-1.5-7b-hf",
max_model_len=4096,
- mm_cache_preprocessor=args.mm_cache_preprocessor)
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids

@@ -192,7 +192,7 @@ def run_llava_next(question: str, modality: str):
prompt = f"[INST] <image>\n{question} [/INST]"
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
max_model_len=8192,
- mm_cache_preprocessor=args.mm_cache_preprocessor)
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids

@@ -205,7 +205,7 @@ def run_llava_next_video(question: str, modality: str):
prompt = f"USER: <video>\n{question} ASSISTANT:"
llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192,
- mm_cache_preprocessor=args.mm_cache_preprocessor)
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids

@@ -223,7 +223,7 @@ def run_llava_onevision(question: str, modality: str):

llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
max_model_len=16384,
- mm_cache_preprocessor=args.mm_cache_preprocessor)
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids

@@ -239,7 +239,7 @@ def run_mantis(question: str, modality: str):
model="TIGER-Lab/Mantis-8B-siglip-llama3",
max_model_len=4096,
hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
- mm_cache_preprocessor=args.mm_cache_preprocessor,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
stop_token_ids = [128009]
return llm, prompt, stop_token_ids
@@ -266,7 +266,7 @@ def run_minicpmv(question: str, modality: str):
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
- mm_cache_preprocessor=args.mm_cache_preprocessor,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
# NOTE The stop_token_ids are different for various versions of MiniCPM-V
# 2.0
@@ -305,7 +305,7 @@ def run_mllama(question: str, modality: str):
max_model_len=4096,
max_num_seqs=16,
enforce_eager=True,
- mm_cache_preprocessor=args.mm_cache_preprocessor,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

prompt = f"<|image|><|begin_of_text|>{question}"
@@ -323,7 +323,7 @@ def run_molmo(question, modality):
model=model_name,
trust_remote_code=True,
dtype="bfloat16",
- mm_cache_preprocessor=args.mm_cache_preprocessor,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

prompt = question
@@ -343,7 +343,7 @@ def run_nvlm_d(question: str, modality: str):
trust_remote_code=True,
max_model_len=4096,
tensor_parallel_size=4,
- mm_cache_preprocessor=args.mm_cache_preprocessor,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -363,7 +363,7 @@ def run_paligemma(question: str, modality: str):
# PaliGemma has special prompt format for VQA
prompt = "caption en"
llm = LLM(model="google/paligemma-3b-mix-224",
- mm_cache_preprocessor=args.mm_cache_preprocessor)
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids

@@ -375,7 +375,7 @@ def run_paligemma2(question: str, modality: str):
# PaliGemma 2 has special prompt format for VQA
prompt = "caption en"
llm = LLM(model="google/paligemma2-3b-ft-docci-448",
- mm_cache_preprocessor=args.mm_cache_preprocessor)
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
stop_token_ids = None
return llm, prompt, stop_token_ids

@@ -405,7 +405,7 @@ def run_phi3v(question: str, modality: str):
max_num_seqs=2,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs={"num_crops": 16},
- mm_cache_preprocessor=args.mm_cache_preprocessor,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
stop_token_ids = None
return llm, prompt, stop_token_ids
@@ -420,7 +420,7 @@ def run_pixtral_hf(question: str, modality: str):
llm = LLM(
model=model_name,
max_model_len=8192,
- mm_cache_preprocessor=args.mm_cache_preprocessor,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

prompt = f"<s>[INST]{question}\n[IMG][/INST]"
@@ -437,7 +437,7 @@ def run_qwen_vl(question: str, modality: str):
trust_remote_code=True,
max_model_len=1024,
max_num_seqs=2,
- mm_cache_preprocessor=args.mm_cache_preprocessor,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

prompt = f"{question}Picture 1: <img></img>\n"
@@ -460,7 +460,7 @@ def run_qwen2_vl(question: str, modality: str):
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
},
- mm_cache_preprocessor=args.mm_cache_preprocessor,
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
@@ -651,9 +651,9 @@ def main(args):
' (if enabled)')

parser.add_argument(
- '--mm-cache-preprocessor',
+ '--disable-mm-preprocessor-cache',
action='store_true',
- help='If True, enable caching of multi-modal preprocessor/mapper.')
+ help='If True, disables caching of multi-modal preprocessor/mapper.')

parser.add_argument(
'--time-generate',
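The net effect of the example-script changes above is that the multi-modal preprocessor cache is now on by default and the script only exposes an opt-out. A minimal sketch of the two modes (illustrative only; the model name and max_model_len are arbitrary choices, not taken from the diff):

```python
# Minimal sketch: after this commit the multi-modal preprocessor cache is
# enabled unless explicitly disabled. Model name and lengths are arbitrary.
from vllm import LLM

# New default: preprocessor cache enabled, nothing extra to pass.
llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)

# Explicit opt-out, mirroring what --disable-mm-preprocessor-cache does in the
# example script above.
llm_no_cache = LLM(model="llava-hf/llava-1.5-7b-hf",
                   max_model_len=4096,
                   disable_mm_preprocessor_cache=True)
```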
11 changes: 5 additions & 6 deletions vllm/config.py
@@ -148,9 +148,8 @@ class ModelConfig:
HuggingFace config.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
- mm_cache_preprocessor: If true, then enables caching of the multi-modal
- preprocessor/mapper. Otherwise, the mapper executes each time, and
- for better performance consider enabling frontend process.
+ disable_mm_preprocessor_cache: If true, then disables caching of the
+ multi-modal preprocessor/mapper. (not recommended)
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
@@ -216,7 +215,7 @@ def __init__(self,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
- mm_cache_preprocessor: bool = False,
+ disable_mm_preprocessor_cache: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None) -> None:
@@ -286,7 +285,7 @@ def __init__(self,
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs
- self.mm_cache_preprocessor = mm_cache_preprocessor
+ self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache

# Set enforce_eager to False if the value is unset.
if self.enforce_eager is None:
@@ -3155,7 +3154,7 @@ def __str__(self):
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
f"use_async_output_proc={self.model_config.use_async_output_proc}, "
- f"mm_cache_preprocessor={self.model_config.mm_cache_preprocessor!r}, " # noqa
+ f"disable_mm_preprocessor_cache={self.model_config.disable_mm_preprocessor_cache!r}, " # noqa
f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
f"pooler_config={self.model_config.pooler_config!r}, "
f"compilation_config={self.compilation_config!r}")
11 changes: 5 additions & 6 deletions vllm/engine/arg_utils.py
@@ -141,7 +141,7 @@ class EngineArgs:
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
- mm_cache_preprocessor: bool = False
+ disable_mm_preprocessor_cache: bool = False
enable_lora: bool = False
enable_lora_bias: bool = False
max_loras: int = 1
@@ -606,11 +606,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help=('Overrides for the multimodal input mapping/processing, '
'e.g., image processor. For example: {"num_crops": 4}.'))
parser.add_argument(
- '--mm-cache-preprocessor',
+ '--disable-mm-preprocessor-cache',
action='store_true',
- help='If true, then enables caching of the multi-modal '
- 'preprocessor/mapper. Otherwise, the mapper executes each time'
- ', and for better performance consider enabling frontend process.')
+ help='If true, then disables caching of the multi-modal '
+ 'preprocessor/mapper. (not recommended)')

# LoRA related configs
parser.add_argument('--enable-lora',
@@ -983,7 +982,7 @@ def create_model_config(self) -> ModelConfig:
use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
- mm_cache_preprocessor=self.mm_cache_preprocessor,
+ disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern)
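Because the renamed engine flag keeps `action='store_true'`, omitting it now yields the cached (default) behavior, while passing it opts out. A small self-contained sketch of that argparse behavior (plain argparse, not vLLM's FlexibleArgumentParser):

```python
# Standalone illustration of the flag semantics after the rename.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--disable-mm-preprocessor-cache',
                    action='store_true',
                    help='If true, then disables caching of the multi-modal '
                         'preprocessor/mapper. (not recommended)')

args = parser.parse_args([])  # flag omitted
assert args.disable_mm_preprocessor_cache is False  # cache stays enabled

args = parser.parse_args(['--disable-mm-preprocessor-cache'])  # explicit opt-out
assert args.disable_mm_preprocessor_cache is True
```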
2 changes: 1 addition & 1 deletion vllm/v1/core/kv_cache_utils.py
@@ -191,7 +191,7 @@ def generate_block_hash_extra_keys(
raise ValueError(
"The number of multi-modal positions and hashes must match. This "
"is likely because you do not enable MM preprocessor hashing. "
- "Please set mm_cache_preprocessor=True.")
+ "Please set disable_mm_preprocessor_cache=False.")

# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
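For context, the error above guards the path where prefix caching folds multi-modal content hashes into its block hashes as extra keys. A simplified, hypothetical illustration of why those hashes are needed (this is not the vLLM implementation; the function name and hashing are placeholders):

```python
# Hypothetical sketch: prefix-cache block hashes must reflect image *content*,
# not just the placeholder tokens, or two prompts with identical tokens but
# different images would wrongly share cached KV blocks.
from typing import List, Optional, Tuple


def block_hash(token_ids: List[int],
               mm_hashes: Optional[List[str]] = None) -> int:
    extra_keys: Tuple[str, ...] = tuple(mm_hashes or ())
    return hash((tuple(token_ids), extra_keys))


same_tokens = [1, 2, 3, 4]
h_cat = block_hash(same_tokens, ["hash-of-cat.png"])
h_dog = block_hash(same_tokens, ["hash-of-dog.png"])
assert h_cat != h_dog  # different images -> different block hashes
```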
20 changes: 17 additions & 3 deletions vllm/v1/engine/mm_input_mapper.py
@@ -43,7 +43,7 @@ def __init__(
self.mm_registry.init_mm_limits_per_prompt(model_config)

# Init cache
- self.use_cache = model_config.mm_cache_preprocessor
+ self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)

# DEBUG: Set to None to disable
@@ -119,7 +119,7 @@ def process_inputs(
class MMInputMapperServer:

def __init__(self, model_config):
- self.use_cache = model_config.mm_cache_preprocessor
+ self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)

def process_inputs(
@@ -151,12 +151,26 @@ class MMHasher:
def __init__(self):
pass

- def hash(self, prompt: PromptType) -> Optional[List[str]]:
+ def hash_mm_data(
+ self,
+ mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]:
+ if mm_data is None:
+ return None
+
+ image_inputs = mm_data['image']
+
+ return self.hash_images(image_inputs)
+
+ def hash_prompt(self, prompt: PromptType) -> Optional[List[str]]:
if "multi_modal_data" not in prompt:
return None

mm_data = prompt["multi_modal_data"]
image_inputs = mm_data["image"]
+
+ return self.hash_images(image_inputs)
+
+ def hash_images(self, image_inputs) -> Optional[List[str]]:
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]
assert len(image_inputs) > 0
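The client and server mappers above share the same pattern: hash the image content, look the hash up in a bounded LRU cache, and only run the expensive HF preprocessing on a miss. A self-contained sketch of that pattern follows (the cache class, its size, and the hash function are stand-ins, not vLLM's `LRUDictCache` or `MMHasher`):

```python
# Illustrative sketch of content-hash-keyed caching of preprocessed inputs.
# sha256 and OrderedDict are stand-ins for vLLM's actual hasher and LRUDictCache.
import hashlib
from collections import OrderedDict
from typing import Any, Callable, Optional

MM_CACHE_SIZE = 256  # assumed capacity for this sketch


class SimpleLRUCache:
    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._data: "OrderedDict[str, Any]" = OrderedDict()

    def get(self, key: str) -> Optional[Any]:
        if key not in self._data:
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key: str, value: Any) -> None:
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.capacity:
            self._data.popitem(last=False)  # evict least recently used


def hash_image_bytes(image_bytes: bytes) -> str:
    return hashlib.sha256(image_bytes).hexdigest()


_cache = SimpleLRUCache(MM_CACHE_SIZE)


def map_image(image_bytes: bytes,
              run_hf_processor: Callable[[bytes], Any],
              use_cache: bool = True) -> Any:
    """Return preprocessed kwargs, reusing cached results for repeated images."""
    if not use_cache:
        return run_hf_processor(image_bytes)
    key = hash_image_bytes(image_bytes)
    cached = _cache.get(key)
    if cached is not None:
        return cached
    result = run_hf_processor(image_bytes)  # the expensive step being cached
    _cache.put(key, result)
    return result
```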
4 changes: 2 additions & 2 deletions vllm/v1/engine/processor.py
@@ -46,7 +46,7 @@ def __init__(
self.mm_input_mapper_client = MMInputMapperClient(model_config)

# Multi-modal hasher (for images)
- self.use_hash = model_config.mm_cache_preprocessor or \
+ self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching
self.mm_hasher = MMHasher()

@@ -80,7 +80,7 @@ def process_inputs(
# Compute MM hashes (if enabled)
mm_hashes = None
if self.use_hash:
- mm_hashes = self.mm_hasher.hash(prompt)
+ mm_hashes = self.mm_hasher.hash_prompt(prompt)

# Process inputs.
preprocessed_inputs = self.input_preprocessor.preprocess(
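The `use_hash` condition above is the behavioral core of the commit: hashes are computed whenever the preprocessor cache is enabled (now the default) or prefix caching needs them. Restated as a tiny helper for clarity (illustrative only, not part of the diff):

```python
# Illustrative restatement of the updated hashing condition.
def need_mm_hashes(disable_mm_preprocessor_cache: bool,
                   enable_prefix_caching: bool) -> bool:
    # Hash unless the preprocessor cache is disabled AND prefix caching is off.
    return (not disable_mm_preprocessor_cache) or enable_prefix_caching


assert need_mm_hashes(False, False) is True   # new default: cache on -> hash
assert need_mm_hashes(True, False) is False   # both disabled -> no hashing
assert need_mm_hashes(True, True) is True     # prefix caching still needs hashes
```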