diff --git a/pyproject.toml b/pyproject.toml
index 3c8c46cc8621e..253b706a774a7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,4 +98,5 @@ markers = [
     "quant_model: run this model test under Quantized category",
     "distributed_2_gpus: run this test only in distributed tests for 2 GPUs",
     "skip_v1: do not run this test with v1",
+    "optional: optional tests that are automatically skipped, include --optional to run them",
 ]
diff --git a/tests/conftest.py b/tests/conftest.py
index 29707f975e2a0..d56942d8912af 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1030,3 +1030,22 @@ def dummy_gemma2_embedding_path():
         with open(json_path, "w") as f:
             json.dump(config, f)
     return _dummy_gemma2_embedding_path
+
+
+# Add the `--optional` flag to allow running tests
+# that are marked with @pytest.mark.optional
+def pytest_addoption(parser):
+    parser.addoption("--optional",
+                     action="store_true",
+                     default=False,
+                     help="run optional tests")
+
+
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--optional"):
+        # --optional given on the CLI: do not skip optional tests
+        return
+    skip_optional = pytest.mark.skip(reason="need --optional option to run")
+    for item in items:
+        if "optional" in item.keywords:
+            item.add_marker(skip_optional)
diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py
index a8a187ebaede4..3fdb7996ba4e0 100644
--- a/tests/kernels/test_prefix_prefill.py
+++ b/tests/kernels/test_prefix_prefill.py
@@ -40,6 +40,13 @@ def test_contexted_kv_attention(
     kv_cache_dtype: str,
     device: str,
 ) -> None:
+
+    if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
+            89):
+        pytest.skip(
+            'Triton limitation: fp8e4nv data type is not supported on CUDA'
+            ' arch < 89')
+
     current_platform.seed_everything(0)
     torch.set_default_device(device)
@@ -235,6 +242,13 @@ def test_contexted_kv_attention_alibi(
     kv_cache_dtype: str,
     device: str,
 ) -> None:
+
+    if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
+            89):
+        pytest.skip(
+            'Triton limitation: fp8e4nv data type is not supported on CUDA'
+            ' arch < 89')
+
     current_platform.seed_everything(0)
     torch.set_default_device(device)
@@ -462,3 +476,52 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
     atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
     torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
+
+
+# These tests are optional and only run when explicitly invoked:
+#
+# pytest -v -s --optional \
+#     tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32
+#
+# These tests are useful to exercise model dtype float32 on Turing devices.
+# We skip them by default to avoid increasing CI test time.
+@pytest.mark.optional
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", [torch.float32])
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
+@torch.inference_mode()
+def test_contexted_kv_attention_f32(
+    num_heads: int,
+    num_queries_per_kv: int,
+    head_size: int,
+    sliding_window: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: str,
+    device: str,
+) -> None:
+    test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
+                                sliding_window, dtype, kv_cache_dtype, device)
+
+
+@pytest.mark.optional
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", [torch.float32])
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_contexted_kv_attention_alibi_f32(
+    num_heads: int,
+    num_queries_per_kv: int,
+    head_size: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: str,
+    device: str,
+) -> None:
+    test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
+                                      dtype, kv_cache_dtype, device)
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index a2a649c8ebcfd..9c11a8df55278 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -7,6 +7,13 @@
 
 from vllm.platforms import current_platform
 
+# Static kernel parameters
+BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
+NUM_WARPS = 8
+
+# Used to check device compatibility
+IS_TURING = current_platform.get_device_capability() == (7, 5)
+
 if triton.__version__ >= "2.1.0":
 
     @triton.jit
