From 59c9b6ebeba79b2d744eec86734a7e13b03dcab7 Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Mon, 16 Dec 2024 22:10:57 -0800
Subject: [PATCH] [V1][VLM] Proper memory profiling for image language models
 (#11210)

Signed-off-by: Roger Wang
Co-authored-by: ywang96
---
 vllm/config.py                        |  8 ++++
 vllm/model_executor/models/pixtral.py |  5 ++
 vllm/multimodal/registry.py           | 23 +++++++--
 vllm/v1/core/scheduler.py             |  7 ++-
 vllm/v1/engine/mm_input_mapper.py     |  1 +
 vllm/v1/worker/gpu_model_runner.py    | 67 ++++++++++++++++++++++++---
 6 files changed, 98 insertions(+), 13 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 9cfd08024ea7b..9ecd3e72afa9f 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1280,6 +1280,14 @@ class SchedulerConfig:
 
     is_multimodal_model: bool = False
 
+    # FIXME(woosuk & ywang96): Below are placeholder values. We need to
+    # calculate the actual values from the configurations.
+    # Multimodal encoder run compute budget, only used in V1
+    max_num_encoder_input_tokens = 16384
+
+    # Multimodal encoder cache size, only used in V1
+    encoder_cache_size = 16384
+
     # Whether to perform preemption by swapping or
     # recomputation. If not specified, we determine the mode as follows:
     # We use recomputation by default since it incurs lower overhead than
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 161d6b41bfa5f..f05ea195e043d 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -245,6 +245,11 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
             # Do not split, return as tensor of shape [1, fs, hs]
             return image_embeds.unsqueeze(0)
 
+        # If the last split index is the last index in image_tokens, we
+        # ignore it to avoid empty split tensor
+        if split_indices[-1] == len(image_tokens):
+            split_indices = split_indices[:-1]
+
         image_embeds = image_embeds.tensor_split(split_indices.cpu())
         return image_embeds
 
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 03f8814a95356..6cd79d414c978 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -200,6 +200,23 @@ def register_max_image_tokens(
         """
         return self.register_max_multimodal_tokens("image", max_mm_tokens)
 
+    def get_max_tokens_per_item_by_modality(
+        self,
+        model_config: "ModelConfig",
+    ) -> Mapping[str, int]:
+        """
+        Get the maximum number of tokens per data item from each modality
+        for profiling the memory usage of a model.
+
+        Note:
+            This is currently directly used only in V1.
+        """
+
+        return {
+            key: plugin.get_max_multimodal_tokens(model_config)
+            for key, plugin in self._plugins.items()
+        }
+
     def get_max_tokens_by_modality(
         self,
         model_config: "ModelConfig",
@@ -216,9 +233,9 @@ def get_max_tokens_by_modality(
         limits_per_plugin = self._limits_by_model[model_config]
 
         return {
-            key: (limits_per_plugin[key] *
-                  plugin.get_max_multimodal_tokens(model_config))
-            for key, plugin in self._plugins.items()
+            key: limits_per_plugin[key] * max_tokens_per_mm_item
+            for key, max_tokens_per_mm_item in
+            self.get_max_tokens_per_item_by_modality(model_config).items()
         }
 
     def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index f76364f64033d..178532e477dae 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -73,14 +73,13 @@ def __init__(
         # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
         # projector if needed). Currently, we assume that the encoder also
         # has the Transformer architecture (e.g., ViT).
-        # FIXME(woosuk): Below are placeholder values. We need to calculate the
-        # actual values from the configurations.
-        self.max_num_encoder_input_tokens = 16384
+        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  #noqa: E501
        # NOTE(woosuk): For the models without encoder (e.g., text-only models),
         # the encoder cache will not be initialized and used, regardless of
         # the cache size. This is because the memory space for the encoder cache
         # is preallocated in the profiling run.
-        self.encoder_cache_manager = EncoderCacheManager(cache_size=16384)
+        self.encoder_cache_manager = EncoderCacheManager(
+            cache_size=self.scheduler_config.encoder_cache_size)
 
     def schedule(self) -> "SchedulerOutput":
         # NOTE(woosuk) on the scheduling algorithm:
diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py
index cca27c2218af7..6cdeba6f3f71e 100644
--- a/vllm/v1/engine/mm_input_mapper.py
+++ b/vllm/v1/engine/mm_input_mapper.py
@@ -54,6 +54,7 @@ def cache_hit_ratio(self, steps):
             logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
                          self.mm_cache_hits / self.mm_cache_total)
 
+    # TODO: Support modalities beyond image.
     def process_inputs(
         self,
         mm_data: MultiModalDataDict,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 67166fb05085c..c6fab5f05fcb3 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -10,15 +10,16 @@
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
-from vllm.inputs import INPUT_REGISTRY, InputRegistry
+from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.multimodal import MultiModalKwargs
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                     FlashAttentionMetadata)
+from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -35,7 +36,6 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device,
-        input_registry: InputRegistry = INPUT_REGISTRY,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -77,7 +77,12 @@ def __init__(
         self.hidden_size = model_config.get_hidden_size()
 
         # Multi-modal data support
-        self.input_registry = input_registry
+        self.input_registry = INPUT_REGISTRY
+        self.mm_registry = MULTIMODAL_REGISTRY
+        # NOTE: mm_input_mapper is only used for memory profiling.
+        self.mm_input_mapper = MMInputMapperClient(self.model_config)
+        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  # noqa: E501
+        self.encoder_cache_size = self.scheduler_config.encoder_cache_size
 
         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
@@ -599,8 +604,6 @@ def _dummy_run(
         return hidden_states
 
     def profile_run(self) -> None:
-        # TODO(woosuk): Profile the max memory usage of the encoder and
-        # the encoder cache.
         # use an empty tensor instead of `None`` to force Dynamo to pass
         # it by reference, rather by specializing on the value `None`.
         # the `dtype` argument does not matter, and we use `float32` as
@@ -612,6 +615,57 @@ def profile_run(self) -> None:
             torch.tensor([], dtype=torch.float32, device=self.device)
             for _ in range(self.num_attn_layers)
         ]
+
+        # Profile with multimodal encoder & encoder cache.
+        # TODO (ywang96): generalize this beyond image modality since
+        # mm_input_mapper only supports image inputs.
+        if self.is_multimodal_model:
+
+            # Create a dummy batch of multimodal inputs.
+            dummy_request_data = self.input_registry.dummy_data_for_profiling(
+                model_config=self.model_config,
+                seq_len=self.max_num_tokens,
+                mm_registry=self.mm_registry,
+            )
+            dummy_mm_data = dummy_request_data.multi_modal_data
+            dummy_mm_kwargs, _ = self.mm_input_mapper.process_inputs(
+                mm_data=dummy_mm_data,
+                mm_hashes=None,
+                mm_processor_kwargs=None,
+                precomputed_mm_inputs=None)
+
+            # NOTE: Currently the model is profiled with a single non-text
+            # modality even when it supports multiple.
+            max_tokens_per_mm_item = max(
+                self.mm_registry.get_max_tokens_per_item_by_modality(
+                    self.model_config).values())
+
+            max_num_mm_items = min(
+                self.max_num_encoder_input_tokens,
+                self.encoder_cache_size) // max_tokens_per_mm_item
+
+            # Dummy data definition in V0 may contain multiple multimodal items
+            # (e.g., multiple images) for a single request; therefore, we
+            # always replicate the first item max_num_mm_items times, since in
+            # V1 they are scheduled to be processed separately.
+            batched_dummy_mm_inputs = MultiModalKwargs.batch(
+                [dummy_mm_kwargs[0]] * max_num_mm_items)
+            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
+                batched_dummy_mm_inputs, device=self.device)
+
+            # Run multimodal encoder.
+            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+                **batched_dummy_mm_inputs)
+            assert len(dummy_encoder_outputs) == max_num_mm_items, (
+                "Expected dimension 0 of encoder outputs to match the number "
+                f"of multimodal data items: {max_num_mm_items}, got "
+                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
+                "due to the 'get_multimodal_embeddings' method of the model "
+                "not being implemented correctly.")
+
+            # Cache the dummy encoder outputs.
+            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
+
         # Trigger compilation for general shape.
         hidden_states = self._dummy_run(self.model, self.max_num_tokens,
                                         dummy_kv_caches)
@@ -620,6 +674,7 @@ def profile_run(self) -> None:
         # TODO(woosuk): Consider the memory usage of the sampler.
         torch.cuda.synchronize()
         del hidden_states, logits
+        self.encoder_cache.clear()
         gc.collect()
 
     def capture_model(self) -> None:
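
A note on the vllm/model_executor/models/pixtral.py change: torch.Tensor.tensor_split returns a trailing empty chunk whenever the last split index equals the length of the split dimension, which is exactly what the new guard avoids. A minimal standalone sketch with made-up shapes, using a plain Python list in place of the real split_indices tensor:

    import torch

    # Toy embeddings standing in for image_embeds: 6 "tokens", hidden size 4.
    embeds = torch.arange(24, dtype=torch.float32).reshape(6, 4)

    # Split points analogous to split_indices; the trailing 6 equals len(embeds).
    split_indices = [2, 4, 6]

    chunks = embeds.tensor_split(split_indices)
    print([c.shape[0] for c in chunks])  # [2, 2, 2, 0]  <- trailing empty chunk

    # Dropping a final index that equals the length avoids the empty chunk,
    # mirroring the guard added in pixtral.py.
    if split_indices[-1] == len(embeds):
        split_indices = split_indices[:-1]

    chunks = embeds.tensor_split(split_indices)
    print([c.shape[0] for c in chunks])  # [2, 2, 2]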
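On the new SchedulerConfig budgets: the number of dummy multimodal items batched during profiling follows from min(compute budget, cache size) divided by the per-item token count reported by the registry. A back-of-the-envelope sketch using the 16384 placeholders from this patch and a hypothetical 576-token-per-image model; the real per-item value comes from get_max_tokens_per_item_by_modality and varies by model:

    # Placeholder budgets introduced in SchedulerConfig (V1 only).
    max_num_encoder_input_tokens = 16384   # encoder compute budget per run
    encoder_cache_size = 16384             # encoder cache capacity, in tokens

    # Hypothetical per-item cost; in the real code this comes from
    # MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_modality(model_config).
    max_tokens_per_mm_item = 576

    # Same arithmetic as profile_run(): the dummy batch is capped by whichever
    # of the two budgets is smaller.
    max_num_mm_items = min(max_num_encoder_input_tokens,
                           encoder_cache_size) // max_tokens_per_mm_item
    print(max_num_mm_items)  # 28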
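The profile_run() flow in gpu_model_runner.py can be mocked outside vLLM to see what the memory profiler ends up measuring: replicate one dummy item, run the encoder once, check that it yields one output per item, keep the outputs alive in the encoder cache during the dummy forward pass, then clear them. The sketch below uses a stand-in encoder and plain tensors; fake_encoder, the shapes, and the sizes are illustrative and not the real MMInputMapperClient / MultiModalKwargs plumbing:

    import torch

    def fake_encoder(pixel_values: torch.Tensor) -> list[torch.Tensor]:
        # Stand-in for model.get_multimodal_embeddings(): one embedding
        # tensor per multimodal item in the batch.
        return [torch.randn(576, 1024) for _ in range(pixel_values.shape[0])]

    max_num_mm_items = 28
    dummy_item = torch.zeros(3, 336, 336)                      # one dummy image
    batched = dummy_item.unsqueeze(0).expand(max_num_mm_items, -1, -1, -1)

    encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
    dummy_outputs = fake_encoder(batched)
    assert len(dummy_outputs) == max_num_mm_items, (
        "encoder must return one output per multimodal item")

    # Keep the outputs alive under a temporary request id so they count toward
    # the peak-memory measurement, then drop them after the profiling pass.
    encoder_cache["tmp"] = dict(enumerate(dummy_outputs))
    # ... run the dummy language-model forward pass here ...
    encoder_cache.clear()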