From b98d89efd4b1a09c11c4d0cf30c9af0e93514764 Mon Sep 17 00:00:00 2001 From: Sky Lee <46676799+skylee-01@users.noreply.github.com> Date: Sun, 17 Nov 2024 00:33:01 +0800 Subject: [PATCH 1/5] [Misc] Medusa supports custom bias (#10361) --- vllm/model_executor/models/medusa.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index de5b2d89c0962..b05360b55466b 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -14,11 +14,14 @@ class ResidualBlock(nn.Module): - def __init__(self, hidden_size: int, num_layers: int) -> None: + def __init__(self, config: VllmConfig, hidden_size: int, + num_layers: int) -> None: super().__init__() self.layers = nn.ModuleList([ - nn.Linear(hidden_size, hidden_size, bias=False) + nn.Linear(hidden_size, + hidden_size, + bias=getattr(config, "medusa_fc_bias", False)) for _ in range(num_layers) ]) self.act = nn.SiLU() @@ -49,7 +52,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() self.config = config self.blocks = nn.ModuleList([ - ResidualBlock(hidden_size=self.config.hidden_size, + ResidualBlock(config=config, + hidden_size=self.config.hidden_size, num_layers=self.config.num_hidden_layers) for _ in range(self.config.num_heads) ]) From 361c29e1740e0b2186f8cca3ed96ad235a8a960a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=94=B5=E8=84=91=E6=98=9F=E4=BA=BA?= Date: Sun, 17 Nov 2024 02:10:00 +0800 Subject: [PATCH 2/5] [Bugfix] Fix M-RoPE position calculation when chunked prefill is enabled (#10388) Signed-off-by: imkero --- .../vision_language/test_qwen2_vl.py | 136 +++++++++++++++++- .../model_executor/layers/rotary_embedding.py | 3 +- vllm/worker/model_runner.py | 1 + 3 files changed, 135 insertions(+), 5 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/test_qwen2_vl.py index 718c675b86fb4..71b6ba4dca435 100644 --- a/tests/models/decoder_only/vision_language/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/test_qwen2_vl.py @@ -18,6 +18,7 @@ IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>" VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>" +MODEL_HIDDEN_SIZE = 1536 def qwen2_vl_chat_template(*query): @@ -230,7 +231,7 @@ def batch_make_video_embeddings( return result -def run_test( +def run_embedding_input_test( vllm_runner: Type[VllmRunner], inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]], model: str, @@ -326,7 +327,7 @@ def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model, [], ) for image, prompt in zip(images, IMAGE_PROMPTS)] - run_test( + run_embedding_input_test( vllm_runner, inputs_per_case, model, @@ -371,7 +372,7 @@ def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets, [], )] - run_test( + run_embedding_input_test( vllm_runner, inputs_per_case, model, @@ -416,7 +417,134 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model, [rescale_video_size(video, factor) for factor in size_factors], ) for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)] - run_test( + run_embedding_input_test( + vllm_runner, + inputs_per_case, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + ) + + +def run_chunked_prefill_test( + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + mm_limit: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Compare inference result between + chunked prefill disabled and chunked prefill enabled + """ + + # NOTE: + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + task="generate", + max_model_len=4000, + max_num_seqs=4, + dtype=dtype, + limit_mm_per_prompt={ + "image": mm_limit, + "video": mm_limit + }, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend + ) as vllm_model: + + outputs_per_case = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images or None, + videos=videos or None) + for prompts, images, videos in inputs + ] + + with vllm_runner( + model, + task="generate", + max_model_len=4000, + max_num_seqs=4, + dtype=dtype, + limit_mm_per_prompt={ + "image": mm_limit, + "video": mm_limit + }, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_chunked_prefill=True, + # should be small enough to ensure prefilling is chunked + max_num_batched_tokens=32, + mm_processor_kwargs={ + "max_pixels": 16 * 28 * 28, + }) as vllm_model_chunked: + outputs_per_case_chunked = [ + vllm_model_chunked.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images or None, + videos=videos or None) for prompts, images, videos in inputs + ] + + for outputs, \ + outputs_chunked \ + in zip(outputs_per_case, + outputs_per_case_chunked): + check_logprobs_close( + outputs_0_lst=outputs, + outputs_1_lst=outputs_chunked, + name_0="non_chunked", + name_1="chunked", + ) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [1]) +@pytest.mark.parametrize("num_logprobs", [10]) +def test_qwen2_vl_mrope_chunked_prefill(vllm_runner, example_prompts, + model: str, dtype: str, + max_tokens: int, + num_logprobs: int) -> None: + """ + Test Qwen2-VL's chunked prefill with M-RoPE + """ + prompts = [ + qwen2_vl_chat_template(IMAGE_PLACEHOLDER, prompt) + for prompt in example_prompts[:1] + ] + + # 1. Qwen2-VL's M-RoPE works only when there are some multi-modal inputs, + # so an image is included in the inputs + # 2. however, Qwen2-VL currently won't work properly + # when chunked prefill is enabled and there are some multi-modal inputs, + # here use a hacky way: provide a **zero-length** image to make it happy + # + # and finally we achieved: + # (1) chunked_prefill enabled; (2) M-RoPE works; to continue our tests + zero_len_image = { + "image_embeds": torch.empty((0, MODEL_HIDDEN_SIZE)), + "image_grid_thw": torch.tensor([[0, 0, 0]]) + } + images = [zero_len_image] * len(prompts) + + inputs_per_case: List[Tuple[List[str], PromptImageInput, + PromptVideoInput]] = [ + (prompts, images, []), + ] + + run_chunked_prefill_test( vllm_runner, inputs_per_case, model, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 63ceec63e8317..b01e4c61fe101 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -847,6 +847,7 @@ def get_input_positions( vision_end_token_id: int, spatial_merge_size: int, context_len: int = 0, + seq_len: Optional[int] = None, ) -> Tuple[List[List[int]], int]: """Get mrope input positions and delta value.""" @@ -921,7 +922,7 @@ def get_input_positions( torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:] + llm_positions = llm_positions[:, context_len:seq_len] mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 042f9f07eace6..22ee3f9f863e4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -700,6 +700,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, spatial_merge_size=hf_config.vision_config. spatial_merge_size, context_len=inter_data.context_lens[seq_idx], + seq_len=inter_data.seq_lens[seq_idx], ) seq_data.mrope_position_delta = mrope_position_delta From 661a34fd4fdd700a29b2db758e23e4e243e7ff18 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 16 Nov 2024 10:45:26 -0800 Subject: [PATCH 3/5] [V1] Add code owners for V1 (#10397) Signed-off-by: Woosuk Kwon --- .github/CODEOWNERS | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index cd721971d01d6..3cb91fc0f8232 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -3,13 +3,16 @@ # This lists cover the "core" components of vLLM that require careful review /vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/core @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/engine/llm_engine.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/executor/executor_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/worker/worker_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/worker/worker.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -/vllm/model_executor/layers/sampler.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill -CMakeLists.txt @tlrmchlsmth @WoosukKwon +/vllm/core @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +CMakeLists.txt @tlrmchlsmth + +# vLLM V1 +/vllm/v1 @WoosukKwon @robertgshaw2-neuralmagic @njhill @ywang96 @comaniac @alexm-neuralmagic # Test ownership /tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo From 4fd937502827a7e06c54ded1f9d9b70ff640e222 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 16 Nov 2024 18:02:14 -0800 Subject: [PATCH 4/5] [2/N][torch.compile] make compilation cfg part of vllm cfg (#10383) Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 8 +- tests/compile/piecewise/test_toy_llama.py | 22 +- tests/compile/test_basic_correctness.py | 2 +- tests/compile/test_full_graph.py | 2 +- tests/compile/test_fusion.py | 2 +- tests/compile/test_wrapper.py | 4 +- tests/compile/utils.py | 2 +- .../model_executor/test_enabled_custom_ops.py | 52 ++--- tests/tpu/test_compilation.py | 2 +- tests/tpu/test_custom_dispatcher.py | 2 +- vllm/compilation/backends.py | 20 +- vllm/compilation/config.py | 159 --------------- vllm/compilation/decorators.py | 10 +- vllm/compilation/fusion.py | 2 +- vllm/compilation/inductor_pass.py | 2 +- vllm/compilation/levels.py | 8 - vllm/compilation/wrapper.py | 11 +- vllm/config.py | 189 ++++++++++++++++++ vllm/envs.py | 13 -- vllm/model_executor/custom_op.py | 27 +-- vllm/model_executor/model_loader/loader.py | 7 +- vllm/platforms/interface.py | 20 +- vllm/platforms/tpu.py | 21 +- vllm/plugins/__init__.py | 30 ++- vllm/v1/worker/gpu_model_runner.py | 10 +- vllm/worker/model_runner.py | 7 +- vllm/worker/tpu_model_runner.py | 8 +- 27 files changed, 359 insertions(+), 283 deletions(-) delete mode 100644 vllm/compilation/config.py delete mode 100644 vllm/compilation/levels.py diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index c631850ecdedb..45f56cbbd4b16 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -11,8 +11,8 @@ from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig +from vllm.plugins import set_current_vllm_config from vllm.utils import direct_register_custom_op global_counter = 0 @@ -82,7 +82,9 @@ def test_simple_piecewise_compile(): os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) - model = SillyModel(vllm_config=VllmConfig(), prefix='') + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = SillyModel(vllm_config=vllm_config, prefix='') inputs = torch.randn(100).cuda() diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index c363a587a818e..8032304e95806 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -15,12 +15,10 @@ from torch.library import Library from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.config import CompilationConfig from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig -from vllm.plugins import set_compilation_config +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.plugins import set_compilation_config, set_current_vllm_config from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -272,9 +270,11 @@ def run_model(llama_config, CompilationLevel.NO_COMPILATION) set_compilation_config(None) - model = LlamaModel(config=llama_config, - vllm_config=VllmConfig(), - prefix="").eval().cuda() + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = LlamaModel(config=llama_config, + vllm_config=vllm_config, + prefix="").eval().cuda() B = 16 # max batch size input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() @@ -395,9 +395,11 @@ def benchmark(): else: set_compilation_config(None) - model = LlamaModel(config=llama_config, - vllm_config=VllmConfig(), - prefix="").eval().cuda().to(torch.bfloat16) + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = LlamaModel(config=llama_config, + vllm_config=vllm_config, + prefix="").eval().cuda().to(torch.bfloat16) B = 256 # max batch size input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 833589ba5dc9f..08747ebc58b75 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -3,7 +3,7 @@ import pytest -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from vllm.utils import cuda_device_count_stateless from ..utils import compare_all_settings diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index f00334934cb46..4dfdfe21a67df 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,6 +1,6 @@ import pytest -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from ..utils import fork_new_process_for_each_test from .utils import TEST_MODELS, check_full_graph_support diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index e4d3defafb951..4db79b070fd8d 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -3,10 +3,10 @@ from compressed_tensors.quantization import FP8_DTYPE import vllm.envs as envs -from vllm.compilation.config import CompilationConfig from vllm.compilation.fusion import (FusionPass, find_auto_fn, find_auto_fn_maybe) from vllm.compilation.reshapes import RedundantReshapesPass +from vllm.config import CompilationConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( apply_fp8_linear) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index 3668c1fab6b89..74f66baaa5ea1 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -3,6 +3,7 @@ import torch from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import CompilationLevel class MyMod(torch.nn.Module): @@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher): def __init__(self, model): self.model = model compiled_callable = torch.compile(self.forward, backend="eager") - super().__init__(compiled_callable) + super().__init__(compiled_callable, + compilation_level=CompilationLevel.DYNAMO_ONCE) def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): # this is the function to be compiled diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 222c63a342a4b..729f10676888b 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -4,7 +4,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from vllm.platforms import current_platform TEST_MODELS = [ diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index af267f804ffa7..c3219bc50646b 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -3,11 +3,13 @@ import pytest +from vllm.config import CompilationConfig, VllmConfig from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, ReLUSquaredActivation, SiluAndMul) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.plugins import set_current_vllm_config # Registered subclass for test @@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation): ]) def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int], default_on: bool): - os.environ["VLLM_CUSTOM_OPS"] = env os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level) + vllm_config = VllmConfig(compilation_config=CompilationConfig( + custom_ops=env.split(","))) + with set_current_vllm_config(vllm_config): + assert CustomOp.default_on() == default_on - # Reset default_on (computed once): - CustomOp.default_on.cache_clear() + ops_enabled = [bool(x) for x in ops_enabled] - assert CustomOp.default_on() == default_on + assert RMSNorm(1024).enabled() == ops_enabled[0] + assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] - ops_enabled = [bool(x) for x in ops_enabled] + assert SiluAndMul().enabled() == ops_enabled[1] + assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] - assert RMSNorm(1024).enabled() == ops_enabled[0] - assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] + assert GeluAndMul().enabled() == ops_enabled[2] + assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] - assert SiluAndMul().enabled() == ops_enabled[1] - assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] + # If registered, subclasses should follow their own name + assert Relu3().enabled() == ops_enabled[3] + assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] - assert GeluAndMul().enabled() == ops_enabled[2] - assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] + # Unregistered subclass + class SiluAndMul2(SiluAndMul): + pass - # If registered, subclasses should follow their own name - assert Relu3().enabled() == ops_enabled[3] - assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] - - # Unregistered subclass - class SiluAndMul2(SiluAndMul): - pass - - # Subclasses should not require registration - assert SiluAndMul2().enabled() == SiluAndMul().enabled() + # Subclasses should not require registration + assert SiluAndMul2().enabled() == SiluAndMul().enabled() @pytest.mark.parametrize( "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]) def test_enabled_ops_invalid(env: str): - os.environ["VLLM_CUSTOM_OPS"] = env - CustomOp.default_on.cache_clear() - - with pytest.raises(AssertionError): - RMSNorm(1024).enabled() + with pytest.raises(Exception): # noqa + vllm_config = VllmConfig(compilation_config=CompilationConfig( + custom_ops=env.split(","))) + with set_current_vllm_config(vllm_config): + RMSNorm(1024).enabled() diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 86d9af88e49ea..941abe17a3378 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -5,7 +5,7 @@ import depyf -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel # disable custom dispatcher, let Dynamo takes over # all the control diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 923d0f1680802..53b10c06135a1 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -1,6 +1,6 @@ import os -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from ..utils import compare_two_settings diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 5682faa158069..22c613931f082 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -10,13 +10,12 @@ import torch.fx as fx import vllm.envs as envs +from vllm.config import CompilationConfig, CompilationLevel from vllm.logger import init_logger from vllm.utils import combine_fx_passes, weak_ref_tensors -from .config import CompilationConfig from .counter import compilation_counter from .fusion import FusionPass -from .levels import CompilationLevel from .reshapes import RedundantReshapesPass logger = init_logger(__name__) @@ -392,7 +391,10 @@ class VllmBackend: sym_tensor_indices: List[int] input_buffers: List[torch.Tensor] - def __init__(self, post_grad_passes: Sequence[Callable] = ()): + def __init__( + self, + compilation_configs: CompilationConfig, + ): global global_graph_pool if global_graph_pool is None: global_graph_pool = torch.cuda.graph_pool_handle() @@ -401,11 +403,13 @@ def __init__(self, post_grad_passes: Sequence[Callable] = ()): # streams, it might not be safe to share a global pool. # only investigate this when we use multiple streams self.graph_pool = global_graph_pool - self.post_grad_passes = post_grad_passes + self.post_grad_passes = [] self.sym_tensor_indices = [] self.input_buffers = [] + self.compilation_configs = compilation_configs + # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -437,10 +441,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: assert not self._called, "VllmBackend can only be called once" self.graph = graph - # config is read now, because only here can + # config is updated now, because only here can # we get the sizes to capture for cudagraph # from compilation context - self.compilation_configs = CompilationConfig.select_and_init_config() + self.compilation_configs.init_during_runtime() self.add_passes_to_config() self.split_gm, self.piecewise_graphs = split_graph( @@ -688,4 +692,6 @@ def select_default_backend(level: int) -> Union[str, Callable]: return backend_str assert level == CompilationLevel.PIECEWISE - return VllmBackend() + from vllm.plugins import get_current_vllm_config + compilation_config = get_current_vllm_config().compilation_config + return VllmBackend(compilation_config) diff --git a/vllm/compilation/config.py b/vllm/compilation/config.py deleted file mode 100644 index 3e663505c627d..0000000000000 --- a/vllm/compilation/config.py +++ /dev/null @@ -1,159 +0,0 @@ -import copy -from pathlib import Path -from typing import Any, Dict, List, Optional - -from pydantic import BaseModel, Field, PrivateAttr - -import vllm.envs as envs -from vllm.logger import init_logger - -from .compile_context import get_compile_context - -logger = init_logger(__name__) - - -class CompilationConfig(BaseModel): - """ - Configuration for compilation. - It has two parts: - - CudaGraph capture: - - use_cudagraph: whether to use cudagraph inside compilation. - - False: cudagraph inside compilation is not used. - - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses. - Note that this is orthogonal to the cudagraph capture out - side of compilation. - TODO: move outside cudagraph logic into compilation. - torch.compile will handle cudagraph capture logic in the future. - - cudagraph_capture_sizes: sizes to capture cudagraph. - - None: capture sizes are inferred from compilation context. - - List[int]: capture sizes are specified. - - cudagraph_num_of_warmups: number of warmup runs for cudagraph. - It means the first several runs will be treated as warmup runs. - Only after that, the execution will be recorded, and the recorded - cudagraph will be used for subsequent runs. - - cudagraph_copy_inputs: whether to copy input tensors for - cudagraph. If the caller can guarantee that the same input buffers - are always used, it can set this to False. Otherwise, it should - set this to True, and the compiler will copy the input to an - internally managed buffer. Default is False. - - Inductor compilation: - - use_inductor: whether to use inductor compilation. - - False: inductor compilation is not used. graph runs in eager. - - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for different sizes specified - in inductor_compile_sizes, using configurations - in inductor_compile_config. - - inductor_compile_sizes: sizes to compile for inductor. - - inductor_specialize_for_cudagraph_no_more_than: an optional integer - to specialize inductor for cudagraph sizes no more than the - specified size. It is useful when we want to specialize inductor - with a subset of cudagraph sizes. - - inductor_compile_config: additional configurations for inductor. - - None: use default configurations. - - inductor_passes: additional passes for inductor. It is a dictionary - from pass name to pass function qualified name. We use function - name because the config uses json format. If we pass the config - from Python, functions can also be passed directly via Python object - constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - - Custom inductor passes: - - dump_graph_stages: list of stages for which we want to dump the graph. - Each pass defines its own stages (before, after, maybe in-between). - - dump_graph_dir: directory to dump the graph. Default is . - - enable_fusion: whether to enable the custom fusion pass. - TODO better pass enabling system. - - Why we have different sizes for cudagraph and inductor: - - cudagraph: a cudagraph captured for a specific size can only be used - for the same size. We need to capture all the sizes we want to use. - - inductor: a graph compiled by inductor for a general shape can be used - for different sizes. Inductor can also compile for specific sizes, - where it can have more information to optimize the graph with fully - static shapes. However, we find the general shape compilation is - sufficient for most cases. It might be beneficial to compile for - certain small batchsizes, where inductor is good at optimizing. - """ - use_inductor: bool = True - inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None - inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict) - inductor_compile_config: Dict = Field(default_factory=dict) - inductor_passes: Dict[str, str] = Field(default_factory=dict) - - use_cudagraph: bool = False - non_cudagraph_ops: List[str] = Field(default_factory=list) - cudagraph_num_of_warmups: int = 0 - cudagraph_capture_sizes: Optional[List[int]] = None - cudagraph_copy_inputs: bool = False - - dump_graph_stages: List[str] = Field(default_factory=list) - dump_graph_dir: Path = Field(default=Path(".")) - enable_fusion: bool = True - - # not configurable, computed after init - compile_sizes: List[int] = PrivateAttr - capture_sizes: List[int] = PrivateAttr - - def model_post_init(self, __context: Any) -> None: - for k, v in self.inductor_passes.items(): - if not isinstance(v, str): - assert callable(v), ( - f"pass {k} should be a function or a qualified name") - self.inductor_compile_config[k] = v - continue - - # resolve function from qualified name - names = v.split(".") - module = ".".join(names[:-1]) - func_name = names[-1] - func = __import__(module).__dict__[func_name] - self.inductor_compile_config[k] = func - - def init_during_runtime(self): - """To complete the initialization of config, - we need to know the compile context, which is only available - during the first run of the model. - """ - context = get_compile_context() - context = copy.deepcopy(context) if context is not None else [] - sizes_to_specialize: List[int] = context - if self.cudagraph_capture_sizes is None: - self.capture_sizes = sizes_to_specialize - else: - self.capture_sizes = self.cudagraph_capture_sizes - logger.info(("cudagraph sizes specified by model runner" - " %s is overridden by config %s"), - sizes_to_specialize, self.cudagraph_capture_sizes) - if self.inductor_specialize_for_cudagraph_no_more_than is not None: - assert self.inductor_compile_sizes is None, ( - "inductor_compile_sizes should be None when " - "inductor_specialize_for_cudagraph_no_more_than is not None") - self.compile_sizes = [ - x for x in self.capture_sizes - if x <= self.inductor_specialize_for_cudagraph_no_more_than - ] - else: - assert self.inductor_compile_sizes is not None, ( - "inductor_compile_sizes should not be None when " - "inductor_specialize_for_cudagraph_no_more_than is None") - self.compile_sizes = self.inductor_compile_sizes - - @staticmethod - def select_and_init_config() -> "CompilationConfig": - """The order of selecting config is: - 1. Use the config specified in environment variable. - 2. Use the config specified in plugins. - 3. Use the default config. - """ - config_path = envs.VLLM_TORCH_COMPILE_CONFIG - if config_path is not None: - with open(config_path) as json_file: - config = CompilationConfig.model_validate_json( - json_file.read()) - else: - from vllm.plugins import get_compilation_config - predefined_config = get_compilation_config() - config = predefined_config if predefined_config is not None else ( - CompilationConfig()) - - config.init_during_runtime() - return config diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index ca1e96a33c014..4b78491bc5a48 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -3,10 +3,8 @@ import torch -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils import supports_dynamo @@ -126,12 +124,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. - self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [ + self.do_not_compile = \ + vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS ] or not supports_dynamo() if self.do_not_compile: return - TorchCompileWrapperWithCustomDispatcher.__init__(self) + TorchCompileWrapperWithCustomDispatcher.__init__( + self, compilation_level=vllm_config.compilation_config.level) cls.__init__ = __init__ # type: ignore diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index eb43604b1399b..e6a3afef85e1b 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -6,8 +6,8 @@ from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, fwd_only, register_replacement) -from vllm.compilation.config import CompilationConfig from vllm.compilation.inductor_pass import InductorPass +from vllm.config import CompilationConfig from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index b23351fa19759..8082a08b40019 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -2,7 +2,7 @@ import torch -from vllm.compilation.config import CompilationConfig +from vllm.config import CompilationConfig # yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( diff --git a/vllm/compilation/levels.py b/vllm/compilation/levels.py deleted file mode 100644 index 19a3a2b526870..0000000000000 --- a/vllm/compilation/levels.py +++ /dev/null @@ -1,8 +0,0 @@ -# constants for the levels of the compilation process - - -class CompilationLevel: - NO_COMPILATION = 0 - DYNAMO_AS_IS = 1 - DYNAMO_ONCE = 2 - PIECEWISE = 3 diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 7366ed4d16b0b..2a1aecc11ce26 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -8,8 +8,7 @@ import torch import vllm.envs as envs - -from .levels import CompilationLevel +from vllm.config import CompilationLevel class TorchCompileWrapperWithCustomDispatcher: @@ -25,7 +24,9 @@ class TorchCompileWrapperWithCustomDispatcher: `torch.compile` over the forward method. """ - def __init__(self, compiled_callable: Optional[Callable] = None): + def __init__(self, + compiled_callable: Optional[Callable] = None, + compilation_level: int = 0): if compiled_callable is None: # default compilation settings @@ -38,7 +39,7 @@ def __init__(self, compiled_callable: Optional[Callable] = None): backend = get_torch_compile_backend() if backend is None: from vllm.compilation.backends import select_default_backend - backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) + backend = select_default_backend(compilation_level) compiled_callable = torch.compile( self.forward, @@ -54,7 +55,7 @@ def __init__(self, compiled_callable: Optional[Callable] = None): # subclasses can use this to switch between the custom dispatcher # and the default Dynamo guard mechanism. self.use_custom_dispatcher: bool = \ - envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE + compilation_level >= CompilationLevel.DYNAMO_ONCE def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. diff --git a/vllm/config.py b/vllm/config.py index 64b2f75e092de..7e37edbe594b1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3,10 +3,12 @@ import json import warnings from dataclasses import dataclass, field, replace +from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List, Literal, Mapping, Optional, Set, Tuple, Type, Union) import torch +from pydantic import BaseModel, Field, PrivateAttr from transformers import PretrainedConfig import vllm.envs as envs @@ -2052,6 +2054,185 @@ def __post_init__(self): f"installed. Original error:\n{otel_import_error_traceback}") +class CompilationLevel: + # constants for the levels of the compilation process + NO_COMPILATION = 0 + DYNAMO_AS_IS = 1 + DYNAMO_ONCE = 2 + PIECEWISE = 3 + + +class CompilationConfig(BaseModel): + """ + Configuration for compilation. + It has three parts: + - Top-level Compilation control: + - level: the level of compilation. + - 0: no compilation. + - 1: dynamo as is. + - 2: dynamo once. + - 3: piecewise compilation. + - custom_ops: fine-grained control over which custom ops to enable/disable. + Use 'all' to enable all, 'none' to disable all. + Also specify a list of custom op names to enable (prefixed with a '+'), + or disable (prefixed with a '-'). + Examples: + - 'all,-op1' to enable all except op1 + - 'none,+op1,+op2' to enable only op1 and op2 + By default, all custom ops are enabled when running without Inductor + and disabled when running with Inductor (compile_level >= Inductor). + - CudaGraph capture: + - use_cudagraph: whether to use cudagraph inside compilation. + - False: cudagraph inside compilation is not used. + - True: cudagraph inside compilation is used. It requires + that all input buffers have fixed addresses. + Note that this is orthogonal to the cudagraph capture out + side of compilation. + TODO: move outside cudagraph logic into compilation. + torch.compile will handle cudagraph capture logic in the future. + - cudagraph_capture_sizes: sizes to capture cudagraph. + - None: capture sizes are inferred from compilation context. + - List[int]: capture sizes are specified. + - cudagraph_num_of_warmups: number of warmup runs for cudagraph. + It means the first several runs will be treated as warmup runs. + Only after that, the execution will be recorded, and the recorded + cudagraph will be used for subsequent runs. + - cudagraph_copy_inputs: whether to copy input tensors for + cudagraph. If the caller can guarantee that the same input buffers + are always used, it can set this to False. Otherwise, it should + set this to True, and the compiler will copy the input to an + internally managed buffer. Default is False. + - Inductor compilation: + - use_inductor: whether to use inductor compilation. + - False: inductor compilation is not used. graph runs in eager. + - True: inductor compilation is used. one graph for symbolic shape + is compiled. In addition, compile for different sizes specified + in inductor_compile_sizes, using configurations + in inductor_compile_config. + - inductor_compile_sizes: sizes to compile for inductor. + - inductor_specialize_for_cudagraph_no_more_than: an optional integer + to specialize inductor for cudagraph sizes no more than the + specified size. It is useful when we want to specialize inductor + with a subset of cudagraph sizes. + - inductor_compile_config: additional configurations for inductor. + - None: use default configurations. + - inductor_passes: additional passes for inductor. It is a dictionary + from pass name to pass function qualified name. We use function + name because the config uses json format. If we pass the config + from Python, functions can also be passed directly via Python object + constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` + - custom inductor passes: + - dump_graph_stages: list of stages for which we want to dump the graph. + Each pass defines its own stages (before, after, maybe in-between). + - dump_graph_dir: directory to dump the graph. Default is . + - enable_fusion: whether to enable the custom fusion pass. + TODO better pass enabling system. + + Why we have different sizes for cudagraph and inductor: + - cudagraph: a cudagraph captured for a specific size can only be used + for the same size. We need to capture all the sizes we want to use. + - inductor: a graph compiled by inductor for a general shape can be used + for different sizes. Inductor can also compile for specific sizes, + where it can have more information to optimize the graph with fully + static shapes. However, we find the general shape compilation is + sufficient for most cases. It might be beneficial to compile for + certain small batchsizes, where inductor is good at optimizing. + """ # noqa + level: int = 0 + custom_ops: List[str] = Field(default_factory=list) + + use_inductor: bool = True + inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None + inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict) + inductor_compile_config: Dict = Field(default_factory=dict) + inductor_passes: Dict[str, str] = Field(default_factory=dict) + + use_cudagraph: bool = False + non_cudagraph_ops: List[str] = Field(default_factory=list) + cudagraph_num_of_warmups: int = 0 + cudagraph_capture_sizes: Optional[List[int]] = None + cudagraph_copy_inputs: bool = False + + dump_graph_stages: List[str] = Field(default_factory=list) + dump_graph_dir: Path = Field(default=Path(".")) + enable_fusion: bool = True + + # not configurable, computed after init + compile_sizes: List[int] = PrivateAttr + capture_sizes: List[int] = PrivateAttr + + def model_post_init(self, __context: Any) -> None: + self.level = envs.VLLM_TORCH_COMPILE_LEVEL + + count_none = self.custom_ops.count("none") + count_all = self.custom_ops.count("all") + assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" + + for k, v in self.inductor_passes.items(): + if not isinstance(v, str): + assert callable(v), ( + f"pass {k} should be a function or a qualified name") + self.inductor_compile_config[k] = v + continue + + # resolve function from qualified name + names = v.split(".") + module = ".".join(names[:-1]) + func_name = names[-1] + func = __import__(module).__dict__[func_name] + self.inductor_compile_config[k] = func + + def init_during_runtime(self): + """To complete the initialization of config, + we need to know the compile context, which is only available + during the first run of the model. + """ + from vllm.compilation.compile_context import get_compile_context + context = get_compile_context() + context = copy.deepcopy(context) if context is not None else [] + sizes_to_specialize: List[int] = context + if self.cudagraph_capture_sizes is None: + self.capture_sizes = sizes_to_specialize + else: + self.capture_sizes = self.cudagraph_capture_sizes + logger.info(("cudagraph sizes specified by model runner" + " %s is overridden by config %s"), + sizes_to_specialize, self.cudagraph_capture_sizes) + if self.inductor_specialize_for_cudagraph_no_more_than is not None: + assert self.inductor_compile_sizes is None, ( + "inductor_compile_sizes should be None when " + "inductor_specialize_for_cudagraph_no_more_than is not None") + self.compile_sizes = [ + x for x in self.capture_sizes + if x <= self.inductor_specialize_for_cudagraph_no_more_than + ] + else: + assert self.inductor_compile_sizes is not None, ( + "inductor_compile_sizes should not be None when " + "inductor_specialize_for_cudagraph_no_more_than is None") + self.compile_sizes = self.inductor_compile_sizes + + @staticmethod + def select_and_init_config() -> "CompilationConfig": + """The order of selecting config is: + 1. Use the config specified in environment variable. + 2. Use the config specified in plugins. + 3. Use the default config. + """ + config_path = envs.VLLM_TORCH_COMPILE_CONFIG + if config_path is not None: + with open(config_path) as json_file: + config = CompilationConfig.model_validate_json( + json_file.read()) + else: + from vllm.plugins import get_compilation_config + predefined_config = get_compilation_config() + config = predefined_config if predefined_config is not None else ( + CompilationConfig()) + + return config + + @dataclass class VllmConfig: """Dataclass which contains all vllm-related configuration. This @@ -2073,6 +2254,8 @@ class VllmConfig: observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None quant_config: Optional[QuantizationConfig] = None + compilation_config: CompilationConfig = field(default=None, + init=True) # type: ignore @staticmethod def _get_quantization_config( @@ -2133,6 +2316,12 @@ def __post_init__(self): self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config) + if self.compilation_config is None: + self.compilation_config = CompilationConfig.select_and_init_config( + ) + + current_platform.check_and_update_config(self) + def __str__(self): return ("model=%r, speculative_config=%r, tokenizer=%r, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " diff --git a/vllm/envs.py b/vllm/envs.py index f320e35971f94..716e835a555f1 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -69,7 +69,6 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_TORCH_COMPILE_LEVEL: int = 0 VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None - VLLM_CUSTOM_OPS: List[str] = [] VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = False @@ -217,18 +216,6 @@ def get_default_config_root(): "VLLM_TORCH_COMPILE_CONFIG": lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None), - # Fine-grained control over which custom ops to enable/disable. - # Use 'all' to enable all, 'none' to disable all. - # Also specify a list of custom op names to enable (prefixed with a '+'), - # or disable (prefixed with a '-'). - # Examples: - # - 'all,-op1' to enable all except op1 - # - 'none,+op1,+op2' to enable only op1 and op2 - # By default, all custom ops are enabled when running without Inductor - # and disabled when running with Inductor (compile_level >= Inductor). - "VLLM_CUSTOM_OPS": - lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","), - # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 24d75f4df4e02..6ae7d7cf6964f 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,12 +1,10 @@ -from functools import lru_cache from typing import Dict, Type import torch.nn as nn -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.plugins import get_current_vllm_config from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -87,6 +85,8 @@ def dispatch_forward(self): @classmethod def enabled(cls) -> bool: # if no name, then it was not registered + compilation_config = get_current_vllm_config().compilation_config + custom_ops = compilation_config.custom_ops if not hasattr(cls, "name"): print_warning_once( f"Custom op {cls.__name__} was not registered, " @@ -94,22 +94,25 @@ def enabled(cls) -> bool: f"It will be enabled/disabled based on the global settings.") return CustomOp.default_on() - enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS - disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS + enabled = f"+{cls.name}" in custom_ops + disabled = f"-{cls.name}" in custom_ops assert not (enabled and disabled), f"Cannot enable and disable {cls.name}" return (CustomOp.default_on() or enabled) and not disabled - # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE - # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence. @staticmethod - @lru_cache def default_on() -> bool: - count_none = envs.VLLM_CUSTOM_OPS.count("none") - count_all = envs.VLLM_CUSTOM_OPS.count("all") - assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" - return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE and \ + """ + On by default if level < CompilationLevel.PIECEWISE + Specifying 'all' or 'none' in custom_op takes precedence. + """ + from vllm.config import CompilationLevel + compilation_config = get_current_vllm_config().compilation_config + custom_ops = compilation_config.custom_ops + count_none = custom_ops.count("none") + count_all = custom_ops.count("all") + return compilation_config.level < CompilationLevel.PIECEWISE and \ not count_none > 0 or count_all > 0 # Dictionary of all custom ops (classes, indexed by registered name). diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 140b61fe6d56a..0f8b81c3ef40c 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -42,6 +42,7 @@ safetensors_weights_iterator) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.plugins import set_current_vllm_config from vllm.utils import is_pin_memory_available @@ -97,7 +98,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: # new-style model class - return model_class(vllm_config=vllm_config, prefix=prefix) + with set_current_vllm_config(vllm_config): + return model_class(vllm_config=vllm_config, prefix=prefix) msg = ("vLLM model class should accept `vllm_config` and `prefix` as " "input arguments. Possibly you have an old-style model class" " registered from out of tree and it is used for new vLLM version. " @@ -121,7 +123,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: kwargs["lora_config"] = vllm_config.lora_config if "scheduler_config" in all_params: kwargs["scheduler_config"] = vllm_config.scheduler_config - return model_class(**kwargs) + with set_current_vllm_config(vllm_config): + return model_class(**kwargs) class BaseModelLoader(ABC): diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 81d8bdae2383c..970c0d1be617e 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,10 +1,15 @@ import enum import random -from typing import NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union import numpy as np import torch +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + class PlatformEnum(enum.Enum): CUDA = enum.auto() @@ -129,6 +134,19 @@ def seed_everything(cls, seed: int) -> None: np.random.seed(seed) torch.manual_seed(seed) + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + """ + Check and update the configuration for the current platform. + + It can raise an exception if the configuration is not compatible with + the current platform, or it can update the configuration to make it + compatible with the current platform. + + The config is passed by reference, so it can be modified in place. + """ + pass + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 8d0ce47df4040..c2e22bfc09f22 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,18 +1,16 @@ import os +from typing import TYPE_CHECKING import torch -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel from vllm.plugins import set_torch_compile_backend from .interface import Platform, PlatformEnum -if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE) - -assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE,\ - "TPU does not support Inductor." +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None set_torch_compile_backend("openxla") @@ -31,3 +29,12 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @classmethod def inference_mode(cls): return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + from vllm.config import CompilationLevel + compilation_config = vllm_config.compilation_config + if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: + compilation_config.level = CompilationLevel.DYNAMO_ONCE + assert compilation_config.level < CompilationLevel.PIECEWISE,\ + "TPU does not support Inductor." diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 7b1bbb14c5302..c20b9ec891d5d 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,11 +1,11 @@ import logging +from contextlib import contextmanager from typing import TYPE_CHECKING, Callable, Optional, Union import vllm.envs as envs if TYPE_CHECKING: - from vllm.compilation.config import CompilationConfig - from vllm.config import VllmConfig + from vllm.config import CompilationConfig, VllmConfig else: CompilationConfig = None VllmConfig = None @@ -72,3 +72,29 @@ def set_compilation_config(config: Optional[CompilationConfig]): def get_compilation_config() -> Optional[CompilationConfig]: return _compilation_config + + +_current_vllm_config: Optional[VllmConfig] = None + + +@contextmanager +def set_current_vllm_config(vllm_config: VllmConfig): + """ + Temporarily set the current VLLM config. + Used during model initialization. + We save the current VLLM config in a global variable, + so that all modules can access it, e.g. custom ops + can access the VLLM config to determine how to dispatch. + """ + global _current_vllm_config + old_vllm_config = _current_vllm_config + try: + _current_vllm_config = vllm_config + yield + finally: + _current_vllm_config = old_vllm_config + + +def get_current_vllm_config() -> VllmConfig: + assert _current_vllm_config is not None, "Current VLLM config is not set." + return _current_vllm_config diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index eebd1de96537f..d60f93a44f6dd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,4 +1,3 @@ -import os import time from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple @@ -8,11 +7,8 @@ import torch.distributed import torch.nn as nn -from vllm import envs from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.config import CompilationConfig -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger @@ -99,7 +95,7 @@ def __init__( pin_memory=self.pin_memory, ) - self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL + self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. @@ -517,9 +513,9 @@ def load_model(self) -> None: # CUDA graphs do not work properly with the custom CUDA kernels. # FIXME(woosuk): Disable inductor to reduce the compilation time # and avoid any potential issues with the inductor. - os.environ["VLLM_CUSTOM_OPS"] = "none" set_compilation_config( CompilationConfig( + custom_ops=["none"], use_cudagraph=True, non_cudagraph_ops=["vllm.unified_v1_flash_attention"], use_inductor=True, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 22ee3f9f863e4..fd89f95445565 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -19,8 +19,7 @@ from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture @@ -1142,8 +1141,8 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") - if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \ - and supports_dynamo(): + if self.vllm_config.compilation_config.level ==\ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): from vllm.plugins import get_torch_compile_backend backend = get_torch_compile_backend() or "eager" self.model = torch.compile( diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index a721186137328..d7a641857a613 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -140,7 +140,7 @@ def load_model(self) -> None: model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() - self.model = ModelWrapper(model) + self.model = ModelWrapper(model, self.vllm_config) def _dummy_run( self, @@ -669,13 +669,15 @@ def execute_model( class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): - def __init__(self, model: nn.Module): + def __init__(self, model: nn.Module, vllm_config: VllmConfig): self.model = model compiled_callable = torch.compile(self.forward, backend="openxla", fullgraph=True, dynamic=False) - super().__init__(compiled_callable) + super().__init__( + compiled_callable, + compilation_level=vllm_config.compilation_config.level) def __call__(self, *args, is_prompt: bool, **kwargs): if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: From 643ecf7b11a3e74c838f438cfc1b3e59c018853b Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sat, 16 Nov 2024 21:18:46 -0800 Subject: [PATCH 5/5] [V1] Refactor model executable interface for all text-only language models (#10374) Signed-off-by: Roger Wang --- vllm/model_executor/models/arctic.py | 16 ++++++++++++++-- vllm/model_executor/models/baichuan.py | 16 ++++++++++++++-- vllm/model_executor/models/bloom.py | 17 ++++++++++++++--- vllm/model_executor/models/commandr.py | 16 ++++++++++++++-- vllm/model_executor/models/dbrx.py | 16 ++++++++++++++-- vllm/model_executor/models/deepseek.py | 16 ++++++++++++++-- vllm/model_executor/models/deepseek_v2.py | 16 ++++++++++++++-- vllm/model_executor/models/eagle.py | 13 ++++++++++--- vllm/model_executor/models/exaone.py | 7 ++++++- vllm/model_executor/models/falcon.py | 16 ++++++++++++++-- vllm/model_executor/models/gemma.py | 7 ++++++- vllm/model_executor/models/gemma2.py | 12 ++++++++++-- vllm/model_executor/models/gpt2.py | 7 +++++-- vllm/model_executor/models/gpt_bigcode.py | 17 +++++++++++++---- vllm/model_executor/models/gpt_j.py | 16 ++++++++++++++-- vllm/model_executor/models/gpt_neox.py | 16 ++++++++++++++-- vllm/model_executor/models/granite.py | 7 ++++++- vllm/model_executor/models/granitemoe.py | 16 ++++++++++++++-- vllm/model_executor/models/internlm2.py | 9 +++++++-- vllm/model_executor/models/jais.py | 14 ++++++++++++-- vllm/model_executor/models/jamba.py | 16 ++++++++++++++-- vllm/model_executor/models/mamba.py | 15 +++++++++++++-- vllm/model_executor/models/minicpm.py | 7 ++++++- vllm/model_executor/models/mixtral.py | 16 ++++++++++++++-- vllm/model_executor/models/mixtral_quant.py | 16 ++++++++++++++-- vllm/model_executor/models/mpt.py | 16 ++++++++++++++-- vllm/model_executor/models/nemotron.py | 7 ++++++- vllm/model_executor/models/olmo.py | 19 +++++++++++++------ vllm/model_executor/models/olmoe.py | 16 ++++++++++++++-- vllm/model_executor/models/orion.py | 16 ++++++++++++++-- vllm/model_executor/models/persimmon.py | 8 +++++++- vllm/model_executor/models/phi.py | 16 ++++++++++++++-- vllm/model_executor/models/phi3_small.py | 19 +++++++++++-------- vllm/model_executor/models/phimoe.py | 16 ++++++++++++++-- vllm/model_executor/models/qwen.py | 16 ++++++++++++++-- vllm/model_executor/models/qwen2.py | 2 +- vllm/model_executor/models/qwen2_cls.py | 7 ++++++- vllm/model_executor/models/qwen2_moe.py | 16 ++++++++++++++-- vllm/model_executor/models/qwen2_rm.py | 7 ++++++- vllm/model_executor/models/solar.py | 4 +++- vllm/model_executor/models/stablelm.py | 16 ++++++++++++++-- vllm/model_executor/models/starcoder2.py | 16 ++++++++++++++-- vllm/model_executor/models/xverse.py | 16 ++++++++++++++-- 43 files changed, 483 insertions(+), 90 deletions(-) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 9ee2a2cc09a24..d52418ee0f6f1 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -389,6 +389,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -396,9 +399,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -439,6 +446,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -446,9 +456,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index aabbd31192a40..01ce7c42cd391 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -284,6 +284,9 @@ def __init__( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -291,9 +294,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -363,6 +370,9 @@ def __init__( self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -370,9 +380,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 84adf574af5e2..cf2eee8172769 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -251,6 +251,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings_layernorm(self.word_embeddings(input_ids)) + def forward( self, input_ids: torch.Tensor, @@ -258,10 +261,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(hidden_states) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -301,6 +307,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -308,9 +317,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index cd5c1d6844716..fbb09a64cde9b 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -280,6 +280,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -287,9 +290,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -354,6 +361,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + @torch.no_grad() def forward( self, @@ -362,9 +372,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index fff8710f6b475..3952ff31e5cec 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -321,6 +321,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.d_model)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -328,9 +331,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.wte(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors hidden_states = intermediate_tensors["hidden_states"] @@ -376,6 +383,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -383,9 +393,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index a9bf1440c4d60..36dfea5a65656 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -353,6 +353,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -360,9 +363,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: hidden_states = intermediate_tensors["hidden_states"] @@ -401,6 +408,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -408,9 +418,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 4fb1eed15a2e7..1e32fe60c7a5b 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -445,6 +445,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -452,9 +455,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -495,6 +502,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -502,9 +512,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 85c51e8404584..f138d13630263 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -78,6 +78,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def sampler(self): return self.model.sampler + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -86,11 +89,14 @@ def forward( attn_metadata: AttentionMetadata, previous_hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - tok_embeds = self.model.model.embed_tokens(input_ids) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + inputs_embeds = self.fc( - torch.cat([tok_embeds, previous_hidden_states], dim=-1)) + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) inputs_embeds[positions == 0] = 0 # masking inputs at position=0 @@ -100,7 +106,8 @@ def forward( positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors) + intermediate_tensors=intermediate_tensors, + ) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index cd3e7da657e0e..52dd603ca558d 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -479,6 +479,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -486,9 +489,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output def compute_logits( diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index b3dbf063ac298..e97abe949ccdb 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -367,6 +367,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -374,9 +377,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.word_embeddings(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] for i in range(self.start_layer, self.end_layer): @@ -432,6 +439,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.LongTensor, @@ -439,9 +449,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 55baba809e58f..ace13664c6ea6 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -390,6 +390,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -397,9 +400,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index eeb3fd98a7eac..a60b4e73a76d4 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -272,6 +272,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: Optional[torch.Tensor], @@ -285,7 +288,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = self.get_input_embeddings(input_ids) hidden_states *= self.normalizer residual = None else: @@ -414,6 +417,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -421,9 +427,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index cc85693f99526..fa0fdad28d161 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -209,6 +209,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -220,7 +223,7 @@ def forward( ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) + inputs_embeds = self.get_input_embeddings(input_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds else: @@ -262,7 +265,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.transformer.make_empty_intermediate_tensors) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.transformer.wte(input_ids) + return self.transformer.get_input_embeddings(input_ids) def forward( self, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index ab25c66c3a887..b2fc79d0d36dc 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -218,6 +218,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -225,11 +228,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + hidden_states = inputs_embeds + self.wpe(position_ids) else: hidden_states = intermediate_tensors["hidden_states"] @@ -285,6 +289,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -292,9 +299,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index a83d03480dde1..cec3fd12a67d6 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -201,6 +201,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -208,9 +211,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.wte(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] for i in range(self.start_layer, self.end_layer): @@ -250,6 +257,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -257,9 +267,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 794b141bfa4aa..11f286d6bcba0 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -214,6 +214,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_in(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -221,9 +224,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_in(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: hidden_states = intermediate_tensors["hidden_states"] for i in range(self.start_layer, self.end_layer): @@ -262,6 +269,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.gpt_neox.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.gpt_neox.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -269,9 +279,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.gpt_neox(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index d1e6e31f2b8d1..cb2583e69d88d 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -409,6 +409,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -416,9 +419,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output def compute_logits( diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 2ed115c56af45..f437dd521a7d5 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -277,6 +277,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -284,9 +287,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) hidden_states *= self.embedding_multiplier residual = None else: @@ -366,6 +373,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.sampler = get_sampler() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -373,9 +383,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 21fa6983063b8..19bfe16e4d5fc 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -290,7 +290,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.tok_embeddings(input_ids) + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -335,6 +335,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -342,9 +345,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 65800c44e5a93..ee49ffb3cd87f 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -250,6 +250,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -257,9 +260,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[IntermediateTensors, torch.Tensor]: if get_pp_group().is_first_rank: - inputs_embeds = self.wte(input_ids) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) if self.wpe is not None: position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds @@ -311,6 +316,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -318,9 +326,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[IntermediateTensors, torch.Tensor]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 88fb8d5cf555a..5612dd6886385 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -292,6 +292,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -299,8 +302,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -381,12 +388,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size) self.sampler = get_sampler() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[KVCache], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: max_batch_size = (_get_graph_batch_size( @@ -409,7 +420,8 @@ def forward(self, mamba_cache_tensors[1], state_indices_tensor) hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_params) + attn_metadata, mamba_cache_params, + inputs_embeds) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 55c575e22a0f6..ac0d265a961f0 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -106,15 +106,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embeddings(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None for i in range(len(self.layers)): @@ -168,12 +175,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size) self.sampler = get_sampler() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.backbone.get_input_embeddings(input_ids) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[KVCache], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: max_batch_size = (_get_graph_batch_size( @@ -194,7 +205,7 @@ def forward(self, state_indices_tensor) hidden_states = self.backbone(input_ids, positions, attn_metadata, - mamba_cache_params) + mamba_cache_params, inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 2db953329fd91..6b67266c53362 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -504,6 +504,9 @@ def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = MiniCPMModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -511,9 +514,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 3eb2f60fd4fc7..eebf5bab5a288 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -281,6 +281,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -288,9 +291,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -363,6 +370,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -370,9 +380,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 95cfb6f54dc10..af2e9586988df 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -318,6 +318,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -325,9 +328,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -368,6 +375,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -375,9 +385,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index e15c0fe8db060..3c74ef2448abb 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -237,6 +237,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.d_model)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -244,9 +247,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.wte(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -283,6 +290,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -290,9 +300,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index e09d7088a69ce..eb45beae7d21a 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -440,6 +440,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -447,9 +450,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output def compute_logits( diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 3467ae5896494..98d4e1ec320a4 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -248,6 +248,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -255,17 +258,16 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: """ :param input_ids: A tensor of shape `(batch_size, seq_len)`. """ if get_pp_group().is_first_rank: - # Get embeddings of input. - # shape: (batch_size, seq_len, d_model) - inputs_embeds = self.embed_tokens(input_ids) - - # embed positions - hidden_states = inputs_embeds + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -315,6 +317,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -322,6 +327,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( input_ids=input_ids, @@ -329,6 +335,7 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, ) return hidden_states diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 3d31919edd862..f4eebab8c98dd 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -269,6 +269,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -276,9 +279,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -326,6 +333,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -333,9 +343,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 38821c8288347..39d659c49cbcf 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -237,6 +237,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): "hidden_states", ], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -244,9 +247,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -286,6 +293,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -293,9 +303,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 2e34a7cc30873..62c509153a111 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -235,6 +235,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -248,7 +251,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -282,6 +285,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 262f6996fc374..a2ab0d74c48db 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -218,6 +218,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -225,9 +228,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -303,6 +310,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -310,9 +320,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 8a5fb6d303e60..2139cec441807 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -324,11 +324,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) def forward( self, @@ -337,9 +334,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor], ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) if (self.mup_embedding_multiplier is not None and self.mup_embedding_multiplier > 0.0): hidden_states = hidden_states * self.mup_embedding_multiplier @@ -397,8 +398,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.dummy_token_indices = None - def get_input_embeddings(self): - return self.model.embed_tokens + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) def set_input_embeddings(self, value): self.model.embed_tokens = value @@ -433,6 +434,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: output_hidden_states = self.model( input_ids=input_ids, @@ -440,6 +442,7 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, ) output_hidden_states = output_hidden_states return output_hidden_states diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 6d71a8949111b..b7e70f8fa2c6d 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -465,6 +465,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -472,9 +475,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -560,6 +567,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -567,9 +577,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 3d26ede722dd1..447632cefcd9a 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -578,6 +578,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=quant_config) if hasattr( config, "visual") else None + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -586,6 +589,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], pixel_values: Optional[QwenImageInputs], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: img_pos = None # If pixel / visual embeddings are provided, this is a visual model @@ -606,6 +610,10 @@ def forward( ) if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.wte(input_ids) # Merge the image embeddings into the hidden states if actually have # visual features and the corresponding image tokens @@ -915,6 +923,9 @@ def _get_image_input_type( ) return None + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -922,7 +933,8 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - pixel_values: Optional[torch.Tensor] = None + pixel_values: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if intermediate_tensors is not None: input_ids = None @@ -932,7 +944,7 @@ def forward( hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, - pixel_values) + pixel_values, inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 431e397e1e10d..8f10df808c216 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -309,7 +309,7 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index 120403e948686..07eb330620a43 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -72,6 +72,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): normalize=False, softmax=True) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -79,9 +82,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) logits, _ = self.score(hidden_states) return logits diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 51c0cd5664fd2..249d94b5d95e9 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -344,6 +344,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -351,9 +354,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: assert intermediate_tensors is not None @@ -395,6 +402,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -402,9 +412,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 55843d8325348..6db467af334f5 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -85,6 +85,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -92,9 +95,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) logits, _ = self.score(hidden_states) return logits diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 4f03ca501fb68..affb2c975ce4a 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -456,9 +456,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 1125f9e9f9617..99acce596602e 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -218,6 +218,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -225,9 +228,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -265,6 +272,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -272,9 +282,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index ce7a7957f52c4..0ef940acebb93 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -221,6 +221,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -228,9 +231,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -273,6 +280,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -280,9 +290,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 153527da20d75..51172d8782a70 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -252,6 +252,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -259,9 +262,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) residual = None else: hidden_states = intermediate_tensors["hidden_states"] @@ -335,6 +342,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -342,9 +352,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits(