Skip to content

Commit

Permalink
Clean up code
Browse files Browse the repository at this point in the history
  • Loading branch information
Faraz9877 committed Dec 12, 2024
1 parent b039820 commit 67aae3e
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 27 deletions.
6 changes: 3 additions & 3 deletions csrc/sparse/cutlass/sparse_compressor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
c3x::ScaledEpilogue>::Cutlass3xGemm,
typename std::conditional<
std::is_same_v<ElementA, cutlass::half_t>,
typename sm90_fp16_config_default<cutlass::half_t,
cutlass::half_t,
c3x::ScaledEpilogue>::Cutlass3xGemm,
typename sm90_fp16_config_default<
cutlass::half_t, cutlass::half_t,
c3x::ScaledEpilogue>::Cutlass3xGemm,
typename sm90_bf16_config_default<
cutlass::bfloat16_t, cutlass::half_t,
c3x::ScaledEpilogue>::Cutlass3xGemm>::type>::type>::type;
Expand Down
2 changes: 1 addition & 1 deletion csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include "cutlass_extensions/common.hpp"
// clang-format on

#include "sparse_scaled_mm_c3x.cuh"
#include "sparse_scaled_mm_c3x.cuh"

using namespace cute;
using namespace vllm;
Expand Down
48 changes: 26 additions & 22 deletions csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
using namespace cute;

/*
This file defines sparse quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later.
This file defines sparse quantized GEMM operations using the CUTLASS 3.x API,
for NVIDIA GPUs with sm90a (Hopper) or later.
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Expand Down Expand Up @@ -192,7 +192,7 @@ struct sm90_fp16_config_default {
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
Expand All @@ -206,7 +206,7 @@ struct sm90_bf16_config_default {
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
KernelSchedule, EpilogueSchedule, float>;
};

//////////////////////// Cherry-Picking Kernels ////////////////////////
Expand All @@ -220,7 +220,7 @@ struct sm90_fp8_config_1 {
using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
Expand All @@ -235,7 +235,7 @@ struct sm90_fp8_config_2 {
using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
Expand All @@ -248,7 +248,7 @@ struct sm90_fp8_config_3 {
using ClusterShape = Shape<_1, _2, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
Expand All @@ -262,7 +262,7 @@ struct sm90_fp8_config_4 {
using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
Expand All @@ -276,7 +276,7 @@ struct sm90_fp8_config_5 {
using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
Expand All @@ -289,7 +289,7 @@ struct sm90_fp8_config_6 {
using ClusterShape = Shape<_1, _2, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
Expand All @@ -304,7 +304,7 @@ struct sm90_fp8_config_7 {
using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
Expand All @@ -319,7 +319,7 @@ struct sm90_fp8_config_8 {
using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
KernelSchedule, EpilogueSchedule, float>;
};
////////////////////////////////////////////////////////////////////////

Expand All @@ -334,7 +334,7 @@ struct sm90_fp8_config_default {
using ClusterShape = Shape<_1, _2, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>;
KernelSchedule, EpilogueSchedule, float>;
};

template <typename InType, typename OutType,
Expand All @@ -352,7 +352,8 @@ struct sm90_fp8_config_M64 {

using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float, TileSchedule>;
KernelSchedule, EpilogueSchedule, float,
TileSchedule>;
};

template <typename InType, typename OutType,
Expand All @@ -370,7 +371,8 @@ struct sm90_fp8_config_M128 {

using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float, TileSchedule>;
KernelSchedule, EpilogueSchedule, float,
TileSchedule>;
};

template <typename InType, typename OutType,
Expand All @@ -389,7 +391,8 @@ struct sm90_fp8_config_M256 {

using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float, TileSchedule>;
KernelSchedule, EpilogueSchedule, float,
TileSchedule>;
};

template <typename InType, typename OutType,
Expand All @@ -408,7 +411,8 @@ struct sm90_fp8_config_M512 {

using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float, TileSchedule>;
KernelSchedule, EpilogueSchedule, float,
TileSchedule>;
};

template <typename InType, typename OutType,
Expand All @@ -423,7 +427,7 @@ struct sm90_int8_config_default {
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>;
KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
Expand All @@ -438,7 +442,7 @@ struct sm90_int8_config_M128 {
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>;
KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
Expand All @@ -452,7 +456,7 @@ struct sm90_int8_config_M64 {
using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>;
KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
Expand All @@ -466,7 +470,7 @@ struct sm90_int8_config_M32_NBig {
using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>;
KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
Expand All @@ -480,7 +484,7 @@ struct sm90_int8_config_M32_NSmall {
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>;
KernelSchedule, EpilogueSchedule, int32_t>;
};

} // namespace
4 changes: 3 additions & 1 deletion vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,9 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,

def cutlass_compress_entry(a: torch.Tensor) \
-> Tuple[torch.Tensor, torch.Tensor]:
assert (a.dtype is [torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16])
assert (a.dtype in [
torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
])

# e.dtype: torch.uint8 so elemsPerElemE = 8b / 2b_per_nz = 4
elemsPerElemE = 4
Expand Down

0 comments on commit 67aae3e

Please sign in to comment.