
Commit

Add the heuristics for fp8
Faraz9877 authored and robertgshaw2-neuralmagic committed Nov 19, 2024
1 parent abfd85d commit 540d0ce
Showing 3 changed files with 19 additions and 21 deletions.
4 changes: 2 additions & 2 deletions benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py
@@ -101,8 +101,8 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
 
-    out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.float16)
-    out_ref = ops.cutlass_scaled_mm(a, bT, scale_a, scale_b, torch.float16)
+    out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16)
+    out_ref = ops.cutlass_scaled_mm(a, bT, scale_a, scale_b, torch.bfloat16)
 
     if not torch.allclose(out.t(), out_ref):
         print("Incorrect result")
22 changes: 10 additions & 12 deletions csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -57,24 +57,22 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
   uint32_t const mp2 =
       std::max(static_cast<uint32_t>(64), next_pow_2(n));  // next power of 2
 
-  // if (mp2 <= 64) {
-  //   // n in [1, 64]
-  //   return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
-  //       out, a, e, b, std::forward<EpilogueArgs>(args)...);
-  // } else if (mp2 <= 128) {
-  if (mp2 <= 128) {
+  if (mp2 <= 64) {
+    // n in [1, 64]
+    return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
+        out, a, e, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 128) {
     // n in (64, 128]
     return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
         out, a, e, b, std::forward<EpilogueArgs>(args)...);
-  // } else if (mp2 <= 256) {
-  } else {
+  } else if (mp2 <= 256) {
     // n in (128, 256]
     return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>(
         out, a, e, b, std::forward<EpilogueArgs>(args)...);
-  // } else {
-  //   // n in (256, inf)
-  //   return cutlass_sparse_gemm_caller<Cutlass3xGemmM512>(
-  //       out, a, e, b, std::forward<EpilogueArgs>(args)...);
+  } else {
+    // n in (256, inf)
+    return cutlass_sparse_gemm_caller<Cutlass3xGemmM512>(
+        out, a, e, b, std::forward<EpilogueArgs>(args)...);
   }
 }

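The dispatch above keys the config choice on the next power of two of n, clamped to at least 64. Below is a minimal, self-contained C++ sketch of just that bucketing logic; next_pow_2 and pick_fp8_config here are local stand-ins written for illustration, not the helpers defined in the repository.

// Sketch of the bucketing enabled above: pick an fp8 GEMM config bucket from
// the next power of two of n, clamped to at least 64. Illustrative names only.
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Smallest power of two >= v, for v > 0 (bit-twiddling variant).
static uint32_t next_pow_2(uint32_t v) {
  v--;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  return v + 1;
}

enum class Fp8Config { M64, M128, M256, M512 };

// Mirrors the if/else chain in cutlass_gemm_sm90_fp8_dispatch.
static Fp8Config pick_fp8_config(uint32_t n) {
  uint32_t const mp2 = std::max<uint32_t>(64, next_pow_2(n));
  if (mp2 <= 64) return Fp8Config::M64;    // n in [1, 64]
  if (mp2 <= 128) return Fp8Config::M128;  // n in (64, 128]
  if (mp2 <= 256) return Fp8Config::M256;  // n in (128, 256]
  return Fp8Config::M512;                  // n in (256, inf)
}

int main() {
  const uint32_t sizes[] = {1, 64, 65, 200, 4096};
  for (uint32_t n : sizes)
    std::printf("n=%u -> config bucket %d\n", n,
                static_cast<int>(pick_fp8_config(n)));
  return 0;
}

For example, n = 200 rounds up to mp2 = 256 and lands in the M256 bucket, while n = 4096 falls through to the M512 bucket.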
14 changes: 7 additions & 7 deletions csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
@@ -533,9 +533,9 @@ struct sm90_fp8_config_M64 {
   // M in [1, 64]
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+      cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative;
-  using TileShape = Shape<_128, _64, _256>;
+  using TileShape = Shape<_64, _64, _256>;
   using ClusterShape = Shape<_1, _1, _1>;
 
   using TileSchedule = cutlass::gemm::PersistentScheduler;
@@ -552,9 +552,9 @@ struct sm90_fp8_config_M128 {
   // M in (64, 128]
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
-  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative;
-  using TileShape = Shape<_128, _64, _256>;
+      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _128, _256>;
   using ClusterShape = Shape<_1, _1, _1>;
 
   using TileSchedule = cutlass::gemm::PersistentScheduler;
@@ -590,9 +590,9 @@ struct sm90_fp8_config_M512 {
   // M in (256, ]
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative;
-  using TileShape = Shape<_256, _128, _128>;
+  using TileShape = Shape<_128, _128, _256>;
   using ClusterShape = Shape<_1, _1, _1>;
 
   using TileSchedule = cutlass::gemm::PersistentScheduler;
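For quick comparison, here is a standalone sketch that tabulates the kernel schedule, epilogue schedule, and tile shape each touched fp8 config carries after this commit; the Fp8TileChoice struct and kChoices table below are illustrative only, not repository code, and the M256 config is not part of the hunks shown here.

// Per-config choices after this commit, laid out as plain data for comparison.
#include <cstdio>

struct Fp8TileChoice {
  const char* config;           // config struct the row describes
  const char* kernel_schedule;  // CUTLASS mainloop schedule
  const char* epilogue;         // CUTLASS epilogue schedule
  int tile_m, tile_n, tile_k;   // CTA tile shape
};

static const Fp8TileChoice kChoices[] = {
    {"sm90_fp8_config_M64", "KernelTmaWarpSpecializedFP8FastAccum",
     "TmaWarpSpecializedCooperative", 64, 64, 256},
    {"sm90_fp8_config_M128", "KernelTmaWarpSpecializedPingpongFP8FastAccum",
     "TmaWarpSpecialized", 64, 128, 256},
    {"sm90_fp8_config_M512", "KernelTmaWarpSpecializedCooperativeFP8FastAccum",
     "TmaWarpSpecializedCooperative", 128, 128, 256},
};

int main() {
  for (const Fp8TileChoice& c : kChoices)
    std::printf("%-22s %-46s %-32s %dx%dx%d\n", c.config, c.kernel_schedule,
                c.epilogue, c.tile_m, c.tile_n, c.tile_k);
  return 0;
}

As a general CUTLASS convention, pingpong mainloops are usually paired with the plain TmaWarpSpecialized epilogue while cooperative mainloops pair with TmaWarpSpecializedCooperative, which matches the M128 and M512 rows above.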
