Commit

Fix formatting issues
Signed-off-by: Pavani Majety <[email protected]>
pavanimajety committed Nov 19, 2024
1 parent 59e3786 commit 9ae1e95
Showing 1 changed file with 17 additions and 27 deletions.
44 changes: 17 additions & 27 deletions vllm/attention/backends/flashinfer.py
@@ -22,34 +22,21 @@

import vllm.envs as envs
from vllm import _custom_ops as ops
-from vllm.attention.backends.abstract import (
-    AttentionBackend,
-    AttentionImpl,
-    AttentionMetadata,
-    AttentionMetadataBuilder,
-    AttentionState,
-    AttentionType,
-)
-from vllm.attention.backends.utils import (
-    PAD_SLOT_ID,
-    compute_slot_mapping,
-    compute_slot_mapping_start_idx,
-    is_block_tables_empty,
-)
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionMetadata,
+                                              AttentionMetadataBuilder,
+                                              AttentionState, AttentionType)
+from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+                                           compute_slot_mapping_start_idx,
+                                           is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.forward_context import get_forward_context
-from vllm.utils import (
-    async_tensor_h2d,
-    direct_register_custom_op,
-    get_kv_cache_torch_dtype,
-    make_tensor_with_pad,
-)
+from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
+                        get_kv_cache_torch_dtype, make_tensor_with_pad)

if TYPE_CHECKING:
-    from vllm.worker.model_runner import (
-        ModelInputForGPUBuilder,
-        ModelInputForGPUWithSamplingMetadata,
-    )
+    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
+                                          ModelInputForGPUWithSamplingMetadata)


class FlashInferBackend(AttentionBackend):
@@ -878,7 +865,10 @@ def unified_flash_infer(

    assert query.shape[0] == num_prefill_tokens
    assert decode_query.shape[0] == num_decode_tokens
-    assert window_size.le
+
+    window_left = -1
+    if window_size is not None:
+        window_left = window_size[0]

    prefill_output: Optional[torch.Tensor] = None
    decode_output: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -911,7 +901,7 @@ def unified_flash_infer(
causal=True,
k_scale=k_scale,
v_scale=v_scale,
-window_left=window_size[0])
+window_left=window_left)
if decode_meta := attn_metadata.decode_metadata:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
@@ -922,7 +912,7 @@
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale,
-window_left=window_size[0])
+window_left=window_left)

if prefill_output is None and decode_output is not None:
# Decode only batch.
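
Aside from the import reformatting, the substantive change in this commit is the sliding-window guard: rather than indexing window_size[0] unconditionally when calling the prefill and decode wrappers, the code now defaults window_left to -1 and only reads window_size[0] when a window is actually configured. Below is a minimal illustrative sketch of that pattern; the helper name resolve_window_left and the standalone asserts are hypothetical and not part of vLLM, and the reading of -1 as "no sliding window" assumes FlashInfer's usual window_left convention rather than anything stated in this diff.

# Minimal sketch (not vLLM code) of the guard applied above before passing
# a sliding-window argument on to a FlashInfer-style attention wrapper.
from typing import Optional, Tuple

def resolve_window_left(window_size: Optional[Tuple[int, int]]) -> int:
    """Return the left window width, or -1 when no sliding window is configured.

    Indexing window_size[0] directly would raise TypeError for models that do
    not use sliding-window attention, where window_size is None.
    """
    window_left = -1  # assumed convention: -1 means "attend to the full context"
    if window_size is not None:
        window_left = window_size[0]  # (left, right) pair; only the left width is used
    return window_left

# Usage: both calls are safe, unlike an unconditional window_size[0].
assert resolve_window_left(None) == -1
assert resolve_window_left((1024, 0)) == 1024

Keeping the default at -1 lets the same prefill and decode call sites serve both windowed and non-windowed models without extra branching at each call.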
