[V1] VLM hashing and mapper caching
alexm-neuralmagic committed Dec 3, 2024
1 parent 3bc94ca commit db28436
Showing 7 changed files with 122 additions and 14 deletions.
4 changes: 4 additions & 0 deletions examples/offline_inference_vision_language.py
@@ -514,7 +514,11 @@ def main(args):
},
} for _ in range(args.num_prompts)]

import time
start_time = time.time()
outputs = llm.generate(inputs, sampling_params=sampling_params)
elapsed_time = time.time() - start_time
print("-- generate time = {}".format(elapsed_time))

for o in outputs:
generated_text = o.outputs[0].text
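For reference, the timing added above wraps llm.generate with time.time(). A reusable sketch (not part of this commit) would put the measurement in a context manager built on the monotonic time.perf_counter:

import time
from contextlib import contextmanager

@contextmanager
def timed(label: str):
    # perf_counter is monotonic, so it is better suited than time.time()
    # for measuring elapsed wall-clock time.
    start = time.perf_counter()
    try:
        yield
    finally:
        print("-- {} time = {:.3f}s".format(label, time.perf_counter() - start))

# Usage, mirroring the snippet above:
#   with timed("generate"):
#       outputs = llm.generate(inputs, sampling_params=sampling_params)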
10 changes: 8 additions & 2 deletions vllm/config.py
@@ -131,6 +131,8 @@ class ModelConfig:
HuggingFace config.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
mm_cache_preprocessor: If True, enable caching of multi-modal
preprocessor/mapper.
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
@@ -169,6 +171,7 @@ def __init__(
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
mm_cache_preprocessor: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None) -> None:
self.model = model
@@ -235,6 +238,7 @@ def __init__(
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs
self.mm_cache_preprocessor = mm_cache_preprocessor

# Set enforce_eager to False if the value is unset.
if self.enforce_eager is None:
@@ -2593,7 +2597,8 @@ def __str__(self):
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s, mm_processor_kwargs=%s") % \
"use_async_output_proc=%s, mm_processor_kwargs=%s, "
"mm_cache_preprocessor=%s") % \
(self.model_config.model, self.speculative_config,
self.model_config.tokenizer,
self.model_config.skip_tokenizer_init,
@@ -2619,7 +2624,8 @@ def __str__(self):
self.scheduler_config.num_scheduler_steps,
self.cache_config.enable_prefix_caching,
self.model_config.use_async_output_proc,
self.model_config.mm_processor_kwargs)
self.model_config.mm_processor_kwargs,
self.model_config.mm_cache_preprocessor)


_current_vllm_config: Optional[VllmConfig] = None
2 changes: 2 additions & 0 deletions vllm/engine/arg_utils.py
@@ -143,6 +143,7 @@ class EngineArgs:
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
mm_cache_preprocessor: bool = False
enable_lora: bool = False
enable_lora_bias: bool = False
max_loras: int = 1
@@ -964,6 +965,7 @@ def create_model_config(self) -> ModelConfig:
use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_cache_preprocessor=self.mm_cache_preprocessor,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
)
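A minimal sketch of enabling the new cache from user code, assuming the usual EngineArgs entry point; the model name below is only an illustrative placeholder:

from vllm.engine.arg_utils import EngineArgs

# mm_cache_preprocessor is a plain dataclass field on EngineArgs, so it can be
# set directly; create_model_config() then forwards it into
# ModelConfig.mm_cache_preprocessor as shown in the diff above.
engine_args = EngineArgs(
    model="llava-hf/llava-1.5-7b-hf",  # placeholder model name
    mm_cache_preprocessor=True,
)
assert engine_args.mm_cache_preprocessor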
4 changes: 0 additions & 4 deletions vllm/v1/engine/core.py
@@ -19,7 +19,6 @@
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.mm_input_mapper import MMInputMapper
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
@@ -55,9 +54,6 @@ def __init__(
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

# Set up multimodal input mapper (e.g., convert PIL images to tensors).
self.mm_input_mapper = MMInputMapper(vllm_config.model_config)

# Setup scheduler.
self.scheduler = Scheduler(vllm_config.scheduler_config,
vllm_config.cache_config,
85 changes: 79 additions & 6 deletions vllm/v1/engine/mm_input_mapper.py
@@ -1,8 +1,13 @@
import time
import PIL

from blake3 import blake3
from typing import Any, Dict, List, Optional

from vllm.config import ModelConfig
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry)
from vllm.v1.utils import LRUDictCache


class MMInputMapper:
@@ -11,29 +16,97 @@ def __init__(
self,
model_config: ModelConfig,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_cache_size: int = 128,
):
self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry.create_input_mapper(
model_config)
self.mm_registry.init_mm_limits_per_prompt(model_config)

self.mm_cache = LRUDictCache(size=mm_cache_size)
self.mm_cache_hits = 0
self.mm_cache_misses = 0

# Set to None to disable (TODO: Disable!)
self.mm_debug_cache_hit_ratio_steps = 32

def cache_hit_ratio(self, steps) -> float:
total_steps = self.mm_cache_hits + self.mm_cache_misses

if total_steps > 0 and total_steps % steps == 0:
print("[debug] MMInputMapper: cache_hit_ratio = {}".format(
self.mm_cache_hits / total_steps))

def process_inputs(
self,
mm_data: MultiModalDataDict,
mm_hash: Optional[List[str]],
mm_processor_kwargs: Optional[Dict[str, Any]],
) -> List[MultiModalKwargs]:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]

use_hash = mm_hash is not None
if use_hash:
assert len(image_inputs) == len(mm_hash) # Sanity

# Process each image input separately so that later we can schedule
# them in a fine-grained manner.
# Utilize caching (if enabled)
mm_inputs: List[MultiModalKwargs] = []
num_images = len(image_inputs)
for i in range(num_images):
mm_input = self.multi_modal_input_mapper(
{"image": [image_inputs[i]]},
mm_processor_kwargs=mm_processor_kwargs,
)
for i in range(len(image_inputs)):
if self.mm_debug_cache_hit_ratio_steps is not None:
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)

mm_input = self.mm_cache.get(mm_hash[i]) if use_hash else None
if mm_input is None:
self.mm_cache_misses += 1
mm_input = self.multi_modal_input_mapper(
{"image": [image_inputs[i]]},
mm_processor_kwargs=mm_processor_kwargs,
)
if use_hash:
self.mm_cache.put(mm_hash[i], mm_input)
else:
self.mm_cache_hits += 1

mm_inputs.append(mm_input)

return mm_inputs


class MMHasher:

def __init__(self):
pass

def hash(self, mm_data: MultiModalDataDict) -> List[str]:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]

ret = []
for image in image_inputs:
assert isinstance(image, PIL.Image.Image)

# FIXME(alexm): Remove debug

# print(" type(data) = {}, data = {}".format(type(image), image))

# Convert image to bytes
start_time = time.time()
bytes = image.tobytes()
elapsed_time = time.time() - start_time
# print(" tobytes time = {}".format(elapsed_time))

# Hash image bytes
start_time = time.time()
hasher = blake3()
hasher.update(bytes)
ret.append(hasher.hexdigest())
elapsed_time = time.time() - start_time

[GitHub Actions / Ruff (3.12) check failure] vllm/v1/engine/mm_input_mapper.py:108:13: F841 Local variable `elapsed_time` is assigned to but never used
# print(" hash time = {}".format(elapsed_time))
# print(" hash val = {}".format(ret[-1]))

return ret
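A small sketch of the hashing contract, assuming only what the diff above defines: identical image bytes produce identical blake3 digests, which is what lets MMInputMapper.process_inputs serve repeated images from its LRU cache instead of re-running the HuggingFace image mapper.

from PIL import Image

from vllm.v1.engine.mm_input_mapper import MMHasher

hasher = MMHasher()
img = Image.new("RGB", (64, 64), color=(255, 0, 0))

# The same pixel data hashed twice yields the same cache key.
digests = hasher.hash({"image": [img, img.copy()]})
assert digests[0] == digests[1]

In the Processor (see vllm/v1/engine/processor.py below), these digests are passed as mm_hash into process_inputs, which checks the cache before invoking the mapper.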
10 changes: 8 additions & 2 deletions vllm/v1/engine/processor.py
@@ -14,7 +14,7 @@
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMInputMapper
from vllm.v1.engine.mm_input_mapper import MMInputMapper, MMHasher


class Processor:
@@ -43,6 +43,10 @@ def __init__(
# Multi-modal (huggingface) input mapper
self.mm_input_mapper = MMInputMapper(model_config)

# Multi-modal hasher (for images)
self.mm_hasher = MMHasher(
) if model_config.mm_cache_preprocessor else None

# TODO: run in an ThreadpoolExecutor or BackgroundProcess.
# This ideally should releases the GIL, so we should not block the
# asyncio loop while this is running.
@@ -101,8 +105,10 @@ def process_inputs(
self.generation_config_fields, eos_token_id)

# Preprocess multi-modal data
mm_hash = self.mm_hasher.hash(decoder_inputs.multi_modal_data
) if self.mm_hasher is not None else None
mm_inputs = self.mm_input_mapper.process_inputs(
decoder_inputs.multi_modal_data,
decoder_inputs.multi_modal_data, mm_hash,
decoder_inputs.mm_processor_kwargs) if len(
decoder_inputs.multi_modal_data) > 0 else None

21 changes: 21 additions & 0 deletions vllm/v1/utils.py
@@ -1,4 +1,5 @@
from typing import Generic, List, TypeVar, overload
from collections import OrderedDict

T = TypeVar("T")

@@ -62,3 +63,23 @@ def __contains__(self, item):

def __len__(self):
return len(self._x)


class LRUDictCache:

def __init__(self, size: int):
self.cache = OrderedDict()
self.size = size

def get(self, key):
if key not in self.cache:
return None

self.cache.move_to_end(key)
return self.cache[key]

def put(self, key, value):
self.cache[key] = value
self.cache.move_to_end(key)
if len(self.cache) > self.size:
self.cache.popitem(last=False)
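A short usage sketch of the new LRUDictCache, showing the move-to-end-on-get behaviour that the mapper cache relies on:

from vllm.v1.utils import LRUDictCache

cache = LRUDictCache(size=2)
cache.put("a", 1)
cache.put("b", 2)

cache.get("a")     # touch "a", making "b" the least recently used entry
cache.put("c", 3)  # exceeds the capacity of 2, so "b" is evicted

assert cache.get("b") is None
assert cache.get("a") == 1 and cache.get("c") == 3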
