From 7379b3d4b2e0b85de43e7c5145ff26c8200aac8a Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 19 Dec 2024 08:27:22 -0800 Subject: [PATCH 1/6] [V1] Fix multimodal profiling for `Molmo` (#11325) Signed-off-by: ywang96 Co-authored-by: ywang96 --- vllm/model_executor/models/molmo.py | 5 +++++ vllm/v1/engine/mm_input_mapper.py | 19 +++++++++++++++++-- vllm/v1/engine/processor.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 +- 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index a328b5a2aeea7..9f744b6918818 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -928,7 +928,11 @@ def image_input_mapper_for_molmo( data: object, ): if isinstance(data, list): + assert len(data) == 1, "Molmo supports only one image per prompt." data = data[0] + + # Remove unused dummy PIL image + data.pop('raw_mm_data', None) return MultiModalKwargs(data) @@ -974,6 +978,7 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int, dummy_imgdata = { "images": out["images"], "image_input_idx": out["image_input_idx"], + "raw_mm_data": dummy_image, } if "image_masks" in out: dummy_imgdata["image_masks"] = out["image_masks"] diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index bba71c29cc108..cb97f743b1d52 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -151,17 +151,31 @@ class MMHasher: def __init__(self): pass - def hash_mm_data( + def hash_dummy_mm_data( self, mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]: + """Hash user-defined dummy multimodal data used for profiling.""" + if mm_data is None: return None image_inputs = mm_data['image'] + # This is a temporary workaround for models (e.g, Molmo) that + # process multimodal data in the input processor (therefore + # image_inputs is MultiModalKwargs instead of raw input format). + # `raw_mm_data` with the original input format is expected + # in this case. + if isinstance(image_inputs, dict): + assert "raw_mm_data" in image_inputs and isinstance( + image_inputs["raw_mm_data"], PIL.Image.Image) + image_inputs = image_inputs.pop("raw_mm_data") + return self.hash_images(image_inputs) - def hash_prompt(self, prompt: PromptType) -> Optional[List[str]]: + def hash_prompt_mm_data(self, prompt: PromptType) -> Optional[List[str]]: + """Hash multimodal data in the user input prompt if they exist.""" + if "multi_modal_data" not in prompt: return None @@ -171,6 +185,7 @@ def hash_prompt(self, prompt: PromptType) -> Optional[List[str]]: return self.hash_images(image_inputs) def hash_images(self, image_inputs) -> Optional[List[str]]: + """Hash PIL image objects to strings.""" if not isinstance(image_inputs, list): image_inputs = [image_inputs] assert len(image_inputs) > 0 diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index ffcaa158d252d..6ee8732bc902c 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -79,7 +79,7 @@ def process_inputs( # Compute MM hashes (if enabled) mm_hashes = None if self.use_hash: - mm_hashes = self.mm_hasher.hash_prompt(prompt) + mm_hashes = self.mm_hasher.hash_prompt_mm_data(prompt) # Process inputs. 
preprocessed_inputs = self.input_preprocessor.preprocess( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8ec4a252d5925..cb89246db0cc9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -638,7 +638,7 @@ def profile_run(self) -> None: # Compute MM hashes (if enabled) mm_hashes = None if self.use_hash: - mm_hashes = self.mm_hasher.hash_mm_data(dummy_mm_data) + mm_hashes = self.mm_hasher.hash_dummy_mm_data(dummy_mm_data) dummy_mm_kwargs = self.mm_input_mapper_client.process_inputs( mm_data=dummy_mm_data, From e24113a8fe5de5b96459d1f8509d1b48fd7ceebe Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Fri, 20 Dec 2024 00:28:00 +0800 Subject: [PATCH 2/6] [Model] Refactor Qwen2-VL to use merged multimodal processor (#11258) Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: DarkLight1337 Co-authored-by: Cyrus Leung Co-authored-by: DarkLight1337 --- examples/offline_inference_vision_language.py | 8 +- .../mm_processor_kwargs/test_qwen2_vl.py | 192 ++---- vllm/model_executor/models/qwen2_audio.py | 4 +- vllm/model_executor/models/qwen2_vl.py | 580 ++++++------------ vllm/multimodal/processing.py | 20 +- 5 files changed, 277 insertions(+), 527 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 64c7b93f4a71b..d5a71862656e7 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -447,7 +447,6 @@ def run_qwen_vl(question: str, modality: str): # Qwen2-VL def run_qwen2_vl(question: str, modality: str): - assert modality == "image" model_name = "Qwen/Qwen2-VL-7B-Instruct" @@ -463,8 +462,13 @@ def run_qwen2_vl(question: str, modality: str): disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, ) + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" - "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" f"{question}<|im_end|>\n" "<|im_start|>assistant\n") stop_token_ids = None diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py index 7e2bea130583e..cd8954ffc48c2 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py @@ -1,12 +1,9 @@ from typing import Any, Dict, Tuple import pytest -import torch -from PIL.Image import Image from transformers import AutoTokenizer -from vllm.inputs import InputContext, token_inputs -from vllm.multimodal import MultiModalRegistry +from vllm.inputs import InputContext, InputProcessingContext from .....conftest import _ImageAssets from ....utils import build_model_context @@ -20,22 +17,9 @@ # NOTE: Qwen2VL supports multiple input modalities, so it registers multiple # input mappers. 
@pytest.fixture() -def image_input_mapper_for_qwen2_vl(): - from vllm.model_executor.models.qwen2_vl import ( - image_input_mapper_for_qwen2_vl) - return image_input_mapper_for_qwen2_vl - - -@pytest.fixture() -def input_processor_for_qwen2_vl(): - from vllm.model_executor.models.qwen2_vl import ( - input_processor_for_qwen2_vl) - return input_processor_for_qwen2_vl - - -@pytest.fixture() -def qwen2_vl_context() -> InputContext: - return build_model_context(model_name=MODEL) +def processor_for_qwen2_vl(): + from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor + return Qwen2VLMultiModalProcessor @pytest.fixture() @@ -45,12 +29,6 @@ def get_max_qwen2_vl_image_tokens(): return get_max_qwen2_vl_image_tokens -@pytest.fixture() -def dummy_data_for_qwen2_vl(): - from vllm.model_executor.models.qwen2_vl import dummy_data_for_qwen2_vl - return dummy_data_for_qwen2_vl - - @pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [ ({}, 1225), ({ @@ -58,110 +36,70 @@ def dummy_data_for_qwen2_vl(): MAX_PIXELS: 512**2 }, 324), ]) -def test_qwen2_vl_max_image_tokens(get_max_qwen2_vl_image_tokens, - qwen2_vl_context: InputContext, - mm_processor_kwargs: Dict[str, Any], - expected_max_tokens: int): +@pytest.mark.parametrize("model", [MODEL]) +def test_qwen2_vl_max_image_tokens( + get_max_qwen2_vl_image_tokens, + model: str, + mm_processor_kwargs: Dict[str, Any], + expected_max_tokens: int, +): """Ensure that the max token calc handles min/max pixels properly.""" - actual_max_tokens = get_max_qwen2_vl_image_tokens(qwen2_vl_context, - **mm_processor_kwargs) - assert actual_max_tokens == expected_max_tokens - - -@pytest.mark.parametrize("mm_processor_kwargs,token_count,img_size", [ - [{}, 1225, (980, 980)], - [{ - MIN_PIXELS: 64**2, - MAX_PIXELS: 512**2 - }, 324, (504, 504)], -]) -def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl, - qwen2_vl_context: InputContext, - mm_processor_kwargs: Dict[str, Any], - token_count: int, img_size: Tuple[int, int]): - """Ensure that the dummy data handles min/max pixels properly.""" - seq_len = 3000 - hf_config = qwen2_vl_context.get_hf_config() - image_token_id = hf_config.image_token_id - - # NOTE: video value is required, but isn't actually used - # when making the dummy data except for error handling currently - dummy_data = dummy_data_for_qwen2_vl( - ctx=qwen2_vl_context, - seq_len=seq_len, - mm_counts={ - "image": 1, - "video": 0 - }, - **mm_processor_kwargs, + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + mm_processor_kwargs=None, ) - seq_data = dummy_data.seq_data - mm_data = dummy_data.multi_modal_data - - # Ensure we have the right number of placeholders for min/max pixel values - assert seq_data.get_token_ids().count(image_token_id) == token_count - # Ensure the images were resized correctly - image = mm_data["image"] - assert isinstance(image, Image) - assert image.size == img_size + actual_max_tokens = get_max_qwen2_vl_image_tokens( + InputContext(ctx.model_config), **mm_processor_kwargs) + assert actual_max_tokens == expected_max_tokens -@pytest.mark.parametrize("mm_processor_kwargs,num_placeholders", [ - ({}, 1426), - ({ - MIN_PIXELS: 64**2, - MAX_PIXELS: 512**2 - }, 330), -]) -def test_input_processor(input_processor_for_qwen2_vl, - qwen2_vl_context: InputContext, - image_assets: _ImageAssets, num_placeholders: int, - mm_processor_kwargs: Dict[str, Any]): - """Ensure that the image processor handles min/max pixels properly.""" - tokenizer = AutoTokenizer.from_pretrained(MODEL) - prompt = 
"<|vision_start|><|image_pad|><|vision_end|>" - - image = image_assets[0].pil_image - hf_config = qwen2_vl_context.get_hf_config() - image_token_id = hf_config.image_token_id - - inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt), - prompt=prompt, - multi_modal_data={"image": [image]}) - - processed_inputs = input_processor_for_qwen2_vl(qwen2_vl_context, inputs, - **mm_processor_kwargs) - assert processed_inputs["prompt_token_ids"].count( - image_token_id) == num_placeholders - assert len(processed_inputs["multi_modal_data"]["image"]) == 1 - - -@pytest.mark.parametrize("mm_processor_kwargs,pixels_shape", [ - ({}, [5704, 1176]), - ({ - MIN_PIXELS: 64**2, - MAX_PIXELS: 512**2 - }, [1320, 1176]), -]) -def test_image_mapper_override(qwen2_vl_context: InputContext, - image_assets: _ImageAssets, - mm_processor_kwargs: Dict[str, Any], - pixels_shape: Tuple[int, int]): - """Ensure that the image mapper handles min/max pixels properly.""" - mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(qwen2_vl_context.model_config) - - image = image_assets[0].pil_image - - mapped_output = mm_registry.map_input( - qwen2_vl_context.model_config, - {"image": image}, - mm_processor_kwargs=mm_processor_kwargs, +@pytest.mark.parametrize( + "mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [ + ({}, 1426, (5704, 1176)), + ({ + MIN_PIXELS: 64**2, + MAX_PIXELS: 512**2 + }, 330, (1320, 1176)), + ]) +@pytest.mark.parametrize("model", [MODEL]) +@pytest.mark.parametrize("num_imgs", [1, 2]) +def test_processor_override( + processor_for_qwen2_vl, + image_assets: _ImageAssets, + model: str, + mm_processor_kwargs: Dict[str, Any], + expected_toks_per_img: int, + expected_pixels_shape: Tuple[int, int], + num_imgs: int, +): + """Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the custom input processor. 
+ ctx = build_model_context( + model_name=model, + tokenizer_name=model, + mm_processor_kwargs=None, ) - - # Dimension 0 of pixel values should match the product of image_grid_thw - actual_pixels_shape = mapped_output["pixel_values"].shape - assert list(actual_pixels_shape) == pixels_shape - assert actual_pixels_shape[0] == torch.prod( - mapped_output["image_grid_thw"]) + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + ctx = InputProcessingContext(ctx.model_config, tokenizer) + # Build the image str / prompt based on the number of images we pass + prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs + images = [image_assets[0].pil_image] * num_imgs + + mm_data = {"image": images} + + processor = processor_for_qwen2_vl(ctx) + processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + + # Ensure we have the right number of placeholders per num_crops size + hf_processor = processor._get_hf_processor(**mm_processor_kwargs) + image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token) + img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) + pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape + + assert img_tok_count == expected_toks_per_img * num_imgs + assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs + assert pixel_shape[1] == expected_pixels_shape[1] diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 07e29b71c2ed4..6259166a7fc57 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -164,7 +164,9 @@ def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - audio_len = get_max_qwen2_audio_audio_tokens(self.ctx) + feature_extractor = self._get_feature_extractor() + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.chunk_length * sampling_rate audio_count = mm_counts["audio"] audio = np.zeros(audio_len) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index cfc90cdab01e4..b38ea923f0bf1 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,28 +22,26 @@ # limitations under the License. 
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from functools import cached_property, partial -from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Set, Tuple, Type, TypedDict, Union) +from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, + Tuple, Type, TypedDict, Union) import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from PIL import Image -from transformers.image_utils import (get_image_size, - infer_channel_dimension_format, - to_numpy_array) +from transformers import BatchFeature +from transformers.models.qwen2_vl import (Qwen2VLImageProcessor, + Qwen2VLProcessor) from transformers.models.qwen2_vl.configuration_qwen2_vl import ( Qwen2VLConfig, Qwen2VLVisionConfig) -from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( - make_batched_images, make_batched_videos, smart_resize) +from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.inputs import InputContext from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU @@ -56,14 +54,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import cached_get_image_processor -from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict, - MultiModalKwargs, NestedTensors) -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.multimodal.inputs import MultiModalDataDict, NestedTensors +from vllm.multimodal.processing import (BaseMultiModalProcessor, + MultiModalDataItems, ProcessorInputs, + PromptReplacement) from vllm.platforms import _Backend -from vllm.sequence import IntermediateTensors, SequenceData +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope -from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils import is_list_of from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend, @@ -159,7 +157,7 @@ class Qwen2VisionMLP(nn.Module): def __init__( self, in_features: int, - hidden_features: int = None, + hidden_features: int, act_layer: Type[nn.Module] = QuickGELU, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -644,78 +642,8 @@ def load_weights(self, weights: Iterable[Tuple[str, # === Vision input helpers === # -def get_mm_processor_kwargs( - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None) -> Dict[str, int]: - mm_processor_kwargs = {} - if min_pixels: - mm_processor_kwargs["min_pixels"] = min_pixels - if max_pixels: - mm_processor_kwargs["max_pixels"] = max_pixels - return mm_processor_kwargs - - -def mm_input_mapper_for_qwen2_vl( - ctx: InputContext, - data: MultiModalData[object], - data_type_key: str, - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, -) -> MultiModalKwargs: - """Input mapper for Qwen2-VL.""" - if data_type_key == "image" and isinstance(data, dict): - 
return MultiModalKwargs({ - "image_embeds": data.get("image_embeds"), - "image_grid_thw": data.get("image_grid_thw"), - }) - if data_type_key == "video" and isinstance(data, dict): - return MultiModalKwargs({ - "video_embeds": data.get("video_embeds"), - "video_grid_thw": data.get("video_grid_thw"), - }) - - model_config = ctx.model_config - # Handle mm processor kwargs; we pass these at creation time - # because preprocess() in transformers doesn't expose them - mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels, - max_pixels=max_pixels) - image_processor = cached_get_image_processor( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - **mm_processor_kwargs, - ) - if image_processor is None: - raise RuntimeError("No HuggingFace processor is available " - "to process the image object") - - images = None - videos = None - if data_type_key == "image": - images = data - else: - assert data_type_key == "video" - videos = data - - try: - batch_data = image_processor \ - .preprocess(images=images, videos=videos, return_tensors="pt") \ - .data - except Exception: - logger.error("Failed to process image (%s)", data) - raise - - return MultiModalKwargs(batch_data) - - -image_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl, - data_type_key="image") -video_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl, - data_type_key="video") - - def _get_vision_info( - image_processor, + vision_config: Qwen2VLVisionConfig, height: int, width: int, min_pixels: int, @@ -726,12 +654,15 @@ def _get_vision_info( ): """Get information (resized height / width and number of vision tokens) of input image / video frame.""" + patch_size = vision_config.patch_size + merge_size = vision_config.spatial_merge_size + temporal_patch_size = vision_config.temporal_patch_size if do_resize: resized_height, resized_width = smart_resize( height=height, width=width, - factor=image_processor.patch_size * image_processor.merge_size, + factor=patch_size * merge_size, min_pixels=min_pixels, max_pixels=max_pixels, ) @@ -742,54 +673,41 @@ def _get_vision_info( grid_t = mm_count else: assert data_type_key == "video" - grid_t = max(mm_count // image_processor.temporal_patch_size, 1) + grid_t = max(mm_count // temporal_patch_size, 1) - grid_h = resized_height // image_processor.patch_size - grid_w = resized_width // image_processor.patch_size + grid_h = resized_height // patch_size + grid_w = resized_width // patch_size vision_tokens = grid_t * grid_h * grid_w - llm_num_vision_tokens = (vision_tokens // image_processor.merge_size // - image_processor.merge_size) + llm_num_vision_tokens = vision_tokens // (merge_size**2) return resized_height, resized_width, llm_num_vision_tokens -def _get_max_image_info( - image_processor, - data_type_key: str = "image", - mm_count: int = 1, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, -): - # Limit min / max pixels unless they're explicitly provided - if min_pixels is None: - min_pixels = max(image_processor.min_pixels, 28 * 28) - if max_pixels is None: - max_pixels = min(image_processor.max_pixels, 1280 * 28 * 28) - - return _get_vision_info( - image_processor, - height=9999999, - width=9999999, - min_pixels=min_pixels, - max_pixels=max_pixels, - data_type_key=data_type_key, - mm_count=mm_count, - ) +def _get_image_processor(hf_processor: Qwen2VLProcessor): + image_processor = hf_processor.image_processor # type: ignore + assert isinstance(image_processor, Qwen2VLImageProcessor) + return image_processor def 
get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str, *, - min_pixels=None, - max_pixels=None) -> int: - mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels, - max_pixels=max_pixels) - image_processor = cached_get_image_processor(ctx.model_config.model, - **mm_processor_kwargs) - max_resized_height, max_resized_width, max_llm_image_tokens = \ - _get_max_image_info(image_processor, data_type_key=data_type_key, - mm_count=1, min_pixels=min_pixels, - max_pixels=max_pixels) + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None) -> int: + hf_config = ctx.get_hf_config(Qwen2VLConfig) + vision_config = hf_config.vision_config + + hf_processor = ctx.get_hf_processor(Qwen2VLProcessor) + image_processor = _get_image_processor(hf_processor) + + _, _, max_llm_image_tokens = _get_vision_info( + vision_config, + height=9999999, + width=9999999, + min_pixels=min_pixels or image_processor.min_pixels, + max_pixels=max_pixels or image_processor.max_pixels, + data_type_key=data_type_key, + ) return max_llm_image_tokens @@ -799,290 +717,166 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key="video") -def dummy_data_for_qwen2_vl( - ctx: InputContext, - seq_len: int, - mm_counts: Mapping[str, int], - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None -) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: - mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels, - max_pixels=max_pixels) - image_processor = cached_get_image_processor(ctx.model_config.model, - **mm_processor_kwargs) - - num_images = mm_counts["image"] - max_resized_height, max_resized_width, max_llm_image_tokens = \ - _get_max_image_info(image_processor, data_type_key="image", - mm_count=num_images, min_pixels=min_pixels, - max_pixels=max_pixels) - if seq_len - max_llm_image_tokens - 2 < 0: - raise RuntimeError( - f"Qwen2-VL cannot process {num_images} images in a prompt, " - "please increase max_model_len or reduce image limit by " - "--limit-mm-per-prompt.") - - # Check video counts. - num_videos = mm_counts["video"] - max_resized_height, max_resized_width, max_llm_video_tokens = \ - _get_max_image_info(image_processor, data_type_key="video", - mm_count=num_videos, min_pixels=min_pixels, - max_pixels=max_pixels) - if seq_len - max_llm_video_tokens - 2 < 0: - raise RuntimeError( - f"Qwen2-VL cannot process {num_videos} videos in a prompt, " - "please increase max_model_len or reduce video limit by " - "--limit-mm-per-prompt.") - - hf_config = ctx.get_hf_config(Qwen2VLConfig) - - dummy_seqdata = SequenceData.from_prompt_token_counts( - (hf_config.vision_start_token_id, 1), - (hf_config.image_token_id, max_llm_image_tokens), - (hf_config.vision_end_token_id, 1), - (0, seq_len - max_llm_image_tokens - 2), - ) - - dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), - color=0) +class Qwen2VLMultiModalDataItems(MultiModalDataItems): - return DummyData(dummy_seqdata, { - "image": - dummy_image if num_images == 1 else [dummy_image] * num_images - }) + @staticmethod + def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems": + """ + Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`. 
+ """ + multi_data = Qwen2VLMultiModalDataItems() + + for k, v in data.items(): + # TODO: Make a separate modality for embedding inputs + # to avoid confusion + # yapf: disable + if k == "video": + # Special case since even a single item can be a list + multi_data[k] = ( # type: ignore[index] + v if (isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment] + or is_list_of(v, list)) else [v] + ) + elif k in ("image", "audio"): + multi_data[k] = ( # type: ignore[index] + v if isinstance(v, (dict, torch.Tensor, list)) else [v] + ) + else: + multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] + # yapf: enable + return multi_data -def _get_llm_num_vision_tokens( - mm_inputs: list, - data_type_key: str, - image_processor, - min_pixels: int, - max_pixels: int, -): - """Get number of vision tokens of multimodal inputs. + def get_item_counts(self) -> Mapping[str, int]: + return { + m: ( + len(items[f"{m}_grid_thw"]) # type: ignore + if isinstance(items, dict) else len(items)) + for m, items in self.items() + } - This method is derived from `transformers.models.qwen2_vl. - image_processing_qwen2_vl.Qwen2VLImageProcessor._preprocess`. - """ - image = to_numpy_array(mm_inputs[0]) - input_data_format = infer_channel_dimension_format(image) - height, width = get_image_size(image, channel_dim=input_data_format) - - _, _, llm_num_vision_tokens = _get_vision_info( - image_processor, - height=height, - width=width, - min_pixels=min_pixels, - max_pixels=max_pixels, - do_resize=image_processor.do_resize, - data_type_key=data_type_key, - mm_count=len(mm_inputs), - ) - return llm_num_vision_tokens +class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): -def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, - data_type_key: str, image_processor: Any, - prompt_token_ids: List[int], min_pixels: Optional[int], - max_pixels: Optional[int]) -> List[int]: - """ - Expand pad tokens for multi-modal inputs (e.g., images or videos). - - Args: - inputs (list): The multi-modal inputs (e.g., images or videos). - token_id (int): The token ID used to represent the multi-modal input. - make_batched_fn (Callable): A function to batch the inputs. - data_type_key (str): The type of the multi-modal input. - image_processor (Any): The image processor used to process the inputs. - prompt_token_ids (List[int]): The list of token IDs in the prompt. - min_pixels (int): min pixels to used for img processing - max_pixels (int): max pixels to be used for img processing - - Returns: - List[int]: The list of token IDs for the multi-modal inputs. 
- """ - indices = [ - idx for idx, token in enumerate(prompt_token_ids) if token == token_id - ] - inputs = make_batched_fn(inputs) - assert len(indices) == len(inputs) - - prompt_token_ids_with_data = [] - for cnt, data in enumerate(inputs): - num_tokens = _get_llm_num_vision_tokens( - [data] if data_type_key == "image" else data, - data_type_key=data_type_key, - image_processor=image_processor, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - if cnt == 0: - end_idx = indices[cnt] - non_data_tokens = prompt_token_ids[:end_idx] - else: - non_data_tokens = prompt_token_ids[indices[cnt - 1] + - 1:indices[cnt]] - prompt_token_ids_with_data.extend(non_data_tokens) - prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens)) - prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:]) - return prompt_token_ids_with_data - - -def input_processor_for_qwen2_vl( - ctx: InputContext, - inputs: DecoderOnlyInputs, - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, -) -> DecoderOnlyInputs: - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None: - return inputs - - image_inputs = multi_modal_data.get("image", None) - video_inputs = multi_modal_data.get("video", None) - - processor = cached_get_processor(ctx.model_config.model) - image_processor = processor.image_processor - # Apply processor kwarg overrides for image processor options - min_pixels = min_pixels if min_pixels else image_processor.min_pixels - max_pixels = max_pixels if max_pixels else image_processor.max_pixels - - model_config = ctx.model_config - hf_config = ctx.get_hf_config(Qwen2VLConfig) + def _get_mm_items( + self, + mm_data: MultiModalDataDict, + ) -> MultiModalDataItems: + return Qwen2VLMultiModalDataItems.from_dict(mm_data) - # To avoid redundant processing of vision objects (resize, rescale, etc.), - # we extract code of calculating number of vision tokens from - # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`. - # - # The following code is equivalent to: - # prompt = inputs["prompt"] - # inputs = processor(text=[prompt], - # images=image_inputs, - # videos=video_inputs, - # padding=True, - # return_tensors="pt") - # prompt_token_ids = inputs["input_ids"][0].tolist() - - tokenizer = cached_get_tokenizer( - model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code) - - prompt_token_ids = inputs["prompt_token_ids"] - - # Expand image pad tokens. 
- - if image_inputs is not None: - if isinstance(image_inputs, dict): - prompt_token_ids_with_image = [] - image_indices = [ - idx for idx, token in enumerate(prompt_token_ids) - if token == hf_config.image_token_id - ] - - # ensure all image tokens have grid_thw - assert \ - len(image_indices) == image_inputs["image_grid_thw"].size(0), \ - "image token num does not match image_grid_thw.shape" - - image_counter = 0 - pad_token_counter = 0 - for idx, token in enumerate(prompt_token_ids): - if idx in image_indices: - grid_thw = image_inputs["image_grid_thw"][image_counter] - grid_t, grid_h, grid_w = grid_thw - num_pad_tokens = (grid_t * grid_h * grid_w // - image_processor.merge_size // - image_processor.merge_size) - prompt_token_ids_with_image.extend([token] * - num_pad_tokens) - image_counter += 1 - pad_token_counter += num_pad_tokens + def _get_hf_processor( + self, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + ) -> Qwen2VLProcessor: + hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor) + image_processor = _get_image_processor(hf_processor) + + if min_pixels: + image_processor.min_pixels = min_pixels + if max_pixels: + image_processor.max_pixels = max_pixels + if max_pixels or min_pixels: + image_processor.size = { + "min_pixels": image_processor.min_pixels, + "max_pixels": image_processor.max_pixels, + } + + return hf_processor + + def _get_processor_data( + self, + mm_items: MultiModalDataItems, + ) -> tuple[dict[str, Any], dict[str, Any]]: + processor_data = dict[str, Any]() + passthrough_data = dict[str, Any]() + + for k, v in mm_items.items(): + # TODO: Make a separate modality for embedding inputs + # to avoid confusion + if k in ("image", "video", "audio"): + if isinstance(v, dict): + # Pass through embedding inputs (dict) + passthrough_data.update(v) + elif isinstance(v, torch.Tensor) and v.ndim == 3: + # Pass through embedding inputs (single) + passthrough_data[f"{k}_embeds"] = [v] + elif (is_list_of(v, torch.Tensor) and len(v) > 0 + and v[0].ndim == 2): + # Pass through embedding inputs (multi) + passthrough_data[f"{k}_embeds"] = v else: - prompt_token_ids_with_image.append(token) + # Map keys to plural form, e.g.: image -> images + processor_data[f"{k}s"] = v + else: + processor_data[k] = v - # ensure all embeddings are used - assert \ - pad_token_counter == image_inputs["image_embeds"].size(0), \ - "image_embeds.shape does not match image_grid_thw" + return processor_data, passthrough_data - prompt_token_ids = prompt_token_ids_with_image - else: - prompt_token_ids = _expand_pad_tokens(image_inputs, - hf_config.image_token_id, - make_batched_images, - "image", - image_processor, - prompt_token_ids, - min_pixels=min_pixels, - max_pixels=max_pixels) - - if video_inputs is not None: - if isinstance(video_inputs, dict): - prompt_token_ids_with_video = [] - video_indices = [ - idx for idx, token in enumerate(prompt_token_ids) - if token == hf_config.video_token_id - ] - - # ensure all video tokens have grid_thw - assert \ - len(video_indices) == video_inputs["video_grid_thw"].size(0), \ - "video token num does not match video_grid_thw.shape" - - video_counter = 0 - pad_token_counter = 0 - for idx, token in enumerate(prompt_token_ids): - if idx in video_indices: - grid_thw = video_inputs["video_grid_thw"][video_counter] - grid_t, grid_h, grid_w = grid_thw - num_pad_tokens = (grid_t * grid_h * grid_w // - image_processor.merge_size // - image_processor.merge_size) - prompt_token_ids_with_video.extend([token] * - num_pad_tokens) - video_counter 
+= 1 - pad_token_counter += num_pad_tokens - else: - prompt_token_ids_with_video.append(token) + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_inputs: BatchFeature, + mm_processor_kwargs: Mapping[str, object], + ) -> list[PromptReplacement]: + hf_processor = self._get_hf_processor() + image_processor = _get_image_processor(hf_processor) + + # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has + # image_token and video_token registered + placeholder = { + "image": hf_processor.image_token, + "video": hf_processor.video_token, + } + merge_length = image_processor.merge_size**2 + + def get_replacement_qwen2vl(item_idx: int, modality: str): + grid_thw = hf_inputs[f"{modality}_grid_thw"][item_idx] + num_tokens = grid_thw.prod() // merge_length + return placeholder[modality] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=placeholder[modality], + replacement=partial(get_replacement_qwen2vl, + modality=modality), + ) for modality in ("image", "video") + ] - # ensure all embeddings are used - assert \ - pad_token_counter == video_inputs["video_embeds"].size(0), \ - "video_embeds.shape does not match video_grid_thw" + def _get_dummy_mm_inputs( + self, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts["image"] + hf_processor = self._get_hf_processor() + image_token: str = hf_processor.image_token + image_processor = _get_image_processor(hf_processor) + + data = {} + resized_height, resized_width = smart_resize( + height=9999999, + width=9999999, + factor=image_processor.patch_size * image_processor.merge_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) - prompt_token_ids = prompt_token_ids_with_video - else: - prompt_token_ids = _expand_pad_tokens(video_inputs, - hf_config.video_token_id, - make_batched_videos, - "video", - image_processor, - prompt_token_ids, - min_pixels=min_pixels, - max_pixels=max_pixels) - - prompt = inputs.get("prompt") - if prompt is None: - prompt = tokenizer.decode(prompt_token_ids) - - return token_inputs( - prompt_token_ids=prompt_token_ids, - prompt=prompt, - multi_modal_data=multi_modal_data, - ) + dummy_image = Image.new("RGB", (resized_width, resized_height), + color=0) + data["image"] = [dummy_image] * num_images + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=data, + mm_processor_kwargs={}, + ) -@MULTIMODAL_REGISTRY.register_image_input_mapper( - image_input_mapper_for_qwen2_vl) -@MULTIMODAL_REGISTRY.register_input_mapper("video", - video_input_mapper_for_qwen2_vl) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens) @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( "video", get_max_qwen2_vl_video_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) -@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) +@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor) class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): packed_modules_mapping = { @@ -1110,7 +904,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config + config: Qwen2VLConfig = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config diff --git 
a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index b00513e5b37cb..6baf19d675d50 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -220,15 +220,18 @@ def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems": multi_data = MultiModalDataItems() for k, v in data.items(): + # TODO: Make a separate modality for embedding inputs + # to avoid confusion # yapf: disable if k == "video": # Special case since even a single item can be a list multi_data[k] = ( # type: ignore[index] - v if is_list_of(v, (list, torch.Tensor)) else [v] + v if (isinstance(v, torch.Tensor) + or is_list_of(v, list)) else [v] ) elif k in ("image", "audio"): multi_data[k] = ( # type: ignore[index] - v if isinstance(v, (list, torch.Tensor)) else [v] + v if isinstance(v, (torch.Tensor, list)) else [v] ) else: multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] @@ -252,6 +255,9 @@ def videos(self) -> Sequence[VideoItem]: def audios(self) -> Sequence[AudioItem]: return self.get("audio", []) + def get_item_counts(self) -> Mapping[str, int]: + return {m: len(items) for m, items in self.items()} + def get_image_size(self, item_idx: int) -> ImageSize: image = self.images[item_idx] @@ -612,6 +618,12 @@ def _get_hf_processor(self) -> ProcessorMixin: def _get_tokenizer(self) -> AnyTokenizer: return self.ctx.tokenizer + def _get_mm_items( + self, + mm_data: MultiModalDataDict, + ) -> MultiModalDataItems: + return MultiModalDataItems.from_dict(mm_data) + @abstractmethod def _get_prompt_replacements( self, @@ -778,7 +790,7 @@ def apply( 3. Extract information about the placeholder tokens from the processed token IDs. """ - mm_items = MultiModalDataItems.from_dict(mm_data) + mm_items = self._get_mm_items(mm_data) hf_inputs = self._apply_hf_processor(prompt_text, mm_items, mm_processor_kwargs) @@ -791,7 +803,7 @@ def apply( # If HF processor already inserts placeholder tokens, # there is no need for us to insert them - mm_item_counts = {m: len(items) for m, items in mm_items.items()} + mm_item_counts = mm_items.get_item_counts() all_placeholders = self._find_placeholders(all_prompt_repls, prompt_ids, mm_item_counts) From cdf22afddad7b29e8d584b77863a563a91ac09fb Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 20 Dec 2024 00:59:32 +0800 Subject: [PATCH 3/6] [Misc] Clean up and consolidate LRUCache (#11339) Signed-off-by: DarkLight1337 --- vllm/adapter_commons/models.py | 9 ++- .../tokenizer_group/tokenizer_group.py | 2 +- vllm/utils.py | 59 ++++++++----------- vllm/v1/engine/mm_input_mapper.py | 6 +- vllm/v1/utils.py | 25 -------- 5 files changed, 34 insertions(+), 67 deletions(-) diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py index a5c04ab78fbe8..468904c90fff4 100644 --- a/vllm/adapter_commons/models.py +++ b/vllm/adapter_commons/models.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Hashable, Optional, TypeVar +from typing import Any, Callable, Dict, Optional, TypeVar from torch import nn @@ -24,14 +24,13 @@ def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): T = TypeVar('T') -class AdapterLRUCache(LRUCache[T]): +class AdapterLRUCache(LRUCache[int, T]): - def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable], - None]): + def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]): super().__init__(capacity) self.deactivate_fn = deactivate_fn - def _on_remove(self, key: Hashable, value: Optional[T]): + def _on_remove(self, key: int, value: 
Optional[T]): logger.debug("Removing adapter int id: %d", key) self.deactivate_fn(key) return super()._on_remove(key, value) diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 761b07f34d2f9..95a8f7098bbac 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -22,7 +22,7 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, self.max_input_length = max_input_length self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) max_loras = tokenizer_config.get("max_loras", 0) - self.lora_tokenizers = LRUCache[AnyTokenizer]( + self.lora_tokenizers = LRUCache[int, AnyTokenizer]( capacity=max(max_loras, max_num_seqs) if enable_lora else 0) @classmethod diff --git a/vllm/utils.py b/vllm/utils.py index ba567feb19792..3934903385ad4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -21,14 +21,13 @@ import warnings import weakref from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task -from collections import UserDict, defaultdict +from collections import OrderedDict, UserDict, defaultdict from collections.abc import Iterable, Mapping from dataclasses import dataclass, field from functools import lru_cache, partial, wraps from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Dict, Generator, Generic, Hashable, List, Literal, - Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union, - overload) + Optional, Tuple, Type, TypeVar, Union, overload) from uuid import uuid4 import numpy as np @@ -154,10 +153,12 @@ } P = ParamSpec('P') -K = TypeVar("K") T = TypeVar("T") U = TypeVar("U") +_K = TypeVar("_K", bound=Hashable) +_V = TypeVar("_V") + class _Sentinel: ... 
@@ -190,50 +191,48 @@ def reset(self) -> None: self.counter = 0 -class LRUCache(Generic[T]): +class LRUCache(Generic[_K, _V]): - def __init__(self, capacity: int): - self.cache: OrderedDict[Hashable, T] = OrderedDict() - self.pinned_items: Set[Hashable] = set() + def __init__(self, capacity: int) -> None: + self.cache = OrderedDict[_K, _V]() + self.pinned_items = set[_K]() self.capacity = capacity - def __contains__(self, key: Hashable) -> bool: + def __contains__(self, key: _K) -> bool: return key in self.cache def __len__(self) -> int: return len(self.cache) - def __getitem__(self, key: Hashable) -> T: + def __getitem__(self, key: _K) -> _V: value = self.cache[key] # Raise KeyError if not exists self.cache.move_to_end(key) return value - def __setitem__(self, key: Hashable, value: T) -> None: + def __setitem__(self, key: _K, value: _V) -> None: self.put(key, value) - def __delitem__(self, key: Hashable) -> None: + def __delitem__(self, key: _K) -> None: self.pop(key) - def touch(self, key: Hashable) -> None: + def touch(self, key: _K) -> None: self.cache.move_to_end(key) - def get(self, - key: Hashable, - default_value: Optional[T] = None) -> Optional[T]: - value: Optional[T] + def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]: + value: Optional[_V] if key in self.cache: value = self.cache[key] self.cache.move_to_end(key) else: - value = default_value + value = default return value - def put(self, key: Hashable, value: T) -> None: + def put(self, key: _K, value: _V) -> None: self.cache[key] = value self.cache.move_to_end(key) self._remove_old_if_needed() - def pin(self, key: Hashable) -> None: + def pin(self, key: _K) -> None: """ Pins a key in the cache preventing it from being evicted in the LRU order. @@ -242,13 +241,13 @@ def pin(self, key: Hashable) -> None: raise ValueError(f"Cannot pin key: {key} not in cache.") self.pinned_items.add(key) - def _unpin(self, key: Hashable) -> None: + def _unpin(self, key: _K) -> None: self.pinned_items.remove(key) - def _on_remove(self, key: Hashable, value: Optional[T]): + def _on_remove(self, key: _K, value: Optional[_V]) -> None: pass - def remove_oldest(self, remove_pinned=False): + def remove_oldest(self, *, remove_pinned: bool = False) -> None: if not self.cache: return @@ -262,17 +261,15 @@ def remove_oldest(self, remove_pinned=False): "cannot remove oldest from the cache.") else: lru_key = next(iter(self.cache)) - self.pop(lru_key) + self.pop(lru_key) # type: ignore def _remove_old_if_needed(self) -> None: while len(self.cache) > self.capacity: self.remove_oldest() - def pop(self, - key: Hashable, - default_value: Optional[T] = None) -> Optional[T]: + def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]: run_on_remove = key in self.cache - value: Optional[T] = self.cache.pop(key, default_value) + value = self.cache.pop(key, default) # remove from pinned items if key in self.pinned_items: self._unpin(key) @@ -280,7 +277,7 @@ def pop(self, self._on_remove(key, value) return value - def clear(self): + def clear(self) -> None: while len(self.cache) > 0: self.remove_oldest(remove_pinned=True) self.cache.clear() @@ -843,10 +840,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]: return [item for sublist in lists for item in sublist] -_K = TypeVar("_K", bound=Hashable) -_V = TypeVar("_V") - - def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]): """ Unlike :class:`itertools.groupby`, groups are not broken by diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py 
index cb97f743b1d52..218724bff6bba 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -8,7 +8,7 @@ from vllm.logger import init_logger from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalKwargs, MultiModalRegistry) -from vllm.v1.utils import LRUDictCache +from vllm.utils import LRUCache logger = init_logger(__name__) @@ -44,7 +44,7 @@ def __init__( # Init cache self.use_cache = not model_config.disable_mm_preprocessor_cache - self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) + self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE) # DEBUG: Set to None to disable self.mm_debug_cache_hit_ratio_steps = None @@ -120,7 +120,7 @@ class MMInputMapperServer: def __init__(self, model_config): self.use_cache = not model_config.disable_mm_preprocessor_cache - self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) + self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE) def process_inputs( self, diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 5f327d7066830..e802c6439b740 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,4 +1,3 @@ -from collections import OrderedDict from collections.abc import Sequence from contextlib import contextmanager from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union, @@ -102,27 +101,3 @@ def make_zmq_socket( finally: ctx.destroy(linger=0) - - -K = TypeVar('K') -V = TypeVar('V') - - -class LRUDictCache(Generic[K, V]): - - def __init__(self, size: int): - self.cache: OrderedDict[K, V] = OrderedDict() - self.size = size - - def get(self, key: K, default=None) -> V: - if key not in self.cache: - return default - - self.cache.move_to_end(key) - return self.cache[key] - - def put(self, key: K, value: V): - self.cache[key] = value - self.cache.move_to_end(key) - if len(self.cache) > self.size: - self.cache.popitem(last=False) From 276738ce0f6aac48ace36bc79aa4a0765fccdfb2 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Fri, 20 Dec 2024 01:37:31 +0800 Subject: [PATCH 4/6] [Bugfix] Fix broken CPU compressed-tensors test (#11338) Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index d77722499d0e9..d89071f30a549 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -11,8 +11,7 @@ def sparse_cutlass_supported() -> bool: - # sparse cutlass is not supported on Rocm - if current_platform.is_rocm(): + if not current_platform.is_cuda(): return False capability_tuple = current_platform.get_device_capability() @@ -22,8 +21,7 @@ def sparse_cutlass_supported() -> bool: def cutlass_fp8_supported() -> bool: - # cutlass is not supported on Rocm - if current_platform.is_rocm(): + if not current_platform.is_cuda(): return False capability_tuple = current_platform.get_device_capability() From e461c262f0d4c9911f1bf75bea723f8ae17219be Mon Sep 17 00:00:00 2001 From: yangzhibin <45459326+Ghjk94522@users.noreply.github.com> Date: Fri, 20 Dec 2024 01:54:24 +0800 Subject: [PATCH 5/6] [Misc] Remove unused vllm/block.py (#11336) --- vllm/block.py | 88 -------------------------------------------- vllm/core/evictor.py | 4 +- 2 files changed, 2 insertions(+), 90 deletions(-) delete mode 100644 vllm/block.py diff --git a/vllm/block.py 
b/vllm/block.py deleted file mode 100644 index 47c381c19383b..0000000000000 --- a/vllm/block.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Token blocks.""" -from typing import TYPE_CHECKING, Iterator, List, Optional - -from vllm.utils import Device - -DEFAULT_LAST_ACCESSED_TIME: float = -1 - - -class PhysicalTokenBlock: - """Represents the state of a block in the KV cache.""" - - def __init__( - self, - device: Device, - block_number: int, - block_size: int, - block_hash: int, - num_hashed_tokens: int, - ) -> None: - self.device = device - self.block_number = block_number - self.block_size = block_size - self.block_hash = block_hash - self.num_hashed_tokens = num_hashed_tokens - - self.ref_count = 0 - self.last_accessed = DEFAULT_LAST_ACCESSED_TIME - - self.computed = False - - def __repr__(self) -> str: - return (f'PhysicalTokenBlock(device={self.device}, ' - f'block_number={self.block_number}, ' - f'num_hashed_tokens={self.num_hashed_tokens}, ' - f'ref_count={self.ref_count}, ' - f'last_accessed={self.last_accessed}, ' - f'computed={self.computed})') - - -class BlockTable: - """Holds a list of blocks with caching of their associated block_ids - """ - - def __init__(self, blocks: Optional[List[PhysicalTokenBlock]] = None): - self._blocks: List[PhysicalTokenBlock] = [] - self._block_ids: List[int] = [] - - if blocks is not None: - for block in blocks: - self.append(block) - - def append(self, block: PhysicalTokenBlock): - self._blocks.append(block) - self._block_ids.append(block.block_number) - - def __len__(self) -> int: - return len(self._blocks) - - def __getitem__(self, key): - return self._blocks[key] - - if TYPE_CHECKING: - - def __iter__(self) -> Iterator[PhysicalTokenBlock]: - raise RuntimeError("Method should be automatically generated") - - def __setitem__(self, key, value): - if isinstance(key, slice): - blocks = value - self._blocks[key] = blocks - self._block_ids[key] = [b.block_number for b in blocks] - else: - block = value - self._blocks[key] = block - self._block_ids[key] = block.block_number - - def reset(self): - self._blocks = [] - self._block_ids = [] - - def copy(self) -> "BlockTable": - return BlockTable(self._blocks) - - def list(self) -> List[PhysicalTokenBlock]: - return self._blocks - - def ids(self) -> List[int]: - return self._block_ids diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py index 44adc4158abec..c9306518223a3 100644 --- a/vllm/core/evictor.py +++ b/vllm/core/evictor.py @@ -13,7 +13,7 @@ class EvictionPolicy(enum.Enum): class Evictor(ABC): """The Evictor subclasses should be used by the BlockAllocator class to - handle eviction of freed PhysicalTokenBlocks. + handle eviction of freed Blocks. """ @abstractmethod @@ -70,7 +70,7 @@ def __init__(self, content_hash: int, num_hashed_tokens: int, class LRUEvictor(Evictor): """Evicts in a least-recently-used order using the last_accessed timestamp - that's recorded in the PhysicalTokenBlock. If there are multiple blocks with + that's recorded in the Block. If there are multiple blocks with the same last_accessed time, then the one with the largest num_hashed_tokens will be evicted. If two blocks each have the lowest last_accessed time and highest num_hashed_tokens value, then one will be chose arbitrarily From a985f7af9f7b249974b283a9d999575ac30fac3d Mon Sep 17 00:00:00 2001 From: Yuan Date: Fri, 20 Dec 2024 03:46:55 +0800 Subject: [PATCH 6/6] [CI] Adding CPU docker pipeline (#11261) Signed-off-by: Yuan Zhou Co-authored-by: Kevin H. 
Luu --- .buildkite/release-pipeline.yaml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 2de6fceb0c3fe..51618a2955fb1 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -55,3 +55,18 @@ steps: password-env: DOCKERHUB_TOKEN env: DOCKER_BUILDKIT: "1" + + - block: "Build CPU release image" + key: block-cpu-release-image-build + depends_on: ~ + + - label: "Build and publish CPU release image" + depends_on: block-cpu-release-image-build + agents: + queue: cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION --progress plain -f Dockerfile.cpu ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION" + env: + DOCKER_BUILDKIT: "1"
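
A minimal usage sketch of the consolidated LRUCache from PATCH 3/6, assuming only the interface shown in the vllm/utils.py hunk (capacity constructor, put/get/pin/remove_oldest/clear, plus the _on_remove hook that AdapterLRUCache overrides and the keyed form used by the V1 MM input mapper, LRUCache[str, MultiModalKwargs]). EvictionLoggingCache, its str/bytes key and value types, and the capacity of 2 are hypothetical choices for illustration and are not taken from the patches above.

    from typing import Optional

    from vllm.utils import LRUCache


    class EvictionLoggingCache(LRUCache[str, bytes]):
        """Hypothetical subclass: like AdapterLRUCache, only _on_remove is overridden."""

        def _on_remove(self, key: str, value: Optional[bytes]) -> None:
            # Invoked by pop()/remove_oldest() whenever an entry leaves the cache.
            print(f"evicted {key!r}")


    cache = EvictionLoggingCache(capacity=2)
    cache.put("a", b"1")
    cache["b"] = b"2"          # __setitem__ delegates to put()
    cache.pin("a")             # pinned keys are skipped by remove_oldest()
    cache.put("c", b"3")       # exceeds capacity: evicts "b", the oldest unpinned entry
    assert cache.get("b") is None
    assert cache["a"] == b"1"  # __getitem__ also refreshes recency
    cache.clear()              # clear() removes pinned entries as well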