[V1] TPU Prototype #10241

Draft: wants to merge 34 commits into base: main

Commits (34)
f69bdea
prototype tpu on v1
robertgshaw2-neuralmagic Nov 12, 2024
1142c89
profile run complete
robertgshaw2-neuralmagic Nov 12, 2024
9cc4fbe
actually dummy run
robertgshaw2-neuralmagic Nov 12, 2024
61f7792
stash
robertgshaw2-neuralmagic Nov 12, 2024
1887d81
update workflow
robertgshaw2-neuralmagic Nov 3, 2024
b8c6444
updated
robertgshaw2-neuralmagic Nov 3, 2024
75e2e53
updated
robertgshaw2-neuralmagic Nov 12, 2024
bebabfc
more cleaning
robertgshaw2-neuralmagic Nov 12, 2024
338e11c
cleanup llmengine
robertgshaw2-neuralmagic Nov 12, 2024
db49d3b
Revert "cleanup llmengine"
robertgshaw2-neuralmagic Nov 12, 2024
4ade5b0
fixt
robertgshaw2-neuralmagic Nov 12, 2024
dc78451
warmup is working!
robertgshaw2-neuralmagic Nov 12, 2024
7f8fdee
stash
robertgshaw2-neuralmagic Nov 12, 2024
f7de1b4
stash
robertgshaw2-neuralmagic Nov 15, 2024
5de1d9f
workin for prefill, except when I compile decode cudagraphs?
robertgshaw2-neuralmagic Nov 16, 2024
15a2f74
working! It was the type of the position ids!
robertgshaw2-neuralmagic Nov 16, 2024
14b9500
forward pass
robertgshaw2-neuralmagic Nov 16, 2024
6eeecb7
correct output for single prompt with --enforce-eager
robertgshaw2-neuralmagic Nov 16, 2024
0b256c2
end to end passing working for single request with CUDAGraphs!
robertgshaw2-neuralmagic Nov 16, 2024
b44227d
yay! working with multiple requests! the issue was copy_() does not s…
robertgshaw2-neuralmagic Nov 16, 2024
451dfbf
yay! working end to end via lm eval harness!
robertgshaw2-neuralmagic Nov 16, 2024
d2ae4a5
we have end to end correctness
robertgshaw2-neuralmagic Nov 16, 2024
7dd18e0
nits
robertgshaw2-neuralmagic Nov 16, 2024
d89200d
updated
robertgshaw2-neuralmagic Nov 16, 2024
75c44b4
update to call .cpu() before slicing to avoid recompilation
robertgshaw2-neuralmagic Nov 17, 2024
58e85eb
a bit faster
robertgshaw2-neuralmagic Nov 17, 2024
fcf4681
better performance due to better input processing
robertgshaw2-neuralmagic Nov 17, 2024
d9dc36a
cleanup PR
robertgshaw2-neuralmagic Nov 17, 2024
85bc154
cleanup
robertgshaw2-neuralmagic Nov 17, 2024
25fff99
cleanup pr
robertgshaw2-neuralmagic Nov 17, 2024
5a87b99
formatting
robertgshaw2-neuralmagic Nov 17, 2024
63b301a
updated
robertgshaw2-neuralmagic Nov 17, 2024
1af03e0
updated
robertgshaw2-neuralmagic Nov 17, 2024
02ee304
fixed accuracy bug
robertgshaw2-neuralmagic Nov 17, 2024
4 changes: 3 additions & 1 deletion tests/entrypoints/openai/test_accuracy.py
@@ -61,12 +61,14 @@ def run_test(more_args):
     )
 
     measured_value = results["results"][TASK][FILTER]
+    print(f"{measured_value=}")
     assert (measured_value - RTOL < EXPECTED_VALUE
             and measured_value + RTOL > EXPECTED_VALUE
             ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
 
 
-@pytest.mark.skipif(not current_platform.is_cuda(),
+@pytest.mark.skipif(not current_platform.is_cuda()
+                    and not current_platform.is_tpu(),
                     reason="V1 currently only supported on CUDA")
 def test_lm_eval_accuracy_v1_engine(monkeypatch):
     """Run with the V1 Engine."""
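The decorator change above widens the skip condition so the V1 accuracy test also runs on TPU hosts. A minimal sketch of the resulting behavior, using plain booleans as stand-ins for vLLM's current_platform checks (the helper name is made up for illustration):

# Hypothetical stand-in for the skipif condition in the diff above:
# the test is skipped only when neither CUDA nor TPU is available.
def skip_v1_accuracy_test(is_cuda: bool, is_tpu: bool) -> bool:
    return not is_cuda and not is_tpu

assert skip_v1_accuracy_test(is_cuda=True, is_tpu=False) is False   # runs on CUDA
assert skip_v1_accuracy_test(is_cuda=False, is_tpu=True) is False   # now runs on TPU
assert skip_v1_accuracy_test(is_cuda=False, is_tpu=False) is True   # skipped elsewhere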
10 changes: 9 additions & 1 deletion vllm/attention/selector.py
@@ -25,6 +25,7 @@ class _Backend(enum.Enum):
     FLASHINFER = enum.auto()
     HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
+    PALLAS_VLLM_V1 = enum.auto()
     IPEX = enum.auto()
     NO_ATTENTION = enum.auto()
 
@@ -140,6 +141,10 @@ def _cached_get_attn_backend(
         from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend as FlashAttentionBackendV1)
         return FlashAttentionBackendV1
+    if backend == _Backend.PALLAS_VLLM_V1:
+        from vllm.v1.attention.backends.pallas import (  # noqa: F401
+            PallasAttentionBackend as PallasAttentionBackendV1)
+        return PallasAttentionBackendV1
     if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
@@ -232,8 +237,11 @@ def which_attn_to_use(head_size: int,
         return _Backend.IPEX
 
     if current_platform.is_tpu():
-        if selected_backend != _Backend.PALLAS:
+        if (selected_backend != _Backend.PALLAS
+                and selected_backend != _Backend.PALLAS_VLLM_V1):
             logger.info("Cannot use %s backend on TPU.", selected_backend)
+        if use_v1:
+            return _Backend.PALLAS_VLLM_V1
         return _Backend.PALLAS
 
     if current_platform.is_rocm():
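The which_attn_to_use change above is the only place where the new enum value is chosen: on TPU, V1 engines get PALLAS_VLLM_V1 and everything else falls back to PALLAS. A self-contained sketch of that branch; the enum and helper below are illustrative stand-ins mirroring the diff, not vLLM's actual API:

import enum


class _Backend(enum.Enum):
    PALLAS = enum.auto()
    PALLAS_VLLM_V1 = enum.auto()


def select_tpu_backend(selected_backend: _Backend, use_v1: bool) -> _Backend:
    # Mirrors the TPU branch in which_attn_to_use: any request other than the
    # two Pallas backends is rejected with a log message, and the V1 engine is
    # routed to the new PALLAS_VLLM_V1 value.
    if selected_backend not in (_Backend.PALLAS, _Backend.PALLAS_VLLM_V1):
        print(f"Cannot use {selected_backend} backend on TPU.")
    return _Backend.PALLAS_VLLM_V1 if use_v1 else _Backend.PALLAS


assert select_tpu_backend(_Backend.PALLAS, use_v1=True) is _Backend.PALLAS_VLLM_V1
assert select_tpu_backend(_Backend.PALLAS, use_v1=False) is _Backend.PALLAS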
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -1225,6 +1225,8 @@ def __init__(self, device: str = "auto") -> None:
         # Some device types require processing inputs on CPU
         if self.device_type in ["neuron", "openvino"]:
             self.device = torch.device("cpu")
+        # Device initialization should happen after initializing the
+        # distributed runtime.
         elif self.device_type in ["tpu"]:
             self.device = None
         else:
3 changes: 3 additions & 0 deletions vllm/utils.py
@@ -728,6 +728,9 @@ def is_pin_memory_available() -> bool:
     elif current_platform.is_hpu():
         print_warning_once("Pin memory is not supported on HPU.")
         return False
+    elif current_platform.is_tpu():
+        print_warning_once("Pin memory is not supported on TPU.")
+        return False
     elif current_platform.is_cpu() or current_platform.is_openvino():
         return False
     return True
3 changes: 2 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -8,7 +8,6 @@
                                               AttentionMetadata, AttentionType)
 from vllm.forward_context import get_forward_context
 from vllm.utils import direct_register_custom_op
-from vllm.vllm_flash_attn import flash_attn_varlen_func
 
 
 class FlashAttentionBackend(AttentionBackend):
@@ -202,6 +201,8 @@ def unified_v1_flash_attention(
         v_scale,
     )
 
+    from vllm.vllm_flash_attn import flash_attn_varlen_func
+
     attn_output = flash_attn_varlen_func(
         q=query[:num_actual_tokens],
         k=key_cache,
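The flash_attn.py change defers the CUDA-only import until unified_v1_flash_attention actually runs, so importing the V1 flash-attention module on a TPU host no longer requires the compiled kernels to be installed. A generic sketch of the same deferred-import pattern; heavy_kernels and fused_varlen_attention are hypothetical names, not a real package:

def run_fused_attention(*args, **kwargs):
    # Deferred import: module import stays cheap and platform-agnostic; an
    # ImportError (if the extension is missing) surfaces only when this code
    # path actually executes.
    from heavy_kernels import fused_varlen_attention  # hypothetical extension
    return fused_varlen_attention(*args, **kwargs)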
298 changes: 298 additions & 0 deletions vllm/v1/attention/backends/pallas.py
@@ -0,0 +1,298 @@
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch_xla

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)


class PallasAttentionBackend(AttentionBackend):

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [128]

    @staticmethod
    def get_name() -> str:
        return "pallas-vllm-v1"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionImpl"]:
        return PallasAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasAttentionMetadata"]:
        return PallasAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (num_kv_heads, num_blocks, block_size, head_size)


@dataclass
class PallasAttentionMetadata:

    is_prompt: bool
    slot_mapping: torch.Tensor
    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None


class PallasAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "Pallas attention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads

        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("ALiBi slopes are not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")
        if logits_soft_cap is not None:
            raise NotImplementedError(
                "Attention logits soft-capping is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = PallasAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PallasAttention. "
                f"Supported head sizes are: {support_head_sizes}.")

        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        if ("lite" not in tpu_type) and ("v6" not in tpu_type):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
                with shape [0] for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """

        assert k_scale == 1.0 and v_scale == 1.0, (
            "k_scale and v_scale are not supported in PallasAttentionImpl.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionImpl")

        # Unpack
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        # Write to KV cache.
        if kv_cache[0].numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            key_cache = kv_cache[0]
            value_cache = kv_cache[1]
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        query = query * self.scale
        if attn_metadata.is_prompt:
            assert seq_len % 16 == 0, (
                "Pallas FlashAttention kernel requires seq_len to be a "
                f"multiple of 16 but got {seq_len}")

            # Handle GQA/MQA.
            if self.num_kv_heads != self.num_heads:
                key = key.repeat_interleave(self.num_queries_per_kv, dim=-2)
                key = key.view(batch_size, seq_len, self.num_heads,
                               self.head_size)
                value = value.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                value = value.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
            # FlashAttention requires [batch_size, num_heads, seq_len, d_model]
            # while the input is [batch_size, seq_len, num_heads, d_model].
            # Permute the input to match the required format.
            output = torch.ops.xla.flash_attention(
                query.permute(0, 2, 1, 3),
                key.permute(0, 2, 1, 3),
                value.permute(0, 2, 1, 3),
                True,
            )
            output = output.permute(0, 2, 1, 3)
        else:
            # Decoding run.
            assert kv_cache[0].numel() > 0
            query = query.squeeze(dim=1)
            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.

            assert attn_metadata.block_tables is not None
            assert attn_metadata.context_lens is not None
            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
            # block table in SMEM. Therefore, if the block table is too large,
            # the kernel compilation will fail. To avoid this, we split the
            # batch dimension into smaller chunks and run the kernel multiple
            # times.
            MAX_SMEM_USAGE = 512 * 1024
            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
            max_num_seq = MAX_SMEM_USAGE // size_per_seq
            # Worked example (hypothetical sizes): with 4-byte block IDs and a
            # block table of 128 entries per sequence, size_per_seq is
            # 4 * 128 = 512 bytes, so max_num_seq = 512 * 1024 // 512 = 1024
            # sequences per kernel call.

            if batch_size <= max_num_seq:
                output = paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    self.megacore_mode,
                )
            else:
                chunk_size = max_num_seq
                # Make sure the chunk size is a multiple of 2.
                chunk_size = chunk_size // 2 * 2
                num_chunks = (batch_size + chunk_size - 1) // chunk_size

                output = torch.empty_like(query)
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = chunk_start + chunk_size
                    # NOTE(woosuk): We skip this line because it causes Dynamo
                    # compilation error. Instead, we rely on the slice operation
                    # to handle the out-of-bound case.
                    # chunk_end = min(chunk_end, batch_size)
                    chunk_output = paged_attention(
                        query[chunk_start:chunk_end],
                        key_cache,
                        value_cache,
                        attn_metadata.context_lens[chunk_start:chunk_end],
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
                    )
                    output[chunk_start:chunk_end] = chunk_output

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)


def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

    key = key.flatten(0, 2)
    value = value.flatten(0, 2)
    key_cache = key_cache.flatten(0, 2)
    value_cache = value_cache.flatten(0, 2)
    key_cache.index_copy_(0, slot_mapping, key)
    value_cache.index_copy_(0, slot_mapping, value)


def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
) -> torch.Tensor:
    batch_size = query.shape[0]
    if megacore_mode == "batch" and batch_size % 2 != 0:
        megacore_mode = None

    # NOTE(woosuk): A temporary workaround to avoid the error:
    # "xla::paged_attention() Expected a value of type 'str' for
    # argument 'megacore_mode' but instead found type 'NoneType'."
    if megacore_mode is not None:
        output = torch.ops.xla.paged_attention(
            query,
            key_cache,
            value_cache,
            context_lens,
            block_tables,
            pages_per_compute_block,
            megacore_mode=megacore_mode,
        )
    else:
        output = torch.ops.xla.paged_attention(
            query,
            key_cache,
            value_cache,
            context_lens,
            block_tables,
            pages_per_compute_block,
        )
    return output
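For intuition about the cache layout used above, here is a tiny, self-contained demonstration of how get_kv_cache_shape and the flatten-plus-index_copy_ pattern in write_to_kv_cache fit together. It runs on plain CPU tensors and makes no claim about how the V1 TPU runner actually builds slot_mapping; the sizes and slot indices below are made up for illustration:

import torch

# Shape returned by PallasAttentionBackend.get_kv_cache_shape(...):
# (num_kv_heads, num_blocks, block_size, head_size).
num_kv_heads, num_blocks, block_size, head_size = 1, 4, 16, 128
key_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)

# Two incoming key vectors (e.g. one token for each of two sequences),
# already flattened the way write_to_kv_cache flattens `key` with flatten(0, 2).
new_keys = torch.randn(2, head_size)

# Hypothetical slot indices into the flattened (head, block, block_offset)
# dimension: slot 0 of block 0 and slot 3 of block 2.
slot_mapping = torch.tensor([0 * block_size + 0, 2 * block_size + 3])

# Flatten (head, block, offset) into one indexable dimension, as in
# write_to_kv_cache. flatten() on this contiguous tensor returns a view,
# so the writes land in key_cache itself.
flat_cache = key_cache.flatten(0, 2)
flat_cache.index_copy_(0, slot_mapping, new_keys)

assert torch.equal(key_cache[0, 0, 0], new_keys[0])
assert torch.equal(key_cache[0, 2, 3], new_keys[1])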