From 3682e33f9ff9d8baade6112a8e75a77da898f504 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 30 Dec 2024 12:24:12 +0800 Subject: [PATCH 01/10] [v1] fix compilation cache (#11598) Signed-off-by: youkaichao --- tests/compile/piecewise/test_toy_llama.py | 15 +++++++- vllm/compilation/backends.py | 22 ++++++----- vllm/config.py | 45 +++++++++++++++++++++-- vllm/v1/worker/gpu_worker.py | 1 + 4 files changed, 69 insertions(+), 14 deletions(-) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 07c10a3a18c55..d4ede4d2320a7 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -7,7 +7,7 @@ initialized randomly with a fixed seed. """ from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, List, Optional, Tuple import torch from torch import nn @@ -54,6 +54,16 @@ class LlamaConfig: tractable_init: bool = False random_seed: int = 0 + def compute_hash(self) -> str: + factors: List[Any] = [] + for k, v in self.__dict__.items(): + if k == "random_seed": + continue + factors.append((k, v)) + factors.sort() + import hashlib + return hashlib.md5(str(factors).encode()).hexdigest() + def __post_init__(self): assert self.mlp_size >= self.hidden_size @@ -263,7 +273,8 @@ def run_model(llama_config, compilation_config = CompilationConfig( level=CompilationLevel.NO_COMPILATION, ) - vllm_config = VllmConfig(compilation_config=compilation_config) + vllm_config = VllmConfig(compilation_config=compilation_config, + additional_config=llama_config) with set_current_vllm_config(vllm_config): model = LlamaModel(config=llama_config, vllm_config=vllm_config, diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 4f960b441f21d..a8dd628b9cd6f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -619,8 +619,10 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, # the entries for different shapes that we need to either # compile or capture cudagraph self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {} - self.to_be_compiled_sizes: Set[int] = self.compile_sizes.union( - self.capture_sizes) + + # to_be_compiled_sizes tracks the remaining sizes to compile, + # and updates during the compilation process, so we need to copy it + self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy() for shape in self.compile_sizes.union(self.capture_sizes): self.concrete_size_entries[shape] = ConcreteSizeEntry( runtime_shape=shape, @@ -628,12 +630,17 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, use_cudagraph=shape in self.capture_sizes, ) + def check_for_ending_compilation(self): + if self.is_last_graph and not self.to_be_compiled_sizes: + # no specific sizes to compile + # save the hash of the inductor graph for the next run + self.compilation_config.inductor_hash_cache.save_to_file() + end_monitoring_torch_compile(self.vllm_config) + def __call__(self, *args) -> Any: if not self.first_run_finished: self.first_run_finished = True - # no specific sizes to compile - if self.is_last_graph and not self.to_be_compiled_sizes: - end_monitoring_torch_compile(self.vllm_config) + self.check_for_ending_compilation() return self.compiled_graph_for_general_shape(*args) runtime_shape = args[self.sym_shape_indices[0]] @@ -662,10 +669,7 @@ def __call__(self, *args) -> Any: # finished compilations for all required shapes if self.is_last_graph and not self.to_be_compiled_sizes: - - # save the hash of the inductor graph for the next run - self.compilation_config.inductor_hash_cache.save_to_file() - end_monitoring_torch_compile(self.vllm_config) + self.check_for_ending_compilation() if not entry.use_cudagraph: return entry.runnable(*args) diff --git a/vllm/config.py b/vllm/config.py index 8e556743c8528..765a46e6aeee3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -9,8 +9,8 @@ from dataclasses import dataclass, field, replace from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict, - Final, List, Literal, Mapping, Optional, Set, Tuple, Type, - Union) + Final, List, Literal, Mapping, Optional, Protocol, Set, + Tuple, Type, Union) import torch from pydantic import BaseModel, Field, PrivateAttr @@ -75,6 +75,12 @@ PretrainedConfig]] +class SupportsHash(Protocol): + + def compute_hash(self) -> str: + ... + + class ModelConfig: """Configuration for the model. @@ -2969,6 +2975,10 @@ class VllmConfig: init=True) # type: ignore kv_transfer_config: KVTransferConfig = field(default=None, init=True) # type: ignore + # some opaque config, only used to provide additional information + # for the hash computation, mainly used for testing and debugging. + additional_config: SupportsHash = field(default=None, + init=True) # type: ignore instance_id: str = "" def compute_hash(self) -> str: @@ -3000,33 +3010,62 @@ def compute_hash(self) -> str: vllm_factors.append(__version__) if self.model_config: vllm_factors.append(self.model_config.compute_hash()) + else: + vllm_factors.append("None") if self.cache_config: vllm_factors.append(self.cache_config.compute_hash()) + else: + vllm_factors.append("None") if self.parallel_config: vllm_factors.append(self.parallel_config.compute_hash()) + else: + vllm_factors.append("None") if self.scheduler_config: vllm_factors.append(self.scheduler_config.compute_hash()) + else: + vllm_factors.append("None") if self.device_config: vllm_factors.append(self.device_config.compute_hash()) + else: + vllm_factors.append("None") if self.load_config: vllm_factors.append(self.load_config.compute_hash()) + else: + vllm_factors.append("None") if self.lora_config: vllm_factors.append(self.lora_config.compute_hash()) + else: + vllm_factors.append("None") if self.speculative_config: vllm_factors.append(self.speculative_config.compute_hash()) + else: + vllm_factors.append("None") if self.decoding_config: vllm_factors.append(self.decoding_config.compute_hash()) + else: + vllm_factors.append("None") if self.observability_config: vllm_factors.append(self.observability_config.compute_hash()) + else: + vllm_factors.append("None") if self.prompt_adapter_config: vllm_factors.append(self.prompt_adapter_config.compute_hash()) + else: + vllm_factors.append("None") if self.quant_config: pass # should be captured by model_config.quantization if self.compilation_config: vllm_factors.append(self.compilation_config.compute_hash()) + else: + vllm_factors.append("None") if self.kv_transfer_config: vllm_factors.append(self.kv_transfer_config.compute_hash()) - + else: + vllm_factors.append("None") + if self.additional_config: + vllm_factors.append(self.additional_config.compute_hash()) + else: + vllm_factors.append("None") factors.append(vllm_factors) hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10] diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0000b09bfaa36..af438f7d5820c 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -48,6 +48,7 @@ def __init__( self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config + self.parallel_config.rank = rank self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method From 628ec6c17b8121517e8f303b64567573036cdb38 Mon Sep 17 00:00:00 2001 From: Liangfu Chen Date: Sun, 29 Dec 2024 21:46:14 -0800 Subject: [PATCH 02/10] [Docker] bump up neuron sdk v2.21 (#11593) Signed-off-by: Liangfu Chen --- Dockerfile.neuron | 6 +++--- requirements-neuron.txt | 4 ++-- vllm/_custom_ops.py | 3 +-- vllm/triton_utils/importing.py | 1 - 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 77162bc82de62..269139fe90f0b 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -1,6 +1,6 @@ # default base image # https://gallery.ecr.aws/neuron/pytorch-inference-neuronx -ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.2-ubuntu20.04" +ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.5.1-neuronx-py310-sdk2.21.0-ubuntu22.04" FROM $BASE_IMAGE @@ -22,9 +22,9 @@ WORKDIR ${APP_MOUNT}/vllm RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas -RUN python3 -m pip install sentencepiece transformers==4.36.2 -U +RUN python3 -m pip install sentencepiece transformers==4.45.2 -U RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -RUN python3 -m pip install --pre neuronx-cc==2.15.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U +RUN python3 -m pip install neuronx-cc==2.16.345.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U COPY . . ARG GIT_REPO_CHECK=0 diff --git a/requirements-neuron.txt b/requirements-neuron.txt index 148fdbe0d6310..5e08d101fcd61 100644 --- a/requirements-neuron.txt +++ b/requirements-neuron.txt @@ -2,6 +2,6 @@ -r requirements-common.txt # Dependencies for Neuron devices -transformers-neuronx >= 0.12.0 -torch-neuronx >= 2.1.2 +transformers-neuronx >= 0.13.0 +torch-neuronx >= 2.5.0 neuronx-cc diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index aeacf5dda5761..eb2f69df42624 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -23,8 +23,7 @@ import vllm._moe_C # noqa: F401 supports_moe_ops = True -# neuron has torch version that doesn't even have impl_abstract -if TYPE_CHECKING or current_platform.is_neuron(): +if TYPE_CHECKING: def register_fake(fn): return lambda name: fn diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index 36315abcdfcda..0c96e0632f646 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -8,7 +8,6 @@ HAS_TRITON = ( find_spec("triton") is not None and not current_platform.is_xpu() # Not compatible - and not current_platform.is_neuron() # neuron has too old torch ) if not HAS_TRITON: From 970d6d0776076f17604077ba4d484cdadd604ceb Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 30 Dec 2024 04:22:13 -0500 Subject: [PATCH 03/10] [Build][Kernel] Update CUTLASS to v3.6.0 (#11607) Signed-off-by: Tyler Michael Smith --- CMakeLists.txt | 4 ++-- .../vllm_cutlass_library_extension.py | 18 +++++++++--------- csrc/quantization/machete/generate.py | 8 ++++---- .../machete/machete_collective_builder.cuh | 10 ++++------ csrc/quantization/machete/machete_mainloop.cuh | 11 ++++------- .../machete/machete_prepacked_layout.cuh | 5 ++--- 6 files changed, 25 insertions(+), 31 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 83c8033434f3b..3206d76125545 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227 + GIT_TAG v3.6.0 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW FALSE + GIT_SHALLOW TRUE ) endif() FetchContent_MakeAvailable(cutlass) diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index a5beea1a35e49..b401736c9824b 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -14,9 +14,9 @@ class VLLMDataType(enum.Enum): class MixedInputKernelScheduleType(enum.Enum): - TmaWarpSpecializedMixedInput = enum_auto() - TmaWarpSpecializedPingpongMixedInput = enum_auto() - TmaWarpSpecializedCooperativeMixedInput = enum_auto() + TmaWarpSpecialized = enum_auto() + TmaWarpSpecializedPingpong = enum_auto() + TmaWarpSpecializedCooperative = enum_auto() VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = { @@ -68,11 +68,11 @@ class MixedInputKernelScheduleType(enum.Enum): MixedInputKernelScheduleType, KernelScheduleType], str] = { **KernelScheduleTag, # type: ignore **{ - MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput: - "cutlass::gemm::KernelTmaWarpSpecializedMixedInput", - MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput: - "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput", - MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput: - "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecialized: + "cutlass::gemm::KernelTmaWarpSpecialized", + MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: + "cutlass::gemm::KernelTmaWarpSpecializedPingpong", + MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: + "cutlass::gemm::KernelTmaWarpSpecializedCooperative", } } diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index ac63afe79a255..2df4d181902f8 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -189,7 +189,7 @@ {{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT {{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT {{DataTypeTag[t.a_token_scale]}}, // TokenScaleT - cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, + cutlass::gemm::KernelTmaWarpSpecializedCooperative, Sch>; {% for sch in schs %} @@ -223,7 +223,7 @@ {{DataTypeTag[t.convert]}}, // ElementConvert {{DataTypeTag[t.accumulator]}}, // Accumulator cutlass::layout::ColumnMajor, - cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput> + cutlass::gemm::KernelTmaWarpSpecializedCooperative> >(args.B); } {%- endfor %} @@ -239,7 +239,7 @@ }; // namespace machete """ -TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput +TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative @@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str: # mostly unique shorter sch_sig def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str: kernel_terse_names_replace = { - "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_", + "KernelTmaWarpSpecializedCooperative": "TmaMI_", "TmaWarpSpecializedCooperative_": "TmaCoop_", "StreamKScheduler": "streamK", } diff --git a/csrc/quantization/machete/machete_collective_builder.cuh b/csrc/quantization/machete/machete_collective_builder.cuh index a74cf8b2dd455..ee825583dee1a 100644 --- a/csrc/quantization/machete/machete_collective_builder.cuh +++ b/csrc/quantization/machete/machete_collective_builder.cuh @@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder< ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType, cute::enable_if_t<( + cute::is_same_v || + cute::is_same_v || cute::is_same_v || - cute::is_same_v || - cute::is_same_v)>> { + KernelTmaWarpSpecializedCooperative>)>> { using CollectiveOp = machete::MacheteCollectiveMma< ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType>; }; -}; // namespace cutlass::gemm::collective \ No newline at end of file +}; // namespace cutlass::gemm::collective diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index 816f33a1078e5..4071b19a3564d 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -66,13 +66,11 @@ struct MacheteCollectiveMma { using Schedule = KernelScheduleType; static_assert( cute::is_same_v || - cute::is_same_v || + cute::is_same_v || + cute::is_same_v || cute::is_same_v || - cute::is_same_v || cute::is_same_v || - cute::is_same_v, + cute::is_same_v, "KernelSchedule must be one of the warp specialized policies"); public: @@ -113,8 +111,7 @@ struct MacheteCollectiveMma { // For coop schedules we have two warp groups cooperatively issuing wgmma // instructions so we use 2 atoms along the M dim (one for each warpgroup) using AtomLayoutMNK = cute::conditional_t< - cute::is_same_v, + cute::is_same_v, Layout>, Layout>>; using TiledMma = decltype(cute::make_tiled_mma( diff --git a/csrc/quantization/machete/machete_prepacked_layout.cuh b/csrc/quantization/machete/machete_prepacked_layout.cuh index 680a858a893c1..81aaa6c4f3a28 100644 --- a/csrc/quantization/machete/machete_prepacked_layout.cuh +++ b/csrc/quantization/machete/machete_prepacked_layout.cuh @@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate { // For coop schedules we have two warp groups cooperatively issuing wgmma // instructions so we use 2 atoms along the M dim (one for each warpgroup) using AtomLayoutMNK = cute::conditional_t< - cute::is_same_v, + cute::is_same_v, Layout>, Layout>>; using TiledMma = decltype(cute::make_tiled_mma( @@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate { } }; -}; // namespace machete \ No newline at end of file +}; // namespace machete From 5dbf854553cb6ac97f0c633ed36ba64e0fc9bb29 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Mon, 30 Dec 2024 18:17:04 +0800 Subject: [PATCH 04/10] [CI/Build][CPU] Fix CPU CI by lazy importing triton FP8 kernels (#11618) Signed-off-by: jiang1.li --- vllm/model_executor/layers/quantization/fp8.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 7f779ac8d3b3e..2fe22903a385b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -15,8 +15,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_w8a8_block_fp8_linear) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -337,6 +335,9 @@ def apply(self, size_k=layer.input_size_per_partition, bias=bias) + # Note: lazy import to avoid triton import error. + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + apply_w8a8_block_fp8_linear) if self.block_quant: assert self.quant_config.weight_block_size is not None return apply_w8a8_block_fp8_linear( From b12e87f942eb7740c17ab546b964bc327afdda37 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 30 Dec 2024 20:24:45 +0800 Subject: [PATCH 05/10] [platforms] enable platform plugins (#11602) Signed-off-by: youkaichao --- .buildkite/test-pipeline.yaml | 25 +- docs/source/design/plugin_system.md | 6 +- tests/conftest.py | 2 +- tests/kernels/test_attention_selector.py | 16 +- .../plugins/vllm_add_dummy_platform/setup.py | 11 + .../vllm_add_dummy_platform/__init__.py | 5 + .../vllm_add_dummy_platform/dummy_platform.py | 5 + tests/plugins_tests/test_platform_plugins.py | 16 + vllm/config.py | 15 +- vllm/distributed/parallel_state.py | 3 +- vllm/engine/arg_utils.py | 2 +- vllm/executor/ray_utils.py | 2 +- .../guided_decoding/__init__.py | 3 +- vllm/model_executor/models/registry.py | 2 +- vllm/model_executor/utils.py | 4 +- vllm/platforms/__init__.py | 320 ++++++++++++------ vllm/plugins/__init__.py | 72 ++-- vllm/spec_decode/metrics.py | 2 +- vllm/usage/usage_lib.py | 2 +- vllm/utils.py | 8 +- vllm/worker/model_runner_base.py | 5 +- vllm/worker/multi_step_model_runner.py | 1 + vllm/worker/worker_base.py | 14 +- 23 files changed, 360 insertions(+), 181 deletions(-) create mode 100644 tests/plugins/vllm_add_dummy_platform/setup.py create mode 100644 tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py create mode 100644 tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py create mode 100644 tests/plugins_tests/test_platform_plugins.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b563c96343f92..bee968b4d2e43 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -106,14 +106,12 @@ steps: source_file_dependencies: - vllm/ commands: - - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py - - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process - pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests @@ -333,8 +331,6 @@ steps: - vllm/ - tests/models commands: - - pip install -e ./plugins/vllm_add_dummy_model - - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s models/test_registry.py - pytest -v -s models/test_initialization.py @@ -469,11 +465,28 @@ steps: - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)' - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - - pip install -e ./plugins/vllm_add_dummy_model - - pytest -v -s distributed/test_distributed_oot.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py +- label: Plugin Tests (2 GPUs) # 40min + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + fast_check: true + source_file_dependencies: + - vllm/plugins/ + - tests/plugins/ + commands: + # begin platform plugin tests, all the code in-between runs on dummy platform + - pip install -e ./plugins/vllm_add_dummy_platform + - pytest -v -s plugins_tests/test_platform_plugins.py + - pip uninstall vllm_add_dummy_platform -y + # end platform plugin tests + # other tests continue here: + - pip install -e ./plugins/vllm_add_dummy_model + - pytest -v -s distributed/test_distributed_oot.py + - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process + - pytest -v -s models/test_oot_registration.py # it needs a clean process + - label: Multi-step Tests (4 GPUs) # 36min working_dir: "/vllm-workspace/tests" num_gpus: 4 diff --git a/docs/source/design/plugin_system.md b/docs/source/design/plugin_system.md index 79aff757518f2..225030885f629 100644 --- a/docs/source/design/plugin_system.md +++ b/docs/source/design/plugin_system.md @@ -41,9 +41,11 @@ Every plugin has three parts: 2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name. 3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module. -## What Can Plugins Do? +## Types of supported plugins -Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM. +- **General plugins** (with group name `vllm.general_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model inside the plugin function. + +- **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported. ## Guidelines for Writing Plugins diff --git a/tests/conftest.py b/tests/conftest.py index 4e939221329cd..6e2f75e33654f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,7 +31,6 @@ to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput -from vllm.platforms import current_platform from vllm.sampling_params import BeamSearchParams from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, identity) @@ -242,6 +241,7 @@ def video_assets() -> _VideoAssets: class HfRunner: def wrap_device(self, x: _T, device: Optional[str] = None) -> _T: + from vllm.platforms import current_platform if x is None or isinstance(x, (bool, )): return x diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index d37f95d48d5b2..916cc2efa3895 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -5,7 +5,10 @@ from tests.kernels.utils import override_backend_env_variable from vllm.attention.selector import which_attn_to_use -from vllm.platforms import cpu, cuda, openvino, rocm +from vllm.platforms.cpu import CpuPlatform +from vllm.platforms.cuda import CudaPlatform +from vllm.platforms.openvino import OpenVinoPlatform +from vllm.platforms.rocm import RocmPlatform from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL @@ -20,26 +23,23 @@ def test_env(name: str, device: str, monkeypatch): override_backend_env_variable(monkeypatch, name) if device == "cpu": - with patch("vllm.attention.selector.current_platform", - cpu.CpuPlatform()): + with patch("vllm.attention.selector.current_platform", CpuPlatform()): backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) assert backend.name == "TORCH_SDPA" elif device == "hip": - with patch("vllm.attention.selector.current_platform", - rocm.RocmPlatform()): + with patch("vllm.attention.selector.current_platform", RocmPlatform()): backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) assert backend.name == "ROCM_FLASH" elif device == "openvino": with patch("vllm.attention.selector.current_platform", - openvino.OpenVinoPlatform()): + OpenVinoPlatform()): backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) assert backend.name == "OPENVINO" else: - with patch("vllm.attention.selector.current_platform", - cuda.CudaPlatform()): + with patch("vllm.attention.selector.current_platform", CudaPlatform()): backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) assert backend.name == name diff --git a/tests/plugins/vllm_add_dummy_platform/setup.py b/tests/plugins/vllm_add_dummy_platform/setup.py new file mode 100644 index 0000000000000..31639906898db --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/setup.py @@ -0,0 +1,11 @@ +from setuptools import setup + +setup( + name='vllm_add_dummy_platform', + version='0.1', + packages=['vllm_add_dummy_platform'], + entry_points={ + 'vllm.platform_plugins': [ + "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin" # noqa + ] + }) diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py new file mode 100644 index 0000000000000..594cef520a7de --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py @@ -0,0 +1,5 @@ +from typing import Optional + + +def dummy_platform_plugin() -> Optional[str]: + return "vllm_add_dummy_platform.dummy_platform.DummyPlatform" diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py new file mode 100644 index 0000000000000..fde93142f1103 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py @@ -0,0 +1,5 @@ +from vllm.platforms.cuda import CudaPlatform + + +class DummyPlatform(CudaPlatform): + device_name = "DummyDevice" diff --git a/tests/plugins_tests/test_platform_plugins.py b/tests/plugins_tests/test_platform_plugins.py new file mode 100644 index 0000000000000..0d27cf9f152e0 --- /dev/null +++ b/tests/plugins_tests/test_platform_plugins.py @@ -0,0 +1,16 @@ +def test_platform_plugins(): + # simulate workload by running an example + import runpy + current_file = __file__ + import os + example_file = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(current_file))), + "examples", "offline_inference.py") + runpy.run_path(example_file) + + # check if the plugin is loaded correctly + from vllm.platforms import _init_trace, current_platform + assert current_platform.device_name == "DummyDevice", ( + f"Expected DummyDevice, got {current_platform.device_name}, " + "possibly because current_platform is imported before the plugin" + f" is loaded. The first import:\n{_init_trace}") diff --git a/vllm/config.py b/vllm/config.py index 765a46e6aeee3..e72c53b6130d0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, get_quantization_config) from vllm.model_executor.models import ModelRegistry -from vllm.platforms import current_platform, interface +from vllm.platforms import CpuArchEnum from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, @@ -349,6 +349,7 @@ def __init__(self, self.is_hybrid = self._init_is_hybrid() self.has_inner_state = self._init_has_inner_state() + from vllm.platforms import current_platform if current_platform.is_neuron(): self.override_neuron_config = override_neuron_config else: @@ -589,6 +590,7 @@ def _verify_quantization(self) -> None: raise ValueError( f"Unknown quantization method: {self.quantization}. Must " f"be one of {supported_quantization}.") + from vllm.platforms import current_platform current_platform.verify_quantization(self.quantization) if self.quantization not in optimized_quantization_methods: logger.warning( @@ -644,6 +646,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config, # Reminder: Please update docs/source/usage/compatibility_matrix.md # If the feature combo become valid + from vllm.platforms import current_platform if not current_platform.is_async_output_supported(self.enforce_eager): logger.warning( "Async output processing is not supported on the " @@ -1012,6 +1015,7 @@ def _verify_args(self) -> None: raise ValueError( "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") + from vllm.platforms import current_platform if (current_platform.is_cuda() and self.block_size is not None and self.block_size > 32): raise ValueError("CUDA Paged Attention kernel only supports " @@ -1279,6 +1283,7 @@ def __post_init__(self) -> None: f"distributed executor backend " f"'{self.distributed_executor_backend}'.") ray_only_devices = ["tpu", "hpu"] + from vllm.platforms import current_platform if (current_platform.device_type in ray_only_devices and self.world_size > 1): if self.distributed_executor_backend is None: @@ -1327,7 +1332,7 @@ def use_ray(self) -> bool: def _verify_args(self) -> None: # Lazy import to avoid circular import from vllm.executor.executor_base import ExecutorBase - + from vllm.platforms import current_platform if self.distributed_executor_backend not in ( "ray", "mp", None) and not (isinstance( self.distributed_executor_backend, type) and issubclass( @@ -1528,6 +1533,7 @@ def compute_hash(self) -> str: def __init__(self, device: str = "auto") -> None: if device == "auto": # Automated device type detection + from vllm.platforms import current_platform self.device_type = current_platform.device_type if not self.device_type: raise RuntimeError("Failed to infer device type") @@ -2241,9 +2247,10 @@ def _get_and_verify_dtype( else: torch_dtype = config_dtype + from vllm.platforms import current_platform if (current_platform.is_cpu() and current_platform.get_cpu_architecture() - == interface.CpuArchEnum.POWERPC + == CpuArchEnum.POWERPC and (config_dtype == torch.float16 or config_dtype == torch.float32)): logger.info( @@ -3083,6 +3090,7 @@ def _get_quantization_config( model_config: ModelConfig, load_config: LoadConfig) -> Optional[QuantizationConfig]: """Get the quantization config.""" + from vllm.platforms import current_platform if model_config.quantization is not None: from vllm.model_executor.model_loader.weight_utils import ( get_quant_config) @@ -3145,6 +3153,7 @@ def __post_init__(self): self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config) + from vllm.platforms import current_platform if self.scheduler_config is not None and \ self.model_config is not None and \ self.scheduler_config.chunked_prefill_enabled and \ diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 5b9236f8c56b6..e6768467f4c27 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -39,7 +39,6 @@ import vllm.envs as envs from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op, supports_custom_op if TYPE_CHECKING: @@ -194,6 +193,7 @@ def __init__( assert self.cpu_group is not None assert self.device_group is not None + from vllm.platforms import current_platform if current_platform.is_cuda_alike(): self.device = torch.device(f"cuda:{local_rank}") else: @@ -1188,6 +1188,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): import ray # Lazy import Ray ray.shutdown() gc.collect() + from vllm.platforms import current_platform if not current_platform.is_cpu(): torch.cuda.empty_cache() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 21966d003c7ef..69c7c5077fe32 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -18,7 +18,6 @@ from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -from vllm.platforms import current_platform from vllm.transformers_utils.utils import check_gguf_file from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, StoreBoolean @@ -1094,6 +1093,7 @@ def create_engine_config(self, use_sliding_window = (model_config.get_sliding_window() is not None) use_spec_decode = self.speculative_model is not None + from vllm.platforms import current_platform if (is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora and not self.enable_prompt_adapter diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 426aa1b5c728f..8d766bad1a072 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -8,7 +8,6 @@ from vllm.config import ParallelConfig from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import get_ip from vllm.worker.worker_base import WorkerWrapperBase @@ -229,6 +228,7 @@ def initialize_ray_cluster( the default Ray cluster address. """ assert_ray_available() + from vllm.platforms import current_platform # Connect to a ray cluster. if current_platform.is_rocm() or current_platform.is_xpu(): diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 694c5b68b1cbd..18b435a42544a 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -6,7 +6,7 @@ from vllm.model_executor.guided_decoding.utils import ( convert_lark_to_gbnf, grammar_is_likely_lark, has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) -from vllm.platforms import CpuArchEnum, current_platform +from vllm.platforms import CpuArchEnum if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -39,6 +39,7 @@ def maybe_backend_fallback( if guided_params.backend == "xgrammar": # xgrammar only has x86 wheels for linux, fallback to outlines + from vllm.platforms import current_platform if current_platform.get_cpu_architecture() is not CpuArchEnum.X86: logger.warning("xgrammar is only supported on x86 CPUs. " "Falling back to use outlines instead.") diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 67268eb4bb85f..07f4b5a3b3bc8 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -18,7 +18,6 @@ import torch.nn as nn from vllm.logger import init_logger -from vllm.platforms import current_platform from .interfaces import (has_inner_state, is_attention_free, is_hybrid, supports_cross_encoding, supports_multimodal, @@ -273,6 +272,7 @@ def _try_load_model_cls( model_arch: str, model: _BaseRegisteredModel, ) -> Optional[Type[nn.Module]]: + from vllm.platforms import current_platform current_platform.verify_model_arch(model_arch) try: return model.load_model_cls() diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 39ead08c238ce..6f1cc9d5e0c30 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -3,10 +3,9 @@ import torch -from vllm.platforms import current_platform - def set_random_seed(seed: int) -> None: + from vllm.platforms import current_platform current_platform.seed_everything(seed) @@ -38,6 +37,7 @@ def set_weight_attrs( # This sometimes causes OOM errors during model loading. To avoid this, # we sync the param tensor after its weight loader is called. # TODO(woosuk): Remove this hack once we have a better solution. + from vllm.platforms import current_platform if current_platform.is_tpu() and key == "weight_loader": value = _make_synced_weight_loader(value) setattr(weight, key, value) diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 419237c252ffd..f6ac14446c021 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -1,123 +1,223 @@ +import logging +import traceback +from itertools import chain +from typing import TYPE_CHECKING, Optional + +from vllm.plugins import load_plugins_by_group +from vllm.utils import resolve_obj_by_qualname + from .interface import _Backend # noqa: F401 -from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform +from .interface import CpuArchEnum, Platform, PlatformEnum -current_platform: Platform +logger = logging.getLogger(__name__) -# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because -# they only indicate the build configuration, not the runtime environment. -# For example, people can install a cuda build of pytorch but run on tpu. -is_tpu = False -try: - # While it's technically possible to install libtpu on a non-TPU machine, - # this is a very uncommon scenario. Therefore, we assume that libtpu is - # installed if and only if the machine has TPUs. - import libtpu # noqa: F401 - is_tpu = True -except Exception: - pass +def tpu_platform_plugin() -> Optional[str]: + is_tpu = False + try: + # While it's technically possible to install libtpu on a + # non-TPU machine, this is a very uncommon scenario. Therefore, + # we assume that libtpu is installed if and only if the machine + # has TPUs. + import libtpu # noqa: F401 + is_tpu = True + except Exception: + pass + + return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None -is_cuda = False -try: - import pynvml - pynvml.nvmlInit() +def cuda_platform_plugin() -> Optional[str]: + is_cuda = False + try: - if pynvml.nvmlDeviceGetCount() > 0: + import pynvml + pynvml.nvmlInit() + try: + if pynvml.nvmlDeviceGetCount() > 0: + is_cuda = True + finally: + pynvml.nvmlShutdown() + except Exception: + # CUDA is supported on Jetson, but NVML may not be. + import os + + def cuda_is_jetson() -> bool: + return os.path.isfile("/etc/nv_tegra_release") \ + or os.path.exists("/sys/class/tegra-firmware") + + if cuda_is_jetson(): is_cuda = True - finally: - pynvml.nvmlShutdown() -except Exception: - # CUDA is supported on Jetson, but NVML may not be. - import os - def cuda_is_jetson() -> bool: - return os.path.isfile("/etc/nv_tegra_release") \ - or os.path.exists("/sys/class/tegra-firmware") + return "vllm.platforms.cuda.CudaPlatform" if is_cuda else None + + +def rocm_platform_plugin() -> Optional[str]: + is_rocm = False + + try: + import amdsmi + amdsmi.amdsmi_init() + try: + if len(amdsmi.amdsmi_get_processor_handles()) > 0: + is_rocm = True + finally: + amdsmi.amdsmi_shut_down() + except Exception: + pass + + return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None + + +def hpu_platform_plugin() -> Optional[str]: + is_hpu = False + try: + from importlib import util + is_hpu = util.find_spec('habana_frameworks') is not None + except Exception: + pass + + return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None + + +def xpu_platform_plugin() -> Optional[str]: + is_xpu = False + + try: + # installed IPEX if the machine has XPUs. + import intel_extension_for_pytorch # noqa: F401 + import oneccl_bindings_for_pytorch # noqa: F401 + import torch + if hasattr(torch, 'xpu') and torch.xpu.is_available(): + is_xpu = True + except Exception: + pass + + return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None + + +def cpu_platform_plugin() -> Optional[str]: + is_cpu = False + try: + from importlib.metadata import version + is_cpu = "cpu" in version("vllm") + except Exception: + pass + + return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None + + +def neuron_platform_plugin() -> Optional[str]: + is_neuron = False + try: + import transformers_neuronx # noqa: F401 + is_neuron = True + except ImportError: + pass - if cuda_is_jetson(): - is_cuda = True + return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None -is_rocm = False -try: - import amdsmi - amdsmi.amdsmi_init() +def openvino_platform_plugin() -> Optional[str]: + is_openvino = False try: - if len(amdsmi.amdsmi_get_processor_handles()) > 0: - is_rocm = True - finally: - amdsmi.amdsmi_shut_down() -except Exception: - pass - -is_hpu = False -try: - from importlib import util - is_hpu = util.find_spec('habana_frameworks') is not None -except Exception: - pass - -is_xpu = False - -try: - # installed IPEX if the machine has XPUs. - import intel_extension_for_pytorch # noqa: F401 - import oneccl_bindings_for_pytorch # noqa: F401 - import torch - if hasattr(torch, 'xpu') and torch.xpu.is_available(): - is_xpu = True -except Exception: - pass - -is_cpu = False -try: - from importlib.metadata import version - is_cpu = "cpu" in version("vllm") -except Exception: - pass - -is_neuron = False -try: - import transformers_neuronx # noqa: F401 - is_neuron = True -except ImportError: - pass - -is_openvino = False -try: - from importlib.metadata import version - is_openvino = "openvino" in version("vllm") -except Exception: - pass - -if is_tpu: - # people might install pytorch built with cuda but run on tpu - # so we need to check tpu first - from .tpu import TpuPlatform - current_platform = TpuPlatform() -elif is_cuda: - from .cuda import CudaPlatform - current_platform = CudaPlatform() -elif is_rocm: - from .rocm import RocmPlatform - current_platform = RocmPlatform() -elif is_hpu: - from .hpu import HpuPlatform - current_platform = HpuPlatform() -elif is_xpu: - from .xpu import XPUPlatform - current_platform = XPUPlatform() -elif is_cpu: - from .cpu import CpuPlatform - current_platform = CpuPlatform() -elif is_neuron: - from .neuron import NeuronPlatform - current_platform = NeuronPlatform() -elif is_openvino: - from .openvino import OpenVinoPlatform - current_platform = OpenVinoPlatform() -else: - current_platform = UnspecifiedPlatform() - -__all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum'] + from importlib.metadata import version + is_openvino = "openvino" in version("vllm") + except Exception: + pass + + return "vllm.platforms.openvino.OpenVinoPlatform" if is_openvino else None + + +builtin_platform_plugins = { + 'tpu': tpu_platform_plugin, + 'cuda': cuda_platform_plugin, + 'rocm': rocm_platform_plugin, + 'hpu': hpu_platform_plugin, + 'xpu': xpu_platform_plugin, + 'cpu': cpu_platform_plugin, + 'neuron': neuron_platform_plugin, + 'openvino': openvino_platform_plugin, +} + + +def resolve_current_platform_cls_qualname() -> str: + platform_plugins = load_plugins_by_group('vllm.platform_plugins') + + activated_plugins = [] + + for name, func in chain(builtin_platform_plugins.items(), + platform_plugins.items()): + try: + assert callable(func) + platform_cls_qualname = func() + if platform_cls_qualname is not None: + activated_plugins.append(name) + except Exception: + pass + + activated_builtin_plugins = list( + set(activated_plugins) & set(builtin_platform_plugins.keys())) + activated_oot_plugins = list( + set(activated_plugins) & set(platform_plugins.keys())) + + if len(activated_oot_plugins) >= 2: + raise RuntimeError( + "Only one platform plugin can be activated, but got: " + f"{activated_oot_plugins}") + elif len(activated_oot_plugins) == 1: + platform_cls_qualname = platform_plugins[activated_oot_plugins[0]]() + logger.info("Platform plugin %s is activated", + activated_oot_plugins[0]) + elif len(activated_builtin_plugins) >= 2: + raise RuntimeError( + "Only one platform plugin can be activated, but got: " + f"{activated_builtin_plugins}") + elif len(activated_builtin_plugins) == 1: + platform_cls_qualname = builtin_platform_plugins[ + activated_builtin_plugins[0]]() + logger.info("Automatically detected platform %s.", + activated_builtin_plugins[0]) + else: + platform_cls_qualname = "vllm.interface.UnspecifiedPlatform" + logger.info( + "No platform detected, vLLM is running on UnspecifiedPlatform") + return platform_cls_qualname + + +_current_platform = None +_init_trace: str = '' + +if TYPE_CHECKING: + current_platform: Platform + + +def __getattr__(name: str): + if name == 'current_platform': + # lazy init current_platform. + # 1. out-of-tree platform plugins need `from vllm.platforms import + # Platform` so that they can inherit `Platform` class. Therefore, + # we cannot resolve `current_platform` during the import of + # `vllm.platforms`. + # 2. when users use out-of-tree platform plugins, they might run + # `import vllm`, some vllm internal code might access + # `current_platform` during the import, and we need to make sure + # `current_platform` is only resolved after the plugins are loaded + # (we have tests for this, if any developer violate this, they will + # see the test failures). + global _current_platform + if _current_platform is None: + platform_cls_qualname = resolve_current_platform_cls_qualname() + _current_platform = resolve_obj_by_qualname( + platform_cls_qualname)() + global _init_trace + _init_trace = "".join(traceback.format_stack()) + return _current_platform + else: + return globals()[name] + + +__all__ = [ + 'Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum', + "_init_trace" +] diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 17f604ea0e202..c50eb2cef4cd5 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,10 +1,10 @@ import logging import os +from typing import Callable, Dict import torch import vllm.envs as envs -from vllm.platforms import current_platform logger = logging.getLogger(__name__) @@ -12,6 +12,39 @@ plugins_loaded = False +def load_plugins_by_group(group: str) -> Dict[str, Callable]: + import sys + if sys.version_info < (3, 10): + from importlib_metadata import entry_points + else: + from importlib.metadata import entry_points + + allowed_plugins = envs.VLLM_PLUGINS + + discovered_plugins = entry_points(group=group) + if len(discovered_plugins) == 0: + logger.debug("No plugins for group %s found.", group) + return {} + logger.info("Available plugins for group %s:", group) + for plugin in discovered_plugins: + logger.info("name=%s, value=%s", plugin.name, plugin.value) + if allowed_plugins is None: + logger.info("all available plugins for group %s will be loaded.", + group) + logger.info("set environment variable VLLM_PLUGINS to control" + " which plugins to load.") + plugins = {} + for plugin in discovered_plugins: + if allowed_plugins is None or plugin.name in allowed_plugins: + try: + func = plugin.load() + plugins[plugin.name] = func + logger.info("plugin %s loaded.", plugin.name) + except Exception: + logger.exception("Failed to load plugin %s", plugin.name) + return plugins + + def load_general_plugins(): """WARNING: plugins can be loaded for multiple times in different processes. They should be designed in a way that they can be loaded @@ -26,6 +59,9 @@ def load_general_plugins(): os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' # see https://github.com/vllm-project/vllm/issues/10619 torch._inductor.config.compile_threads = 1 + + from vllm.platforms import current_platform + if current_platform.is_xpu(): # see https://github.com/pytorch/pytorch/blob/8cada5cbe5450e17c26fb8b358116785324537b2/torch/_dynamo/config.py#L158 # noqa os.environ['TORCH_COMPILE_DISABLE'] = 'True' @@ -47,33 +83,7 @@ def load_general_plugins(): if plugins_loaded: return plugins_loaded = True - import sys - if sys.version_info < (3, 10): - from importlib_metadata import entry_points - else: - from importlib.metadata import entry_points - - allowed_plugins = envs.VLLM_PLUGINS - - discovered_plugins = entry_points(group='vllm.general_plugins') - if len(discovered_plugins) == 0: - logger.debug("No plugins found.") - return - logger.info("Available plugins:") - for plugin in discovered_plugins: - logger.info("name=%s, value=%s, group=%s", plugin.name, plugin.value, - plugin.group) - if allowed_plugins is None: - logger.info("all available plugins will be loaded.") - logger.info("set environment variable VLLM_PLUGINS to control" - " which plugins to load.") - else: - logger.info("plugins to load: %s", allowed_plugins) - for plugin in discovered_plugins: - if allowed_plugins is None or plugin.name in allowed_plugins: - try: - func = plugin.load() - func() - logger.info("plugin %s loaded.", plugin.name) - except Exception: - logger.exception("Failed to load plugin %s", plugin.name) + plugins = load_plugins_by_group(group='vllm.general_plugins') + # general plugins, we only need to execute the loaded functions + for func in plugins.values(): + func() diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 03dc46600d8a9..d678f4578499b 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -6,7 +6,6 @@ from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) -from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available @@ -94,6 +93,7 @@ def init_tensors(self, def maybe_collect_rejsample_metrics( self, k: int) -> Optional[SpecDecodeWorkerMetrics]: # currently using cuda.Event, skip for any non_cuda_alike platform + from vllm.platforms import current_platform if not current_platform.is_cuda_alike(): return None diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 9ae46ff43a916..a9deee881f41a 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -17,7 +17,6 @@ import vllm.envs as envs from vllm.connections import global_http_connection -from vllm.platforms import current_platform from vllm.version import __version__ as VLLM_VERSION _config_home = envs.VLLM_CONFIG_ROOT @@ -152,6 +151,7 @@ def _report_usage_once(self, model_architecture: str, usage_context: UsageContext, extra_kvs: Dict[str, Any]) -> None: # Platform information + from vllm.platforms import current_platform if current_platform.is_cuda_alike(): device_property = torch.cuda.get_device_properties(0) self.gpu_count = torch.cuda.device_count() diff --git a/vllm/utils.py b/vllm/utils.py index 2b46c1fef0d09..8ef07d2c326a3 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -50,7 +50,6 @@ import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger -from vllm.platforms import current_platform if TYPE_CHECKING: from vllm.config import VllmConfig @@ -609,6 +608,7 @@ def create_kv_caches_with_random_flash( seed: int = 0, device: Optional[str] = "cuda", ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + from vllm.platforms import current_platform current_platform.seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) @@ -650,7 +650,7 @@ def create_kv_caches_with_random( raise ValueError( f"Does not support key cache of type fp8 with head_size {head_size}" ) - + from vllm.platforms import current_platform current_platform.seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) @@ -703,6 +703,7 @@ def print_warning_once(msg: str) -> None: @lru_cache(maxsize=None) def is_pin_memory_available() -> bool: + from vllm.platforms import current_platform return current_platform.is_pin_memory_available() @@ -713,6 +714,7 @@ def __init__(self, device: Optional[torch.types.Device] = None): def current_memory_usage(self) -> float: # Return the memory usage in bytes. + from vllm.platforms import current_platform if current_platform.is_cuda_alike(): torch.cuda.reset_peak_memory_stats(self.device) mem = torch.cuda.max_memory_allocated(self.device) @@ -1066,6 +1068,7 @@ def _cuda_device_count_stateless( import torch.cuda import torch.version + from vllm.platforms import current_platform if not torch.cuda._is_compiled(): return 0 if current_platform.is_rocm(): @@ -1673,6 +1676,7 @@ def direct_register_custom_op( return if not supports_custom_op(): + from vllm.platforms import current_platform assert not current_platform.is_cuda_alike(), ( "cuda platform needs torch>=2.4 to support custom op, " "chances are you are using an old version of pytorch " diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index cd4770202a186..c7abad7e0258d 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -12,7 +12,6 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SequenceGroupMetadata if TYPE_CHECKING: @@ -265,13 +264,13 @@ def prepare_model_input( """ raise NotImplementedError - @current_platform.inference_mode() def execute_model( self, model_input: T, kv_caches: Optional[List[torch.Tensor]], - intermediate_tensors: Optional[IntermediateTensors], + intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs, ) -> Optional[List[SamplerOutput]]: """ Execute the model on the given input. diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 65d9bab0e2822..dee63a75c0605 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -544,6 +544,7 @@ def execute_model( model_input.record_step_event(current_stream) if get_pp_group().is_last_rank and self.is_driver_worker: + assert isinstance(output, list) assert len( output ) == 1, "MultiStepModelRunner requires single-step base_models" diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 3ac7fb8dfb766..249b3ed2dfd37 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -11,7 +11,6 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import (enable_trace_function_call_for_thread, resolve_obj_by_qualname, update_environment_variables) @@ -44,6 +43,8 @@ def __init__( self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config self.kv_transfer_config = vllm_config.kv_transfer_config + from vllm.platforms import current_platform + self.current_platform = current_platform @abstractmethod def init_device(self) -> None: @@ -74,17 +75,17 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise NotImplementedError - @current_platform.inference_mode() def start_worker_execution_loop(self) -> None: """Execute model loop in parallel worker. You can stop the loop by executing a driver worker with an empty output. See `stop_remote_worker_execution_loop` for more details. """ - while True: - output = self.execute_model(execute_model_req=None) - if output is None: - return None + with self.current_platform.inference_mode(): + while True: + output = self.execute_model(execute_model_req=None) + if output is None: + return None @abstractmethod def execute_model( @@ -352,6 +353,7 @@ def execute_model( model_execute_time = time.perf_counter() - start_time if not get_pp_group().is_last_rank: # output is IntermediateTensors + assert isinstance(output, IntermediateTensors) if (self.observability_config is not None and self.observability_config.collect_model_execute_time): output.tensors["model_execute_time"] = torch.tensor( From 8d9b6721e7f5b7d191951c6f1cd12710ffd08093 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 30 Dec 2024 23:01:35 +0800 Subject: [PATCH 06/10] [VLM] Abstract out multi-modal data parsing in merged processor (#11620) Signed-off-by: DarkLight1337 --- .buildkite/test-pipeline.yaml | 4 +- vllm/model_executor/models/chatglm.py | 4 +- vllm/model_executor/models/llava.py | 18 +- vllm/model_executor/models/phi3v.py | 19 +- vllm/model_executor/models/qwen2_audio.py | 22 +- vllm/model_executor/models/qwen2_vl.py | 153 +++++----- vllm/model_executor/models/ultravox.py | 22 +- vllm/multimodal/__init__.py | 9 +- vllm/multimodal/audio.py | 4 +- vllm/multimodal/base.py | 8 +- vllm/multimodal/image.py | 4 +- vllm/multimodal/inputs.py | 195 ++++-------- vllm/multimodal/parse.py | 344 ++++++++++++++++++++++ vllm/multimodal/processing.py | 62 ++-- vllm/multimodal/video.py | 4 +- 15 files changed, 560 insertions(+), 312 deletions(-) create mode 100644 vllm/multimodal/parse.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index bee968b4d2e43..c6f8316412e2f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -356,7 +356,7 @@ steps: - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/language -m 'not core_model' -- label: Multi-Modal Models Test (Standard) # 28min +- label: Multi-Modal Models Test (Standard) # 40min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -372,7 +372,7 @@ steps: - pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model -- label: Multi-Modal Models Test (Extended) 1 # 1h16m +- label: Multi-Modal Models Test (Extended) 1 # 48m optional: true source_file_dependencies: - vllm/ diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 6c50882d83c3b..ffd6891b25965 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -33,7 +33,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs, +from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs, NestedTensors) from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, @@ -54,7 +54,7 @@ def calculate_image_placeholder(vision_config): def mm_input_mapper_for_glmv( ctx: InputContext, - data: MultiModalData[object], + data: ModalityData[object], ) -> Dict: model_config = ctx.model_config tokenizer = cached_get_tokenizer( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 0ecba5a1cae0f..1d6ee2a0be72e 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -20,11 +20,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems, - MultiModalFieldConfig, MultiModalInputsV2, - MultiModalKwargs, NestedTensors) +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputsV2, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import ImageProcessorItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - ProcessorInputs, PromptReplacement, + MultiModalDataItems, ProcessorInputs, + PromptReplacement, full_groupby_modality) from vllm.sequence import IntermediateTensors @@ -179,7 +181,9 @@ def _get_prompt_replacements( assert isinstance(vision_config, PixtralVisionConfig) def get_replacement_pixtral(item_idx: int): - image_size = mm_items.get_image_size(item_idx) + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + ( num_width_tokens, num_height_tokens, @@ -591,8 +595,8 @@ def apply( result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) - mm_items = self._get_mm_items(mm_data) - mm_item_counts = mm_items.get_item_counts() + mm_items = self._to_mm_items(mm_data) + mm_item_counts = mm_items.get_all_counts() mm_kwargs = result["mm_kwargs"] # We reimplement the functionality of MLlavaProcessor from diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index fefa9fd62d1d0..15362db6cdfbf 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -32,12 +32,13 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems, - MultiModalFieldConfig, MultiModalInputsV2, - MultiModalKwargs, NestedTensors, - PlaceholderRange) +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputsV2, MultiModalKwargs, + NestedTensors, PlaceholderRange) +from vllm.multimodal.parse import ImageProcessorItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - ProcessorInputs, PromptReplacement, + MultiModalDataItems, ProcessorInputs, + PromptReplacement, _BoundPromptReplacement, _PlaceholderInfo) from vllm.sequence import IntermediateTensors @@ -381,7 +382,9 @@ def _get_prompt_replacements( assert isinstance(bos_token_id, int) def get_replacement_phi3v(item_idx: int): - image_size = mm_items.get_image_size(item_idx) + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + num_tokens = image_processor.calc_num_image_tokens_from_image_size( width=image_size.width, height=image_size.height, @@ -389,12 +392,14 @@ def get_replacement_phi3v(item_idx: int): return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id] + num_images = mm_items.get_count("image", strict=False) + return [ PromptReplacement( modality="image", target=image_token, replacement=get_replacement_phi3v, - ) for image_token in image_tokens[:len(mm_items.images)] + ) for image_token in image_tokens[:num_images] ] def _apply_prompt_replacements( diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 25a351bd9c656..e3d43b017f894 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -20,8 +20,8 @@ # limitations under the License. """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from functools import cached_property -from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, + Union) import numpy as np import torch @@ -38,10 +38,12 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, - ProcessorInputs, PromptReplacement) + MultiModalDataItems, ProcessorInputs, + PromptReplacement) from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP @@ -99,15 +101,9 @@ def _get_hf_processor( def _get_feature_extractor(self) -> WhisperFeatureExtractor: return self._get_hf_processor().feature_extractor # type: ignore - def _get_hf_mm_data( - self, - mm_items: MultiModalDataItems, - ) -> tuple[dict[str, Any], dict[str, Any]]: - # resample audio to the model's sampling rate + def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self._get_feature_extractor() - mm_items.resample_audios(feature_extractor.sampling_rate) - - return super()._get_hf_mm_data(mm_items) + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) def _call_hf_processor( self, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 574845ef5a525..6181fe3dd13d8 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -25,7 +25,6 @@ from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, Set, Tuple, Type, TypedDict, Union) -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -55,15 +54,16 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems, +from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalFieldConfig, MultiModalKwargs, - NestedTensors) + NestedTensors, VideoItem) +from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, - ProcessorInputs, PromptReplacement) + MultiModalDataItems, ProcessorInputs, + PromptReplacement) from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope -from vllm.utils import is_list_of from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend, @@ -719,61 +719,81 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key="video") -class Qwen2VLMultiModalDataItems(MultiModalDataItems): +class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], + dict[str, torch.Tensor]]): - @staticmethod - def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems": - """ - Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`. - """ - multi_data = Qwen2VLMultiModalDataItems() - - for k, v in data.items(): - # TODO: Make a separate modality for embedding inputs - # to avoid confusion - # yapf: disable - if k == "video": - # Special case since even a single item can be a list - multi_data[k] = ( # type: ignore[index] - v if ( - isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment] - or is_list_of(v, list) - or isinstance(v[0], (np.ndarray, torch.Tensor)) - and v[0].ndim == 4 - ) else [v] - ) - elif k in ("image", "audio"): - multi_data[k] = ( # type: ignore[index] - v if isinstance(v, (dict, torch.Tensor, list)) else [v] - ) - else: - multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] - # yapf: enable + def __init__(self, data: dict, modality: str) -> None: + super().__init__(data) - return multi_data + self.modality = modality - def get_item_counts(self) -> Mapping[str, int]: - return { - m: ( - len(items[f"{m}_grid_thw"]) # type: ignore - if isinstance(items, dict) else len(items)) - for m, items in self.items() - } + grid_thw = data[f"{modality}_grid_thw"] + slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist() + self._slices = [ + slice(slice_idxs[i], slice_idxs[i + 1]) + for i in range(len(grid_thw)) + ] - def has_embedding_inputs(self) -> bool: - return any( - isinstance(items, dict) or any( - isinstance(item, torch.Tensor) for item in items) - for items in self.values()) + def __repr__(self) -> str: + return (f"{type(self).__name__}(modality={self.modality!r})") + def get_count(self) -> int: + return len(self.data[f"{self.modality}_grid_thw"]) -class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): + def get(self, index: int) -> dict[str, torch.Tensor]: + out = {} + for k, v in self.data.items(): + if v != f"{self.modality}_grid_thw": + v = v[self._slices[index]] + + out[k] = v + + return out + + def get_processor_data(self) -> Mapping[str, object]: + return {} + + def get_passthrough_data(self) -> Mapping[str, object]: + return self.data + + +class Qwen2ImageEmbeddingItems(Qwen2EmbeddingItems): + + def __init__(self, data: dict) -> None: + super().__init__(data, "image") + + +class Qwen2VideoEmbeddingItems(Qwen2EmbeddingItems): - def _get_mm_items( + def __init__(self, data: dict) -> None: + super().__init__(data, "video") + + +class Qwen2MultiModalDataParser(MultiModalDataParser): + + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return Qwen2EmbeddingItems(data, modality="image") + + return super()._parse_image_data(data) + + def _parse_video_data( self, - mm_data: MultiModalDataDict, - ) -> MultiModalDataItems: - return Qwen2VLMultiModalDataItems.from_dict(mm_data) + data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return Qwen2EmbeddingItems(data, modality="video") + + return super()._parse_video_data(data) + + +class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): + + def _get_data_parser(self) -> MultiModalDataParser: + return Qwen2MultiModalDataParser() def _get_hf_processor( self, @@ -796,35 +816,6 @@ def _get_hf_processor( return hf_processor - def _get_hf_mm_data( - self, - mm_items: MultiModalDataItems, - ) -> tuple[dict[str, Any], dict[str, Any]]: - processor_data = dict[str, Any]() - passthrough_data = dict[str, Any]() - - for k, v in mm_items.items(): - # TODO: Make a separate modality for embedding inputs - # to avoid confusion - if k in ("image", "video", "audio"): - if isinstance(v, dict): - # Pass through embedding inputs (dict) - passthrough_data.update(v) - elif isinstance(v, torch.Tensor) and v.ndim == 3: - # Pass through embedding inputs (single) - passthrough_data[f"{k}_embeds"] = [v] - elif (is_list_of(v, torch.Tensor) and len(v) > 0 - and v[0].ndim == 2): - # Pass through embedding inputs (multi) - passthrough_data[f"{k}_embeds"] = v - elif len(v) > 0: - # Map keys to plural form, e.g.: image -> images - processor_data[f"{k}s"] = v - else: - processor_data[k] = v - - return processor_data, passthrough_data - def _get_prompt_replacements( self, mm_items: MultiModalDataItems, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 7b4aeeec5f403..7e853e5b90096 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,8 +3,8 @@ import math from functools import cached_property, lru_cache -from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, - Tuple, TypedDict, Union) +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, + TypedDict, Union) import numpy as np import torch @@ -24,10 +24,12 @@ from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig, - MultiModalKwargs, NestedTensors) +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, - ProcessorInputs, PromptReplacement) + MultiModalDataItems, ProcessorInputs, + PromptReplacement) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.utils import is_list_of @@ -85,15 +87,9 @@ def _get_feature_extractor(self) -> WhisperFeatureExtractor: hf_processor = self._get_hf_processor() return hf_processor.audio_processor.feature_extractor # type: ignore - def _get_hf_mm_data( - self, - mm_items: MultiModalDataItems, - ) -> tuple[dict[str, Any], dict[str, Any]]: - # resample audio to the model's sampling rate + def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self._get_feature_extractor() - mm_items.resample_audios(feature_extractor.sampling_rate) - - return super()._get_hf_mm_data(mm_items) + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) def _call_hf_processor( self, diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 9255e062e4870..e58bbe81717a0 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,8 +1,7 @@ from .base import MultiModalPlaceholderMap, MultiModalPlugin -from .inputs import (BatchedTensorInputs, MultiModalData, - MultiModalDataBuiltins, MultiModalDataDict, - MultiModalKwargs, MultiModalPlaceholderDict, - NestedTensors) +from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, + MultiModalDataDict, MultiModalKwargs, + MultiModalPlaceholderDict, NestedTensors) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -16,7 +15,7 @@ __all__ = [ "BatchedTensorInputs", - "MultiModalData", + "ModalityData", "MultiModalDataBuiltins", "MultiModalDataDict", "MultiModalKwargs", diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index 3e09ef1fcbb56..de80f22bac2a3 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -9,7 +9,7 @@ from vllm.utils import PlaceholderModule from .base import MediaIO, MultiModalPlugin -from .inputs import AudioItem, MultiModalData, MultiModalKwargs +from .inputs import AudioItem, ModalityData, MultiModalKwargs try: import librosa @@ -31,7 +31,7 @@ def get_data_key(self) -> str: def _default_input_mapper( self, ctx: InputContext, - data: MultiModalData[AudioItem], + data: ModalityData[AudioItem], **mm_processor_kwargs, ) -> MultiModalKwargs: raise NotImplementedError("There is no default audio input mapper") diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index cdda6f8052794..7f4029e726332 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -15,12 +15,12 @@ from vllm.config import ModelConfig from vllm.sequence import SequenceGroupMetadata -from .inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs, +from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs, PlaceholderRange) logger = init_logger(__name__) -MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]], +MultiModalInputMapper = Callable[[InputContext, ModalityData[object]], MultiModalKwargs] """ Return a dictionary to be passed as keyword arguments to @@ -69,7 +69,7 @@ def get_data_key(self) -> str: def _default_input_mapper( self, ctx: InputContext, - data: MultiModalData[Any], + data: ModalityData[Any], **mm_processor_kwargs, ) -> MultiModalKwargs: """ @@ -118,7 +118,7 @@ def wrapper(model_cls: N) -> N: def map_input( self, model_config: "ModelConfig", - data: MultiModalData[Any], + data: ModalityData[Any], mm_processor_kwargs: Optional[dict[str, Any]], ) -> MultiModalKwargs: """ diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 14c79dfadec0c..da13a381c4530 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -13,7 +13,7 @@ from vllm.utils import is_list_of from .base import MediaIO, MultiModalPlugin -from .inputs import ImageItem, MultiModalData, MultiModalKwargs +from .inputs import ImageItem, ModalityData, MultiModalKwargs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -44,7 +44,7 @@ def _get_hf_image_processor( def _default_input_mapper( self, ctx: InputContext, - data: MultiModalData[ImageItem], + data: ModalityData[ImageItem], **mm_processor_kwargs, ) -> MultiModalKwargs: model_config = ctx.model_config diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 1fbda6e0b8750..db489af7ac475 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -2,53 +2,74 @@ from collections import UserDict, defaultdict from collections.abc import Mapping, Sequence from dataclasses import dataclass -from typing import (Any, Literal, NamedTuple, TypedDict, TypeVar, Union, cast, - final) +from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final import numpy as np import torch import torch.types from PIL.Image import Image from transformers import BatchFeature -from typing_extensions import NotRequired, TypeAlias, assert_never +from typing_extensions import NotRequired, TypeAlias from vllm.utils import JSONTree, is_list_of, json_map_leaves _T = TypeVar("_T") -# yapf: disable -ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor] +HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor] """ A :class:`transformers.image_utils.ImageInput` representing a single image item, which can be passed to a HuggingFace :code:`ImageProcessor`. """ -VideoItem: TypeAlias = Union[ - list[Image], - np.ndarray, - torch.Tensor, - list[np.ndarray], - list[torch.Tensor], -] +HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor, + list[np.ndarray], list[torch.Tensor]] """ A :class:`transformers.image_utils.VideoInput` representing a single video item, which can be passed to a HuggingFace :code:`VideoProcessor`. """ -AudioItem: TypeAlias = Union[ - np.ndarray, - list[float], - # `(audio, sampling_rate)`: If the audio's sampling rate is different - # from that expected by the model, we need to resample it. - tuple[np.ndarray, float], -] +HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor] """ Represents a single audio item, which can be passed to a HuggingFace :code:`AudioProcessor`. """ -# yapf: enable -MultiModalData: TypeAlias = Union[_T, list[_T]] +ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor] +""" +A :class:`transformers.image_utils.ImageInput` representing a single image +item, which can be passed to a HuggingFace :code:`ImageProcessor`. + +Alternatively, a 3-D tensor or batch of 2-D tensors, +which are treated as image embeddings; +these are directly passed to the model without HF processing. +""" + +VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor] +""" +A :class:`transformers.image_utils.VideoInput` representing a single video +item, which can be passed to a HuggingFace :code:`VideoProcessor`. + +Alternatively, a 3-D tensor or batch of 2-D tensors, +which are treated as video embeddings; +these are directly passed to the model without HF processing. +""" + +AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], + torch.Tensor] +""" +Represents a single audio +item, which can be passed to a HuggingFace :code:`AudioProcessor`. + +Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate +is different from that expected by the model; +these are resampled to the model's sampling rate before being processed by HF. + +Alternatively, a 3-D tensor or batch of 2-D tensors, +which are treated as audio embeddings; +these are directly passed to the model without HF processing. +""" + +ModalityData: TypeAlias = Union[_T, list[_T]] """ Either a single data item, or a list of data items. @@ -61,17 +82,17 @@ class MultiModalDataBuiltins(TypedDict, total=False): """Type annotations for modality types predefined by vLLM.""" - image: MultiModalData[ImageItem] + image: ModalityData[ImageItem] """The input image(s).""" - video: MultiModalData[VideoItem] + video: ModalityData[VideoItem] """The input video(s).""" - audio: MultiModalData[AudioItem] + audio: ModalityData[AudioItem] """The input audio(s).""" -MultiModalDataDict: TypeAlias = Mapping[str, MultiModalData[Any]] +MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]] """ A dictionary containing an entry for each modality type to input. @@ -83,123 +104,6 @@ class MultiModalDataBuiltins(TypedDict, total=False): """ -class ImageSize(NamedTuple): - width: int - height: int - - -class MultiModalDataItems(UserDict[str, list[Any]]): - """ - As :class:`MultiModalDataDict`, but normalized such that each entry - corresponds to a list. - """ - - @staticmethod - def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems": - """ - Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`. - """ - multi_data = MultiModalDataItems() - - for k, v in data.items(): - # TODO: Make a separate modality for embedding inputs - # to avoid confusion - # yapf: disable - if k == "video": - # Special case since even a single item can be a list - multi_data[k] = ( # type: ignore[index] - v if ( - isinstance(v, torch.Tensor) - or is_list_of(v, list) - or isinstance(v[0], (np.ndarray, torch.Tensor)) - and v[0].ndim == 4 - ) else [v] - ) - elif k in ("image", "audio"): - multi_data[k] = ( # type: ignore[index] - v if isinstance(v, (torch.Tensor, list)) else [v] - ) - else: - multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] - # yapf: enable - - return multi_data - - # NOTE: When a field (e.g. `images`) doesn't exist, directly appending to - # `self.images` doesn't update this dictionary, which may be confusing - # We annotate the getter methods as `Sequence` to prevent others from - # trying to update the list in this way - @property - def images(self) -> Sequence[ImageItem]: - return self.get("image", []) - - @property - def videos(self) -> Sequence[VideoItem]: - return self.get("video", []) - - @property - def audios(self) -> Sequence[AudioItem]: - return self.get("audio", []) - - def get_item_counts(self) -> Mapping[str, int]: - return {m: len(items) for m, items in self.items()} - - def has_embedding_inputs(self) -> bool: - return any( - any(isinstance(item, torch.Tensor) for item in items) - for items in self.values()) - - def get_image_size(self, item_idx: int) -> ImageSize: - image = self.images[item_idx] - - if isinstance(image, Image): - return ImageSize(*image.size) - if isinstance(image, (np.ndarray, torch.Tensor)): - _, h, w = image.shape - return ImageSize(w, h) - - assert_never(image) - - def get_audio_with_sr( - self, - item_idx: int, - *, - default_sr: float, - ) -> tuple[np.ndarray, float]: - audio = self.audios[item_idx] - - if isinstance(audio, tuple): - return audio - if isinstance(audio, list): - return np.array(audio), default_sr - if isinstance(audio, np.ndarray): - return audio, default_sr - - assert_never(audio) - - def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None: - """ - If :code:`drop_sr=True`, the audio items in this dictionary are updated - to be NumPy arrays which implicitly means that their sampling rate is - the same as the model's expected sampling rate; otherwise, they remain - as :code:`(audio, new_sr)` tuples. - """ - # Avoid circular import - from .audio import resample_audio - - if not self.audios: - return - - new_audios = [] - for item_idx in range(len(self.audios)): - audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr) - audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr) - - new_audios.append(audio if drop_sr else (audio, new_sr)) - - self["audio"] = new_audios - - class PlaceholderRange(TypedDict): """ Placeholder location information for multi-modal data. @@ -436,7 +340,7 @@ def from_items_by_key( ) -> "MultiModalKwargs": data = { key: items[0].field.reduce(items).data - for key, items in items_by_key.items() + for key, items in items_by_key.items() if len(items) > 0 } return MultiModalKwargs(data, @@ -567,6 +471,11 @@ def get_items_by_modality( Get the keyword arguments corresponding to an item identified by its modality and index. """ + if modality not in self._keys_by_modality: + available_modalities = set(self._keys_by_modality.keys()) + raise KeyError(f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}") + keys_to_gather = self._keys_by_modality[modality] return { diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py new file mode 100644 index 0000000000000..17a795247372e --- /dev/null +++ b/vllm/multimodal/parse.py @@ -0,0 +1,344 @@ +from abc import ABC, abstractmethod +from collections import UserDict +from collections.abc import Callable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar + +import numpy as np +import torch +from PIL.Image import Image +from typing_extensions import TypeAlias, TypeGuard, assert_never + +from vllm.utils import is_list_of + +from .audio import resample_audio +from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, + ImageItem, ModalityData, MultiModalDataDict, + NestedTensors, VideoItem) + +_T = TypeVar("_T") +_I = TypeVar("_I") + + +class ModalityDataItems(ABC, Generic[_T, _I]): + + def __init__(self, data: _T) -> None: + super().__init__() + + self.data = data + + def __len__(self) -> int: + return self.get_count() + + def __getitem__(self, index: int) -> _I: + return self.get(index) + + if TYPE_CHECKING: + # Auto-generated + def __iter__(self) -> Iterator[_I]: + ... + + @abstractmethod + def get_count(self) -> int: + """Get the number of data items.""" + raise NotImplementedError + + @abstractmethod + def get(self, index: int) -> _I: + """Get a data item by its index.""" + raise NotImplementedError + + def get_all(self) -> list[_I]: + """Get all data items.""" + return [self.get(idx) for idx in range(self.get_count())] + + @abstractmethod + def get_processor_data(self) -> Mapping[str, object]: + """Get the data to pass to the HF processor.""" + raise NotImplementedError + + @abstractmethod + def get_passthrough_data(self) -> Mapping[str, object]: + """Get the data to pass directly to the model.""" + raise NotImplementedError + + +class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]): + + def __init__(self, data: Sequence[_T], modality: str) -> None: + super().__init__(data) + + self.modality = modality + + def __repr__(self) -> str: + return (f"{type(self).__name__}(modality={self.modality!r})") + + def get_count(self) -> int: + return len(self.data) + + def get(self, index: int) -> _T: + return self.data[index] + + def get_processor_data(self) -> Mapping[str, object]: + return {f"{self.modality}s": self.data} + + def get_passthrough_data(self) -> Mapping[str, object]: + return {} + + +class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]): + + def __init__(self, data: NestedTensors, modality: str) -> None: + super().__init__(data) + + self.modality = modality + + def __repr__(self) -> str: + return (f"{type(self).__name__}(modality={self.modality!r})") + + def get_count(self) -> int: + return len(self.data) + + def get(self, index: int) -> object: + return self.data[index] + + def get_processor_data(self) -> Mapping[str, object]: + return {} + + def get_passthrough_data(self) -> Mapping[str, object]: + return {f"{self.modality}_embeds": self.data} + + +class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): + + def __init__(self, data: Sequence[HfAudioItem]) -> None: + super().__init__(data, "audio") + + +class AudioEmbeddingItems(EmbeddingItems): + + def __init__(self, data: NestedTensors) -> None: + super().__init__(data, "audio") + + +class ImageSize(NamedTuple): + width: int + height: int + + +class ImageProcessorItems(ProcessorBatchItems[HfImageItem]): + + def __init__(self, data: Sequence[HfImageItem]) -> None: + super().__init__(data, "image") + + def get_image_size(self, item_idx: int) -> ImageSize: + image = self.get(item_idx) + + if isinstance(image, Image): + return ImageSize(*image.size) + if isinstance(image, (np.ndarray, torch.Tensor)): + _, h, w = image.shape + return ImageSize(w, h) + + assert_never(image) + + +class ImageEmbeddingItems(EmbeddingItems): + + def __init__(self, data: NestedTensors) -> None: + super().__init__(data, "image") + + +class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): + + def __init__(self, data: Sequence[HfVideoItem]) -> None: + super().__init__(data, "video") + + +class VideoEmbeddingItems(EmbeddingItems): + + def __init__(self, data: NestedTensors) -> None: + super().__init__(data, "video") + + +_D = TypeVar("_D", bound=ModalityDataItems[Any, Any]) + + +class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]): + """ + As :class:`MultiModalDataDict`, but normalized such that each entry + corresponds to a list. + """ + + def get_count(self, modality: str, *, strict: bool = True) -> int: + """ + Get the number of data items belonging to a modality. + + If `strict=False`, return `0` instead of raising :exc:`KeyError` + even if the modality is not found. + """ + if modality not in self: + if strict: + available_modalities = set(self.keys()) + raise KeyError(f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}") + + return 0 + + return self[modality].get_count() + + def get_all_counts(self) -> Mapping[str, int]: + """Get the number of items belonging to each modality.""" + return {m: items.get_count() for m, items in self.items()} + + def get_items( + self, + modality: str, + typ: type[_D], + ) -> _D: + """ + Get the data items belonging to a modality, + requiring that they belong to a certain type. + """ + if modality not in self: + available_modalities = set(self.keys()) + raise KeyError(f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}") + + items = self[modality] + if not isinstance(items, typ): + raise TypeError(f"Invalid type of data items for {modality=}. " + f"Expected type: {typ}, but " + f"found type: {type(items)}") + + return items + + +ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]], + ModalityDataItems[Any, Any]] + + +class MultiModalDataParser: + """ + Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`. + """ + + def __init__(self, *, target_sr: Optional[float] = None) -> None: + super().__init__() + + self.target_sr = target_sr + + def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]: + if isinstance(data, torch.Tensor): + return data.ndim == 3 + if is_list_of(data, torch.Tensor): + return len(data) == 0 or data[0].ndim == 2 + + return False + + def _get_audio_with_sr( + self, + audio: AudioItem, + ) -> tuple[np.ndarray, Optional[float]]: + if isinstance(audio, tuple): + return audio + if isinstance(audio, list): + return np.array(audio), None + if isinstance(audio, np.ndarray): + return audio, None + if isinstance(audio, torch.Tensor): + return audio.numpy(), None + + assert_never(audio) + + def _parse_audio_data( + self, + data: ModalityData[AudioItem], + ) -> ModalityDataItems[Any, Any]: + if self._is_embeddings(data): + return AudioEmbeddingItems(data) + + if (is_list_of(data, float) + or isinstance(data, + (np.ndarray, torch.Tensor)) and data.ndim == 1 + or isinstance(data, tuple)): + data_items = [data] + elif isinstance(data, (np.ndarray, torch.Tensor)): + data_items = [elem for elem in data] + else: + data_items = data + + new_audios = list[np.ndarray]() + for data_item in data_items: + audio, orig_sr = self._get_audio_with_sr(data_item) + if orig_sr is None: + new_audio = audio + else: + target_sr = self.target_sr + if target_sr is None: + raise RuntimeError( + "Audio resampling is not supported when " + "`target_sr` is not provided") + + new_audio = resample_audio(audio, + orig_sr=orig_sr, + target_sr=target_sr) + + new_audios.append(new_audio) + + return AudioProcessorItems(new_audios) + + def _parse_image_data( + self, + data: ModalityData[ImageItem], + ) -> ModalityDataItems[Any, Any]: + if self._is_embeddings(data): + return ImageEmbeddingItems(data) + + if (isinstance(data, Image) + or isinstance(data, + (np.ndarray, torch.Tensor)) and data.ndim == 3): + data_items = [data] + elif isinstance(data, (np.ndarray, torch.Tensor)): + data_items = [elem for elem in data] + else: + data_items = data + + return ImageProcessorItems(data_items) + + def _parse_video_data( + self, + data: ModalityData[VideoItem], + ) -> ModalityDataItems[Any, Any]: + if self._is_embeddings(data): + return VideoEmbeddingItems(data) + + if (is_list_of(data, Image) + or isinstance(data, + (np.ndarray, torch.Tensor)) and data.ndim == 4): + data_items = [data] + elif isinstance(data, (np.ndarray, torch.Tensor)): + data_items = [elem for elem in data] + else: + data_items = data + + return VideoProcessorItems(data_items) + + def _get_subparsers(self) -> Mapping[str, ModalityDataParser]: + return { + "audio": self._parse_audio_data, + "image": self._parse_image_data, + "video": self._parse_video_data, + } + + def parse_mm_data(self, + mm_data: MultiModalDataDict) -> MultiModalDataItems: + subparsers = self._get_subparsers() + + mm_items = MultiModalDataItems() + for k, v in mm_data.items(): + if k not in subparsers: + raise ValueError(f"Unsupported modality: {k}") + + mm_items[k] = subparsers[k](v) + + return mm_items diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 3ece0762e3228..180489166b407 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -15,11 +15,12 @@ from vllm.inputs import DummyData, InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import LRUCache, flatten_2d_lists, full_groupby, is_list_of +from vllm.utils import LRUCache, flatten_2d_lists, full_groupby -from .inputs import (MultiModalDataDict, MultiModalDataItems, - MultiModalFieldConfig, MultiModalFieldItem, - MultiModalInputsV2, MultiModalKwargs, PlaceholderRange) +from .inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalFieldItem, MultiModalInputsV2, MultiModalKwargs, + PlaceholderRange) +from .parse import MultiModalDataItems, MultiModalDataParser logger = init_logger(__name__) @@ -621,6 +622,16 @@ def __call__( ) -> MultiModalInputsV2: return self.apply(prompt, mm_data, hf_processor_mm_kwargs) + def _get_data_parser(self) -> MultiModalDataParser: + """ + Construct a data parser to preprocess multi-modal data items + before passing them to :meth:`_get_hf_mm_data`. + + You can support additional modalities by creating a subclass + of :class:`MultiModalDataParser` that has additional subparsers. + """ + return MultiModalDataParser() + def _get_hf_processor(self) -> ProcessorMixin: """ Subclasses can add keyword arguments to this method to accept @@ -631,11 +642,16 @@ def _get_hf_processor(self) -> ProcessorMixin: def _get_tokenizer(self) -> AnyTokenizer: return self.ctx.tokenizer - def _get_mm_items( + def _to_mm_items( self, mm_data: MultiModalDataDict, ) -> MultiModalDataItems: - return MultiModalDataItems.from_dict(mm_data) + """ + Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems` + before passing them to :meth:`_get_hf_mm_data`. + """ + parser = self._get_data_parser() + return parser.parse_mm_data(mm_data) @abstractmethod def _get_mm_fields_config( @@ -680,22 +696,9 @@ def _get_hf_mm_data( processor_data = dict[str, Any]() passthrough_data = dict[str, Any]() - for k, v in mm_items.items(): - # TODO: Make a separate modality for embedding inputs - # to avoid confusion - if k in ("image", "video", "audio"): - if isinstance(v, torch.Tensor) and v.ndim == 3: - # Pass through embedding inputs (single) - passthrough_data[f"{k}_embeds"] = [v] - elif (is_list_of(v, torch.Tensor) and len(v) > 0 - and v[0].ndim == 2): - # Pass through embedding inputs (multi) - passthrough_data[f"{k}_embeds"] = v - elif len(v) > 0: - # Map keys to plural form, e.g.: image -> images - processor_data[f"{k}s"] = v - else: - processor_data[k] = v + for items in mm_items.values(): + processor_data.update(items.get_processor_data()) + passthrough_data.update(items.get_passthrough_data()) return processor_data, passthrough_data @@ -756,7 +759,7 @@ def _apply_hf_processor_missing( cached items; instead, we rely on our own prompt replacement logic for the full text. """ - mm_missing_counts = mm_missing_data_items.get_item_counts() + mm_missing_counts = mm_missing_data_items.get_all_counts() prompt_ids, _ = self._apply_hf_processor( prompt_text=prompt_text, @@ -789,7 +792,8 @@ def _cached_apply_hf_processor( cache = self.cache model_id = self.ctx.model_config.model - if cache is None or mm_data_items.has_embedding_inputs(): + _, passthrough_data = self._get_hf_mm_data(mm_data_items) + if cache is None or passthrough_data: return self._apply_hf_processor( prompt_text=prompt_text, mm_items=mm_data_items, @@ -812,7 +816,7 @@ def _cached_apply_hf_processor( modality: [mm_data_items[modality][idx] for idx in idxs] for modality, idxs in mm_missing_idxs.items() } - mm_missing_data_items = self._get_mm_items(mm_missing_data) + mm_missing_data_items = self._to_mm_items(mm_missing_data) prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing( prompt_text=prompt_text, @@ -852,7 +856,7 @@ def _cached_apply_hf_processor( mm_merged_field_items[modality] = merged_modal_items_lst if self.enable_sanity_checks: - mm_missing_counts = mm_missing_data_items.get_item_counts() + mm_missing_counts = mm_missing_data_items.get_all_counts() assert all( item_count == mm_missing_counts[modality] for modality, item_count in mm_missing_next_idx.items()), dict( @@ -865,7 +869,7 @@ def _cached_apply_hf_processor( ) if self.enable_sanity_checks: - mm_item_counts = mm_data_items.get_item_counts() + mm_item_counts = mm_data_items.get_all_counts() for modality, item_count in mm_item_counts.items(): for item_idx in range(item_count): @@ -958,7 +962,7 @@ def apply( 3. Extract information about the placeholder tokens from the processed token IDs. """ - mm_items = self._get_mm_items(mm_data) + mm_items = self._to_mm_items(mm_data) prompt_ids, mm_kwargs = self._cached_apply_hf_processor( prompt_text, @@ -975,7 +979,7 @@ def apply( # If HF processor already inserts placeholder tokens, # there is no need for us to insert them - mm_item_counts = mm_items.get_item_counts() + mm_item_counts = mm_items.get_all_counts() all_placeholders = self._find_placeholders(prompt_repls, prompt_ids, mm_item_counts) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index b7d43c830cc46..1ad1f5abc27a2 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -15,7 +15,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import PlaceholderModule, is_list_of -from .base import MediaIO, MultiModalData +from .base import MediaIO, ModalityData from .image import ImageMediaIO, ImagePlugin from .inputs import MultiModalKwargs, VideoItem @@ -54,7 +54,7 @@ def _get_hf_video_processor( def _default_input_mapper( self, ctx: InputContext, - data: MultiModalData[VideoItem], + data: ModalityData[VideoItem], **mm_processor_kwargs, ) -> MultiModalKwargs: model_config = ctx.model_config From 5886aa496e8fa31c9180bcfc8e89faaa8899907d Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Mon, 30 Dec 2024 10:51:02 -0500 Subject: [PATCH 07/10] [V1] [6/N] API Server: Better Shutdown (#11586) --- vllm/entrypoints/openai/api_server.py | 44 ++++++++------------------- vllm/v1/engine/async_llm.py | 25 +++++++++++++-- vllm/v1/engine/core_client.py | 16 ++++------ 3 files changed, 40 insertions(+), 45 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 094cc15a317e9..bac72d87376da 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -68,7 +68,7 @@ from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, - is_valid_ipv6_address, kill_process_tree, set_ulimit) + is_valid_ipv6_address, set_ulimit) from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -133,32 +133,21 @@ async def build_async_engine_client_from_engine_args( Returns the Client or None if the creation failed. """ - # Fall back - # TODO: fill out feature matrix. + # AsyncLLMEngine. if (MQLLMEngineClient.is_unsupported_config(engine_args) or envs.VLLM_USE_V1 or disable_frontend_multiprocessing): - engine_config = engine_args.create_engine_config( - UsageContext.OPENAI_API_SERVER) - uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), - "uses_ray", False) - - build_engine = partial(AsyncLLMEngine.from_engine_args, - engine_args=engine_args, - engine_config=engine_config, - usage_context=UsageContext.OPENAI_API_SERVER) - if uses_ray: - # Must run in main thread with ray for its signal handlers to work - engine_client = build_engine() - else: - engine_client = await asyncio.get_running_loop().run_in_executor( - None, build_engine) - yield engine_client - if hasattr(engine_client, "shutdown"): - engine_client.shutdown() - return + engine_client: Optional[EngineClient] = None + try: + engine_client = AsyncLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.OPENAI_API_SERVER) + yield engine_client + finally: + if engine_client and hasattr(engine_client, "shutdown"): + engine_client.shutdown() - # Otherwise, use the multiprocessing AsyncLLMEngine. + # MQLLMEngine. else: if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: # Make TemporaryDirectory for prometheus multiprocessing @@ -737,15 +726,6 @@ def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler) - # The child processes will send SIGQUIT to this process when - # any error happens. This process then clean up the whole tree. - # TODO(rob): move this into AsyncLLM.__init__ once we remove - # the context manager below. - def sigquit_handler(signum, frame): - kill_process_tree(os.getpid()) - - signal.signal(signal.SIGQUIT, sigquit_handler) - async with build_async_engine_client(args) as engine_client: app = build_app(args) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 213ddaa023dbc..3f097ca7f439c 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,4 +1,6 @@ import asyncio +import os +import signal from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union from vllm.config import ModelConfig, VllmConfig @@ -16,6 +18,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext +from vllm.utils import kill_process_tree from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor @@ -38,6 +41,22 @@ def __init__( log_requests: bool = True, start_engine_loop: bool = True, ) -> None: + + # The child processes will send SIGQUIT when unrecoverable + # errors happen. We kill the process tree here so that the + # stack trace is very evident. + # TODO: rather than killing the main process, we should + # figure out how to raise an AsyncEngineDeadError and + # handle at the API server level so we can return a better + # error code to the clients calling VLLM. + def sigquit_handler(signum, frame): + logger.fatal( + "AsyncLLM got SIGQUIT from worker processes, shutting " + "down. See stack trace above for root cause issue.") + kill_process_tree(os.getpid()) + + signal.signal(signal.SIGQUIT, sigquit_handler) + assert start_engine_loop self.log_requests = log_requests @@ -276,9 +295,9 @@ async def _run_output_handler(self): # 4) Abort any requests that finished due to stop strings. await self.engine_core.abort_requests_async(reqs_to_abort) - except BaseException as e: - logger.error(e) - raise e + except Exception as e: + logger.exception("EngineCore output handler hit an error: %s", e) + kill_process_tree(os.getpid()) async def abort(self, request_id: str) -> None: """Abort RequestId in self, detokenizer, and engine core.""" diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index beb5d57c20c83..3293205e110af 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -6,7 +6,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import get_open_zmq_ipc_path +from vllm.utils import get_open_zmq_ipc_path, make_zmq_socket from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreProfile, EngineCoreRequest, EngineCoreRequestType, EngineCoreRequestUnion) @@ -144,17 +144,13 @@ def __init__( else: self.ctx = zmq.Context() # type: ignore[attr-defined] - # Path for IPC. + # Paths and sockets for IPC. output_path = get_open_zmq_ipc_path() input_path = get_open_zmq_ipc_path() - - # Get output (EngineCoreOutput) from EngineCore. - self.output_socket = self.ctx.socket(zmq.constants.PULL) - self.output_socket.connect(output_path) - - # Send input (EngineCoreRequest) to EngineCore. - self.input_socket = self.ctx.socket(zmq.constants.PUSH) - self.input_socket.bind(input_path) + self.output_socket = make_zmq_socket(self.ctx, output_path, + zmq.constants.PULL) + self.input_socket = make_zmq_socket(self.ctx, input_path, + zmq.constants.PUSH) # Start EngineCore in background process. self.proc_handle: Optional[BackgroundProcHandle] From 36e76700453924c8d421db99af70a88a1df835cd Mon Sep 17 00:00:00 2001 From: whyiug Date: Tue, 31 Dec 2024 02:51:04 +0800 Subject: [PATCH 08/10] [Bugfix] Validate and concatenate image embeddings in MiniCPMVBaseModel (#11631) --- vllm/model_executor/models/minicpmv.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 1e8f9bd4cf418..712022502539b 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -487,6 +487,12 @@ def _parse_and_validate_inputs( image_embeds = kwargs.pop("image_embeds", None) if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of image embeds. " + f"Got type: {type(image_embeds)}") + if isinstance(image_embeds, list): + image_embeds = torch.concat(image_embeds) + return MiniCPMVImageEmbeddingInputs( image_bounds=self._get_image_bounds(input_ids, im_start_id, im_end_id, slice_start_id, From ccb1aabccaa7aaf07b08fd8be30380e828efba0f Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Mon, 30 Dec 2024 12:27:07 -0800 Subject: [PATCH 09/10] [benchmark] Remove dependency for H100 benchmark step (#11572) --- .buildkite/nightly-benchmarks/benchmark-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 708e548727cf5..868b8e95db01d 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -73,7 +73,7 @@ steps: # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: H100 - depends_on: block-h100 + depends_on: ~ plugins: - docker#v5.12.0: image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT From a2a40bcd0d8275e19c46e9cc06ee994d8839b98d Mon Sep 17 00:00:00 2001 From: Matthias Vogler <60004995+ayylemao@users.noreply.github.com> Date: Tue, 31 Dec 2024 02:33:06 +0100 Subject: [PATCH 10/10] [Model][LoRA]LoRA support added for MolmoForCausalLM (#11439) Signed-off-by: Matthias Vogler Signed-off-by: Jee Jee Li Co-authored-by: Matthias Vogler Co-authored-by: Jee Jee Li --- docs/source/models/supported_models.md | 2 +- vllm/model_executor/models/molmo.py | 45 ++++++++++++++++++++++++-- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 518505abeb2a9..613343281464c 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -666,7 +666,7 @@ See [this page](#generative-models) for more information on how to use generativ - Molmo - T + I - `allenai/Molmo-7B-D-0924`, `allenai/Molmo-72B-0924`, etc. - - + - ✅︎ - ✅︎ - ✅︎ * - `NVLM_D_Model` diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 5d52d2c3e6b48..cc25be9f5b6a9 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer @@ -43,7 +44,7 @@ SequenceData) from vllm.transformers_utils.processor import get_processor -from .interfaces import SupportsMultiModal, SupportsPP +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, merge_multimodal_embeddings) @@ -1161,8 +1162,8 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) -class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): - +class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, + SupportsLoRA): hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ # vision backbone mapping @@ -1191,6 +1192,32 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): }, ) + packed_modules_mapping = { + "qkv_proj": ["qkv_proj"], + "gate_up_proj": ["gate_up_proj"], # language model + "merged_linear": ["gate_proj", "up_proj"] # image_projector + } + + # LoRA specific attributes + supported_lora_modules = [ + # language model + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", # same name with image_projector + # vision tower + "wq", + "wk", + "wv", + "wo", + "w1", + "w2", + # image_projector + "merged_linear", + ] + embedding_modules = {} + embedding_padding_modules = [] + # BitandBytes specific attributes bitsandbytes_stacked_params_mapping = { "gate_proj": ("merged_linear", 0), @@ -1202,8 +1229,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config + lora_config = vllm_config.lora_config self.config = config self.multimodal_config = multimodal_config + self.lora_config = lora_config vision_config = VisionBackboneConfig() self.vision_backbone = MolmoVisionBackbone(config, vision_config, @@ -1377,6 +1406,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weights = _get_weights_with_merged_embedding(weights) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="model", + connector="vision_backbone.image_projector", + tower_model="vision_backbone", + ) + def _get_weights_with_merged_embedding( weights: Iterable[Tuple[str, torch.Tensor]]