[Bugfix] Fix chunked prefill with model dtype float32 on Turing Devices (vllm-project#9850)

Signed-off-by: Wallas Santos <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
wallashss and mgoin authored Nov 25, 2024
1 parent d04b13a commit c27df94
Showing 6 changed files with 122 additions and 13 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -98,4 +98,5 @@ markers = [
"quant_model: run this model test under Quantized category",
"distributed_2_gpus: run this test only in distributed tests for 2 GPUs",
"skip_v1: do not run this test with v1",
"optional: optional tests that are automatically skipped, include --optional to run them",
]
19 changes: 19 additions & 0 deletions tests/conftest.py
@@ -1030,3 +1030,22 @@ def dummy_gemma2_embedding_path():
with open(json_path, "w") as f:
json.dump(config, f)
return _dummy_gemma2_embedding_path


# Add the flag `--optional` to allow running tests
# that are marked with @pytest.mark.optional
def pytest_addoption(parser):
parser.addoption("--optional",
action="store_true",
default=False,
help="run optional test")


def pytest_collection_modifyitems(config, items):
if config.getoption("--optional"):
# --optional given in cli: do not skip optional tests
return
skip_optional = pytest.mark.skip(reason="need --optional option to run")
for item in items:
if "optional" in item.keywords:
item.add_marker(skip_optional)
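
For reference, the hook pair above means any test carrying the `optional` marker (registered in pyproject.toml earlier in this commit) is skipped unless pytest is invoked with `--optional`. A minimal sketch of how a test opts in, using a hypothetical test name that is not part of this commit:

import pytest


@pytest.mark.optional
def test_slow_float32_path():
    # Skipped by default; collected and run with:
    #   pytest --optional -k slow_float32_path
    assert True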
63 changes: 63 additions & 0 deletions tests/kernels/test_prefix_prefill.py
@@ -40,6 +40,13 @@ def test_contexted_kv_attention(
kv_cache_dtype: str,
device: str,
) -> None:

if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
89):
pytest.skip(
'Triton limitation: fp8e4nv data type is not supported on CUDA'
' arch < 89')

current_platform.seed_everything(0)
torch.set_default_device(device)

@@ -235,6 +242,13 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype: str,
device: str,
) -> None:

if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
89):
pytest.skip(
'Triton limitation: fp8e4nv data type is not supported on CUDA'
' arch < 89')

current_platform.seed_everything(0)
torch.set_default_device(device)

@@ -462,3 +476,52 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


# These tests are optional and only run when explicitly invoked:
#
# pytest -v -s --optional \
# tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32
#
# They are useful for testing model dtype float32 on Turing devices.
# We skip them by default to avoid increasing CI run time.
@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@torch.inference_mode()
def test_contexted_kv_attention_f32(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
sliding_window: int,
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
) -> None:
test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
sliding_window, dtype, kv_cache_dtype, device)


@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_contexted_kv_attention_alibi_f32(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
) -> None:
test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
dtype, kv_cache_dtype, device)
41 changes: 28 additions & 13 deletions vllm/attention/ops/prefix_prefill.py
@@ -7,6 +7,13 @@

from vllm.platforms import current_platform

# Static kernels parameters
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
NUM_WARPS = 8

# To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5)

if triton.__version__ >= "2.1.0":

@triton.jit
@@ -50,6 +57,7 @@ def _fwd_kernel(
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
IN_PRECISION: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, # head size
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
@@ -130,7 +138,7 @@ def _fwd_kernel(
k = k_load

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N]
qk += tl.dot(q, k)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
@@ -178,7 +186,7 @@ def _fwd_kernel(
v = v_load
p = p.to(v.dtype)

acc += tl.dot(p, v)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# # update m_i and l_i
l_i = l_i_new
m_i = m_i_new
@@ -204,7 +212,7 @@ def _fwd_kernel(
other=0.0)

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk *= sm_scale
# apply causal mask
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
@@ -238,7 +246,7 @@ def _fwd_kernel(
other=0.0)
p = p.to(v.dtype)

acc += tl.dot(p, v)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
@@ -485,6 +493,7 @@ def _fwd_kernel_alibi(
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
IN_PRECISION: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, # head size
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
@@ -560,7 +569,7 @@ def _fwd_kernel_alibi(
k = k_load

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
@@ -600,7 +609,7 @@ def _fwd_kernel_alibi(
v = v_load
p = p.to(v.dtype)

acc += tl.dot(p, v, allow_tf32=False)
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
@@ -635,7 +644,7 @@ def _fwd_kernel_alibi(
other=0.0)

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, allow_tf32=False)
qk = tl.dot(q, k, acc=qk, input_precision='ieee')
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
@@ -673,7 +682,7 @@ def _fwd_kernel_alibi(
other=0.0)
p = p.to(v.dtype)

acc += tl.dot(p, v, allow_tf32=False)
acc = tl.dot(p, v, acc=acc, input_precision='ieee')
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
@@ -709,13 +718,17 @@ def context_attention_fwd(q,
alibi_slopes=None,
sliding_window=None):

BLOCK = 128 if current_platform.has_device_capability(80) else 64
NUM_WARPS = 8

q_dtype_is_f32 = q.dtype is torch.float32
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
if q.dtype is torch.float32:
BLOCK = BLOCK // 2
# if q.dtype is torch.float32:
BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK

# Turing does not have tensor cores for float32 multiplication, so
# use 'ieee' as a fallback so the Triton kernels still work. There is
# also a warning in vllm/config.py to inform users about this
# fallback implementation.
IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None

# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
@@ -799,6 +812,7 @@ def context_attention_fwd(q,
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
@@ -850,6 +864,7 @@ def context_attention_fwd(q,
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
IN_PRECISION=IN_PRECISION,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
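
The recurring change in this file replaces in-place accumulation such as `qk += tl.dot(q, k)` with `tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)`, so Turing devices can force plain IEEE float32 math instead of TF32. Below is a minimal toy sketch of the same switch, assuming Triton 3.x (which exposes `input_precision`) and a CUDA device; `_toy_dot` is a hypothetical kernel, not part of vLLM:

import torch
import triton
import triton.language as tl


@triton.jit
def _toy_dot(a_ptr, b_ptr, c_ptr, IN_PRECISION: tl.constexpr,
             BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    acc = tl.zeros([BLOCK, BLOCK], dtype=tl.float32)
    # None lets Triton pick the default precision (TF32 on Ampere+);
    # 'ieee' forces full float32 multiply-accumulate, the fallback used
    # on Turing, which has no TF32 tensor cores.
    acc = tl.dot(a, b, acc=acc, input_precision=IN_PRECISION)
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], acc)


a = torch.randn(16, 16, device="cuda")
b = torch.randn(16, 16, device="cuda")
c = torch.empty_like(a)
_toy_dot[(1, )](a, b, c, IN_PRECISION='ieee', BLOCK=16)
torch.testing.assert_close(c, a @ b, atol=1e-5, rtol=1e-5)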
10 changes: 10 additions & 0 deletions vllm/config.py
@@ -2388,6 +2388,16 @@ def __post_init__(self):
self.quant_config = VllmConfig._get_quantization_config(
self.model_config, self.load_config)

if self.scheduler_config is not None and \
self.model_config is not None and \
self.scheduler_config.chunked_prefill_enabled and \
self.model_config.dtype == torch.float32 and \
current_platform.get_device_capability() == (7, 5):
print_warning_once(
"Turing devices tensor cores do not support float32 matmul. "
"To workaround this limitation, vLLM will set 'ieee' input "
"precision for chunked prefill triton kernels.")

if self.compilation_config is None:
self.compilation_config = CompilationConfig()
if envs.VLLM_USE_V1 and not self.model_config.enforce_eager:
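
For context, a hedged sketch of a configuration that reaches this warning path, assuming a Turing (compute capability 7.5) GPU; the model choice is illustrative only:

from vllm import LLM

# With float32 weights and chunked prefill on a Turing GPU, vLLM now
# logs the warning above and runs the prefix-prefill Triton kernels
# with 'ieee' input precision instead of TF32.
llm = LLM(model="facebook/opt-125m",
          dtype="float32",
          enable_chunked_prefill=True)
print(llm.generate("Hello, my name is")[0].outputs[0].text)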
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
@@ -1055,6 +1055,7 @@ def create_engine_config(self) -> VllmConfig:
msg = "Chunked prefill is not supported for embedding models"
raise ValueError(msg)


speculative_config = SpeculativeConfig.maybe_create_spec_config(
target_model_config=model_config,
target_parallel_config=parallel_config,