@@ -50,6 +57,7 @@ def _fwd_kernel(
         stride_v_cache_d,
         stride_v_cache_bl,
         num_queries_per_kv: int,
+        IN_PRECISION: tl.constexpr,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,  # head size
         BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
@@ -130,7 +138,7 @@ def _fwd_kernel(
                 k = k_load
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
-            qk += tl.dot(q, k)
+            qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
             qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                           float("-inf"))
             qk *= sm_scale
@@ -178,7 +186,7 @@ def _fwd_kernel(
                 v = v_load
             p = p.to(v.dtype)
 
-            acc += tl.dot(p, v)
+            acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
             # # update m_i and l_i
             l_i = l_i_new
             m_i = m_i_new
@@ -204,7 +212,7 @@ def _fwd_kernel(
                         other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-            qk += tl.dot(q, k)
+            qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
             qk *= sm_scale
             # apply causal mask
             qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
@@ -238,7 +246,7 @@ def _fwd_kernel(
                         other=0.0)
 
             p = p.to(v.dtype)
-            acc += tl.dot(p, v)
+            acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
             # update m_i and l_i
             l_i = l_i_new
             m_i = m_i_new
@@ -485,6 +493,7 @@ def _fwd_kernel_alibi(
         stride_v_cache_d,
         stride_v_cache_bl,
         num_queries_per_kv: int,
+        IN_PRECISION: tl.constexpr,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,  # head size
         BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
@@ -560,7 +569,7 @@ def _fwd_kernel_alibi(
                 k = k_load
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-            qk += tl.dot(q, k)
+            qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
             qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                           float("-inf"))
             qk *= sm_scale
@@ -600,7 +609,7 @@ def _fwd_kernel_alibi(
                 v = v_load
             p = p.to(v.dtype)
 
-            acc += tl.dot(p, v, allow_tf32=False)
+            acc = tl.dot(p, v, acc=acc, input_precision='ieee')
             # update m_i and l_i
             l_i = l_i_new
             m_i = m_i_new
@@ -635,7 +644,7 @@ def _fwd_kernel_alibi(
                         other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-            qk += tl.dot(q, k, allow_tf32=False)
+            qk = tl.dot(q, k, acc=qk, input_precision='ieee')
             qk *= sm_scale
             qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                           float("-inf"))
@@ -673,7 +682,7 @@ def _fwd_kernel_alibi(
                         other=0.0)
 
             p = p.to(v.dtype)
-            acc += tl.dot(p, v, allow_tf32=False)
+            acc = tl.dot(p, v, acc=acc, input_precision='ieee')
             # update m_i and l_i
             l_i = l_i_new
             m_i = m_i_new
@@ -709,13 +718,16 @@ def context_attention_fwd(q,
                           alibi_slopes=None,
                           sliding_window=None):
-    BLOCK = 128 if current_platform.has_device_capability(80) else 64
-    NUM_WARPS = 8
-
+    q_dtype_is_f32 = q.dtype is torch.float32
     # need to reduce num. blocks when using fp32
     # due to increased use of GPU shared memory
-    if q.dtype is torch.float32:
-        BLOCK = BLOCK // 2
+    BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK
+
+    # Turing does not have tensor cores for float32 multiplication;
+    # use 'ieee' precision as a fallback so the Triton kernels still work.
+    # There is also a warning in vllm/config.py to inform users about
+    # this fallback implementation.
+    IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None
 
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
@@ -799,6 +812,7 @@ def context_attention_fwd(q,
             v_cache.stride(
                 3),  #[num_blocks, num_kv_heads, head_size, block_size]
             num_queries_per_kv=num_queries_per_kv,
+            IN_PRECISION=IN_PRECISION,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_DMODEL_PADDED=Lk_padded,
@@ -850,6 +864,7 @@ def context_attention_fwd(q,
             v_cache.stride(
                 3),  #[num_blocks, num_kv_heads, head_size, block_size]
             num_queries_per_kv=num_queries_per_kv,
+            IN_PRECISION=IN_PRECISION,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_DMODEL_PADDED=Lk_padded,
diff --git a/vllm/config.py b/vllm/config.py
index f9ecb02cd5bde..c87feaec3e5f6 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2388,6 +2388,16 @@ def __post_init__(self):
             self.quant_config = VllmConfig._get_quantization_config(
                 self.model_config, self.load_config)
 
+        if self.scheduler_config is not None and \
+            self.model_config is not None and \
+            self.scheduler_config.chunked_prefill_enabled and \
+            self.model_config.dtype == torch.float32 and \
+            current_platform.get_device_capability() == (7, 5):
+            print_warning_once(
+                "Turing devices' tensor cores do not support float32 matmul. "
+                "To work around this limitation, vLLM will set 'ieee' input "
+                "precision for chunked prefill Triton kernels.")
+
         if self.compilation_config is None:
             self.compilation_config = CompilationConfig()
         if envs.VLLM_USE_V1 and not self.model_config.enforce_eager:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index a43e133f21ac2..ca68c1d57151c 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1055,6 +1055,7 @@ def create_engine_config(self) -> VllmConfig:
             msg = "Chunked prefill is not supported for embedding models"
             raise ValueError(msg)
 
+
         speculative_config = SpeculativeConfig.maybe_create_spec_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,