From 8d7b0df243770d92f5ee7ae58bdda319564072ec Mon Sep 17 00:00:00 2001 From: Faraz Shahsavan Date: Mon, 18 Nov 2024 17:41:13 +0000 Subject: [PATCH] Enable other datatypes --- .../cutlass_benchmarks/sparse_mm/bench_v1.py | 96 +++++++++++++------ csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu | 2 +- vllm/_custom_ops.py | 4 +- 3 files changed, 69 insertions(+), 33 deletions(-) diff --git a/benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py b/benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py index 76be1cd395bfd..c54b2bf264db9 100644 --- a/benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py +++ b/benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py @@ -5,7 +5,7 @@ import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement -from utils import make_rand_sparse_tensors, to_fp16 +from utils import make_rand_sparse_tensors, to_fp16, to_bf16 import vllm._custom_ops as ops @@ -32,12 +32,22 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str) -> Iterable[TMeasurement]: assert dtype == torch.int8 - a_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) + # Create tensors + b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) + aT = a.t() + bT = b.t() scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, bT, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out.t(), out_ref): + print("Incorrect result") + exit() + timers = [] # pytorch impl - bfloat16 @@ -55,25 +65,25 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, # cutlass impl: bf16 output timers.append( bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, + ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.bfloat16)) # cutlass with bias: bf16 output timers.append( bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.bfloat16, + ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.bfloat16, bias)) # cutlass impl: fp16 output timers.append( bench_fn(label, sub_label, "cutlass_i8_i8_fp16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, + ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.float16)) # cutlass with bias: fp16 output timers.append( bench_fn(label, sub_label, "cutlass_i8_i8_fp16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16, + ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.float16, bias.to(dtype=torch.float16))) return timers @@ -172,12 +182,28 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, def bench_fp16(dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str) -> Iterable[TMeasurement]: assert dtype == torch.float16 - a_compressed, e, a, b = make_rand_sparse_tensors(torch.float16, m, n, k) + m, k, n = 1, 128, 256 + + # Create tensors + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float16, m, n, k) + aT = a.t() + bT = b.t() scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16) + out_ref = to_bf16(a@bT) + + if not torch.allclose(out.t(), out_ref, rtol=1e-2, atol=1e-2): + print("Incorrect result") + print(out.t()) + print(out_ref) + exit() + else: + print("Correct result") + timers = [] # # pytorch impl w. bf16 @@ -213,25 +239,25 @@ def bench_fp16(dtype: torch.dtype, m: int, k: int, n: int, label: str, # cutlass impl: bf16 output timers.append( bench_fn(label, sub_label, "cutlass_fp16_fp16_bf16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, + ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.bfloat16)) # cutlass impl: fp16 output timers.append( bench_fn(label, sub_label, "cutlass_fp16_fp16_fp16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16)) + ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.float16)) - # cutlass impl: bf16 output, with bias - timers.append( - bench_fn(label, sub_label, "cutlass_fp16_fp16_bf16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.bfloat16, - bias)) + # # cutlass impl: bf16 output, with bias + # timers.append( + # bench_fn(label, sub_label, "cutlass_fp16_fp16_bf16_scaled_sparse_mm_bias", + # ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.bfloat16, + # bias)) - # cutlass impl: fp16 output, with bias - timers.append( - bench_fn(label, sub_label, "cutlass_fp16_fp16_fp16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16, - bias.to(dtype=torch.float16))) + # # cutlass impl: fp16 output, with bias + # timers.append( + # bench_fn(label, sub_label, "cutlass_fp16_fp16_fp16_scaled_sparse_mm_bias", + # ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.float16, + # bias.to(dtype=torch.float16))) return timers @@ -239,12 +265,22 @@ def bench_fp16(dtype: torch.dtype, m: int, k: int, n: int, label: str, def bench_bf16(dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str) -> Iterable[TMeasurement]: assert dtype == torch.bfloat16 - a_compressed, e, a, b = make_rand_sparse_tensors(torch.bfloat16, m, n, k) + # Create tensors + b_compressed, e, a, b = make_rand_sparse_tensors(torch.bfloat16, m, n, k) + aT = a.t() + bT = b.t() scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16) + out_ref = to_bf16(a@bT) + + if not torch.allclose(out.t(), out_ref): + print("Incorrect result") + exit() + timers = [] # # pytorch impl w. bf16 @@ -280,24 +316,24 @@ def bench_bf16(dtype: torch.dtype, m: int, k: int, n: int, label: str, # cutlass impl: bf16 output timers.append( bench_fn(label, sub_label, "cutlass_bf16_bf16_bf16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, + ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.bfloat16)) # cutlass impl: fp16 output timers.append( bench_fn(label, sub_label, "cutlass_bf16_bf16_fp16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16)) + ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.float16)) # cutlass impl: bf16 output, with bias timers.append( bench_fn(label, sub_label, "cutlass_bf16_bf16_bf16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.bfloat16, + ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.bfloat16, bias)) # cutlass impl: fp16 output, with bias timers.append( bench_fn(label, sub_label, "cutlass_bf16_bf16_fp16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16, + ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.float16, bias.to(dtype=torch.float16))) return timers @@ -305,12 +341,12 @@ def bench_bf16(dtype: torch.dtype, m: int, k: int, n: int, label: str, def bench_v1(dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str) -> Iterable[TMeasurement]: - # if dtype == torch.int8: - # return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) if dtype == torch.float8_e4m3fn: return bench_fp8(dtype, m, k, n, label, sub_label) - # if dtype == torch.float16: - # return bench_fp16(dtype, m, k, n, label, sub_label) - # if dtype == torch.bfloat16: - # return bench_bf16(dtype, m, k, n, label, sub_label) + if dtype == torch.float16: + return bench_fp16(dtype, m, k, n, label, sub_label) + if dtype == torch.bfloat16: + return bench_bf16(dtype, m, k, n, label, sub_label) raise ValueError("unsupported type") diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu index 4f8359db1518d..461fb30b80347 100644 --- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu @@ -226,7 +226,7 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor co out, a, e, b, std::forward(epilogue_args)...); } } - else { + else { // a.dtype() == torch::kBFloat16 TORCH_CHECK(a.dtype() == torch::kBFloat16); TORCH_CHECK(e.dtype() == torch::kUInt8); TORCH_CHECK(b.dtype() == torch::kBFloat16); diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 01a93a76275cd..0759a92b969b8 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -542,8 +542,8 @@ def cutlass_scaled_sparse_mm(a: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: # assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype + assert bias is None or bias.shape[0] == a.shape[0] \ + and bias.dtype == out_dtype m = a.shape[0] n = b.shape[1]