From 9242621cea291883ee71be072ecf672d1fc5b510 Mon Sep 17 00:00:00 2001 From: Hui Liu <96135754+hliuca@users.noreply.github.com> Date: Mon, 2 Dec 2024 14:14:25 -0800 Subject: [PATCH] restore accidental deletion --- vllm/attention/backends/rocm_flash_attn.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index abec46b2e599c..904e4ede90f8f 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -218,6 +218,12 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables) + # Batch may be composed of prefill|decodes, adjust query start indices + # to refer to the start of decodes when the two are split apart. + # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + if self._cached_decode_metadata.query_start_loc is not None: + qs = self._cached_decode_metadata.query_start_loc + self._cached_decode_metadata.query_start_loc = qs - qs[0] return self._cached_decode_metadata def advance_step(self,