Skip to content

Commit

Permalink
[V1][VLM] Proper memory profiling for image language models (#11210)
Browse files Browse the repository at this point in the history
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: ywang96 <[email protected]>
  • Loading branch information
ywang96 and ywang96 authored Dec 17, 2024
1 parent 66d4b16 commit 59c9b6e
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 13 deletions.
8 changes: 8 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,14 @@ class SchedulerConfig:

is_multimodal_model: bool = False

# FIXME(woosuk & ywang96): Below are placeholder values. We need to
# calculate the actual values from the configurations.
# Multimodal encoder run compute budget, only used in V1
max_num_encoder_input_tokens = 16384

# Multimodal encoder cache size, only used in V1
encoder_cache_size = 16384

# Whether to perform preemption by swapping or
# recomputation. If not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
Expand Down
5 changes: 5 additions & 0 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)

# If the last split index is the last index in image_tokens, we
# ignore it to avoid empty split tensor
if split_indices[-1] == len(image_tokens):
split_indices = split_indices[:-1]

image_embeds = image_embeds.tensor_split(split_indices.cpu())
return image_embeds

Expand Down
23 changes: 20 additions & 3 deletions vllm/multimodal/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,23 @@ def register_max_image_tokens(
"""
return self.register_max_multimodal_tokens("image", max_mm_tokens)

def get_max_tokens_per_item_by_modality(
self,
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality
for profiling the memory usage of a model.
Note:
This is currently directly used only in V1.
"""

return {
key: plugin.get_max_multimodal_tokens(model_config)
for key, plugin in self._plugins.items()
}

def get_max_tokens_by_modality(
self,
model_config: "ModelConfig",
Expand All @@ -216,9 +233,9 @@ def get_max_tokens_by_modality(
limits_per_plugin = self._limits_by_model[model_config]

return {
key: (limits_per_plugin[key] *
plugin.get_max_multimodal_tokens(model_config))
for key, plugin in self._plugins.items()
key: limits_per_plugin[key] * max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
}

def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
Expand Down
7 changes: 3 additions & 4 deletions vllm/v1/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,13 @@ def __init__(
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed). Currently, we assume that the encoder also
# has the Transformer architecture (e.g., ViT).
# FIXME(woosuk): Below are placeholder values. We need to calculate the
# actual values from the configurations.
self.max_num_encoder_input_tokens = 16384
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens #noqa: E501
# NOTE(woosuk): For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized and used, regardless of
# the cache size. This is because the memory space for the encoder cache
# is preallocated in the profiling run.
self.encoder_cache_manager = EncoderCacheManager(cache_size=16384)
self.encoder_cache_manager = EncoderCacheManager(
cache_size=self.scheduler_config.encoder_cache_size)

def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/engine/mm_input_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def cache_hit_ratio(self, steps):
logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
self.mm_cache_hits / self.mm_cache_total)

# TODO: Support modalities beyond image.
def process_inputs(
self,
mm_data: MultiModalDataDict,
Expand Down
67 changes: 61 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
Expand All @@ -35,7 +36,6 @@ def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
input_registry: InputRegistry = INPUT_REGISTRY,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
Expand Down Expand Up @@ -77,7 +77,12 @@ def __init__(
self.hidden_size = model_config.get_hidden_size()

# Multi-modal data support
self.input_registry = input_registry
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
# NOTE: mm_input_mapper is only used for memory profiling.
self.mm_input_mapper = MMInputMapperClient(self.model_config)
self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens # noqa: E501
self.encoder_cache_size = self.scheduler_config.encoder_cache_size

# Lazy initialization
# self.model: nn.Module # Set after load_model
Expand Down Expand Up @@ -599,8 +604,6 @@ def _dummy_run(
return hidden_states

def profile_run(self) -> None:
# TODO(woosuk): Profile the max memory usage of the encoder and
# the encoder cache.
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
Expand All @@ -612,6 +615,57 @@ def profile_run(self) -> None:
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]

# Profile with multimodal encoder & encoder cache.
# TODO (ywang96): generalize this beyond image modality since
# mm_input_mapper only supports image inputs.
if self.is_multimodal_model:

# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data
dummy_mm_kwargs, _ = self.mm_input_mapper.process_inputs(
mm_data=dummy_mm_data,
mm_hashes=None,
mm_processor_kwargs=None,
precomputed_mm_inputs=None)

# NOTE: Currently model is profiled with a single non-text
# modality even when it supports multiple.
max_tokens_per_mm_item = max(
self.mm_registry.get_max_tokens_per_item_by_modality(
self.model_config).values())

max_num_mm_items = min(
self.max_num_encoder_input_tokens,
self.encoder_cache_size) // max_tokens_per_mm_item

# Dummy data definition in V0 may contain multiple multimodal items
# (e.g, multiple images) for a single request, therefore here we
# always replicate first item by max_num_mm_items times since in V1
# they are scheduled to be processed separately.
batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs[0]] * max_num_mm_items)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs, device=self.device)

# Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
assert len(dummy_encoder_outputs) == max_num_mm_items, (
"Expected dimension 0 of encoder outputs to match the number "
f"of multimodal data items: {max_num_mm_items}, got "
f"{len(dummy_encoder_outputs)=} instead. This is most likely "
"due to the 'get_multimodal_embeddings' method of the model "
"not implemented correctly.")

# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.model, self.max_num_tokens,
dummy_kv_caches)
Expand All @@ -620,6 +674,7 @@ def profile_run(self) -> None:
# TODO(woosuk): Consider the memory usage of the sampler.
torch.cuda.synchronize()
del hidden_states, logits
self.encoder_cache.clear()
gc.collect()

def capture_model(self) -> None:
Expand Down

0 comments on commit 59c9b6e

Please sign in to comment.