From 9ae1e959896c8acc574943498bf00142994a927f Mon Sep 17 00:00:00 2001
From: Pavani Majety
Date: Tue, 19 Nov 2024 11:44:53 -0800
Subject: [PATCH] Fix formatting issues

Signed-off-by: Pavani Majety
---
 vllm/attention/backends/flashinfer.py | 44 +++++++++++----------------
 1 file changed, 17 insertions(+), 27 deletions(-)

diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index ff6186bcdd542..bd14e66ad8a91 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -22,34 +22,21 @@
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
-from vllm.attention.backends.abstract import (
-    AttentionBackend,
-    AttentionImpl,
-    AttentionMetadata,
-    AttentionMetadataBuilder,
-    AttentionState,
-    AttentionType,
-)
-from vllm.attention.backends.utils import (
-    PAD_SLOT_ID,
-    compute_slot_mapping,
-    compute_slot_mapping_start_idx,
-    is_block_tables_empty,
-)
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionMetadata,
+                                              AttentionMetadataBuilder,
+                                              AttentionState, AttentionType)
+from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+                                           compute_slot_mapping_start_idx,
+                                           is_block_tables_empty)
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.forward_context import get_forward_context
-from vllm.utils import (
-    async_tensor_h2d,
-    direct_register_custom_op,
-    get_kv_cache_torch_dtype,
-    make_tensor_with_pad,
-)
+from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
+                        get_kv_cache_torch_dtype, make_tensor_with_pad)
 
 if TYPE_CHECKING:
-    from vllm.worker.model_runner import (
-        ModelInputForGPUBuilder,
-        ModelInputForGPUWithSamplingMetadata,
-    )
+    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
+                                          ModelInputForGPUWithSamplingMetadata)
 
 
 class FlashInferBackend(AttentionBackend):
@@ -878,7 +865,10 @@ def unified_flash_infer(
 
     assert query.shape[0] == num_prefill_tokens
     assert decode_query.shape[0] == num_decode_tokens
-    assert window_size.le
+
+    window_left = -1
+    if window_size is not None:
+        window_left = window_size[0]
 
     prefill_output: Optional[torch.Tensor] = None
     decode_output: Optional[torch.Tensor] = None
@@ -911,7 +901,7 @@
                 causal=True,
                 k_scale=k_scale,
                 v_scale=v_scale,
-                window_left=window_size[0])
+                window_left=window_left)
     if decode_meta := attn_metadata.decode_metadata:
         assert attn_metadata.decode_metadata is not None
         assert attn_metadata.decode_metadata.decode_wrapper is not None
@@ -922,7 +912,7 @@
             logits_soft_cap=logits_soft_cap,
             k_scale=k_scale,
             v_scale=v_scale,
-            window_left=window_size[0])
+            window_left=window_left)
     if prefill_output is None and decode_output is not None:
         # Decode only batch.
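
Note on the functional change: besides the import reformatting, the patch replaces the dangling "assert window_size.le" with an explicit fallback so a missing sliding window no longer errors out. The sketch below is a minimal standalone illustration of that logic, assuming window_size is an optional (left, right) pair as the patch's indexing suggests; resolve_window_left is a hypothetical helper name, not part of vLLM or FlashInfer.

# Standalone sketch (not vLLM code): mirrors the window_left handling
# introduced above. -1 matches the patch's default when no sliding
# window is configured.
from typing import Optional, Tuple


def resolve_window_left(window_size: Optional[Tuple[int, int]]) -> int:
    """Return the left sliding-window bound, or -1 when unset."""
    window_left = -1
    if window_size is not None:
        window_left = window_size[0]
    return window_left


# Example usage: the prefill and decode wrapper calls in the patch pass
# this value as window_left=...
assert resolve_window_left(None) == -1
assert resolve_window_left((128, 0)) == 128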