From af7c4a92e654684066e61518d6ed90feda983635 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sun, 8 Dec 2024 22:29:16 -0800 Subject: [PATCH 01/18] [Doc][V1] Add V1 support column for multimodal models (#10998) Signed-off-by: Roger Wang --- docs/source/models/supported_models.rst | 26 ++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c9b3fa8485ff1..4e5b10967e3bb 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -495,7 +495,7 @@ Text Generation --------------- .. list-table:: - :widths: 25 25 15 25 5 5 + :widths: 25 25 15 20 5 5 5 :header-rows: 1 * - Architecture @@ -504,47 +504,55 @@ Text Generation - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + - V1 * - :code:`AriaForConditionalGeneration` - Aria - T + I - :code:`rhymes-ai/Aria` - - ✅︎ + - * - :code:`Blip2ForConditionalGeneration` - BLIP-2 - T + I\ :sup:`E` - :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc. - - ✅︎ + - * - :code:`ChameleonForConditionalGeneration` - Chameleon - T + I - :code:`facebook/chameleon-7b` etc. - - ✅︎ + - * - :code:`FuyuForCausalLM` - Fuyu - T + I - :code:`adept/fuyu-8b` etc. - - ✅︎ + - * - :code:`ChatGLMModel` - GLM-4V - T + I - :code:`THUDM/glm-4v-9b` etc. - ✅︎ - ✅︎ + - * - :code:`H2OVLChatModel` - H2OVL - T + I\ :sup:`E+` - :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc. - - ✅︎ + - * - :code:`Idefics3ForConditionalGeneration` - Idefics3 - T + I - :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc. - ✅︎ + - - * - :code:`InternVLChatModel` - InternVL 2.5, Mono-InternVL, InternVL 2.0 @@ -552,96 +560,112 @@ Text Generation - :code:`OpenGVLab/InternVL2_5-4B`, :code:`OpenGVLab/Mono-InternVL-2B`, :code:`OpenGVLab/InternVL2-4B`, etc. - - ✅︎ + - ✅︎ * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - T + I\ :sup:`E+` - :code:`llava-hf/llava-1.5-7b-hf`, :code:`TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. - - ✅︎ + - ✅︎ * - :code:`LlavaNextForConditionalGeneration` - LLaVA-NeXT - T + I\ :sup:`E+` - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - - ✅︎ + - * - :code:`LlavaNextVideoForConditionalGeneration` - LLaVA-NeXT-Video - T + V - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. - - ✅︎ + - * - :code:`LlavaOnevisionForConditionalGeneration` - LLaVA-Onevision - T + I\ :sup:`+` + V\ :sup:`+` - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. - - ✅︎ + - * - :code:`MiniCPMV` - MiniCPM-V - T + I\ :sup:`E+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - ✅︎ - ✅︎ + - * - :code:`MllamaForConditionalGeneration` - Llama 3.2 - T + I\ :sup:`+` - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. - - + - * - :code:`MolmoForCausalLM` - Molmo - T + I - :code:`allenai/Molmo-7B-D-0924`, :code:`allenai/Molmo-72B-0924`, etc. - - ✅︎ + - ✅︎ * - :code:`NVLM_D_Model` - NVLM-D 1.0 - T + I\ :sup:`E+` - :code:`nvidia/NVLM-D-72B`, etc. - - ✅︎ + - ✅︎ * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - T + I\ :sup:`E` - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc. - - ✅︎ + - * - :code:`Phi3VForCausalLM` - Phi-3-Vision, Phi-3.5-Vision - T + I\ :sup:`E+` - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - - ✅︎ + - ✅︎ * - :code:`PixtralForConditionalGeneration` - Pixtral - T + I\ :sup:`+` - :code:`mistralai/Pixtral-12B-2409`, :code:`mistral-community/pixtral-12b` etc. - - ✅︎ + - ✅︎ * - :code:`QWenLMHeadModel` - Qwen-VL - T + I\ :sup:`E+` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - ✅︎ - ✅︎ + - * - :code:`Qwen2AudioForConditionalGeneration` - Qwen2-Audio - T + A\ :sup:`+` - :code:`Qwen/Qwen2-Audio-7B-Instruct` - - ✅︎ + - * - :code:`Qwen2VLForConditionalGeneration` - Qwen2-VL - T + I\ :sup:`E+` + V\ :sup:`E+` - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. - ✅︎ - ✅︎ + - * - :code:`UltravoxModel` - Ultravox - T + A\ :sup:`E+` - :code:`fixie-ai/ultravox-v0_3` - - ✅︎ + - | :sup:`E` Pre-computed embeddings can be inputted for this modality. | :sup:`+` Multiple items can be inputted per text prompt for this modality. From d1c2e15eb31ef12e688ce0cb71895f88eaf4cd4f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 23:09:04 -0800 Subject: [PATCH 02/18] [torch.compile] add dynamo time tracking (#11005) Signed-off-by: youkaichao --- vllm/compilation/backends.py | 6 ++++++ vllm/compilation/decorators.py | 6 +++--- vllm/compilation/monitor.py | 9 +++++++-- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 1206424ae1e3f..f002a8ff905b1 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -265,7 +265,13 @@ def configure_post_pass(self): def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: + # when dynamo calls the backend, it means the bytecode + # transform and analysis are done compilation_counter.num_graphs_seen += 1 + from .monitor import torch_compile_start_time + dynamo_time = time.time() - torch_compile_start_time + logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) + self.compilation_configs.compilation_time += dynamo_time # we control the compilation process, each instance can only be # called once diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index a32dced57e5b3..938430fe2a501 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -145,6 +145,7 @@ def _support_torch_compile( def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) + self.vllm_config = vllm_config # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. self.do_not_compile = \ @@ -157,9 +158,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): TorchCompileWrapperWithCustomDispatcher.__init__( self, compilation_level=vllm_config.compilation_config.level) - if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE: - start_monitoring_torch_compile(vllm_config.compilation_config) - cls.__init__ = __init__ def __call__(self, *args, **kwargs): @@ -186,6 +184,8 @@ def __call__(self, *args, **kwargs): raise ValueError( "Unsupported dynamic dimensions" f" {dims} for argument {k} with type {type(arg)}.") + # here, it is the starting point of the `torch.compile` process + start_monitoring_torch_compile(self.vllm_config.compilation_config) # if we don't use custom dispatcher, we can directly call the # compiled function and let torch.compile handle the dispatching, diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index f718e46423212..3348674b09af2 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -1,14 +1,19 @@ +import time + from vllm.config import CompilationConfig, CompilationLevel from vllm.logger import init_logger logger = init_logger(__name__) +torch_compile_start_time: float = 0.0 + def start_monitoring_torch_compile(compilation_config: CompilationConfig): - pass + global torch_compile_start_time + torch_compile_start_time = time.time() def end_monitoring_torch_compile(compilation_config: CompilationConfig): if compilation_config.level == CompilationLevel.PIECEWISE: - logger.info("graph compilation takes %.2f s in total", + logger.info("torch.compile takes %.2f s in total", compilation_config.compilation_time) From c690357928fd2812f450bfb0c3629a816f5e9a55 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Mon, 9 Dec 2024 08:27:10 -0800 Subject: [PATCH 03/18] [V1] Fix Detokenizer loading in `AsyncLLM` (#10997) Signed-off-by: Roger Wang --- vllm/v1/engine/async_llm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 4ef372fd8464b..0bcccda2bf329 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -65,7 +65,12 @@ def __init__( input_registry) # Detokenizer (converts EngineCoreOutputs --> RequestOutput). - self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer) + self.detokenizer = Detokenizer( + tokenizer_name=vllm_config.model_config.tokenizer, + tokenizer_mode=vllm_config.model_config.tokenizer_mode, + trust_remote_code=vllm_config.model_config.trust_remote_code, + revision=vllm_config.model_config.tokenizer_revision, + ) # EngineCore (starts the engine in background process). self.engine_core = EngineCoreClient.make_client( From e691b26f6fae5a3a1c220d15f20de83c7d78ed51 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 9 Dec 2024 11:44:27 -0500 Subject: [PATCH 04/18] [Core] Require xgrammar >= 0.1.6 (#11021) Signed-off-by: Russell Bryant --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index 72fb020a82c4e..112528880c0ac 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines >= 0.0.43, < 0.1 -xgrammar >= 0.1.5; platform_machine == "x86_64" +xgrammar >= 0.1.6; platform_machine == "x86_64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs From aea2fc38c3b31b9a8ea7d1cffb8f37a2da6f6075 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Tue, 10 Dec 2024 01:24:46 +0800 Subject: [PATCH 05/18] [Platform] Move `async output` check to platform (#10768) Signed-off-by: wangxiyuan --- vllm/config.py | 17 +++-------------- vllm/platforms/cpu.py | 6 +++++- vllm/platforms/cuda.py | 12 +++++++++++- vllm/platforms/hpu.py | 6 +++++- vllm/platforms/interface.py | 11 +++++++++++ vllm/platforms/neuron.py | 6 +++++- vllm/platforms/openvino.py | 6 +++++- vllm/platforms/rocm.py | 12 +++++++++++- vllm/platforms/tpu.py | 6 +++++- vllm/platforms/xpu.py | 6 +++++- 10 files changed, 66 insertions(+), 22 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 7fbe04eaaf4f8..29f0839dcabba 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -513,11 +513,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config, # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid - if device_config.device_type not in ("cuda", "tpu", "xpu", "hpu"): + if not current_platform.is_async_output_supported(self.enforce_eager): logger.warning( - "Async output processing is only supported for CUDA, TPU, XPU " - "and HPU." - "Disabling it for other platforms.") + "Async output processing is not supported on the " + "current platform type %s.", current_platform.device_type) self.use_async_output_proc = False return @@ -527,16 +526,6 @@ def verify_async_output_proc(self, parallel_config, speculative_config, self.use_async_output_proc = False return - # Reminder: Please update docs/source/usage/compatibility_matrix.rst - # If the feature combo become valid - if device_config.device_type == "cuda" and self.enforce_eager: - logger.warning( - "To see benefits of async output processing, enable CUDA " - "graph. Since, enforce-eager is enabled, async output " - "processor cannot be used") - self.use_async_output_proc = not self.enforce_eager - return - # Async postprocessor is not necessary with embedding mode # since there is no token generation if self.task == "embedding": diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 680ee74129739..e5142b985d1f2 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import psutil import torch @@ -37,6 +37,10 @@ def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_device_total_memory(cls, device_id: int = 0) -> int: return psutil.virtual_memory().total + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 846a1869da228..edaf377b501df 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -4,7 +4,7 @@ import os from functools import lru_cache, wraps -from typing import TYPE_CHECKING, Callable, List, TypeVar +from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar import pynvml import torch @@ -88,6 +88,16 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. Since, enforce-eager is enabled, async output " + "processor cannot be used") + return False + return True + @classmethod def is_full_nvlink(cls, device_ids: List[int]) -> bool: raise NotImplementedError diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 10aaa6d54962c..7f22bee3eaa74 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -20,6 +20,10 @@ class HpuPlatform(Platform): def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: return _Backend.HPU_ATTN + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + @staticmethod def inference_mode(): return torch.no_grad() diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 0be7df7941b8b..db06d2c18e681 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -6,11 +6,15 @@ import numpy as np import torch +from vllm.logger import init_logger + if TYPE_CHECKING: from vllm.config import VllmConfig else: VllmConfig = None +logger = init_logger(__name__) + class _Backend(enum.Enum): FLASH_ATTN = enum.auto() @@ -147,6 +151,13 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: """Get the total memory of a device in bytes.""" raise NotImplementedError + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + """ + Check if the current platform supports async output. + """ + raise NotImplementedError + @classmethod def inference_mode(cls): """A device-specific wrapper of `torch.inference_mode`. diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 87655ea198303..1e5c4bddfa24f 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from .interface import Platform, PlatformEnum @@ -18,6 +18,10 @@ class NeuronPlatform(Platform): def get_device_name(cls, device_id: int = 0) -> str: return "neuron" + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index 29b61e955d9ab..e0f8e8b4b49fe 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -37,6 +37,10 @@ def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_device_name(self, device_id: int = 0) -> str: return "openvino" + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + @classmethod def inference_mode(self): return torch.inference_mode(mode=True) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 3c14fbc179f69..66674e3ebe91f 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,6 +1,6 @@ import os from functools import lru_cache -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -72,6 +72,16 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.cuda.get_device_properties(device_id) return device_props.total_memory + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. Since, enforce-eager is enabled, async output " + "processor cannot be used") + return False + return True + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index b138f7e1c54c5..10d874349f36b 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -35,6 +35,10 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 9665786f4c499..11dbd04d55671 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -41,6 +41,10 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.xpu.get_device_properties(device_id) return device_props.total_memory + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + @staticmethod def inference_mode(): return torch.no_grad() From 25b79d9fd38e2c53ce281be23241d8939ec7320c Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 9 Dec 2024 12:33:41 -0500 Subject: [PATCH 06/18] [V1] Input Batch Relocation (#10962) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- vllm/v1/worker/gpu_input_batch.py | 280 +++++++++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 273 +--------------------------- 2 files changed, 283 insertions(+), 270 deletions(-) create mode 100644 vllm/v1/worker/gpu_input_batch.py diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py new file mode 100644 index 0000000000000..457784bb0287c --- /dev/null +++ b/vllm/v1/worker/gpu_input_batch.py @@ -0,0 +1,280 @@ +# Datastructures defining an input batch + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set + +import numpy as np +import torch + +from vllm.multimodal import MultiModalKwargs +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.v1.sample.metadata import SamplingMetadata + +if TYPE_CHECKING: + from vllm.multimodal.inputs import PlaceholderRange + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + mm_inputs: List[MultiModalKwargs] + mm_positions: List["PlaceholderRange"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + @property + def num_tokens(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self.req_id_to_index: Dict[str, int] = {} + + self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), + dtype=np.int32) + self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + + # Attention-related. + self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32) + self.block_table_cpu_tensor = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_cpu = self.block_table_cpu_tensor.numpy() + + # Sampling-related. + self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: Set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: Set[str] = set() + + # req_index -> generator + self.generators: Dict[int, torch.Generator] = {} + + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. + num_prompt_tokens = len(request.prompt_token_ids) + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + num_blocks = len(request.block_ids) + self.block_table_cpu[req_index, :num_blocks] = request.block_ids + + sampling_params = request.sampling_params + self.temperature_cpu[req_index] = sampling_params.temperature + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + self.top_k_cpu[req_index] = sampling_params.top_k + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_id) + + self.generators[req_index] = request.generator + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[req_id] = num_logprobs + if sampling_params.prompt_logprobs: + self.prompt_logprob_reqs.add(req_id) + + def remove_request(self, req_id: str) -> Optional[int]: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self.req_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + return req_index + + def clear(self) -> None: + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + + def condense(self, empty_req_indices: List[int]) -> None: + if self.num_reqs == 0: + # The batched states are empty. + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = self.num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. + req_id = self.req_ids[last_req_index] + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + # TODO(woosuk): Optimize the copy of token_ids_cpu and + # block_table_cpu. + self.token_ids_cpu[empty_index] = self.token_ids_cpu[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table_cpu[empty_index] = self.block_table_cpu[ + last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + # Decrement last_req_index since it is now empty. + last_req_index -= 1 + + def make_sampling_metadata( + self, + skip_copy: bool = False, + ) -> SamplingMetadata: + if not skip_copy: + self.temperature[:self.num_reqs].copy_( + self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_( + self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_k[:self.num_reqs].copy_( + self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + ) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def max_num_logprobs(self) -> int: + return max(self.num_logprobs.values()) if self.num_logprobs else 0 + + @property + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + @property + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e8d964a722f60..7f95be06188e3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,7 +1,6 @@ import gc import time -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import numpy as np import torch @@ -15,16 +14,16 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalKwargs -from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sampling_params import SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: - from vllm.multimodal.inputs import PlaceholderRange from vllm.v1.core.scheduler import SchedulerOutput logger = init_logger(__name__) @@ -609,269 +608,3 @@ def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: if batch_size <= size: return size return None - - -@dataclass -class CachedRequestState: - - req_id: str - prompt_token_ids: List[int] - prompt: Optional[str] - mm_inputs: List[MultiModalKwargs] - mm_positions: List["PlaceholderRange"] - sampling_params: SamplingParams - generator: Optional[torch.Generator] - - block_ids: List[int] - num_computed_tokens: int - output_token_ids: List[int] - - @property - def num_tokens(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) - - -class InputBatch: - - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_blocks_per_req: int, - device: torch.device, - pin_memory: bool, - ): - self.max_num_reqs = max_num_reqs - self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req - self.device = device - self.pin_memory = pin_memory - - self.req_ids: List[Optional[str]] = [None] * max_num_reqs - self.req_id_to_index: Dict[str, int] = {} - - self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), - dtype=np.int32) - self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) - - # Attention-related. - self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), - device=self.device, - dtype=torch.int32) - self.block_table_cpu_tensor = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device="cpu", - dtype=torch.int32, - pin_memory=pin_memory, - ) - self.block_table_cpu = self.block_table_cpu_tensor.numpy() - - # Sampling-related. - self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.temperature_cpu = self.temperature_cpu_tensor.numpy() - self.greedy_reqs: Set[str] = set() - self.random_reqs: Set[str] = set() - - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.top_p_cpu = self.top_p_cpu_tensor.numpy() - self.top_p_reqs: Set[str] = set() - - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.top_k_cpu = self.top_k_cpu_tensor.numpy() - self.top_k_reqs: Set[str] = set() - - # req_index -> generator - self.generators: Dict[int, torch.Generator] = {} - - self.num_logprobs: Dict[str, int] = {} - self.prompt_logprob_reqs: Set[str] = set() - - def add_request( - self, - request: "CachedRequestState", - req_index: Optional[int] = None, - ) -> None: - if req_index is None: - req_index = self.num_reqs - assert req_index < self.max_num_reqs - - req_id = request.req_id - self.req_ids[req_index] = req_id - self.req_id_to_index[req_id] = req_index - - # Copy the prompt token ids and output token ids. - num_prompt_tokens = len(request.prompt_token_ids) - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids - start_idx = num_prompt_tokens - end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids - - self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - num_blocks = len(request.block_ids) - self.block_table_cpu[req_index, :num_blocks] = request.block_ids - - sampling_params = request.sampling_params - self.temperature_cpu[req_index] = sampling_params.temperature - if sampling_params.sampling_type == SamplingType.GREEDY: - self.greedy_reqs.add(req_id) - else: - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - if sampling_params.top_p < 1: - self.top_p_reqs.add(req_id) - self.top_k_cpu[req_index] = sampling_params.top_k - if sampling_params.top_k > 0: - self.top_k_reqs.add(req_id) - - self.generators[req_index] = request.generator - - num_logprobs = sampling_params.logprobs - if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[req_id] = num_logprobs - if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_id) - - def remove_request(self, req_id: str) -> Optional[int]: - req_index = self.req_id_to_index.pop(req_id, None) - if req_index is None: - return None - self.req_ids[req_index] = None - - self.greedy_reqs.discard(req_id) - self.random_reqs.discard(req_id) - self.top_p_reqs.discard(req_id) - self.top_k_reqs.discard(req_id) - self.generators.pop(req_index, None) - self.num_logprobs.pop(req_id, None) - self.prompt_logprob_reqs.discard(req_id) - return req_index - - def clear(self) -> None: - self.req_ids = [None] * self.max_num_reqs - self.req_id_to_index.clear() - self.greedy_reqs.clear() - self.random_reqs.clear() - self.top_p_reqs.clear() - self.top_k_reqs.clear() - self.generators.clear() - self.num_logprobs.clear() - self.prompt_logprob_reqs.clear() - - def condense(self, empty_req_indices: List[int]) -> None: - if self.num_reqs == 0: - # The batched states are empty. - return - - # NOTE(woosuk): This function assumes that the empty_req_indices - # is sorted in descending order. - last_req_index = self.num_reqs + len(empty_req_indices) - 1 - while empty_req_indices: - # Find the largest non-empty index. - while last_req_index in empty_req_indices: - last_req_index -= 1 - - # Find the smallest empty index. - empty_index = empty_req_indices.pop() - if empty_index >= last_req_index: - break - - # Swap the states. - req_id = self.req_ids[last_req_index] - self.req_ids[empty_index] = req_id - self.req_ids[last_req_index] = None - self.req_id_to_index[req_id] = empty_index - - # TODO(woosuk): Optimize the copy of token_ids_cpu and - # block_table_cpu. - self.token_ids_cpu[empty_index] = self.token_ids_cpu[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] - self.block_table_cpu[empty_index] = self.block_table_cpu[ - last_req_index] - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] - self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] - self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - generator = self.generators.pop(last_req_index, None) - if generator is not None: - self.generators[empty_index] = generator - - # Decrement last_req_index since it is now empty. - last_req_index -= 1 - - def make_sampling_metadata( - self, - skip_copy: bool = False, - ) -> SamplingMetadata: - if not skip_copy: - self.temperature[:self.num_reqs].copy_( - self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_( - self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_( - self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], - all_greedy=self.all_greedy, - all_random=self.all_random, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, - generators=self.generators, - max_num_logprobs=self.max_num_logprobs, - ) - - @property - def num_reqs(self) -> int: - return len(self.req_id_to_index) - - @property - def all_greedy(self) -> bool: - return len(self.random_reqs) == 0 - - @property - def all_random(self) -> bool: - return len(self.greedy_reqs) == 0 - - @property - def no_top_p(self) -> bool: - return len(self.top_p_reqs) == 0 - - @property - def no_top_k(self) -> bool: - return len(self.top_k_reqs) == 0 - - @property - def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) if self.num_logprobs else 0 - - @property - def no_logprob(self) -> bool: - return len(self.num_logprobs) == 0 - - @property - def no_prompt_logprob(self) -> bool: - return len(self.prompt_logprob_reqs) == 0 From edc4fa31888b4a41060acb7b16250540f051ad59 Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Mon, 9 Dec 2024 11:46:58 -0800 Subject: [PATCH 07/18] [ci/build] Recompile CI dependencies list with Python 3.12 (#11013) Signed-off-by: kevin --- requirements-test.txt | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index 19369254dbe26..38a064bca449a 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,8 +1,8 @@ # -# This file is autogenerated by pip-compile with Python 3.9 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # -# pip-compile requirements-test.in +# python3.12 -m piptools compile requirements-test.in -o requirements-test.txt # absl-py==2.1.0 # via rouge-score @@ -27,10 +27,6 @@ anyio==4.6.2.post1 # via httpx argcomplete==3.5.1 # via datamodel-code-generator -async-timeout==4.0.3 - # via - # aiohttp - # redis attrs==24.2.0 # via # aiohttp @@ -111,10 +107,6 @@ email-validator==2.2.0 # via pydantic evaluate==0.4.3 # via lm-eval -exceptiongroup==1.2.2 - # via - # anyio - # pytest fastrlock==0.8.2 # via cupy-cuda12x filelock==3.16.1 @@ -165,8 +157,6 @@ idna==3.10 # httpx # requests # yarl -importlib-resources==6.4.5 - # via matplotlib inflect==5.6.2 # via datamodel-code-generator iniconfig==2.0.0 @@ -518,12 +508,6 @@ timm==1.0.11 # via -r requirements-test.in tokenizers==0.20.3 # via transformers -toml==0.10.2 - # via datamodel-code-generator -tomli==2.0.2 - # via - # black - # pytest torch==2.5.1 # via # -r requirements-test.in @@ -567,12 +551,9 @@ typepy[datetime]==1.3.2 # tabledata typing-extensions==4.12.2 # via - # anyio - # black # huggingface-hub # librosa # mistral-common - # multidict # pydantic # pydantic-core # torch @@ -590,8 +571,6 @@ xxhash==3.5.0 # evaluate yarl==1.17.1 # via aiohttp -zipp==3.20.2 - # via importlib-resources zstandard==0.23.0 # via lm-eval From 3b61cb450d899dc423feb264c297d4d18d701678 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 9 Dec 2024 12:38:46 -0800 Subject: [PATCH 08/18] [V1] Further reduce CPU overheads in flash-attn (#10989) Signed-off-by: Woosuk Kwon --- csrc/cache_kernels.cu | 14 ++++++++++++-- vllm/v1/attention/backends/flash_attn.py | 21 ++++++++++++++++----- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1be806bbfa43c..8a95279f9a25a 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -307,10 +307,20 @@ void reshape_and_cache_flash( torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping, // [num_tokens] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] const std::string& kv_cache_dtype, const double k_scale, const double v_scale) { - int num_tokens = key.size(0); + // NOTE(woosuk): In vLLM V1, key.size(0) can be different from + // slot_mapping.size(0) because of padding for CUDA graphs. + // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because + // both include padding. + // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) + // since key includes padding for CUDA graphs, while slot_mapping does not. + // In this case, slot_mapping.size(0) represents the actual number of tokens + // before padding. + // For compatibility with both cases, we use slot_mapping.size(0) as the + // number of tokens. + int num_tokens = slot_mapping.size(0); int num_heads = key.size(1); int head_size = key.size(2); int block_size = key_cache.size(1); diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d37989055c2e5..251a103e60f06 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -138,14 +138,25 @@ def forward( # Profiling run. return output - num_actual_tokens = attn_metadata.num_actual_tokens + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + num_actual_tokens = attn_metadata.num_actual_tokens # Reshape the input keys and values and store them in the cache. - key_cache = kv_cache[0] - value_cache = kv_cache[1] + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] and + # value[:num_actual_tokens] because the reshape_and_cache_flash op uses + # the slot_mapping's shape to determine the number of actual tokens. + key_cache, value_cache = kv_cache.unbind(0) torch.ops._C_cache_ops.reshape_and_cache_flash( - key[:num_actual_tokens], - value[:num_actual_tokens], + key, + value, key_cache, value_cache, attn_metadata.slot_mapping, From ca871491edb0fba11fe9aa94300bd8d282fa29e1 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 10 Dec 2024 04:54:44 +0800 Subject: [PATCH 09/18] [Misc][LoRA] Abstract PunicaWrapper (#10955) Signed-off-by: Jee Jee Li --- tests/lora/test_layers.py | 49 +- vllm/lora/layers.py | 7 +- vllm/lora/models.py | 8 +- vllm/lora/punica.py | 725 -------------------- vllm/lora/punica_wrapper/__init__.py | 7 + vllm/lora/punica_wrapper/punica_base.py | 480 +++++++++++++ vllm/lora/punica_wrapper/punica_gpu.py | 358 ++++++++++ vllm/lora/punica_wrapper/punica_selector.py | 14 + vllm/lora/punica_wrapper/utils.py | 159 +++++ 9 files changed, 1058 insertions(+), 749 deletions(-) delete mode 100644 vllm/lora/punica.py create mode 100644 vllm/lora/punica_wrapper/__init__.py create mode 100644 vllm/lora/punica_wrapper/punica_base.py create mode 100644 vllm/lora/punica_wrapper/punica_gpu.py create mode 100644 vllm/lora/punica_wrapper/punica_selector.py create mode 100644 vllm/lora/punica_wrapper/utils.py diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index a113e3f7abc1e..fb8c0b2a7ba26 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -28,7 +28,7 @@ # yapf: enable from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, PackedLoRALayerWeights) -from vllm.lora.punica import PunicaWrapper +from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, @@ -48,11 +48,12 @@ torch.float32: (5e-3, 5e-3), torch.bfloat16: (3e-2, 2e-2), } -CUDA_DEVICES = [ +# TODO: Modify this based on platform +DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] -# We will launch different triton kernels between the prefill and decode +#For GPU, we will launch different triton kernels between the prefill and decode # stages, so we need to verify this. prefill stage(True) or decode stage(False) STAGES = [True, False] @@ -192,9 +193,18 @@ def create_random_inputs( return inputs, index_mapping, prompt_mapping +def check_punica_wrapper(punica_wrapper) -> bool: + if current_platform.is_cuda_alike(): + from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU + + return type(punica_wrapper) is PunicaWrapperGPU + else: + return False + + @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: @@ -205,7 +215,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: torch.set_default_device(device) max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -296,7 +307,7 @@ def create_random_embedding_layer(): # @pytest.mark.skip( # reason="Fails when loras are in any slot other than the first.") @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) def test_embeddings_with_new_embeddings(dist_init, num_loras, device, @@ -305,7 +316,8 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -432,7 +444,7 @@ def create_random_embedding_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("stage", STAGES) def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, @@ -441,7 +453,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -563,7 +576,7 @@ def _pretest(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("bias_enabled", [True, False]) def test_linear_replicated(dist_init, num_loras, device, stage, @@ -571,7 +584,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage, torch.cuda.set_device(device) torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -675,7 +689,7 @@ def create_random_linear_replicated_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("fully_shard", [True, False]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("bias_enabled", [True, False]) def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, @@ -683,7 +697,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, torch.cuda.set_device(device) torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -797,7 +812,7 @@ def create_random_linear_parallel_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("fully_shard", [True, False]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) @pytest.mark.parametrize("bias_enabled", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, @@ -805,7 +820,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, torch.cuda.set_device(device) torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -963,7 +979,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, seed = 0 current_platform.seed_everything(seed) torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device) + assert check_punica_wrapper(punica_wrapper) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 3e9c2ceb83eac..38cb846578d5c 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -17,7 +17,6 @@ tensor_model_parallel_all_reduce, tensor_model_parallel_gather) from vllm.distributed.utils import divide -from vllm.lora.punica import PunicaWrapper # yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, @@ -33,7 +32,7 @@ VocabParallelEmbedding) if TYPE_CHECKING: - pass + from vllm.lora.punica_wrapper import PunicaWrapperBase def _get_lora_device(base_layer: nn.Module) -> torch.device: @@ -115,9 +114,9 @@ def set_lora( def set_mapping( self, - punica_wrapper: PunicaWrapper, + punica_wrapper, ): - self.punica_wrapper: PunicaWrapper = punica_wrapper + self.punica_wrapper: PunicaWrapperBase = punica_wrapper @classmethod def can_replace_layer( diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 9855b57d0c9c9..49cd9f0c236ad 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -21,7 +21,7 @@ LinearScalingRotaryEmbeddingWithLora, LoRAMapping) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.punica import PunicaWrapper +from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, is_regex_target_modules, parse_fine_tuned_lora_name, replace_submodule) @@ -331,9 +331,9 @@ def __init__( self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.long_lora_context: Optional[LongContextLoRAContext] = None - self.punica_wrapper = PunicaWrapper(max_num_batched_tokens, - max_batches=self.max_num_seqs, - device=self.device) + self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens, + max_batches=self.max_num_seqs, + device=self.device) # Scaling factor -> offset to the sin_cos_cache to it. # Used for long context lora. self.scaling_factor_to_offset: Dict[float, int] = {} diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py deleted file mode 100644 index 563d1181d6fcb..0000000000000 --- a/vllm/lora/punica.py +++ /dev/null @@ -1,725 +0,0 @@ -""" -Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. -https://arxiv.org/abs/2310.18547 -""" - -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union - -import torch - -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON: - from vllm.lora.ops.bgmv_expand import bgmv_expand - from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice - from vllm.lora.ops.bgmv_shrink import bgmv_shrink - from vllm.lora.ops.sgmv_expand import sgmv_expand - from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice - from vllm.lora.ops.sgmv_shrink import sgmv_shrink - -if TYPE_CHECKING: - # avoid circuit import - from vllm.lora.layers import LoRAMapping - from vllm.lora.models import LongContextLoRAContext - - -def compute_meta( - token_lora_tensor: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: - """ - Get the information required for the sgmv kernel. With the features: - 1. If consecutive requests in the batch use the same LoRA, this function - will combine them into a single request, improving sgmv kernel inference - performance. - 2. At the beginning of each prefill stage inference, recalculations are - needed based on the input, but only once. - """ - - lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( - token_lora_tensor, return_counts=True) - cum_result = torch.cumsum(seq_length_tensor, dim=0) - b_seq_start_tensor = torch.zeros_like(seq_length_tensor) - b_seq_start_tensor[1:].copy_(cum_result[:-1]) - max_length = seq_length_tensor.max().item() - token_nums = seq_length_tensor.sum().item() - batch_size = lora_indices_tensor.size(0) - no_lora = False - # -1 means no lora should be applied. Use `no_lora` to determine whether - # the current step requires LoRA. If LoRA is not needed, the prefill stage - # does not need to launch the triton kernel, which can improve performance - if batch_size == 1 and lora_indices_tensor == -1: - no_lora = True - return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, token_nums, no_lora) - - -# TODO see if this can be vectorized -def convert_mapping( - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - device: torch.device, - long_lora_context: Optional["LongContextLoRAContext"] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], List[int]]: - """Converts LoRAMapping to index tensors. - - Args: - mapping: LoRAMapping mapping rows in a batch to LoRA ids. - lora_index_to_id: List mapping LoRA ids to LoRA indices. - max_loras: Maximum number of LoRAs. - vocab_size: Model vocab size. - extra_vocab_size: Extra vocab size each LoRA can have. - long_lora_context: Passed if there are long context lora in a batch. - - Returns: - A tuple of tensors: - base_indices: Tensor of shape [batch_size] mapping batch rows to - LoRA indices. - sampler_indices: Tensor of shape [batch_size] mapping requests to - LoRA indices for sampler. For generation, this will be the - same as base_indicies. For prefill, this will map requests - to LoRA indices. - sampler_indices_padded: Tensor of shape [batch_size] mapping - requests to LoRA indices for sampler with padding. - Same as sampler_indicies, but -1 is replaced with - max_loras. - embeddings_indices: Tensor of shape [2, batch_size] mapping - requests to embedding indices. First row is for embeddings - added by the LoRAs, second row is for the LoRA.lora_a - embeddings. - long_lora_indices: Tensor of shape [batch_size] mapping - requests to RoPE offsets and rot dims for long LoRAs. - None if long context lora doesn't exist. - indices_len: List of lengths of the above tensors. It contains - (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_indices). - """ - index_mapping_indices: List[int] = list(mapping.index_mapping).copy() - embedding_indices = index_mapping_indices.copy() - lora_indices = index_mapping_indices.copy() - long_lora_offsets: Optional[torch.Tensor] = None - if long_lora_context: - long_lora_offsets = torch.zeros(len(index_mapping_indices), - device=device, - dtype=torch.long) - prompt_mapping: List[int] = [ - lora_index_to_id.index(x) if x > 0 else -1 - for x in mapping.prompt_mapping - ] - lora_idx = None - for i in range(len(index_mapping_indices)): - # TODO index can be slow. optimize - lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) - if index_mapping_indices[i] > 0 else -1) - embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 - lora_indices[i] = lora_idx - if long_lora_context: - assert long_lora_offsets is not None - lora_offset: int = long_lora_context.offsets_by_lora_id.get( - index_mapping_indices[i], 0) - long_lora_offsets[i] = lora_offset - - indices_list: List[Union[List[int], torch.Tensor]] = [ - index_mapping_indices, - lora_indices, - embedding_indices, - ] - if long_lora_context: - assert long_lora_offsets is not None - indices_list.append(long_lora_offsets) - indices = torch.tensor(indices_list, dtype=torch.long, device=device) - prompt_mapping_tensor = torch.tensor(prompt_mapping, - dtype=torch.long, - device=device) - embeddings_indices = torch.stack([ - indices[2] * extra_vocab_size, - indices[2] * (vocab_size + extra_vocab_size), - ]) - embeddings_indices[embeddings_indices == -1] = max_loras - 1 - base_indices = indices[1] - sampler_indices = prompt_mapping_tensor - sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 - sampler_indices_padded = torch.arange( - 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( - sampler_indices_padded * len(sampler_indices_padded)) - long_lora_indices = None - long_lora_indices_len: Optional[int] = None - if long_lora_context: - long_lora_indices = indices[3] - long_lora_indices_len = long_lora_indices.shape[-1] - # Contain length of indices tensors. Used to index into each tensor. - indices_len = [ - base_indices.shape[-1], - sampler_indices.shape[-1], - sampler_indices_padded.shape[-1], - embeddings_indices.shape[-1], - ] - if long_lora_indices_len is not None: - indices_len.append(long_lora_indices_len) - else: - # If long_lora doesn't exist,append None - indices_len.append(None) - - return ( - base_indices, - sampler_indices, - sampler_indices_padded, - embeddings_indices, - long_lora_indices, - indices_len, - ) - - -class PunicaWrapper: - """ - PunicaWrapper is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for - Multi-LoRA, and to provide the interface for the punica kernel. - """ - - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str]): - self._token_lora_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._sampler_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._sampler_indices_padded = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._embeddings_indices = torch.empty(2, - max_num_batched_tokens, - dtype=torch.long, - device=device) - self._long_lora_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - - # 5 is the number of indicies tensors. - # base_indices, sampler_indices, sampler_indices_padded, - # embeddings_indices,long_lora_indices - self.indices_len: List[Optional[int]] = [None] * 5 - # these attributes are the information required for sgmv kernel - self._seq_start_locs = torch.empty(max_batches, - dtype=torch.long, - device=device) - self._seq_lengths = torch.empty(max_batches, - dtype=torch.long, - device=device) - self._lora_indices_per_batch = torch.empty(max_batches, - dtype=torch.long, - device=device) - self.device: torch.device = device - self.max_length: int = 0 - self.token_nums: int = 0 - self.batch_size: int = -1 - self.is_prefill = False - self.no_lora = False - - def update_metadata( - self, - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, - ): - - self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size, - long_lora_context) - if mapping.is_prefill: - # Update metadata required for prefill-related operators. - self._update_prefill_metada(self.token_lora_indices) - self.is_prefill = True - else: - self.is_prefill = False - - def _update_base_metadata( - self, - mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, - ): - ( - base_indices, - sampler_indices, - sampler_indices_padded, - embeddings_indices, - long_lora_offsets_tensor, - indices_len, - ) = convert_mapping( - mapping, - lora_index_to_id, - max_loras, - vocab_size, - extra_vocab_size, - self.device, - long_lora_context, - ) - self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) - self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) - self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded) - self._embeddings_indices[:embeddings_indices. - shape[0], :embeddings_indices.shape[1]].copy_( - embeddings_indices) - if long_lora_offsets_tensor is not None: - self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( - long_lora_offsets_tensor) - else: - self._long_lora_indices.zero_() - self.indices_len[:] = indices_len - - def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: - - (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, token_nums, - no_lora) = compute_meta(token_lora_tensor) - - self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( - b_seq_start_tensor) - self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) - self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( - lora_indices_tensor) - self.batch_size = batch_size - self.max_length = max_length - self.token_nums = token_nums - self.no_lora = no_lora - - @property - def prefill_metadata( - self - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: - """ - This property provides a convenient way to access the necessary - metadata for prefill-related kernel computations. - 1. seq_start_locs: Tensor of sequence start positions. - 2. seq_lengths: Tensor of sequence lengths. - 3. lora_indices_per_batch: Tensor of lora indices, and an index of - -1 means no lora should be applied. - 4. batch_size: Batch size after clustering identical lora indices. - 5. max_length: The maximum sequence length in the batch. - 6. token_nums: The token numbers in the batch. - """ - return (self._seq_start_locs[:self.batch_size], - self._seq_lengths[:self.batch_size], - self._lora_indices_per_batch[:self.batch_size], - self.batch_size, self.max_length, self.token_nums) - - @property - def token_lora_indices(self) -> torch.Tensor: - """ - This property provides the lora indices corresponding to each token - in the batch. An index of -1 means no lora should be applied. - """ - token_lora_len = self.indices_len[0] - return self._token_lora_indices[:token_lora_len] - - @property - def sampler_indices(self) -> torch.Tensor: - """ - This property is used to access the lora indices specifically for - LogitsProcessorWithLoRA. - """ - sampler_indices_len = self.indices_len[1] - return self._sampler_indices[:sampler_indices_len] - - @property - def sampler_indices_padded(self) -> torch.Tensor: - """ - This property provides access to padded sampler indices. - """ - indices_padded_len = self.indices_len[2] - return self._sampler_indices_padded[:indices_padded_len] - - @property - def embeddings_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for lora embeddings, - specifically for VocabParallelEmbeddingWithLoRA. - """ - embeddings_indices_len = self.indices_len[3] - return self._embeddings_indices[:, :embeddings_indices_len] - - @property - def long_lora_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for long context - lora, specifically for LinearScalingRotaryEmbeddingWithLora. - """ - long_lora_len = self.indices_len[4] - return self._long_lora_indices[:long_lora_len] - - def _shrink_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - #No LoRA request, so return directly - if self.no_lora: - return - sgmv_shrink( - x, - w_t_all, - y, - *self.prefill_metadata, - scale, - ) - - def _shrink_decode( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) - - def _expand_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_input: bool, - ): - #No LoRA request, so return directly - if self.no_lora: - return - sgmv_expand( - x, - w_t_all, - y, - *self.prefill_metadata, - add_input, - ) - - def _expand_decode( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_input: bool, - ): - bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) - - def _expand_slice_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool, - ): - #No LoRA request, so return directly - if self.no_lora: - return - sgmv_expand_slice( - x, - w_t_all, - y, - *self.prefill_metadata, - y_offset, - y_slice_size, - add_input, - ) - - def _expand_slice_decode( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool, - ): - bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_input) - - def _apply_expand(self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool = True): - """ - Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` - computation, which is suitable for the - GEMM of lora'b. - """ - - expand_slice_fun: Callable = (self._expand_slice_prefill - if self.is_prefill else - self._expand_slice_decode) - expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) - - def _apply_bias( - self, - indices: torch.Tensor, - output: torch.Tensor, - output_slices: Tuple[int, ...], - lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], - ): - """Applies bias to output - - Input shapes: - lora_bias_stacked: 3 element tuple of (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - offset_left = 0 - for slice_idx, slice in enumerate(output_slices): - bias = lora_bias_stacked[slice_idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[indices] - bias[indices == -1] = 0 - output[:, offset_left:offset_left + slice] += bias - offset_left += slice - - return output.view_as(org_output) - - def _apply_shrink( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - """ - Perform the ` y+=x@w_t_all` computation, which is suitable for the - GEMM of lora'a. - When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `_shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the _shrink_decode function - should be called. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - shrink_fun: Callable = (self._shrink_prefill - if self.is_prefill else self._shrink_decode) - shrink_fun(y, x, w_t_all, scale) - y = y.view_as(y_org) - - def add_shrink( - self, - y: Union[Tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - scale: float, - ): - """ - Performs GEMM for multiple slices of lora_a. - When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `_shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the _shrink_decode function - should be called. - - Semantics: - for i in range(len(lora_a_stacked)): - y[i] += (x @ lora_a_stacked[i]) * scale - - Args: - y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors - x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights - scale (float): Scaling factor for the operation - """ - - x = x.view(-1, x.shape[-1]) - # TODO fuse these kernels - for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) - - def add_expand( - self, - y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], - offset_start: int = 0, - add_input=True, - ) -> None: - """ - Performs GEMM and bias addition for multiple slices of lora_b. - - Semantics: - for i in range(len(lora_b_stacked)): - slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] - offset += slice - - Args: - y (torch.Tensor): Output tensor. - x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): - bias's weight - output_slices (Tuple[int, ...]): Every slice's size - add_input (bool): Defaults to True. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - offset_left = offset_start - if lora_bias_stacked is not None: - self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) - for slice_idx in range(len(lora_b_stacked)): - self._apply_expand( - y, - x[slice_idx], - lora_b_stacked[slice_idx], - offset_left, - output_slices[slice_idx], - add_input=add_input, - ) - offset_left += output_slices[slice_idx] - y = y.view_as(y_org) - - def add_lora_embedding( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_input: bool = True, - ): - """ - Applies lora specifically for VocabParallelEmbeddingWithLoRA. - - Semantics: - y += x @ lora_b_stacked - - Args: - y (torch.Tensor): Output tensor. - x (torch.Tensor): Input tensor. - lora_b_stacked (torch.Tensor): lora_b's weights. - add_input (bool): Default to True. - - """ - - # Embedding layer only need expand op - expand_fun: Callable = (self._expand_prefill - if self.is_prefill else self._expand_decode) - expand_fun(y, x, lora_b_stacked, add_input) - - def add_lora_linear( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - scale: float, - output_slices: Tuple[int, ...], - *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None: - """ - Applicable to linear-related lora. - - Semantics: - for i in range(len(lora_a_stacked)): - y[i] += ( - x[i].unsqueeze(0) - @ lora_a_stacked[indices[i], layer_idx, :, :] - @ lora_b_stacked[indices[i], layer_idx, :, :] - * scale - ).squeeze(0)+lora_bias_stacked[i] - - Args: - y (torch.Tensor): Output tensor. Will be changed in-place. - x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. - scale (float): Scaling factor. - output_slices (Tuple[int, ...]): Every slice's size. - buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. - """ - - assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) - - if buffer is None: - r = lora_b_stacked[0].size(-1) - # We set the buffer to be float32 by default ,refer to: - # https://github.com/triton-lang/triton/issues/1387 - buffer = tuple( - torch.zeros( - (x.size(0), r), dtype=torch.float32, device=x.device) - for _ in range(len(output_slices))) - self.add_shrink(buffer, x, lora_a_stacked, scale) - self.add_expand(y, - buffer, - lora_b_stacked, - None, - output_slices, - add_input=True) - - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None) -> None: - """ - Applies lora specifically for LogitsProcessorWithLoRA. - - Semantics: - buffer = (x @ lora_a_stacked) * scale - y += buffer @ lora_b_stacked - - Args: - y (torch.Tensor): Output tensor. - x (torch.Tensor): Input tensor. - lora_a_stacked (torch.Tensor): lora_a's weights. - lora_b_stacked (torch.Tensor):lora_b's weights. - scale (float): Scaling factor. - buffer (Optional[torch.Tensor]):Default to None. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - x = x.view(-1, x.shape[-1]) - r = lora_b_stacked.size(-1) - if buffer is None: - # We set the buffer to be float32 by default ,refer to: - # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - # LogitsProcessorWithLoRA always using bgmv. - bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, - lora_b_stacked, - y, - self.sampler_indices, - add_inputs=True) - y = y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/__init__.py b/vllm/lora/punica_wrapper/__init__.py new file mode 100644 index 0000000000000..48ada3926ea46 --- /dev/null +++ b/vllm/lora/punica_wrapper/__init__.py @@ -0,0 +1,7 @@ +from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase +from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper + +__all__ = [ + "PunicaWrapperBase", + "get_punica_wrapper", +] diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py new file mode 100644 index 0000000000000..0a5a84bdd8deb --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -0,0 +1,480 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from .utils import compute_meta, convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + +class PunicaWrapperABC(ABC): + """ + PunicaWrapper ABC. + """ + + @abstractmethod + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs, + ) -> None: + """ + Update the lora-related metadata + """ + raise NotImplementedError + + @abstractmethod + def add_shrink( + self, + y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ) -> None: + """ + Performs GEMM for multiple slices of lora_a. + """ + + raise NotImplementedError + + @abstractmethod + def add_expand( + self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_input=True, + **kwargs, + ) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + """ + raise NotImplementedError + + @abstractmethod + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_input: bool = True, + **kwargs, + ) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA, + and this layer only requires the expand operation. + """ + raise NotImplementedError + + @abstractmethod + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + """ + + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + """ + raise NotImplementedError + + +class PunicaWrapperBase(PunicaWrapperABC): + """ + PunicaWrapperBase is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + self._token_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices_padded = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._embeddings_indices = torch.empty(2, + max_num_batched_tokens, + dtype=torch.long, + device=device) + self._long_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + + # 5 is the number of indicies tensors. + # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices,long_lora_indices + self.indices_len: List[Optional[int]] = [None] * 5 + # these attributes are the information required for sgmv kernel + self._seq_start_locs = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._seq_lengths = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._lora_indices_per_batch = torch.empty(max_batches, + dtype=torch.long, + device=device) + self.device: torch.device = device + self.max_length: int = 0 + self.token_nums: int = 0 + self.batch_size: int = -1 + self.is_prefill = False + self.no_lora = False + + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + self.device, + long_lora_context, + ) + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. + shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + self.indices_len[:] = indices_len + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + + (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, + no_lora) = compute_meta(token_lora_tensor) + + self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( + b_seq_start_tensor) + self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( + lora_indices_tensor) + self.batch_size = batch_size + self.max_length = max_length + self.token_nums = token_nums + self.no_lora = no_lora + + def _apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias[indices == -1] = 0 + output[:, offset_left:offset_left + slice] += bias + offset_left += slice + + return output.view_as(org_output) + + @property + def prefill_metadata( + self + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for prefill-related kernel computations. + 1. seq_start_locs: Tensor of sequence start positions. + 2. seq_lengths: Tensor of sequence lengths. + 3. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + 4. batch_size: Batch size after clustering identical lora indices. + 5. max_length: The maximum sequence length in the batch. + 6. token_nums: The token numbers in the batch. + """ + return (self._seq_start_locs[:self.batch_size], + self._seq_lengths[:self.batch_size], + self._lora_indices_per_batch[:self.batch_size], + self.batch_size, self.max_length, self.token_nums) + + @property + def token_lora_indices(self) -> torch.Tensor: + """ + This property provides the lora indices corresponding to each token + in the batch. An index of -1 means no lora should be applied. + """ + token_lora_len = self.indices_len[0] + return self._token_lora_indices[:token_lora_len] + + @property + def sampler_indices(self) -> torch.Tensor: + """ + This property is used to access the lora indices specifically for + LogitsProcessorWithLoRA. + """ + sampler_indices_len = self.indices_len[1] + return self._sampler_indices[:sampler_indices_len] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. + """ + indices_padded_len = self.indices_len[2] + return self._sampler_indices_padded[:indices_padded_len] + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA. + """ + embeddings_indices_len = self.indices_len[3] + return self._embeddings_indices[:, :embeddings_indices_len] + + @property + def long_lora_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for long context + lora, specifically for LinearScalingRotaryEmbeddingWithLora. + """ + long_lora_len = self.indices_len[4] + return self._long_lora_indices[:long_lora_len] + + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if mapping.is_prefill: + # Update metadata required for prefill-related operators. + self._update_prefill_metada(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + + @abstractmethod + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs) -> None: + """ + Performs GEMM for multiple slices of lora_a. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_input=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_input (bool): Defaults to True. + + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_input: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + and this layer only requires the expand operation. + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_input (bool): Default to True. + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + # TODO: implement it based on torch ops + raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py new file mode 100644 index 0000000000000..b2af29de129ce --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -0,0 +1,358 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import Callable, Optional, Tuple, Union, final + +import torch + +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.lora.ops.bgmv_expand import bgmv_expand + from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice + from vllm.lora.ops.bgmv_shrink import bgmv_shrink + from vllm.lora.ops.sgmv_expand import sgmv_expand + from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice + from vllm.lora.ops.sgmv_shrink import sgmv_shrink + +from .punica_base import PunicaWrapperBase + + +@final +class PunicaWrapperGPU(PunicaWrapperBase): + """ + PunicaWrapperGPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica triton kernel. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_input, + ) + + def _expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_input, + ) + + def _expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_input) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_input=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_input (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_input=add_input, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_input: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_input (bool): Default to True. + """ + + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_input) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_input=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + # LogitsProcessorWithLoRA always using bgmv. + bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + y = y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py new file mode 100644 index 0000000000000..df6c1bdc7dd71 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -0,0 +1,14 @@ +from vllm.platforms import current_platform +from vllm.utils import print_info_once + +from .punica_base import PunicaWrapperBase + + +def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: + if current_platform.is_cuda_alike(): + # Lazy import to avoid ImportError + from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU + print_info_once("Using PunicaWrapperGPU.") + return PunicaWrapperGPU(*args, **kwargs) + else: + raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py new file mode 100644 index 0000000000000..7360c8c09e3ac --- /dev/null +++ b/vllm/lora/punica_wrapper/utils.py @@ -0,0 +1,159 @@ +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + +def compute_meta( + token_lora_tensor: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: + """ + Get the information required for the sgmv kernel. With the features: + 1. If consecutive requests in the batch use the same LoRA, this function + will combine them into a single request, improving sgmv kernel inference + performance. + 2. At the beginning of each prefill stage inference, recalculations are + needed based on the input, but only once. + """ + + lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( + token_lora_tensor, return_counts=True) + cum_result = torch.cumsum(seq_length_tensor, dim=0) + b_seq_start_tensor = torch.zeros_like(seq_length_tensor) + b_seq_start_tensor[1:].copy_(cum_result[:-1]) + max_length = seq_length_tensor.max().item() + token_nums = seq_length_tensor.sum().item() + batch_size = lora_indices_tensor.size(0) + no_lora = False + # -1 means no lora should be applied. Use `no_lora` to determine whether + # the current step requires LoRA. If LoRA is not needed, the prefill stage + # does not need to launch the triton kernel, which can improve performance + if batch_size == 1 and lora_indices_tensor == -1: + no_lora = True + return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, no_lora) + + +# TODO see if this can be vectorized +def convert_mapping( + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + device: torch.device, + long_lora_context: Optional["LongContextLoRAContext"] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], List[int]]: + """Converts LoRAMapping to index tensors. + + Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_index_to_id: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. + + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. + indices_len: List of lengths of the above tensors. It contains + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). + """ + index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device=device, + dtype=torch.long) + prompt_mapping: List[int] = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(index_mapping_indices)): + # TODO index can be slow. optimize + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + lora_indices[i] = lora_idx + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, + lora_indices, + embedding_indices, + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device=device) + prompt_mapping_tensor = torch.tensor(prompt_mapping, + dtype=torch.long, + device=device) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping_tensor + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.arange( + 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( + sampler_indices_padded * len(sampler_indices_padded)) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. + indices_len = [ + base_indices.shape[-1], + sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1], + ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) + else: + # If long_lora doesn't exist,append None + indices_len.append(None) + + return ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_indices, + indices_len, + ) From a811dd660856a5c222a1447fe1d93deccbc162fd Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 10 Dec 2024 04:55:10 +0800 Subject: [PATCH 10/18] [Model] merged input processor for Phi-3-Vision models (#10977) Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Cyrus Leung --- tests/entrypoints/openai/test_vision.py | 4 +- .../openai/test_vision_embedding.py | 4 +- .../mm_processor_kwargs/test_phi3v.py | 136 ++------ tests/multimodal/test_processor_kwargs.py | 169 +++++----- vllm/inputs/registry.py | 4 +- vllm/model_executor/models/phi3v.py | 298 +++++------------- vllm/multimodal/processing.py | 29 +- 7 files changed, 235 insertions(+), 409 deletions(-) diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 157d873a75b4d..a0b6edd566561 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -89,7 +89,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=772, total_tokens=782) + completion_tokens=10, prompt_tokens=775, total_tokens=785) message = choice.message message = chat_completion.choices[0].message @@ -181,7 +181,7 @@ async def test_single_chat_session_image_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=772, total_tokens=782) + completion_tokens=10, prompt_tokens=775, total_tokens=785) message = choice.message message = chat_completion.choices[0].message diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index d0c43b47bf0af..425f2a10ec855 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -95,5 +95,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, assert len(embeddings["data"]) == 1 assert len(embeddings["data"][0]["embedding"]) == 3072 assert embeddings["usage"]["completion_tokens"] == 0 - assert embeddings["usage"]["prompt_tokens"] == 762 - assert embeddings["usage"]["total_tokens"] == 762 + assert embeddings["usage"]["prompt_tokens"] == 765 + assert embeddings["usage"]["total_tokens"] == 765 diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py index 60a8f63eb5faa..c16192a1e1438 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py @@ -2,12 +2,10 @@ from typing import Optional import pytest -import torch -from transformers import AutoImageProcessor, AutoTokenizer +from transformers import AutoTokenizer -from vllm.inputs import InputContext, token_inputs +from vllm.inputs import InputContext, InputProcessingContext from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID -from vllm.multimodal import MultiModalRegistry from .....conftest import _ImageAssets from ....utils import build_model_context @@ -17,15 +15,9 @@ # Wrap lazy imports to avoid initializing CUDA during test collection @pytest.fixture() -def input_processor_for_phi3v(): - from vllm.model_executor.models.phi3v import input_processor_for_phi3v - return input_processor_for_phi3v - - -@pytest.fixture() -def dummy_data_for_phi3v(): - from vllm.model_executor.models.phi3v import dummy_data_for_phi3v - return dummy_data_for_phi3v +def processor_for_phi3v(): + from vllm.model_executor.models.phi3v import Phi3VProcessor + return Phi3VProcessor @pytest.fixture() @@ -34,53 +26,6 @@ def get_max_phi3v_image_tokens(): return get_max_phi3v_image_tokens -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("num_crops", [4, 16, None]) -def test_input_mapper_override(model: str, image_assets: _ImageAssets, - num_crops: Optional[int]): - """Ensure that the [default] input mapper handles num_crops properly.""" - # We pass the processor kwargs here since for this model, we fall back to - # the default mapper; this will fall back to the HF mapper and forward - # mm_processor_kwargs to it. - mm_processor_kwargs = { - "num_crops": num_crops - } if num_crops is not None else {} - ctx = build_model_context( - model_name=model, - tokenizer_name=model, - trust_remote_code=True, - mm_processor_kwargs=mm_processor_kwargs, - ) - - hf_processor = AutoImageProcessor.from_pretrained(model, - trust_remote_code=True, - **mm_processor_kwargs) - - mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(ctx.model_config) - - image = image_assets[0].pil_image - hf_result = hf_processor.preprocess( - image, - return_tensors="pt", - ) - - vllm_result = mm_registry.map_input( - ctx.model_config, - {"image": image}, - ) - - assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"]) - assert torch.all( - hf_result["num_img_tokens"] == vllm_result["num_img_tokens"]) - - # For pixel values, the second axis should be the num_crops + 1 - # for the rescaled original image. The default value in VLLM falls - # back to the HF config, which is why we compare to the processor num_crops - assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"]) - assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1 - - @pytest.mark.parametrize("model", models) @pytest.mark.parametrize("num_crops,expected_max_tokens", [ (4, 781), @@ -112,48 +57,20 @@ def test_max_tokens_override(get_max_phi3v_image_tokens, model: str, @pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [ - (4, 781, 1), - (4, 781, 2), - (16, 2653, 1), - (16, 2653, 2), -]) -def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int, - toks_per_img: int, num_imgs: int): - """Ensure dummy_data_for_phi3v handles num_crops properly.""" - # Same as the previous test - don't initialize mm_processor_kwargs - # in this test and assume that the kwargs will be correctly expanded by - # the partial when calling the dummy data func. - ctx = build_model_context( - model_name=model, - tokenizer_name=model, - trust_remote_code=True, - mm_processor_kwargs=None, - ) - - dummy_data = dummy_data_for_phi3v( - ctx=ctx, - seq_len=8192, # Should be bigger than num_imgs * toks_per_img - mm_counts={"image": num_imgs}, - num_crops=num_crops, - ) - sequence_data = dummy_data.seq_data - # Ensure we have the right number of placeholders per num_crops size - img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID) - assert img_tok_count == toks_per_img * num_imgs - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [ - (4, 757, 1), - (4, 757, 2), - (16, 1921, 1), - (16, 1921, 2), -]) -def test_input_processor_override(input_processor_for_phi3v, - image_assets: _ImageAssets, model: str, - num_crops: int, expected_toks_per_img: int, - num_imgs: int): +@pytest.mark.parametrize( + "num_crops,expected_toks_per_img,num_imgs", + [ + (4, 757, 1), + (4, 757, 2), + (16, 1921, 1), + (16, 1921, 2), + # the default num_crops of phi-3.5-vision is 4 + (None, 757, 2), + (None, 757, 2), + ]) +def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets, + model: str, num_crops: Optional[int], + expected_toks_per_img: int, num_imgs: int): """Ensure input_processor_for_phi3v handles num_crops properly.""" # Same as the previous test - don't initialize mm_processor_kwargs # in this test and assume that the kwargs will be correctly expanded by @@ -163,19 +80,20 @@ def test_input_processor_override(input_processor_for_phi3v, tokenizer_name=model, trust_remote_code=True, ) - tokenizer = AutoTokenizer.from_pretrained(model) + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + ctx = InputProcessingContext(ctx.model_config, tokenizer) # Build the image str / prompt based on the number of images we pass img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" images = [image_assets[0].pil_image] * num_imgs - inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt), - prompt=prompt, - multi_modal_data={"image": images}) + mm_data = {"image": images} + mm_processor_kwargs = {} + if num_crops is not None: + mm_processor_kwargs = {"num_crops": num_crops} - processed_inputs = input_processor_for_phi3v(ctx, - inputs, - num_crops=num_crops) + processor = processor_for_phi3v(ctx) + processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) # Ensure we have the right number of placeholders per num_crops size img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index e6c8793989e13..d141cdf1f083b 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -15,13 +15,13 @@ # Used for fast tests where the model doesn't matter DUMMY_MODEL_ID = "facebook/opt-125m" # Used for tests that need a multimodal model -MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" +MULTIMODAL_MODEL_ID = "OpenGVLab/InternVL2-2B" # For mm_processor_kwargs - we test overrides by defining mocks for each place # it is used, and ensuring that we can pass processor kwargs an override value # to receive the intended result for things like sequence length etc. -DEFAULT_NUM_CROPS = 4 -NUM_CROPS_OVERRIDE = 16 +DEFAULT_MAX_DYNAMIC_PATCH = 6 +MAX_DYNAMIC_PATCH_OVERRIDE = 4 # Mocks for all of the places that we use the mm_processor_kwargs @@ -33,10 +33,11 @@ def use_processor_mock(): def custom_processor(ctx: InputContext, inputs: DecoderOnlyInputs, *, - num_crops=DEFAULT_NUM_CROPS): + max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH): # For testing purposes, we don't worry about the prompt - return token_inputs(prompt_token_ids=[], - mm_processor_kwargs={"num_crops": num_crops}) + return token_inputs( + prompt_token_ids=[], + mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch}) with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", return_value=custom_processor): @@ -52,9 +53,9 @@ def custom_dummy_data_factory(self, seq_len: int, mm_counts: Mapping[str, int], *, - num_crops=DEFAULT_NUM_CROPS): + max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH): seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * max_dynamic_patch)) return DummyData(seq_data, None) with patch( @@ -65,15 +66,15 @@ def custom_dummy_data_factory(self, # Lazy import to avoid CUDA reinitialization error def mm_model_cls(): - from vllm.model_executor.models.phi3v import Phi3VForCausalLM + from vllm.model_executor.models.internvl import InternVLChatModel - return Phi3VForCausalLM + return InternVLChatModel # lambda whose signature matches max token calcs extra & mapper + extra kwargs -get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops -custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { - "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) +get_max_dynamic_patch = lambda ctx, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: max_dynamic_patch # noqa: E501 +custom_mapper = lambda ctx, data, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: { # noqa: E501 + "pixel_values": torch.zeros(size=(1, max_dynamic_patch + 1, 3, 448, 448)) } @@ -88,27 +89,28 @@ def test_default_processor_is_a_noop(): assert proc_inputs is proc_outputs -def _get_num_crops_info(init_num_crops: int, inference_num_crops: int): - """Get the init / inference kwargs and expected num_crops for this test.""" - # If we have a value for num_crops, pass the override value and make +def _get_max_dynamic_patch_info(init_max_dynamic_patch: int, + inference_max_dynamic_patch: int): + """Get the init / inference kwargs and expected max_dynamic_patch.""" + # If we have a value for max_dynamic_patch, pass the override value and make # sure we get that value as a return-value from out mock processor, # otherwise fall back to the default value - init_kwargs = None if init_num_crops is None else { - "num_crops": init_num_crops + init_kwargs = None if init_max_dynamic_patch is None else { + "max_dynamic_patch": init_max_dynamic_patch } - inference_kwargs = None if inference_num_crops is None else { - "num_crops": inference_num_crops + inference_kwargs = None if inference_max_dynamic_patch is None else { + "max_dynamic_patch": inference_max_dynamic_patch } - if inference_num_crops is not None: - expected_seq_count = inference_num_crops - elif init_num_crops is not None: - expected_seq_count = init_num_crops + if inference_max_dynamic_patch is not None: + expected_seq_count = inference_max_dynamic_patch + elif init_max_dynamic_patch is not None: + expected_seq_count = init_max_dynamic_patch else: - expected_seq_count = DEFAULT_NUM_CROPS + expected_seq_count = DEFAULT_MAX_DYNAMIC_PATCH return init_kwargs, inference_kwargs, expected_seq_count -def _get_processed_num_crops( +def _get_processed_max_dynamic_patch( processor: Callable[[ProcessorInputs], ProcessorInputs], inference_kwargs: Optional[Dict[str, int]], ) -> int: @@ -120,27 +122,30 @@ def _get_processed_num_crops( assert "type" in processed_inputs assert processed_inputs["type"] == "token" assert "mm_processor_kwargs" in processed_inputs - return processed_inputs["mm_processor_kwargs"]["num_crops"] + return processed_inputs["mm_processor_kwargs"]["max_dynamic_patch"] -@pytest.mark.parametrize("init_num_crops,inference_num_crops", [ - (None, None), - (NUM_CROPS_OVERRIDE, None), - (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE), -]) -def test_input_processor_kwargs(use_processor_mock, init_num_crops, - inference_num_crops): +@pytest.mark.parametrize( + "init_max_dynamic_patch,inference_max_dynamic_patch", [ + (None, None), + (MAX_DYNAMIC_PATCH_OVERRIDE, None), + (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE), + ]) +def test_input_processor_kwargs(use_processor_mock, init_max_dynamic_patch, + inference_max_dynamic_patch): """Ensure input processors can use processor kwargs.""" dummy_registry = InputRegistry() - init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info( - init_num_crops, inference_num_crops) + (init_kwargs, inference_kwargs, + expected_seq_count) = _get_max_dynamic_patch_info( + init_max_dynamic_patch, inference_max_dynamic_patch) ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) - num_crops_val = _get_processed_num_crops(processor, inference_kwargs) + max_dynamic_patch_val = _get_processed_max_dynamic_patch( + processor, inference_kwargs) - assert num_crops_val == expected_seq_count + assert max_dynamic_patch_val == expected_seq_count @pytest.mark.parametrize( @@ -165,18 +170,21 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, processor = dummy_registry.create_input_processor(ctx.model_config) # Should filter out the inference time kwargs - num_crops_val = _get_processed_num_crops(processor, mm_processor_kwargs) - assert num_crops_val == DEFAULT_NUM_CROPS + max_dynamic_patch_val = _get_processed_max_dynamic_patch( + processor, mm_processor_kwargs) + assert max_dynamic_patch_val == DEFAULT_MAX_DYNAMIC_PATCH ### Test overrides for the dummy data -@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) -def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): +@pytest.mark.parametrize("max_dynamic_patch", + [None, MAX_DYNAMIC_PATCH_OVERRIDE]) +def test_dummy_data_kwarg_overrides(use_dummy_data_mock, max_dynamic_patch): """Ensure dummy data factories can use processor kwargs.""" - mm_processor_kwargs = None if num_crops is None else { - "num_crops": num_crops + mm_processor_kwargs = None if max_dynamic_patch is None else { + "max_dynamic_patch": max_dynamic_patch } - expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH + if max_dynamic_patch is None else max_dynamic_patch) dummy_registry = InputRegistry() ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=mm_processor_kwargs) @@ -217,17 +225,20 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, # len is solely dependent on the value of the mm_processor_kwargs. dummy_data = dummy_registry.dummy_data_for_profiling( ctx.model_config, seq_len=-1, mm_registry=mm_registry) - assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS + assert len( + dummy_data.seq_data.prompt_token_ids) == DEFAULT_MAX_DYNAMIC_PATCH ### Test overrides for the max token count per multimodal instance -@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) -def test_max_tokens_kwarg_overrides(num_crops): +@pytest.mark.parametrize("max_dynamic_patch", + [None, MAX_DYNAMIC_PATCH_OVERRIDE]) +def test_max_tokens_kwarg_overrides(max_dynamic_patch): """Ensure max token calcs can use processor kwargs.""" - mm_processor_kwargs = None if num_crops is None else { - "num_crops": num_crops + mm_processor_kwargs = None if max_dynamic_patch is None else { + "max_dynamic_patch": max_dynamic_patch } - expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH + if max_dynamic_patch is None else max_dynamic_patch) ctx = build_model_context(MULTIMODAL_MODEL_ID, task="generate", @@ -239,11 +250,11 @@ def test_max_tokens_kwarg_overrides(num_crops): mm_registry.init_mm_limits_per_prompt(ctx.model_config) # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the mm_processor_kwargs. + # our max_dynamic_patch value back from the mm_processor_kwargs. with patch.object( mm_registry._get_plugin("image"), "_max_mm_tokens", - {mm_model_cls(): get_num_crops}, + {mm_model_cls(): get_max_dynamic_patch}, ): max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( ctx.model_config) @@ -279,26 +290,29 @@ def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs): with patch.object( mm_registry._get_plugin("image"), "_max_mm_tokens", - {mm_model_cls(): get_num_crops}, + {mm_model_cls(): get_max_dynamic_patch}, ): max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( ctx.model_config) - assert max_multimodal_tokens == DEFAULT_NUM_CROPS + assert max_multimodal_tokens == DEFAULT_MAX_DYNAMIC_PATCH ### Test overrides for the mapper -@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE]) -def test_default_mapper_with_processor_kwargs(image_assets, num_crops): +@pytest.mark.parametrize( + "max_dynamic_patch", + [DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE]) +def test_default_mapper_with_processor_kwargs(image_assets, max_dynamic_patch): """Ensure that the mapper processor kwargs can fall back to HF models.""" # NOTE - we don't validate bad inputs for the default mapper, because it's # through the automodel interface in transformers, so we can't easily # inspect what kwargs are or are not allowed. - ctx = build_model_context(MULTIMODAL_MODEL_ID, - task="generate", - trust_remote_code=True, - mm_processor_kwargs={"num_crops": num_crops}, - limit_mm_per_prompt={"image": 1}) + ctx = build_model_context( + MULTIMODAL_MODEL_ID, + task="generate", + trust_remote_code=True, + mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch}, + limit_mm_per_prompt={"image": 1}) mm_registry = MultiModalRegistry() mm_registry.init_mm_limits_per_prompt(ctx.model_config) @@ -307,20 +321,22 @@ def test_default_mapper_with_processor_kwargs(image_assets, num_crops): mm_inputs = {"image": image} mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) - # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336] - assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1 + # pixel vals should have shape: [batch, max_dynamic_patch+1, ...] + assert mapped_inputs["pixel_values"].shape[1] == max_dynamic_patch + 1 -@pytest.mark.parametrize("init_num_crops,inference_num_crops", [ - (None, None), - (NUM_CROPS_OVERRIDE, None), - (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE), -]) -def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops, - inference_num_crops): +@pytest.mark.parametrize( + "init_max_dynamic_patch,inference_max_dynamic_patch", [ + (None, None), + (MAX_DYNAMIC_PATCH_OVERRIDE, None), + (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE), + ]) +def test_custom_mapper_kwarg_overrides(image_assets, init_max_dynamic_patch, + inference_max_dynamic_patch): """Ensure custom mappers can use processor kwargs.""" - init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info( - init_num_crops, inference_num_crops) + (init_kwargs, inference_kwargs, + expected_seq_count) = _get_max_dynamic_patch_info( + init_max_dynamic_patch, inference_max_dynamic_patch) ctx = build_model_context(MULTIMODAL_MODEL_ID, task="generate", @@ -335,7 +351,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops, # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the mm_processor_kwargs. + # our max_dynamic_patch value back from the mm_processor_kwargs. mm_registry._get_plugin("image").register_input_mapper(custom_mapper)( mm_model_cls()) mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs, @@ -373,11 +389,12 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets, # Patch the image registry for phi3v with our lambda that is compatible # with overrides, then ensure that calling the method correctly echos - # our num_crops value back from the mm_processor_kwargs. + # our max_dynamic_patch value back from the mm_processor_kwargs. mm_registry._get_plugin("image").register_input_mapper(custom_mapper)( mm_model_cls()) # Should filter out the inference time kwargs mapped_inputs = mm_registry.map_input( ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs) - assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1 + assert mapped_inputs["pixel_values"].shape[1] == ( + DEFAULT_MAX_DYNAMIC_PATCH + 1) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 646554c72481a..0dfed3b7e61bf 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -69,12 +69,12 @@ class InputProcessingContext(InputContext): tokenizer: AnyTokenizer """The tokenizer used to tokenize the inputs.""" - def get_hf_processor(self) -> ProcessorMixin: + def get_hf_processor(self, **kwargs) -> ProcessorMixin: return cached_get_processor( self.model_config.tokenizer, tokenizer=self.tokenizer, # Override the tokenizer with ours trust_remote_code=self.model_config.trust_remote_code, - ) + **kwargs) N = TypeVar("N", bound=Type[nn.Module]) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index eef23029a2aca..3c7854ce388ab 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -12,22 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools -import re -from functools import cached_property, lru_cache -from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, - Tuple, TypedDict, Union) +from functools import cached_property +from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, + TypedDict, Union) -import numpy as np import torch import torch.nn as nn -from PIL import Image -from transformers import CLIPVisionConfig, PretrainedConfig +from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig, + ProcessorMixin) from vllm.attention import AttentionMetadata -from vllm.config import ModelConfig, VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) +from vllm.config import VllmConfig +from vllm.inputs import InputContext from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -36,12 +32,18 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors, PlaceholderRange -from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token +from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.processing import (BaseMultiModalProcessor, + InputProcessingContext, + ModalityProcessingMetadata, + MultiModalDataDict, + MultiModalProcessingMetadata, + PromptReplacement) from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of -from .clip import dummy_image_for_clip, dummy_seq_data_for_clip +from .clip import dummy_image_for_clip from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, @@ -303,231 +305,99 @@ def add_image_newline(self, image_features_hd): return image_features_hd_newline -# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57 -def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): - target_height = int(np.ceil(height / padding_unit) * padding_unit) - top_padding = int((target_height - height) / 2) - bottom_padding = target_height - height - top_padding - padded_width = width - padded_height = height + top_padding + bottom_padding - return padded_width, padded_height - - -# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90 -def _calc_hd_transform_size(*, width: int, height: int, hd_num: int): - transposed = False - if width < height: - width, height = height, width - transposed = True - - ratio = width / height - scale = 1 - while scale * np.ceil(scale / ratio) <= hd_num: - scale += 1 - scale -= 1 - - new_width = int(scale * 336) - new_height = int(new_width / ratio) - - padded_width, padded_height = _calc_padded_size(width=new_width, - height=new_height) - - if transposed: - padded_width, padded_height = padded_height, padded_width - - return padded_width, padded_height - - -# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181 -def get_phi3v_image_feature_size( - hf_config: Dict[str, Any], - *, - input_height: int, - input_width: int, - num_crops: int, -) -> int: - if num_crops is None: - num_crops = hf_config.get("num_crops", 16) - new_width, new_height = _calc_hd_transform_size(width=input_width, - height=input_height, - hd_num=num_crops) - - return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \ - + (new_height // 336 + 1) * 12 - - def get_max_phi3v_image_tokens(ctx: InputContext, *, num_crops: Optional[int] = None): + mm_processor_kwargs = {} + if num_crops is not None: + mm_processor_kwargs["num_crops"] = num_crops - return get_phi3v_image_feature_size( - ctx.get_hf_image_processor_config(), - input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, - input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, - num_crops=num_crops, + model_config = ctx.model_config + image_processor = cached_get_image_processor( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + **mm_processor_kwargs, + ) + + num_tokens = image_processor.calc_num_image_tokens_from_image_size( + width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) + return num_tokens -def dummy_data_for_phi3v(ctx: InputContext, - seq_len: int, - mm_counts: Mapping[str, int], - *, - num_crops: Optional[int] = None): +def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext, + mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops) - - seq_data, ranges = dummy_seq_data_for_clip( - CLIP_VIT_LARGE_PATCH14_336_CONFIG, - seq_len, - num_images, - image_token_id=_IMAGE_TOKEN_ID, - image_feature_size_override=image_feature_size, - ) - mm_data = dummy_image_for_clip( + data = dummy_image_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, num_images, image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) - return DummyData(seq_data, mm_data, ranges) - + hf_processor = ctx.get_hf_processor() + image_processor = hf_processor.image_processor # type: ignore + hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt") -@lru_cache -def _get_image_placeholder_token_id_candidates( - model_config: ModelConfig, - idx: int, -) -> List[List[int]]: - assert idx > 0 + return MultiModalKwargs(**hf_inputs) - tokenizer = cached_get_tokenizer(model_config.tokenizer) - # This is used when the image token is at the start of the string - start_candidate = tokenizer.encode(f"<|image_{idx}|>", - add_special_tokens=False) +def create_metadata_for_phi3v( + ctx: InputProcessingContext) -> MultiModalProcessingMetadata: + return { + "image": + ModalityProcessingMetadata(prompt_repls=[ + PromptReplacement(target=[_IMAGE_TOKEN_ID], + repl_unit=[_IMAGE_TOKEN_ID], + repl_count=get_max_phi3v_image_tokens(ctx)), + ]), + } - # This is used when the image token is in the middle of the string - # We need to get the token for "<", not "▁<" - # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json - a_token_id, = tokenizer.encode("a", add_special_tokens=False) - a_token_id_, *middle_candidate = tokenizer.encode(f"a<|image_{idx}|>", - add_special_tokens=False) - assert a_token_id == a_token_id_ - return [start_candidate, middle_candidate] +class Phi3VProcessor(BaseMultiModalProcessor): + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__( + ctx=ctx, + metadata=create_metadata_for_phi3v(ctx), + ) -def input_processor_for_phi3v(ctx: InputContext, - inputs: DecoderOnlyInputs, - *, - num_crops: Optional[int] = None): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - model_config = ctx.model_config - hf_config = ctx.get_hf_image_processor_config() - - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - w, h = image_data.size - image_feature_size = [ - get_phi3v_image_feature_size(hf_config, - input_width=w, - input_height=h, - num_crops=num_crops) - ] - image_data = [image_data] - elif is_list_of(image_data, Image.Image): - image_feature_size = [] - for image in image_data: - w, h = image.size - image_feature_size.append( - get_phi3v_image_feature_size(hf_config, - input_width=w, - input_height=h, - num_crops=num_crops)) - elif isinstance(image_data, torch.Tensor): - image_feature_size = [image_data.shape[0]] - image_data = [image_data] - elif is_list_of(image_data, torch.Tensor): - image_feature_size = [item.shape[0] for item in image_data] - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - - prompt = inputs.get("prompt") - if prompt is None: - # for async server request, we assume prompt and its token_ids is always - # in correct format. And num_image_tags == len(image_data) always True. - image_idx = range(1, len(image_data) + 1) - new_prompt = None - else: - image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt))) - if prompt.count("<|image|>") > 0: - logger.warning("Please follow the prompt format that is " - "documented on HuggingFace which does not involve " - "repeating <|image|> tokens.") - elif (num_image_tags := len(image_idx)) > 1: - assert num_image_tags == len( - image_data), "The count of image_placeholder not match image's" - new_prompt = prompt - - prompt_token_ids = inputs["prompt_token_ids"].copy() - - # masked placeholder with image token id - for idx in image_idx: - candidates = _get_image_placeholder_token_id_candidates(model_config, - idx=idx) - - for candidate in candidates: - for i in range(len(prompt_token_ids) - len(candidate) + 1): - if prompt_token_ids[i:i + len(candidate)] == candidate: - prompt_token_ids[i:i + - len(candidate)] = ([_IMAGE_TOKEN_ID] * - len(candidate)) - break - - # merge consecutive tag ids - merged_token_ids: List[int] = [] - for is_placeholder, token_ids in itertools.groupby( - prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID): - if is_placeholder: - merged_token_ids.append(_IMAGE_TOKEN_ID) - else: - merged_token_ids.extend(list(token_ids)) - - # TODO: Move this to utils or integrate with clip. - new_token_ids: List[int] = [] - placeholder_ranges: List[PlaceholderRange] = [] - placeholder_idx = 0 - while merged_token_ids: - token_id = merged_token_ids.pop(0) - if token_id == _IMAGE_TOKEN_ID: - replacement_ids = repeat_and_pad_token( - _IMAGE_TOKEN_ID, - repeat_count=image_feature_size[placeholder_idx], - ) - placeholder_ranges.append({ - "offset": len(new_token_ids), - "length": len(replacement_ids) - }) - new_token_ids.extend(replacement_ids) - placeholder_idx += 1 - else: - new_token_ids.append(token_id) - - # NOTE: Create a defensive copy of the original inputs - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"image": placeholder_ranges}) + def _get_hf_processor( + self, + *, + num_crops: Optional[int] = None, + ) -> ProcessorMixin: + if num_crops is not None: + return self.ctx.get_hf_processor(num_crops=num_crops) + return self.ctx.get_hf_processor() + + def _apply_hf_processor( + self, + prompt: str, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._apply_hf_processor( + prompt, mm_data, mm_processor_kwargs) + # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids, + # which will cause OverflowError when decoding the prompt_ids. + # Therefore, we need to do an early replacement here + token_ids = processed_outputs['input_ids'] + token_ids[token_ids < 0] = _IMAGE_TOKEN_ID + processed_outputs['input_ids'] = token_ids + return processed_outputs + + def _get_dummy_mm_kwargs( + self, + mm_counts: Mapping[str, int], + ) -> MultiModalKwargs: + return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts) -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v) -@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) +@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index c3a95d60e6fe6..922c83b6fd8a9 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -3,7 +3,8 @@ from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass from functools import lru_cache -from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union +from typing import (Any, Dict, Generic, NamedTuple, Optional, Protocol, + TypeVar, Union, cast) import torch from transformers import BatchFeature, ProcessorMixin @@ -11,7 +12,8 @@ from vllm.inputs import DummyData, InputProcessingContext from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import flatten_2d_lists, full_groupby, is_list_of +from vllm.utils import (flatten_2d_lists, full_groupby, is_list_of, + resolve_mm_processor_kwargs) from .inputs import (AudioItem, ImageItem, MultiModalDataDict, MultiModalInputsV2, MultiModalKwargs, PlaceholderRange, @@ -543,8 +545,14 @@ def __init__( self.ctx = ctx self.metadata = metadata + self.init_mm_processor_kwargs = (ctx.model_config.mm_processor_kwargs + or {}) - def _get_hf_processor(self) -> ProcessorMixin: + def _get_hf_processor( + self, + **mm_processor_kwargs: Mapping[str, object], + ) -> ProcessorMixin: + # by default, we won't pass any kwargs to the processor initialization return self.ctx.get_hf_processor() def _get_tokenizer(self) -> AnyTokenizer: @@ -581,7 +589,13 @@ def _apply_hf_processor( mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: - hf_processor = self._get_hf_processor() + # some mm_processor_kwargs may be used in processor initialization + # instead of processor call + processor_init_kwargs = { + **self.init_mm_processor_kwargs, + **mm_processor_kwargs, + } + hf_processor = self._get_hf_processor(**processor_init_kwargs) processor_data = dict[str, Any]() passthrough_data = dict[str, Any]() @@ -601,6 +615,13 @@ def _apply_hf_processor( else: processor_data[k] = v + # filter mm_processor_kwargs used in processor call + mm_processor_kwargs = resolve_mm_processor_kwargs( + self.init_mm_processor_kwargs, + cast(Dict[str, Any], mm_processor_kwargs), + hf_processor, + ) + try: hf_inputs = hf_processor( text=prompt, # type: ignore From cbcbdb1ceb9c219d13b2386e101992c399410551 Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Mon, 9 Dec 2024 22:21:06 +0100 Subject: [PATCH 11/18] [Bugfix][Hardware][Gaudi] Bump vllm_hpu_extension version (#11028) Signed-off-by: Konrad Zawora --- requirements-hpu.txt | 2 +- vllm/attention/backends/hpu_attn.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/requirements-hpu.txt b/requirements-hpu.txt index 4674efb812cfd..17d40d0ee131a 100644 --- a/requirements-hpu.txt +++ b/requirements-hpu.txt @@ -8,4 +8,4 @@ pandas tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@fd7f2e6 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@e096d6f diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 2c62e565c04c7..f90d15d4207e7 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -111,8 +111,16 @@ def __init__( self.matmul_qk = Matmul() self.softmax = Softmax() self.matmul_av = Matmul() + self.batch2block_matmul = Matmul() + self.block2batch_matmul = Matmul() + # NOTE(kzawora): Contiguous PA is off until model runner supports it self.k_cache = VLLMKVCache() + self.k_cache.use_contiguous_pa = False self.v_cache = VLLMKVCache() + self.v_cache.use_contiguous_pa = False + # NOTE(kzawora): Pipelined PA is off until model runner supports it + ops.pa_impl = ops.pa + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window self.alibi_slopes = alibi_slopes @@ -228,9 +236,12 @@ def forward( block_mapping=attn_metadata.block_mapping, block_bias=attn_metadata.attn_bias, block_scales=attn_metadata.block_scales, + block_groups=None, scale=self.scale, matmul_qk_op=self.matmul_qk, matmul_av_op=self.matmul_av, + batch2block_matmul_op=self.batch2block_matmul, + block2batch_matmul_op=self.block2batch_matmul, keys_fetch_func=self.k_cache.fetch_from_cache, values_fetch_func=self.v_cache.fetch_from_cache) # Reshape the output tensor. From 1a2f8fb828f0444705db319786b2e901159f184e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 9 Dec 2024 13:47:24 -0800 Subject: [PATCH 12/18] [v1] fix use compile sizes (#11000) Signed-off-by: youkaichao --- vllm/config.py | 1 + vllm/v1/worker/gpu_model_runner.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index 29f0839dcabba..5fb9563fcf3a3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2522,6 +2522,7 @@ def __post_init__(self): self.compilation_config.custom_ops = ["none"] self.compilation_config.use_cudagraph = True self.compilation_config.use_inductor = True + self.compilation_config.cudagraph_num_of_warmups = 1 self.compilation_config.pass_config.enable_fusion = False self.compilation_config.pass_config.enable_reshape = False self.compilation_config.level = CompilationLevel.PIECEWISE diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7f95be06188e3..c601aca13feaf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -582,6 +582,9 @@ def capture_model(self) -> None: # can reuse the memory pool allocated for the large shapes. with graph_capture(): for num_tokens in reversed(self.cudagraph_batch_sizes): + for _ in range(self.vllm_config.compilation_config. + cudagraph_num_of_warmups): + self._dummy_run(self.model, num_tokens, self.kv_caches) self._dummy_run(self.model, num_tokens, self.kv_caches) end_time = time.perf_counter() From 9c6459e4cb020ec1ad9ea08cac9309b83d432fc8 Mon Sep 17 00:00:00 2001 From: xendo Date: Mon, 9 Dec 2024 22:53:24 +0100 Subject: [PATCH 13/18] [Neuron] Upgrade neuron to 2.20.2 (#11016) Signed-off-by: Jerzy Zagorski Co-authored-by: Jerzy Zagorski --- Dockerfile.neuron | 3 ++- vllm/utils.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 76dbd4c04d3f3..77162bc82de62 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -1,5 +1,6 @@ # default base image -ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.0-ubuntu20.04" +# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx +ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.2-ubuntu20.04" FROM $BASE_IMAGE diff --git a/vllm/utils.py b/vllm/utils.py index 1f19d9eacd16d..2bb1fb2af40f4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1628,7 +1628,7 @@ def direct_register_custom_op( library object. If you want to bind the operator to a different library, make sure the library object is alive when the operator is used. """ - if is_in_doc_build(): + if is_in_doc_build() or not supports_custom_op(): return import torch.library if hasattr(torch.library, "infer_schema"): From b63ba848323efd88207b12d7582501d525503b8a Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:00:29 -0500 Subject: [PATCH 14/18] [ROCm][bugfix] scpecilative decoding worker class (#11035) Signed-off-by: Gregory Shtrasberg --- vllm/platforms/rocm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 66674e3ebe91f..0133f26a0b1bc 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -93,6 +93,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: elif vllm_config.speculative_config: parallel_config.worker_cls = \ "vllm.spec_decode.spec_decode_worker.create_spec_worker" + parallel_config.sd_worker_cls = \ + "vllm.worker.worker.Worker" else: parallel_config.worker_cls = "vllm.worker.worker.Worker" From 5ed5d5f128d26a48c1b1db16c319fcb96c93799d Mon Sep 17 00:00:00 2001 From: Richard Liu <39319471+richardsliu@users.noreply.github.com> Date: Mon, 9 Dec 2024 15:07:48 -0800 Subject: [PATCH 15/18] Build tpu image in release pipeline (#10936) Signed-off-by: Richard Liu Co-authored-by: Kevin H. Luu --- .buildkite/release-pipeline.yaml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 93e118fb3eab8..2de6fceb0c3fe 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -39,3 +39,19 @@ steps: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain ." - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" + + - label: "Build and publish TPU release image" + depends_on: ~ + if: build.env("NIGHTLY") == "1" + agents: + queue: tpu_queue_postmerge + commands: + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f Dockerfile.tpu ." + - "docker push vllm/vllm-tpu:nightly" + - "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT" + plugins: + - docker-login#v3.0.0: + username: vllm + password-env: DOCKERHUB_TOKEN + env: + DOCKER_BUILDKIT: "1" From 6faec545057e6152e92e8ab619fc018e20864943 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 9 Dec 2024 15:08:19 -0800 Subject: [PATCH 16/18] [V1] Do not store `None` in self.generators (#11038) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_input_batch.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 457784bb0287c..25d95ac6e26af 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -102,6 +102,8 @@ def __init__( self.top_k_reqs: Set[str] = set() # req_index -> generator + # NOTE(woosuk): The indices of the requests that do not have their own + # generator should not be included in the dictionary. self.generators: Dict[int, torch.Generator] = {} self.num_logprobs: Dict[str, int] = {} @@ -147,7 +149,10 @@ def add_request( if sampling_params.top_k > 0: self.top_k_reqs.add(req_id) - self.generators[req_index] = request.generator + # NOTE(woosuk): self.generators should not include the requests that + # do not have their own generator. + if request.generator is not None: + self.generators[req_index] = request.generator num_logprobs = sampling_params.logprobs if num_logprobs is not None and num_logprobs > 0: From 6d525288c1a40ee70f9cff2fe08657f23bae88dc Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 9 Dec 2024 20:15:34 -0500 Subject: [PATCH 17/18] [Docs] Add dedicated tool calling page to docs (#10554) Signed-off-by: mgoin Co-authored-by: Tyler Michael Smith --- docs/source/index.rst | 1 + .../serving/openai_compatible_server.md | 217 ------------- docs/source/usage/tool_calling.md | 287 ++++++++++++++++++ 3 files changed, 288 insertions(+), 217 deletions(-) create mode 100644 docs/source/usage/tool_calling.md diff --git a/docs/source/index.rst b/docs/source/index.rst index 86b1eed2d26ba..c45c941b00e20 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -102,6 +102,7 @@ Documentation usage/lora usage/multimodal_inputs + usage/tool_calling usage/structured_outputs usage/spec_decode usage/compatibility_matrix diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index d75e90807ca1d..f75653106cf66 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -361,220 +361,3 @@ $ vllm serve SOME_MODEL --config config.yaml **NOTE** In case an argument is supplied simultaneously using command line and the config file, the value from the commandline will take precedence. The order of priorities is `command line > config file values > defaults`. - ---- - -## Tool calling in the chat completion API -vLLM currently supports named function calling, as well as the `auto` and `none` options for the `tool_choice` field in the chat completion API. The `tool_choice` option `required` is **not yet supported** but on the roadmap. - -It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. -Please see below for recommended configuration and chat templates to use when function calling is to be used with the different models. - - -### Named Function Calling -vLLM supports named function calling in the chat completion API by default. It does so using Outlines, so this is -enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a -high-quality one. - -vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. - -To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and -specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request. - - -### Automatic Function Calling -To enable this feature, you should set the following flags: -* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it -deems appropriate. -* `--tool-call-parser` -- select the tool parser to use (listed below). Additional tool parsers -will continue to be added in the future, and also can register your own tool parsers in the `--tool-parser-plugin`. -* `--tool-parser-plugin` -- **optional** tool parser plugin used to register user defined tool parsers into vllm, the registered tool parser name can be specified in `--tool-call-parser`. -* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages -that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their -`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat -template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates) -from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json) - -If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template! - - -#### Hermes Models (`hermes`) - -All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported. -* `NousResearch/Hermes-2-Pro-*` -* `NousResearch/Hermes-2-Theta-*` -* `NousResearch/Hermes-3-*` - - -_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge -step in their creation_. - -Flags: `--tool-call-parser hermes` - - -#### Mistral Models (`mistral`) - -Supported models: -* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed) -* Additional mistral function-calling models are compatible as well. - -Known issues: -1. Mistral 7B struggles to generate parallel tool calls correctly. -2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is -much shorter than what vLLM generates. Since an exception is thrown when this condition -is not met, the following additional chat templates are provided: - -* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that -it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits) -* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt -when tools are provided, that results in much better reliability when working with parallel tool calling. - - -Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` - - -#### Llama Models (`llama3_json`) - -Supported models: -* `meta-llama/Meta-Llama-3.1-8B-Instruct` -* `meta-llama/Meta-Llama-3.1-70B-Instruct` -* `meta-llama/Meta-Llama-3.1-405B-Instruct` -* `meta-llama/Meta-Llama-3.1-405B-Instruct-FP8` - -The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) in Llama-3.2 models, see the `pythonic` tool parser below. -Other tool calling formats like the built in python tool calling or custom tool calling are not supported. - -Known issues: -1. Parallel tool calls are not supported. -2. The model can generate parameters with a wrong format, such as generating - an array serialized as string instead of an array. - -The `tool_chat_template_llama3_json.jinja` file contains the "official" Llama chat template, but tweaked so that -it works better with vLLM. - -Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja` - -#### IBM Granite - -Supported models: -* `ibm-granite/granite-3.0-8b-instruct` - -Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja` - -`examples/tool_chat_template_granite.jinja`: this is a modified chat template from the original on Huggingface. Parallel function calls are supported. - -* `ibm-granite/granite-20b-functioncalling` - -Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` - -`examples/tool_chat_template_granite_20b_fc.jinja`: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. - - -#### InternLM Models (`internlm`) - -Supported models: -* `internlm/internlm2_5-7b-chat` (confirmed) -* Additional internlm2.5 function-calling models are compatible as well - -Known issues: -* Although this implementation also supports InternLM2, the tool call results are not stable when testing with the `internlm/internlm2-chat-7b` model. - -Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja` - - -#### Jamba Models (`jamba`) -AI21's Jamba-1.5 models are supported. -* `ai21labs/AI21-Jamba-1.5-Mini` -* `ai21labs/AI21-Jamba-1.5-Large` - - -Flags: `--tool-call-parser jamba` - - -#### Models with Pythonic Tool Calls (`pythonic`) - -A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models. - -As a concrete example, these models may look up the weather in San Francisco and Seattle by generating: -```python -[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')] -``` - -Limitations: -* The model must not generate both text and tool calls in the same generation. This may not be hard to change for a specific model, but the community currently lacks consensus on which tokens to emit when starting and ending tool calls. (In particular, the Llama 3.2 models emit no such tokens.) -* Llama's smaller models struggle to use tools effectively. - -Example supported models: -* `meta-llama/Llama-3.2-1B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) -* `meta-llama/Llama-3.2-3B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) -* `Team-ACE/ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`) -* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`) - -Flags: `--tool-call-parser pythonic --chat-template {see_above}` - ---- -**WARNING** -Llama's smaller models frequently fail to emit tool calls in the correct format. Your mileage may vary. - ---- - - -### How to write a tool parser plugin - -A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py. - -Here is a summary of a plugin file: - -```python - -# import the required packages - -# define a tool parser and register it to vllm -# the name list in register_module can be used -# in --tool-call-parser. you can define as many -# tool parsers as you want here. -@ToolParserManager.register_module(["example"]) -class ExampleToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): - super().__init__(tokenizer) - - # adjust request. e.g.: set skip special tokens - # to False for tool call output. - def adjust_request( - self, request: ChatCompletionRequest) -> ChatCompletionRequest: - return request - - # implement the tool call parse for stream call - def extract_tool_calls_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: Sequence[int], - current_token_ids: Sequence[int], - delta_token_ids: Sequence[int], - request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: - return delta - - # implement the tool parse for non-stream call - def extract_tool_calls( - self, - model_output: str, - request: ChatCompletionRequest, - ) -> ExtractedToolCallInformation: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=text) - - -``` - -Then you can use this plugin in the command line like this. -``` - --enable-auto-tool-choice \ - --tool-parser-plugin - --tool-call-parser example \ - --chat-template \ -``` - diff --git a/docs/source/usage/tool_calling.md b/docs/source/usage/tool_calling.md new file mode 100644 index 0000000000000..f8be023307b0c --- /dev/null +++ b/docs/source/usage/tool_calling.md @@ -0,0 +1,287 @@ +# Tool Calling + +vLLM currently supports named function calling, as well as the `auto` and `none` options for the `tool_choice` field in the chat completion API. The `tool_choice` option `required` is **not yet supported** but on the roadmap. + +## Quickstart + +Start the server with tool calling enabled. This example uses Meta's Llama 3.1 8B model, so we need to use the llama3 tool calling chat template from the vLLM examples directory: + +```bash +vllm serve meta-llama/Llama-3.1-8B-Instruct \ + --enable-auto-tool-choice \ + --tool-call-parser llama3_json \ + --chat-template examples/tool_chat_template_llama3_json.jinja +``` + +Next, make a request to the model that should result in it using the available tools: + +```python +from openai import OpenAI +import json + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy") + +def get_weather(location: str, unit: str): + return f"Getting the weather for {location} in {unit}..." +tool_functions = {"get_weather": get_weather} + +tools = [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City and state, e.g., 'San Francisco, CA'"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]} + }, + "required": ["location", "unit"] + } + } +}] + +response = client.chat.completions.create( + model=client.models.list().data[0].id, + messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}], + tools=tools, + tool_choice="auto" +) + +tool_call = response.choices[0].message.tool_calls[0].function +print(f"Function called: {tool_call.name}") +print(f"Arguments: {tool_call.arguments}") +print(f"Result: {get_weather(**json.loads(tool_call.arguments))}") +``` + +Example output: +``` +Function called: get_weather +Arguments: {"location": "San Francisco, CA", "unit": "fahrenheit"} +Result: Getting the weather for San Francisco, CA in fahrenheit... +``` + +This example demonstrates: +- Setting up the server with tool calling enabled +- Defining an actual function to handle tool calls +- Making a request with `tool_choice="auto"` +- Handling the structured response and executing the corresponding function + +You can also specify a particular function using named function calling by setting `tool_choice={"type": "function", "function": {"name": "get_weather"}}`. Note that this will use the guided decoding backend - so the first time this is used, there will be several seconds of latency (or more) as the FSM is compiled for the first time before it is cached for subsequent requests. + +Remember that it's the callers responsibility to: +1. Define appropriate tools in the request +2. Include relevant context in the chat messages +3. Handle the tool calls in your application logic + +For more advanced usage, including parallel tool calls and different model-specific parsers, see the sections below. + +## Named Function Calling +vLLM supports named function calling in the chat completion API by default. It does so using Outlines through guided decoding, so this is +enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a +high-quality one. + +vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. +For best results, we recommend ensuring that the expected output format / schema is specified in the prompt to ensure that the model's intended generation is aligned with the schema that it's being forced to generate by the guided decoding backend. + +To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and +specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request. + + +## Automatic Function Calling + +To enable this feature, you should set the following flags: +* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it +deems appropriate. +* `--tool-call-parser` -- select the tool parser to use (listed below). Additional tool parsers +will continue to be added in the future, and also can register your own tool parsers in the `--tool-parser-plugin`. +* `--tool-parser-plugin` -- **optional** tool parser plugin used to register user defined tool parsers into vllm, the registered tool parser name can be specified in `--tool-call-parser`. +* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages +that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their +`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat +template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates) +from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json) + +If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template! + + +### Hermes Models (`hermes`) + +All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported. +* `NousResearch/Hermes-2-Pro-*` +* `NousResearch/Hermes-2-Theta-*` +* `NousResearch/Hermes-3-*` + + +_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge +step in their creation_. + +Flags: `--tool-call-parser hermes` + + +### Mistral Models (`mistral`) + +Supported models: +* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed) +* Additional mistral function-calling models are compatible as well. + +Known issues: +1. Mistral 7B struggles to generate parallel tool calls correctly. +2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is +much shorter than what vLLM generates. Since an exception is thrown when this condition +is not met, the following additional chat templates are provided: + +* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that +it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits) +* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt +when tools are provided, that results in much better reliability when working with parallel tool calling. + + +Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` + + +### Llama Models (`llama3_json`) + +Supported models: +* `meta-llama/Meta-Llama-3.1-8B-Instruct` +* `meta-llama/Meta-Llama-3.1-70B-Instruct` +* `meta-llama/Meta-Llama-3.1-405B-Instruct` +* `meta-llama/Meta-Llama-3.1-405B-Instruct-FP8` + +The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) in Llama-3.2 models, see the `pythonic` tool parser below. +Other tool calling formats like the built in python tool calling or custom tool calling are not supported. + +Known issues: +1. Parallel tool calls are not supported. +2. The model can generate parameters with a wrong format, such as generating + an array serialized as string instead of an array. + +The `tool_chat_template_llama3_json.jinja` file contains the "official" Llama chat template, but tweaked so that +it works better with vLLM. + +Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja` + +#### IBM Granite + +Supported models: +* `ibm-granite/granite-3.0-8b-instruct` + +Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja` + +`examples/tool_chat_template_granite.jinja`: this is a modified chat template from the original on Huggingface. Parallel function calls are supported. + +* `ibm-granite/granite-20b-functioncalling` + +Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` + +`examples/tool_chat_template_granite_20b_fc.jinja`: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. + + +### InternLM Models (`internlm`) + +Supported models: +* `internlm/internlm2_5-7b-chat` (confirmed) +* Additional internlm2.5 function-calling models are compatible as well + +Known issues: +* Although this implementation also supports InternLM2, the tool call results are not stable when testing with the `internlm/internlm2-chat-7b` model. + +Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja` + + +### Jamba Models (`jamba`) +AI21's Jamba-1.5 models are supported. +* `ai21labs/AI21-Jamba-1.5-Mini` +* `ai21labs/AI21-Jamba-1.5-Large` + + +Flags: `--tool-call-parser jamba` + + +### Models with Pythonic Tool Calls (`pythonic`) + +A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models. + +As a concrete example, these models may look up the weather in San Francisco and Seattle by generating: +```python +[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')] +``` + +Limitations: +* The model must not generate both text and tool calls in the same generation. This may not be hard to change for a specific model, but the community currently lacks consensus on which tokens to emit when starting and ending tool calls. (In particular, the Llama 3.2 models emit no such tokens.) +* Llama's smaller models struggle to use tools effectively. + +Example supported models: +* `meta-llama/Llama-3.2-1B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) +* `meta-llama/Llama-3.2-3B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) +* `Team-ACE/ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`) +* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`) + +Flags: `--tool-call-parser pythonic --chat-template {see_above}` + +--- +**WARNING** +Llama's smaller models frequently fail to emit tool calls in the correct format. Your mileage may vary. + +--- + + +## How to write a tool parser plugin + +A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py. + +Here is a summary of a plugin file: + +```python + +# import the required packages + +# define a tool parser and register it to vllm +# the name list in register_module can be used +# in --tool-call-parser. you can define as many +# tool parsers as you want here. +@ToolParserManager.register_module(["example"]) +class ExampleToolParser(ToolParser): + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + # adjust request. e.g.: set skip special tokens + # to False for tool call output. + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + return request + + # implement the tool call parse for stream call + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + return delta + + # implement the tool parse for non-stream call + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) + + +``` + +Then you can use this plugin in the command line like this. +``` + --enable-auto-tool-choice \ + --tool-parser-plugin + --tool-call-parser example \ + --chat-template \ +``` + From d1f6d1c8af892c7269f113711783374eebb52511 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 10 Dec 2024 10:23:07 +0800 Subject: [PATCH 18/18] [Model] Add has_weight to RMSNorm and re-enable weights loading tracker for Mamba (#10739) Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/layers/layernorm.py | 11 ++++++-- .../layers/mamba/mamba_mixer.py | 26 +++++++++++++------ vllm/model_executor/models/mamba.py | 9 +++++-- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 345919c5d1636..43ea4eb5a4d1a 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -20,6 +20,7 @@ def __init__( hidden_size: int, eps: float = 1e-6, var_hidden_size: Optional[int] = None, + has_weight: bool = True, ) -> None: super().__init__() @@ -27,7 +28,11 @@ def __init__( self.variance_epsilon = eps self.variance_size_override = (None if var_hidden_size == hidden_size else var_hidden_size) - self.weight = nn.Parameter(torch.ones(hidden_size)) + self.has_weight = has_weight + + self.weight = torch.ones(hidden_size) + if self.has_weight: + self.weight = nn.Parameter(self.weight) def forward_native( self, @@ -59,7 +64,9 @@ def forward_native( variance = x_var.pow(2).mean(dim=-1, keepdim=True) x = x * torch.rsqrt(variance + self.variance_epsilon) - x = x.to(orig_dtype) * self.weight + x = x.to(orig_dtype) + if self.has_weight: + x = x * self.weight if residual is None: return x else: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 8ef0a6cdf2c52..10bec75f49fdf 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -40,6 +40,7 @@ def __init__(self, use_conv_bias: bool, use_bias: bool, use_rms_norm: bool, + rms_norm_has_weight: bool = True, rms_norm_eps: float = 1e-5, activation="silu"): super().__init__() @@ -105,14 +106,23 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): input_is_parallel=True, ) - self.dt_layernorm = RMSNorm(time_step_rank, - eps=rms_norm_eps) if use_rms_norm else None - - self.b_layernorm = RMSNorm(ssm_state_size, - eps=rms_norm_eps) if use_rms_norm else None - - self.c_layernorm = RMSNorm(ssm_state_size, - eps=rms_norm_eps) if use_rms_norm else None + self.dt_layernorm = RMSNorm( + time_step_rank, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) if use_rms_norm else None + + self.b_layernorm = RMSNorm( + ssm_state_size, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) if use_rms_norm else None + + self.c_layernorm = RMSNorm( + ssm_state_size, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) if use_rms_norm else None def forward_native(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index b32032e411b0a..8bdcd2c5aad1f 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,5 +1,5 @@ """PyTorch MAMBA model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch from torch import nn @@ -47,6 +47,7 @@ def __init__(self, use_conv_bias=config.use_conv_bias, use_bias=config.use_bias, use_rms_norm=self.is_falcon_mamba, + rms_norm_has_weight=not self.is_falcon_mamba, rms_norm_eps=mixer_rms_eps, activation=config.hidden_act) @@ -241,8 +242,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "A_log" in name: name = name.replace("A_log", "A") @@ -254,3 +257,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params