Minimize includes and reformat the compressor file
Faraz9877 committed Dec 14, 2024
1 parent 8879323 commit 18ba3de
Showing 8 changed files with 76 additions and 57 deletions.
10 changes: 7 additions & 3 deletions benchmarks/cutlass_benchmarks/sparse_benchmarks.py
@@ -54,6 +54,8 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
print("Incorrect results")
print(out)
print(out_ref)
else:
print("Correct results")

timers = []
# pytorch impl - bfloat16
@@ -100,9 +102,9 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
assert dtype == torch.float8_e4m3fn
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
k)
scale_a = (torch.randn((m, 1), device="cuda", dtype=torch.float32))
scale_b = (torch.randn((1, n), device="cuda", dtype=torch.float32))
bias = torch.rand((n, ), device="cuda", dtype=torch.bfloat16) * 10
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
torch.bfloat16)
@@ -112,6 +114,8 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
print("Incorrect results")
print(out)
print(out_ref)
else:
print("Correct results")

timers = []

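With unit scales and a zero bias, the fp8 benchmark's result can be checked directly against a dense reference. The diff only shows the success/failure prints; the sketch below is one plausible shape of that check (the reference computation via an upcast dense matmul and the tolerances are assumptions, not the benchmark's actual code):

```python
import torch

def report_correctness(out: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> None:
    # Hypothetical dense reference: upcast the fp8 operands and multiply.
    out_ref = torch.mm(a.to(torch.bfloat16), b.to(torch.bfloat16))
    if not torch.allclose(out, out_ref, rtol=1e-2, atol=5e-2):
        print("Incorrect results")
        print(out)
        print(out_ref)
    else:
        print("Correct results")
```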
7 changes: 2 additions & 5 deletions benchmarks/cutlass_benchmarks/utils.py
@@ -62,11 +62,8 @@ def prune_to_2_4(tensor):

def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
# a = torch.randn((m, k), device='cuda') * 5
# b = torch.randn((n, k), device='cuda').t() * 5

a = torch.ones((m, k), device='cuda')
b = torch.ones((n, k), device='cuda').t()
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5

b = prune_to_2_4(b.t()).t()

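The utils change above restores random operands and routes `b` through `prune_to_2_4`, whose body is outside this diff. As a rough sketch of what 2:4 structured pruning typically does (keep the two largest-magnitude values in every contiguous group of four along the last dimension); this is an assumption about the helper's behavior, not its actual implementation:

```python
import torch

def prune_to_2_4_sketch(t: torch.Tensor) -> torch.Tensor:
    """Zero all but the 2 largest-magnitude entries in each group of 4 along the last dim."""
    groups = t.reshape(-1, 4)
    keep = groups.abs().topk(2, dim=-1).indices          # indices of the two survivors
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
    return (groups * mask).reshape(t.shape)
```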
4 changes: 2 additions & 2 deletions csrc/ops.h
@@ -161,8 +161,8 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);

bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
torch::Tensor const& a);
bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
torch::Tensor& e, torch::Tensor const& a);
#endif

void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
100 changes: 60 additions & 40 deletions csrc/sparse/cutlass/sparse_compressor.cu
@@ -4,60 +4,89 @@

#include "sparse_scaled_mm_c3x.cuh"

#include "cute/tensor.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"

#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
// clang-format on

using namespace cute;
using namespace vllm;

/// Make A structured sparse by replacing elements with 0 and compress it
template <typename ElementA_>
bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
torch::Tensor const& a) {
template <typename ElementA_, typename ElementAcc_>
bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
TORCH_CHECK(a.dim() == 2)
// Check for strides and alignment
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
TORCH_CHECK(a.stride(1) == 1)

int m = a.size(0);
int k = a.size(1);

using ProblemShape = Shape<int, int, int, int>;
// Sparse kernel setup; this kernel is not used for matmul,
// but just for setting up the compressor utility
// A matrix configuration
using ElementA = ElementA_;
using LayoutTagA = cutlass::layout::RowMajor;
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
// B matrix configuration
using ElementB = ElementA;
using LayoutTagB = cutlass::layout::ColumnMajor;
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
// C/D matrix configuration
using ElementC = float;
using LayoutTagC = cutlass::layout::ColumnMajor;
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
// Core kernel configurations
using ElementAccumulator = ElementAcc_;
using TileShape = Shape<_128, _128, _128>;
using TileShapeRef = Shape<_128, _128, _64>;
using ClusterShape = Shape<_1, _2, _1>;
using KernelSchedule = typename std::conditional<
std::is_same_v<ElementA, cutlass::float_e4m3_t>,
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
cutlass::gemm::KernelTmaWarpSpecialized>::type;

using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using ProblemShape = Shape<int, int, int, int>;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
AlignmentC, ElementC, LayoutTagC, AlignmentC,
EpilogueSchedule>::CollectiveOp;

using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;

using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

// Layouts for reference (non-sparse) tensors
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
using StrideE = StrideA;

using Gemm = typename sm90_config_default<ElementA, cutlass::half_t,
c3x::ScaledEpilogue>::Cutlass3xGemm;

using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;

int64_t lda = a.stride(0);

using StrideA = Stride<int64_t, Int<1>, int64_t>;
using StrideB = Stride<int64_t, Int<1>, int64_t>;
using StrideC = typename Gemm::StrideC;

StrideA a_stride{lda, Int<1>{}, 0};

using GemmKernel = typename Gemm::GemmKernel;
// The n (=1) dimension does not matter for the compressor
typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};

using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
@@ -66,9 +95,6 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;

LayoutA a_layout = SparseConfig::fill_layoutA(prob_shape);
LayoutE e_layout = SparseConfig::fill_layoutE(prob_shape);

// Offline compressor kernel
using CompressorUtility =
cutlass::transform::kernel::StructuredSparseCompressorUtility<
@@ -85,9 +111,6 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
auto [M, N, K, L] = prob_shape;

StrideA stride_A;
StrideA stride_A_compressed;
StrideE stride_E;

stride_A =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));

@@ -103,11 +126,6 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
auto e_ptr =
static_cast<typename Gemm::CollectiveMainloop::ElementE*>(e.data_ptr());

stride_A_compressed =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KC, L));
stride_E =
cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L));

cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
@@ -128,16 +146,18 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
return true;
}

bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
torch::Tensor const& a) {
bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
torch::Tensor& e, torch::Tensor const& a) {
if (a.dtype() == torch::kBFloat16) {
return sparsify_and_compress<cutlass::bfloat16_t>(a_compressed, e, a);
return cutlass_sparse_compress<cutlass::bfloat16_t, float>(a_compressed, e,
a);
} else if (a.dtype() == torch::kFloat16) {
return sparsify_and_compress<cutlass::half_t>(a_compressed, e, a);
return cutlass_sparse_compress<cutlass::half_t, float>(a_compressed, e, a);
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
return sparsify_and_compress<cutlass::float_e4m3_t>(a_compressed, e, a);
return cutlass_sparse_compress<cutlass::float_e4m3_t, float>(a_compressed,
e, a);
} else if (a.dtype() == torch::kInt8) {
return sparsify_and_compress<int8_t>(a_compressed, e, a);
return cutlass_sparse_compress<int8_t, int32_t>(a_compressed, e, a);
}
return false;
}
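On the Python side, the renamed `cutlass_sparse_compress_entry` is reached through the registered `_C` op with preallocated output buffers and a boolean return for unsupported inputs (see the `torch_bindings.cpp` and `_custom_ops.py` hunks below). A minimal sketch of driving it directly; the buffer shapes follow the usual 2:4 layout (K/2 kept values per row) but are assumptions, not values taken from this diff:

```python
import torch

m, k = 128, 256
a = torch.randn((m, k), device="cuda", dtype=torch.float16)

# Assumed buffers: K/2 surviving values per row, plus packed uint8 metadata
# (the exact metadata width depends on the element type and is a guess here).
a_nzs = torch.empty((m, k // 2), dtype=a.dtype, device=a.device)
a_meta = torch.empty((m, k // 8), dtype=torch.uint8, device=a.device)

if not torch.ops._C.cutlass_sparse_compress_entry(a_nzs, a_meta, a):
    raise ValueError("sparse compression failed (unsupported dtype or layout)")
```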
1 change: 0 additions & 1 deletion csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -4,7 +4,6 @@

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#include "sparse_scaled_mm_c3x.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
// clang-format on

using namespace cute;
5 changes: 2 additions & 3 deletions csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
@@ -9,13 +9,12 @@
#include "cutlass/cutlass.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/torch_utils.hpp"
// clang-format on
4 changes: 2 additions & 2 deletions csrc/torch_bindings.cpp
@@ -324,9 +324,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// CUTLASS sparse matrix compressor
ops.def(
"cutlass_sparse_compress(Tensor! a_compressed, Tensor! e,"
"cutlass_sparse_compress_entry(Tensor! a_compressed, Tensor! e,"
" Tensor a) -> bool");
ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
ops.impl("cutlass_sparse_compress_entry", &cutlass_sparse_compress_entry);

// Mamba selective scan kernel
ops.def(
2 changes: 1 addition & 1 deletion vllm/_custom_ops.py
@@ -580,7 +580,7 @@ def cutlass_sparse_compress(a: torch.Tensor) \
dtype=torch.uint8,
device=a.device)

if not (torch.ops._C.cutlass_sparse_compress(a_nzs, a_meta, a)):
if not (torch.ops._C.cutlass_sparse_compress_entry(a_nzs, a_meta, a)):
raise ValueError

assert (a_nzs.is_contiguous())
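Putting the pieces together, a hedged end-to-end sketch of the user-facing flow: compress the sparse operand once with `ops.cutlass_sparse_compress`, then pass the values and metadata to `ops.cutlass_scaled_sparse_mm` exactly as the benchmark above does. The wrapper's return convention, the operand layout, and the omission of an explicit 2:4 pruning step are assumptions inferred from the benchmark, not guaranteed by this diff:

```python
import torch
from vllm import _custom_ops as ops

m, n, k = 512, 512, 512

# fp8 operands, as in bench_fp8; explicit 2:4 pruning of b is elided here.
a = torch.randn((m, k), device="cuda").to(torch.float8_e4m3fn)
b = torch.randn((n, k), device="cuda").to(torch.float8_e4m3fn)

b_compressed, e = ops.cutlass_sparse_compress(b)  # assumed return: (values, metadata)

scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
                                   torch.bfloat16)
```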
