diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index abec46b2e599c..904e4ede90f8f 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -218,6 +218,12 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables) + # Batch may be composed of prefill|decodes, adjust query start indices + # to refer to the start of decodes when the two are split apart. + # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + if self._cached_decode_metadata.query_start_loc is not None: + qs = self._cached_decode_metadata.query_start_loc + self._cached_decode_metadata.query_start_loc = qs - qs[0] return self._cached_decode_metadata def advance_step(self,