diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index c74d2aacb0beb..fcc17c7727f94 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -1,4 +1,4 @@ -#include "broadcast_load_epilogue_c3x.hpp" +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" /* This file defines custom epilogues for fusing channel scales, token scales, diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index aa7f517ce6ff0..afe53797322f9 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -2,7 +2,7 @@ Run `pytest tests/kernels/test_cutlass.py`. """ -from typing import Optional, Tuple, Type +from typing import Optional, Type import pytest import torch @@ -55,62 +55,6 @@ def rand_int8(shape: tuple, device: str = "cuda"): return to_int8(torch.rand(shape, device=device) * 255 - 128) -def to_bf16(tensor: torch.Tensor) -> torch.Tensor: - return tensor.to(dtype=torch.bfloat16) - - -def to_fp16(tensor: torch.Tensor) -> torch.Tensor: - return tensor.to(dtype=torch.float16) - - -def prune_to_2_4(tensor): - # Reshape tensor to [N, 4] where N is number of groups of 4 - original_shape = tensor.shape - reshaped = tensor.reshape(-1, 4) - - # Get indices of top 2 absolute values in each group of 4 - _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) - - # Create binary mask - mask = torch.zeros_like(reshaped) - mask.scatter_(dim=1, - index=indices, - src=torch.ones_like(indices, dtype=mask.dtype)) - - # Apply mask and reshape back - pruned = reshaped * mask - - # Turn all -0.0 to 0.0 - pruned[pruned == -0.0] = 0.0 - - return pruned.reshape(original_shape) - - -def make_rand_sparse_tensors( - dtype: torch.dtype, m: int, n: int, k: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 - - b = prune_to_2_4(b.t()).t() - - if dtype == torch.int8: - a, b = to_int8(a), to_int8(b) - elif dtype == torch.float8_e4m3fn: - a, b = to_fp8(a), to_fp8(b) - elif dtype == torch.float16: - a, b = to_fp16(a), to_fp16(b) - elif dtype == torch.bfloat16: - a, b = to_bf16(a), to_bf16(b) - else: - raise ValueError("unsupported dtype") - - b_compressed, e = ops.cutlass_compress_entry(b.t()) - - # Compressed B, Metadata, Original A, B - return b_compressed, e, a, b - - def baseline_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, @@ -459,35 +403,6 @@ def test_cutlass_subset(): torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) -# Test working with a subset of A and B for sparse matmul -def test_cutlass_sparse_subset(): - big_m = 1024 - m, n, k = 512, 512, 512 - - # Create tensors - b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, - big_m, n, k) - a = whole_a[0:m, 0:k] - scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 - scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 - - print("in test") - - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=torch.bfloat16) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16) - - torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) - - # Test to make sure cuda graphs work class CutlassLayer(torch.nn.Module): diff --git a/tests/kernels/test_semi_structured.py b/tests/kernels/test_semi_structured.py index 783ae12ba6cc7..3d8560238de3d 100644 --- a/tests/kernels/test_semi_structured.py +++ b/tests/kernels/test_semi_structured.py @@ -1,292 +1,132 @@ -import pytest +"""Tests for sparse cutlass kernels + +Run `pytest tests/kernels/test_semi_structured.py`. +""" +from typing import Optional, Tuple, Type + import torch -from tests.quantization.utils import is_quant_method_supported -from vllm.model_executor.layers.sparsity.utils.cusparse_2_4_utils import ( - clear_cache, compress_to_torch_sparse_semi_structured_mat, - decompress_torch_sparse_semi_structured_mat, dense_matmul, - generate_pruned_semi_structured_mat, get_random_mat, - is_semi_structured_supported, semi_structured_dense_sparse_T_gemm, - semi_structured_dense_sparse_T_gemm_scaled, - semi_structured_sparse_dense_gemm, - semi_structured_sparse_dense_gemm_scaled) - -DTYPES = [torch.float16, torch.bfloat16, torch.int8] -SIZES = [(128, 128), (1024, 8192)] -SIZES_FP8 = [(32, 64), (1024, 1024)] -MNK = [(128, 128, 128), (128, 512, 1024), (512, 512, 512), (1024, 2048, 4096)] - - -# From pytorch test -def to_float8(x, dtype=torch.float8_e4m3fn): - finfo = torch.finfo(dtype) - # Calculate the scale as dtype max divided by absmax - scale = finfo.max / x.abs().max().clamp(min=1e-12) - # scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) - # Return both float8 data and the inverse scale (as float), - # as both required as inputs to torch._scaled_mm - return x_scl_sat.to(dtype), scale.float().reciprocal() - - -@pytest.mark.skipif( - not is_semi_structured_supported(), - reason="Semi structured matmul is not supported on this GPU type.") -@pytest.mark.parametrize("size", SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -def test_semi_structured_compress(size, dtype): - input_pruned = generate_pruned_semi_structured_mat(*size, dtype) - output_pruned = decompress_torch_sparse_semi_structured_mat( - compress_to_torch_sparse_semi_structured_mat(input_pruned)) - torch.testing.assert_close(input_pruned, output_pruned) - - -# TODO modelopt config has to be replaced with corresponding fp8_24 config -@pytest.mark.skipif( - not is_semi_structured_supported() - or not is_quant_method_supported("modelopt"), - reason="Semi structured fp8 matmul is not supported on this GPU type.") -@pytest.mark.parametrize("size", SIZES_FP8) -def test_semi_structured_fp8_compress(size): - dtype = torch.float8_e4m3fn - input_pruned = generate_pruned_semi_structured_mat(*size, dtype) - output_pruned = decompress_torch_sparse_semi_structured_mat( - compress_to_torch_sparse_semi_structured_mat(input_pruned)) - torch.testing.assert_close(input_pruned.to(torch.float32), - output_pruned.to(torch.float32), - rtol=1e-1, - atol=1e-1) - - -@pytest.mark.skipif( - not is_semi_structured_supported(), - reason="Semi structured matmul is not supported on this GPU type.") -@pytest.mark.parametrize("mnk", MNK) -@pytest.mark.parametrize("dtype", DTYPES) -def test_torch_semi_structured_sparse_dense_matmul(mnk, dtype): - M, N, K = mnk - A_pruned = generate_pruned_semi_structured_mat(M, K, dtype) - A = compress_to_torch_sparse_semi_structured_mat(A_pruned) - B = get_random_mat(K, N, dtype) - if dtype is torch.int8: - with pytest.raises(ValueError): - C_sparse = semi_structured_sparse_dense_gemm(A, B) +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +capability = current_platform.get_device_capability() +capability = capability[0] * 10 + capability[1] + + +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor): + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def rand_int8(shape: tuple, device: str = "cuda"): + return to_int8(torch.rand(shape, device=device) * 255 - 128) + + +def to_bf16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.bfloat16) + + +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.float16) + + +def prune_to_2_4(tensor): + # Reshape tensor to [N, 4] where N is number of groups of 4 + original_shape = tensor.shape + reshaped = tensor.reshape(-1, 4) + + # Get indices of top 2 absolute values in each group of 4 + _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) + + # Create binary mask + mask = torch.zeros_like(reshaped) + mask.scatter_(dim=1, + index=indices, + src=torch.ones_like(indices, dtype=mask.dtype)) + + # Apply mask and reshape back + pruned = reshaped * mask + + # Turn all -0.0 to 0.0 + pruned[pruned == -0.0] = 0.0 + + return pruned.reshape(original_shape) + + +def make_rand_sparse_tensors( + dtype: torch.dtype, m: int, n: int, k: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + b = prune_to_2_4(b.t()).t() + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = to_fp8(a), to_fp8(b) + elif dtype == torch.float16: + a, b = to_fp16(a), to_fp16(b) + elif dtype == torch.bfloat16: + a, b = to_bf16(a), to_bf16(b) else: - C_sparse = semi_structured_sparse_dense_gemm(A, B) - C = dense_matmul(A_pruned, B, dtype) - torch.testing.assert_close(C, C_sparse) - - # Verify cache - B = get_random_mat(K, N, dtype) - C = dense_matmul(A_pruned, B, dtype) - C_sparse = semi_structured_sparse_dense_gemm(A, B) - torch.testing.assert_close(C, C_sparse) - - C_sparse = semi_structured_sparse_dense_gemm(A, B, cached=False) - torch.testing.assert_close(C, C_sparse) - clear_cache() - - -@pytest.mark.skipif( - not is_semi_structured_supported(), - reason="Semi structured matmul is not supported on this GPU type.") -@pytest.mark.parametrize("mnk", MNK) -@pytest.mark.parametrize("dtype", DTYPES) -def test_torch_semi_structured_sparse_dense_T_matmul(mnk, dtype): - M, N, K = mnk - A_pruned = generate_pruned_semi_structured_mat(M, K, dtype) - A = compress_to_torch_sparse_semi_structured_mat(A_pruned) - B = get_random_mat(N, K, dtype) - - C_sparse = semi_structured_sparse_dense_gemm(A, B.t()) - C = dense_matmul(A_pruned, B.t(), dtype) - torch.testing.assert_close(C, C_sparse) - - # Verify cache - B = get_random_mat(N, K, dtype) - C = dense_matmul(A_pruned, B.t(), dtype) - C_sparse = semi_structured_sparse_dense_gemm(A, B.t()) - torch.testing.assert_close(C, C_sparse) - - C_sparse = semi_structured_sparse_dense_gemm(A, B.t(), cached=False) - torch.testing.assert_close(C, C_sparse) - clear_cache() - - -# TODO modelopt config has to be replaced with corresponding fp8_24 config -@pytest.mark.skipif( - not is_semi_structured_supported() - or not is_quant_method_supported("modelopt"), - reason="Semi structured fp8 matmul is not supported on this GPU type.") -def test_torch_semi_structured_sparse_dense_T_fp8_matmul(): - M, N, K = (32, 64, 32) - dtype = torch.float8_e4m3fn - A_pruned = generate_pruned_semi_structured_mat(M, N, dtype=dtype) - A = compress_to_torch_sparse_semi_structured_mat(A_pruned) - B = torch.full((K, N), .25, device='cuda', dtype=dtype).t() - - C = dense_matmul(A_pruned, B, dtype=dtype).to(torch.float32) - C_sparse = semi_structured_sparse_dense_gemm(A, B).to(torch.float32) - torch.testing.assert_close(C, C_sparse, rtol=1e-1, atol=1e-1) - - # Cached version - B = torch.full((K, N), .25, device='cuda', dtype=dtype).t() - C = dense_matmul(A_pruned, B, dtype=dtype).to(torch.float32) - C_sparse = semi_structured_sparse_dense_gemm(A, B).to(torch.float32) - torch.testing.assert_close(C, C_sparse, rtol=1e-1, atol=1e-1) - - # Noncached version - C_sparse = semi_structured_sparse_dense_gemm(A, B, cached=False).to( - torch.float32) - torch.testing.assert_close(C, C_sparse, rtol=1e-1, atol=1e-1) - clear_cache() - - -@pytest.mark.skipif( - not is_semi_structured_supported(), - reason="Semi structured matmul is not supported on this GPU type.") -@pytest.mark.parametrize("mnk", MNK) -@pytest.mark.parametrize("dtype", DTYPES) -def test_torch_semi_structured_dense_sparse_T_matmul(mnk, dtype): - M, N, K = mnk - B_T_pruned = generate_pruned_semi_structured_mat(N, K, dtype) - B_T = compress_to_torch_sparse_semi_structured_mat(B_T_pruned) - A = get_random_mat(M, K, dtype) - - C_sparse = semi_structured_dense_sparse_T_gemm(A, B_T) - C = dense_matmul(A, B_T_pruned.t(), dtype) - torch.testing.assert_close(C, C_sparse) - - C_sparse = semi_structured_dense_sparse_T_gemm(A, B_T, cached=False) - C = dense_matmul(A, B_T_pruned.t(), dtype) - torch.testing.assert_close(C, C_sparse) - clear_cache() - - -# TODO modelopt config has to be replaced with corresponding fp8_24 config -@pytest.mark.skipif( - not is_semi_structured_supported() - or not is_quant_method_supported("modelopt"), - reason="Semi structured fp8 matmul is not supported on this GPU type.") -def test_torch_semi_structured_dense_sparse_T_fp8_matmul(): - M, N, K = (32, 64, 32) - dtype = torch.float8_e4m3fn - B_T_pruned = generate_pruned_semi_structured_mat(N, K, dtype=dtype) - B_T = compress_to_torch_sparse_semi_structured_mat(B_T_pruned) - A = torch.full((M, K), .25, device='cuda', dtype=dtype) - - C_sparse = semi_structured_dense_sparse_T_gemm(A, B_T).to(torch.float32) - C = dense_matmul(A, B_T_pruned.t(), dtype=dtype).to(torch.float32) - torch.testing.assert_close(C, C_sparse, rtol=1e-1, atol=1e-1) - - C_sparse = semi_structured_dense_sparse_T_gemm(A, B_T).to(torch.float32) - C = dense_matmul(A, B_T_pruned.t(), dtype=dtype).to(torch.float32) - torch.testing.assert_close(C, C_sparse, rtol=1e-1, atol=1e-1) - clear_cache() - - -@pytest.mark.skipif( - not is_semi_structured_supported() - or not is_quant_method_supported("modelopt"), - reason="Semi structured fp8 matmul is not supported on this GPU type.") -def test_torch_semi_structured_sparse_dense_T_fp8_scaled_matmul(): - M, N, K = (32, 64, 32) - A_pruned = generate_pruned_semi_structured_mat(M, N, dtype=torch.float16) - A_pruned_fp8, scale_A = to_float8(A_pruned) - B = torch.rand((K, N), device='cuda').to(torch.float16).t() - B_fp8, scale_B = to_float8(B) - - A_fp8_sparse = compress_to_torch_sparse_semi_structured_mat(A_pruned_fp8) - - C = torch._scaled_mm(A_pruned_fp8, - B_fp8, - scale_a=scale_A, - scale_b=scale_B, - out_dtype=torch.float32) - C_sparse = semi_structured_sparse_dense_gemm_scaled(A_fp8_sparse, - B_fp8, - scale_a=scale_A, - scale_b=scale_B).to( - torch.float32) - torch.testing.assert_close(C, C_sparse, rtol=7e-2, atol=7e-2) - - # cached - B = torch.rand((K, N), device='cuda').to(torch.float16).t() - B_fp8, scale_B = to_float8(B) - - C = torch._scaled_mm(A_pruned_fp8, - B_fp8, - scale_a=scale_A, - scale_b=scale_B, - out_dtype=torch.float32) - C_sparse = semi_structured_sparse_dense_gemm_scaled(A_fp8_sparse, - B_fp8, - scale_a=scale_A, - scale_b=scale_B).to( - torch.float32) - torch.testing.assert_close(C, C_sparse, rtol=7e-2, atol=7e-2) - - # noncached - C_sparse = semi_structured_sparse_dense_gemm_scaled(A_fp8_sparse, - B_fp8, - scale_a=scale_A, - scale_b=scale_B, - cached=False).to( - torch.float32) - torch.testing.assert_close(C, C_sparse, rtol=7e-2, atol=7e-2) - clear_cache() - - -@pytest.mark.skipif( - not is_semi_structured_supported() - or not is_quant_method_supported("modelopt"), - reason="Semi structured fp8 matmul is not supported on this GPU type.") -def test_torch_semi_structured_dense_sparse_T_fp8_scaled_matmul(): - M, N, K = (32, 64, 32) - A = torch.rand((M, K), device='cuda', dtype=torch.float16) - A_fp8, scale_a = to_float8(A) - B_T_pruned = generate_pruned_semi_structured_mat(N, K, dtype=torch.float16) - B_T_pruned_fp8, scale_b = to_float8(B_T_pruned) - B_T_packed = compress_to_torch_sparse_semi_structured_mat(B_T_pruned_fp8) - - C_sparse = semi_structured_dense_sparse_T_gemm_scaled(A_fp8, - B_T_packed, - scale_a=scale_a, - scale_b=scale_b).to( - torch.float32) - C = torch._scaled_mm(B_T_pruned_fp8, - A_fp8.t(), - scale_a=scale_b, - scale_b=scale_a, - out_dtype=torch.float32).t() - torch.testing.assert_close(C, C_sparse, rtol=7e-2, atol=7e-2) - clear_cache() - - -@pytest.mark.skipif( - not is_semi_structured_supported(), - reason="Semi structured matmul is not supported on this GPU type.") -def test_torch_semi_structured_sparse_dense_t_int8_scaled_matmul(): - dtype = torch.int8 - M, N, K = (32, 64, 32) - A_pruned = generate_pruned_semi_structured_mat(M, K, dtype) - A = compress_to_torch_sparse_semi_structured_mat(A_pruned) - B = get_random_mat(N, K, dtype) - - scale_a = torch.tensor(2.0, dtype=torch.float32, device='cuda') - scale_b = torch.tensor(2.0, dtype=torch.float32, device='cuda') - - C = dense_matmul(A_pruned, - B.t(), - dtype=dtype, - scale_a=scale_a, - scale_b=scale_b).to(torch.float32) - C_sparse = semi_structured_sparse_dense_gemm_scaled(A, - B.t(), - scale_a=scale_a, - scale_b=scale_b).to( - torch.float32) - torch.testing.assert_close(C, C_sparse, rtol=1e-1, atol=1e-1) - clear_cache() + raise ValueError("unsupported dtype") + + b_compressed, e = ops.cutlass_compress_entry(b.t()) + + # Compressed B, Metadata, Original A, B + return b_compressed, e, a, b + + +def baseline_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: Type[torch.dtype], + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = (scale_a * (scale_b * (torch.mm( + a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype) + if bias is not None: + output = output + bias + + return output + + +# Test working with a subset of A and B for sparse matmul +def test_cutlass_sparse_subset(): + big_m = 1024 + m, n, k = 512, 512, 512 + + # Create tensors + b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, + big_m, n, k) + a = whole_a[0:m, 0:k] + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + + print("in test") + + out = ops.cutlass_scaled_sparse_mm(a, + b_comp, + e, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + baseline = baseline_scaled_mm(a, + b, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + + torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) + diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d99d2340275af..72e744fc9ba0d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -537,8 +537,7 @@ def cutlass_compress_entry(a: torch.Tensor) \ assert (a.dtype is torch.int8 or a.dtype is torch.float8_e4m3fn or \ a.dtype is torch.bfloat16 or a.dtype is torch.float16) - # Not exactly sure what the right value would be based on cutlass definitions - # Let's assume e.dtype: torch.uint8 so elemsPerElemE = 8b / 2b_per_nz = 4 + # e.dtype: torch.uint8 so elemsPerElemE = 8b / 2b_per_nz = 4 elemsPerElemE = 4 m = a.shape[0]