From d08b78b50c94239beca3701d286c6d6202b44bd9 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:13:18 -0500 Subject: [PATCH] Properly initializing the new field in the attn metadata (#337) --- tests/kernels/utils.py | 2 ++ tests/worker/test_model_input.py | 3 +++ vllm/attention/backends/abstract.py | 5 ++--- vllm/attention/backends/blocksparse_attn.py | 2 ++ vllm/attention/backends/flash_attn.py | 3 +++ vllm/attention/backends/flashinfer.py | 2 ++ vllm/attention/backends/placeholder_attn.py | 3 +++ vllm/attention/backends/rocm_flash_attn.py | 2 ++ vllm/attention/backends/torch_sdpa.py | 1 + vllm/attention/backends/utils.py | 2 ++ vllm/attention/backends/xformers.py | 2 ++ vllm/worker/hpu_model_runner.py | 7 +++++-- vllm/worker/model_runner.py | 2 -- vllm/worker/model_runner_base.py | 3 +-- vllm/worker/openvino_model_runner.py | 1 + vllm/worker/tpu_model_runner.py | 5 +++++ vllm/worker/xpu_model_runner.py | 2 ++ 17 files changed, 38 insertions(+), 9 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index eb4afc8b3ae36..0fb22638374a4 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -914,6 +914,7 @@ def make_test_metadata( num_prefills=num_prefills, slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -963,6 +964,7 @@ def make_test_metadata( num_prefills=num_prefills, slot_mapping=kv_mmap.slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 309854e6babf3..57f1fd47a600f 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -74,6 +74,7 @@ def test_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), @@ -126,6 +127,7 @@ def test_embedding_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) model_input = ModelInputForGPUWithPoolingMetadata( input_tokens=torch.ones(10), @@ -177,6 +179,7 @@ def test_multi_step_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) frozen_model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 0af04224d853d..3127285df6c63 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, fields from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) @@ -126,8 +126,7 @@ class AttentionMetadata: # Enable/disable KV scales calculation. This is so that we can disable the # calculation until after prefill and cuda graph capture. - enable_kv_scales_calculation: bool = field(init=False, - default_factory=lambda: True) + enable_kv_scales_calculation: bool @property @abstractmethod diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 6c34a0327c2d4..1869dbab0cbf5 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -222,6 +222,7 @@ def prefill_metadata( slot_mapping=self.slot_mapping[:self.num_prefill_tokens], multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -251,6 +252,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 9c89088795207..c640de998149e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -224,6 +224,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -268,6 +269,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, @@ -550,6 +552,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_decode_query_len=max_decode_query_len, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 91b2e6dfc070b..2285f20c1c8af 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -218,6 +218,7 @@ def graph_capture_get_metadata_for_batch( num_prefills=0, slot_mapping=self._graph_slot_mapping[:batch_size], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, num_prefill_tokens=0, num_decode_tokens=batch_size, max_prefill_seq_len=0, @@ -711,6 +712,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, max_prefill_seq_len=max_prefill_seq_len, diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 534f79b3a60bf..d2dc0d6cf0a5f 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -140,6 +140,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_decode_query_len=0, @@ -173,6 +174,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_decode_query_len=self.max_decode_query_len, @@ -378,6 +380,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 228f9c9c3dd65..efaa74f67bafd 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -165,6 +165,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: slot_mapping=self.slot_mapping[:self.num_prefill_tokens], multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -202,6 +203,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index d361ff8cdff82..8e586c00024d5 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -372,6 +372,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], prefill_block_tables=prefill_block_tables, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, ) return attn_metadata diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 03d5d84141d35..d7d4a5166975c 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -274,6 +274,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -326,6 +327,7 @@ def graph_capture_get_metadata_for_batch( num_decode_tokens=batch_size, slot_mapping=self._graph_slot_mapping[:batch_size], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], max_query_len=1, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 5192c07d3e495..dcd60da43d520 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -217,6 +217,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -261,6 +262,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 9d479f412af46..9d49c6d0a0983 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -892,7 +892,8 @@ def _prepare_prompt( num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps= - None # FIXME(kzawora): mutli-modality will not work here + None, # FIXME(kzawora): mutli-modality will not work here + enable_kv_scales_calculation=False, ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) @@ -1046,7 +1047,9 @@ def _prepare_decode( num_prefill_tokens=0, num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + ) return PrepareDecodeMetadata(input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e33a51887d827..ed93a87a0697d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -174,8 +174,6 @@ def from_broadcasted_tensor_dict( if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( attn_backend, tensor_dict) - if "enable_kv_scales_calculation" in tensor_dict: - tensor_dict.pop("enable_kv_scales_calculation") return cls(**tensor_dict) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 4a3f711e3a9ed..cd4770202a186 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -47,8 +47,7 @@ def _init_attn_metadata_from_tensor_dict( # Extract the fields used to create AttentionMetadata. valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): - if field.name in tensor_dict and field.name != \ - 'enable_kv_scales_calculation': + if field.name in tensor_dict: valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 6000e5dfe4e30..23a48e29731a7 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -278,6 +278,7 @@ def _prepare_model_input( block_indices_begins=block_indices_begins_tensor, max_context_len=max_context_len_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 9a054eb8a4cf7..d4a11b1ee220c 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -184,6 +184,7 @@ def _dummy_run( num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=None, context_lens=None, effective_query_lens=None, @@ -202,6 +203,7 @@ def _dummy_run( num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, effective_query_lens=effective_query_lens, @@ -233,6 +235,7 @@ def _dummy_run( num_decode_tokens=batch_size * seq_len, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, ) @@ -418,6 +421,7 @@ def _prepare_prompt( num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, effective_query_lens=prompt_lens, @@ -489,6 +493,7 @@ def _prepare_decode( num_decode_tokens=batch_size, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, ) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 9cf25387560da..a4a03050a46c3 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -257,6 +257,7 @@ def _prepare_prompt( is_prompt=True, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, seq_lens=seq_lens, seqlen_q=seqlen_q, max_seqlen=max_seqlen, @@ -341,6 +342,7 @@ def _prepare_decode( is_prompt=False, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, seq_lens=seq_lens, seqlen_q=torch.tensor([]), max_seqlen=0,