From 18ba3de38e15cd4edb7579c3d89a50cc1d72c60e Mon Sep 17 00:00:00 2001
From: Faraz Shahsavan
Date: Sat, 14 Dec 2024 06:28:59 +0000
Subject: [PATCH] Minimize includes and reformat the compressor file

---
 .../cutlass_benchmarks/sparse_benchmarks.py  |  10 +-
 benchmarks/cutlass_benchmarks/utils.py       |   7 +-
 csrc/ops.h                                   |   4 +-
 csrc/sparse/cutlass/sparse_compressor.cu     | 100 +++++++++++-------
 csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu  |   1 -
 csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh |   5 +-
 csrc/torch_bindings.cpp                      |   4 +-
 vllm/_custom_ops.py                          |   2 +-
 8 files changed, 76 insertions(+), 57 deletions(-)

diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
index 2bf65182c76ec..3d1c5e392f9e2 100644
--- a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
+++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
@@ -54,6 +54,8 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
         print("Incorrect results")
         print(out)
         print(out_ref)
+    else:
+        print("Correct results")
 
     timers = []
     # pytorch impl - bfloat16
@@ -100,9 +102,9 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     assert dtype == torch.float8_e4m3fn
     b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
                                                      k)
-    scale_a = (torch.randn((m, 1), device="cuda", dtype=torch.float32))
-    scale_b = (torch.randn((1, n), device="cuda", dtype=torch.float32))
-    bias = torch.rand((n, ), device="cuda", dtype=torch.bfloat16) * 10
+    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+    bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
 
     out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
                                        torch.bfloat16)
@@ -112,6 +114,8 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
         print("Incorrect results")
         print(out)
         print(out_ref)
+    else:
+        print("Correct results")
 
     timers = []
 
diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py
index c53cee52642f4..ef06fcd6604dd 100644
--- a/benchmarks/cutlass_benchmarks/utils.py
+++ b/benchmarks/cutlass_benchmarks/utils.py
@@ -62,11 +62,8 @@ def prune_to_2_4(tensor):
 
 
 def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
                              k: int) -> Tuple[torch.Tensor, torch.Tensor]:
-    # a = torch.randn((m, k), device='cuda') * 5
-    # b = torch.randn((n, k), device='cuda').t() * 5
-
-    a = torch.ones((m, k), device='cuda')
-    b = torch.ones((n, k), device='cuda').t()
+    a = torch.randn((m, k), device='cuda') * 5
+    b = torch.randn((n, k), device='cuda').t() * 5
 
     b = prune_to_2_4(b.t()).t()
 
diff --git a/csrc/ops.h b/csrc/ops.h
index d1b2e212f8a44..211ddd690d7dc 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -161,8 +161,8 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b_scales,
                               c10::optional<torch::Tensor> const& bias);
 
-bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
-                             torch::Tensor const& a);
+bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
+                                   torch::Tensor& e, torch::Tensor const& a);
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
diff --git a/csrc/sparse/cutlass/sparse_compressor.cu b/csrc/sparse/cutlass/sparse_compressor.cu
index d8f1e5e852a40..aa1dee73a70ce 100644
--- a/csrc/sparse/cutlass/sparse_compressor.cu
+++ b/csrc/sparse/cutlass/sparse_compressor.cu
@@ -4,60 +4,89 @@
 
 #include "sparse_scaled_mm_c3x.cuh"
 
-#include "cute/tensor.hpp"
"cutlass/numeric_types.h" #include "cutlass/numeric_conversion.h" - #include "cutlass/transform/device/transform_universal_adapter.hpp" #include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/util/host_tensor.h" #include "cutlass/util/packed_stride.hpp" - -#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" // clang-format on using namespace cute; using namespace vllm; /// Make A structured sparse by replacing elements with 0 and compress it -template -bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e, - torch::Tensor const& a) { +template +bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e, + torch::Tensor const& a) { // Checks for conformality TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn || a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16); TORCH_CHECK(a.dim() == 2) // Check for strides and alignment + TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity TORCH_CHECK(a.stride(1) == 1) int m = a.size(0); int k = a.size(1); - using ProblemShape = Shape; + // Sparse kernel setup; this kernel is not used for matmul, + // but just for setting up the compressor utility + // A matrix configuration using ElementA = ElementA_; using LayoutTagA = cutlass::layout::RowMajor; + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + // B matrix configuration + using ElementB = ElementA; + using LayoutTagB = cutlass::layout::ColumnMajor; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + // C/D matrix configuration + using ElementC = float; + using LayoutTagC = cutlass::layout::ColumnMajor; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + // Core kernel configurations + using ElementAccumulator = ElementAcc_; + using TileShape = Shape<_128, _128, _128>; + using TileShapeRef = Shape<_128, _128, _64>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = typename std::conditional< + std::is_same_v, + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum, + cutlass::gemm::KernelTmaWarpSpecialized>::type; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using ProblemShape = Shape; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC, + AlignmentC, ElementC, LayoutTagC, AlignmentC, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA, + LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - // Layouts for reference (non-sparse) tensors using StrideA = cutlass::gemm::TagToStrideA_t; using StrideE = StrideA; - using Gemm = typename sm90_config_default::Cutlass3xGemm; - - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int64_t lda = a.stride(0); - using StrideA = Stride, int64_t>; - using StrideB = Stride, 
-  using StrideC = typename Gemm::StrideC;
-  StrideA a_stride{lda, Int<1>{}, 0};
-
-  using GemmKernel = typename Gemm::GemmKernel;
 
+  // The n (=1) dimension does not matter for the compressor
   typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};
 
   using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
@@ -66,9 +95,6 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
   using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
   using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
 
-  LayoutA a_layout = SparseConfig::fill_layoutA(prob_shape);
-  LayoutE e_layout = SparseConfig::fill_layoutE(prob_shape);
-
   // Offline compressor kernel
   using CompressorUtility =
       cutlass::transform::kernel::StructuredSparseCompressorUtility<
@@ -85,9 +111,6 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
   auto [M, N, K, L] = prob_shape;
 
   StrideA stride_A;
-  StrideA stride_A_compressed;
-  StrideE stride_E;
-
   stride_A =
       cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
 
@@ -103,11 +126,6 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
 
   auto e_ptr = static_cast<ElementE*>(e.data_ptr());
 
-  stride_A_compressed =
-      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KC, L));
-  stride_E =
-      cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L));
-
   cutlass::KernelHardwareInfo hw_info;
   hw_info.device_id = 0;
   hw_info.sm_count =
@@ -128,16 +146,18 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
   return true;
 }
 
-bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
-                             torch::Tensor const& a) {
+bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
+                                   torch::Tensor& e, torch::Tensor const& a) {
   if (a.dtype() == torch::kBFloat16) {
-    return sparsify_and_compress<cutlass::bfloat16_t>(a_compressed, e, a);
+    return cutlass_sparse_compress<cutlass::bfloat16_t, float>(a_compressed, e,
+                                                               a);
   } else if (a.dtype() == torch::kFloat16) {
-    return sparsify_and_compress<cutlass::half_t>(a_compressed, e, a);
+    return cutlass_sparse_compress<cutlass::half_t, float>(a_compressed, e, a);
   } else if (a.dtype() == torch::kFloat8_e4m3fn) {
-    return sparsify_and_compress<cutlass::float_e4m3_t>(a_compressed, e, a);
+    return cutlass_sparse_compress<cutlass::float_e4m3_t, float>(a_compressed,
+                                                                 e, a);
   } else if (a.dtype() == torch::kInt8) {
-    return sparsify_and_compress<int8_t>(a_compressed, e, a);
+    return cutlass_sparse_compress<int8_t, int32_t>(a_compressed, e, a);
   }
   return false;
 }
\ No newline at end of file
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
index 5aa8ade0df376..b50e9a3a2c240 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -4,7 +4,6 @@
 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
 
 #include "sparse_scaled_mm_c3x.cuh"
-#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
 // clang-format on
 
 using namespace cute;
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
index 9267b87cd3cc7..81a8819bde60a 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
@@ -9,13 +9,12 @@
 #include "cutlass/cutlass.h"
 #include "cutlass/gemm/device/gemm_universal_adapter.h"
-#include "cutlass/gemm/kernel/gemm_universal.hpp"
 #include "cutlass/epilogue/collective/collective_builder.hpp"
 #include "cutlass/gemm/collective/collective_builder.hpp"
 
-#include "cutlass_extensions/cute_utils.cuh"
-#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "core/math.hpp" +#include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "cutlass_extensions/common.hpp" #include "cutlass_extensions/torch_utils.hpp" // clang-format on diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 546d01e0d9025..99f0a16e8b0f2 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -324,9 +324,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // CUTLASS sparse matrix compressor ops.def( - "cutlass_sparse_compress(Tensor! a_compressed, Tensor! e," + "cutlass_sparse_compress_entry(Tensor! a_compressed, Tensor! e," " Tensor a) -> bool"); - ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress); + ops.impl("cutlass_sparse_compress_entry", &cutlass_sparse_compress_entry); // Mamba selective scan kernel ops.def( diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index dc22d90bd0a5c..accb71c93dbd3 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -580,7 +580,7 @@ def cutlass_sparse_compress(a: torch.Tensor) \ dtype=torch.uint8, device=a.device) - if not (torch.ops._C.cutlass_sparse_compress(a_nzs, a_meta, a)): + if not (torch.ops._C.cutlass_sparse_compress_entry(a_nzs, a_meta, a)): raise ValueError assert (a_nzs.is_contiguous())