From 60508ffda91c22e4cde3b18f149d222211db8886 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 18 Dec 2024 09:57:16 -0500 Subject: [PATCH] [Kernel]: Cutlass 2:4 Sparsity + FP8/Int8 Quant Support (#10995) Co-authored-by: Faraz Shahsavan Co-authored-by: ilmarkov Co-authored-by: Rahul Tuli Co-authored-by: rshaw@neuralmagic.com --- CMakeLists.txt | 26 +- .../cutlass_benchmarks/sparse_benchmarks.py | 384 ++++++++++++++ benchmarks/cutlass_benchmarks/utils.py | 96 ++++ .../cutlass_benchmarks/w8a8_benchmarks.py | 28 +- .../cutlass_benchmarks/weight_shapes.py | 2 +- csrc/core/math.hpp | 7 + csrc/cutlass_extensions/common.cpp | 11 + csrc/cutlass_extensions/common.hpp | 35 ++ .../epilogue/scaled_mm_epilogues_c3x.hpp | 4 +- csrc/ops.h | 9 + csrc/quantization/cutlass_w8a8/common.hpp | 27 - .../cutlass_w8a8/scaled_mm_c2x.cuh | 3 +- .../cutlass_w8a8/scaled_mm_c3x.cu | 3 +- .../cutlass_w8a8/scaled_mm_entry.cu | 12 +- csrc/sparse/cutlass/sparse_compressor_c3x.cu | 163 ++++++ .../sparse/cutlass/sparse_compressor_entry.cu | 42 ++ csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu | 303 +++++++++++ csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh | 496 ++++++++++++++++++ csrc/sparse/cutlass/sparse_scaled_mm_entry.cu | 59 +++ csrc/torch_bindings.cpp | 15 + pyproject.toml | 2 +- tests/kernels/test_semi_structured.py | 131 +++++ tests/quantization/test_compressed_tensors.py | 103 +++- tests/weight_loading/models.txt | 2 + .../run_model_weight_loading_test.sh | 4 + tests/weight_loading/test_weight_loading.py | 7 + vllm/_custom_ops.py | 103 ++++ .../compressed_tensors/compressed_tensors.py | 187 ++++++- .../compressed_tensors/schemes/__init__.py | 15 +- .../schemes/compressed_tensors_24.py | 203 +++++++ 30 files changed, 2365 insertions(+), 117 deletions(-) create mode 100644 benchmarks/cutlass_benchmarks/sparse_benchmarks.py create mode 100644 benchmarks/cutlass_benchmarks/utils.py create mode 100644 csrc/core/math.hpp create mode 100644 csrc/cutlass_extensions/common.cpp create mode 100644 csrc/cutlass_extensions/common.hpp delete mode 100644 csrc/quantization/cutlass_w8a8/common.hpp create mode 100644 csrc/sparse/cutlass/sparse_compressor_c3x.cu create mode 100644 csrc/sparse/cutlass/sparse_compressor_entry.cu create mode 100644 csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu create mode 100644 csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh create mode 100644 csrc/sparse/cutlass/sparse_scaled_mm_entry.cu create mode 100644 tests/kernels/test_semi_structured.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py diff --git a/CMakeLists.txt b/CMakeLists.txt index bf19b3d227171..51b49a18dddf2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -206,7 +206,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. - set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG v3.5.1 + GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW TRUE + GIT_SHALLOW FALSE ) endif() FetchContent_MakeAvailable(cutlass) @@ -241,7 +241,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/awq/gemm_kernels.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_compressor_entry.cu" + "csrc/cutlass_extensions/common.cpp") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -271,11 +274,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() # - # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require + # The cutlass_scaled_mm cutlass_scaled_sparse_mm, and cutlass_compressor kernels + # For Hopper (c3x, i.e. CUTLASS 3.x) require # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" + "csrc/sparse/cutlass/sparse_compressor_c3x.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") @@ -284,12 +290,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}") else() if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) - message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is " + message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is " "not >= 12.0, we recommend upgrading to CUDA 12.0 or " - "later if you intend on running FP8 quantized models on " + "later if you intend on running FP8 sparse or quantized models on " "Hopper.") else() - message(STATUS "Not building scaled_mm_c3x as no compatible archs found " + message(STATUS "Not building cutlass_c3x as no compatible archs found " "in CUDA target architectures") endif() @@ -404,7 +410,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py new file mode 100644 index 0000000000000..3d1c5e392f9e2 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -0,0 +1,384 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_sparse_tensors +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + + +# bench +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + # pytorch impl - bfloat16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) + + # pytorch impl - float16 + timers.append( + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + + # cutlass impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass sparse impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass sparse with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, + k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + + # pytorch impl w. bf16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"))) + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16)) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True)) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16)) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16, bias.to(dtype=torch.float16))) + + return timers + + +def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + results = [] + for m, k, n in MKNs: + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']") + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py new file mode 100644 index 0000000000000..ef06fcd6604dd --- /dev/null +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -0,0 +1,96 @@ +# Cutlass bench utils +from typing import Iterable, Tuple + +import torch + +import vllm._custom_ops as ops + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def to_bf16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.bfloat16) + + +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.float16) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + if dtype == torch.int8: + return to_int8(a), to_int8(b) + if dtype == torch.float8_e4m3fn: + return to_fp8(a), to_fp8(b) + + raise ValueError("unsupported dtype") + + +def prune_to_2_4(tensor): + # Reshape tensor to [N, 4] where N is number of groups of 4 + original_shape = tensor.shape + reshaped = tensor.reshape(-1, 4) + + # Get indices of top 2 absolute values in each group of 4 + _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) + + # Create binary mask + mask = torch.zeros_like(reshaped) + mask.scatter_(dim=1, + index=indices, + src=torch.ones_like(indices, dtype=mask.dtype)) + + # Apply mask and reshape back + pruned = reshaped * mask + + # Turn all -0.0 to 0.0 + pruned[pruned == -0.0] = 0.0 + + return pruned.reshape(original_shape) + + +def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + b = prune_to_2_4(b.t()).t() + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = to_fp8(a), to_fp8(b) + elif dtype == torch.float16: + a, b = to_fp16(a), to_fp16(b) + elif dtype == torch.bfloat16: + a, b = to_bf16(a), to_bf16(b) + else: + raise ValueError("unsupported dtype") + + b_compressed, e = ops.cutlass_sparse_compress(b.t()) + + # Compressed B, Metadata, Original A, B + return b_compressed, e, a, b + + +def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, + m: int, n: int, k: int) -> \ + Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: + ABs = [] + for _ in range(num_tensors): + b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) + if b_comp is not None: + ABs.append(make_rand_sparse_tensors(dtype, m, n, k)) + BComps, Es, As, Bs = zip(*ABs) + return list(BComps), list(Es), list(As), list(Bs) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 63cf5d50cac75..d0353bc8cb42a 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -8,6 +8,7 @@ import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_tensors from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops @@ -17,31 +18,6 @@ DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] DEFAULT_TP_SIZES = [1] -# helpers - - -def to_fp8(tensor: torch.Tensor) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) - - -def to_int8(tensor: torch.Tensor) -> torch.Tensor: - return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) - - -def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> Tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 - - if dtype == torch.int8: - return to_int8(a), to_int8(b) - if dtype == torch.float8_e4m3fn: - return to_fp8(a), to_fp8(b) - - raise ValueError("unsupported dtype") - # bench def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, @@ -386,4 +362,4 @@ def to_torch_dtype(dt): model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() - args.func(args) + args.func(args) \ No newline at end of file diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py index 25ec9d6028627..d58fb0bf86374 100644 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -40,4 +40,4 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], -} +} \ No newline at end of file diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp new file mode 100644 index 0000000000000..ba9f40a230c8e --- /dev/null +++ b/csrc/core/math.hpp @@ -0,0 +1,7 @@ +#include +#include + +inline uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} \ No newline at end of file diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/cutlass_extensions/common.cpp new file mode 100644 index 0000000000000..3d2093ab94297 --- /dev/null +++ b/csrc/cutlass_extensions/common.cpp @@ -0,0 +1,11 @@ +#include "cutlass_extensions/common.hpp" + +int32_t get_sm_version_num() { + int32_t major_capability, minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + 0); + int32_t version_num = major_capability * 10 + minor_capability; + return version_num; +} \ No newline at end of file diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp new file mode 100644 index 0000000000000..85e359aa57113 --- /dev/null +++ b/csrc/cutlass_extensions/common.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include +#include "cuda_runtime.h" +#include + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, \ + cutlassGetStatusString(error)); \ + } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \ + } + +inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { + int max_shared_mem_per_block_opt_in = 0; + cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + device); + return max_shared_mem_per_block_opt_in; +} + +int32_t get_sm_version_num(); diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 95764ecddc79f..fcc17c7727f94 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -36,13 +36,13 @@ struct ScaledEpilogueBase { // Don't want to support nullptr by default template using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; // Don't want to support nullptr by default template using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; // This utility function constructs the arguments for the load descriptors diff --git a/csrc/ops.h b/csrc/ops.h index 816b471d062d2..c145e4eda0845 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -162,6 +162,15 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& azp_adj, c10::optional const& azp, c10::optional const& bias); + +void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& e, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + c10::optional const& bias); + +bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed, + torch::Tensor& e, torch::Tensor const& a); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp deleted file mode 100644 index bf04bb400790f..0000000000000 --- a/csrc/quantization/cutlass_w8a8/common.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include "cutlass/cutlass.h" -#include - -/** - * Helper function for checking CUTLASS errors - */ -#define CUTLASS_CHECK(status) \ - { \ - TORCH_CHECK(status == cutlass::Status::kSuccess, \ - cutlassGetStatusString(status)) \ - } - -inline uint32_t next_pow_2(uint32_t const num) { - if (num <= 1) return num; - return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); -} - -inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { - int max_shared_mem_per_block_opt_in = 0; - cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, - cudaDevAttrMaxSharedMemoryPerBlockOptin, - device); - return max_shared_mem_per_block_opt_in; -} - diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index d03242f44ab1d..75681f7f37820 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -21,7 +21,8 @@ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" -#include "common.hpp" +#include "core/math.hpp" +#include "cutlass_extensions/common.hpp" // clang-format on using namespace cute; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 33581a63d4c3d..8190277997161 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -24,7 +24,8 @@ #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" -#include "common.hpp" +#include "core/math.hpp" +#include "cutlass_extensions/common.hpp" // clang-format on using namespace cute; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 97a969cf5e3e0..4f7b6588ef3f7 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -3,6 +3,8 @@ #include #include +#include "cutlass_extensions/common.hpp" + void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) { return false; } -int32_t get_sm_version_num() { - int32_t major_capability, minor_capability; - cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, - 0); - cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, - 0); - int32_t version_num = major_capability * 10 + minor_capability; - return version_num; -} - void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, diff --git a/csrc/sparse/cutlass/sparse_compressor_c3x.cu b/csrc/sparse/cutlass/sparse_compressor_c3x.cu new file mode 100644 index 0000000000000..218c5317b4de6 --- /dev/null +++ b/csrc/sparse/cutlass/sparse_compressor_c3x.cu @@ -0,0 +1,163 @@ +// clang-format will break include orders +// clang-format off +#include + +#include "sparse_scaled_mm_c3x.cuh" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +// clang-format on + +using namespace cute; +using namespace vllm; + +/// Make A structured sparse by replacing elements with 0 and compress it +template +bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a) { + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn || + a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16); + TORCH_CHECK(a.dim() == 2) + // Check for strides and alignment + TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity + TORCH_CHECK(a.stride(1) == 1) + + int m = a.size(0); + int k = a.size(1); + + // Sparse kernel setup; this kernel is not used for matmul, + // but just for setting up the compressor utility + // A matrix configuration + using ElementA = ElementA_; + using LayoutTagA = cutlass::layout::RowMajor; + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + // B matrix configuration + using ElementB = ElementA; + using LayoutTagB = cutlass::layout::ColumnMajor; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + // C/D matrix configuration + using ElementC = float; + using LayoutTagC = cutlass::layout::ColumnMajor; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + // Core kernel configurations + using ElementAccumulator = ElementAcc_; + using TileShape = Shape<_128, _128, _128>; + using TileShapeRef = Shape<_128, _128, _64>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = typename std::conditional< + std::is_same_v, + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum, + cutlass::gemm::KernelTmaWarpSpecialized>::type; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using ProblemShape = Shape; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC, + AlignmentC, ElementC, LayoutTagC, AlignmentC, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA, + LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideE = StrideA; + + using StrideA = Stride, int64_t>; + + // The n (=1) dimension does not matter for the compressor + typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1}; + + using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA; + using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE; + + using ElementE = typename GemmKernel::CollectiveMainloop::ElementE; + using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig; + + // Offline compressor kernel + using CompressorUtility = + cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, ElementA, LayoutTagA, SparseConfig>; + + using CompressorKernel = + cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, ElementA, LayoutTagA, SparseConfig, + cutlass::arch::Sm90>; + + using Compressor = + cutlass::transform::device::TransformUniversalAdapter; + + auto [M, N, K, L] = prob_shape; + + StrideA stride_A; + stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + + CompressorUtility compressor_utility(prob_shape, stride_A); + + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + auto a_ptr = static_cast(a.data_ptr()); + + auto a_nzs_ptr = static_cast(a_nzs.data_ptr()); + auto a_meta_ptr = static_cast( + a_meta.data_ptr()); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + typename Compressor::Arguments arguments{ + prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}}; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + return true; +} + +bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a) { + if (a.dtype() == torch::kBFloat16) { + return cutlass_sparse_compress(a_nzs, a_meta, + a); + } else if (a.dtype() == torch::kFloat16) { + return cutlass_sparse_compress(a_nzs, a_meta, a); + } else if (a.dtype() == torch::kFloat8_e4m3fn) { + return cutlass_sparse_compress(a_nzs, a_meta, + a); + } else if (a.dtype() == torch::kInt8) { + return cutlass_sparse_compress(a_nzs, a_meta, a); + } + return false; +} \ No newline at end of file diff --git a/csrc/sparse/cutlass/sparse_compressor_entry.cu b/csrc/sparse/cutlass/sparse_compressor_entry.cu new file mode 100644 index 0000000000000..d23d937b6ac28 --- /dev/null +++ b/csrc/sparse/cutlass/sparse_compressor_entry.cu @@ -0,0 +1,42 @@ +#include + +#include +#include + +#include "cutlass_extensions/common.hpp" + +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X +bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a); +#endif + +bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta, + torch::Tensor const& a) { + // Checks for conformality + TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2); + TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) && + a_nzs.size(1) * 2 == a.size(1) && + a_meta.size(1) * 2 * 4 == a.size(1)); + // Considering elemsPerMetaElem = 8b / 2b_per_nz = 4 + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 && + a_meta.stride(1) == 1); // Row-major + TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression + + at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); + int32_t version_num = get_sm_version_num(); + + // Guard against compilation issues for sm90 kernels +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X + if (version_num >= 90) { + return cutlass_sparse_compress_sm90(a_nzs, a_meta, a); + } +#endif + + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_scaled_sparse_mm for a compute capability less than " + "CUDA device capability: ", + version_num); +} diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu new file mode 100644 index 0000000000000..b50e9a3a2c240 --- /dev/null +++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu @@ -0,0 +1,303 @@ +// clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 +#include "sparse_scaled_mm_c3x.cuh" +// clang-format on + +using namespace cute; +using namespace vllm; + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_fp8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_fp8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM256 = + typename sm90_fp8_config_M256::Cutlass3xGemm; + using Cutlass3xGemmM512 = + typename sm90_fp8_config_M512::Cutlass3xGemm; + + using Cutlass3xGemm1 = + typename sm90_fp8_config_1::Cutlass3xGemm; + using Cutlass3xGemm2 = + typename sm90_fp8_config_2::Cutlass3xGemm; + using Cutlass3xGemm3 = + typename sm90_fp8_config_3::Cutlass3xGemm; + using Cutlass3xGemm4 = + typename sm90_fp8_config_4::Cutlass3xGemm; + using Cutlass3xGemm5 = + typename sm90_fp8_config_5::Cutlass3xGemm; + using Cutlass3xGemm6 = + typename sm90_fp8_config_6::Cutlass3xGemm; + using Cutlass3xGemm7 = + typename sm90_fp8_config_7::Cutlass3xGemm; + using Cutlass3xGemm8 = + typename sm90_fp8_config_8::Cutlass3xGemm; + + uint32_t const n = bt_nzs.size(0); + uint32_t const m = a.size(0); // Batch size + uint32_t const mp2 = + std::max(static_cast(64), next_pow_2(m)); // next power of 2 + + if (mp2 <= 64) { + if (n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 4096 || n == 6144) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else if (mp2 <= 128) { + if (n == 4096) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 6144) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else if (mp2 <= 256) { + if (n == 4096) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 6144) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else { + if (n == 6144 || n == 28672) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (n == 4096) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } + + // Otherwise the default heuristic + if (mp2 <= 64) { + // n in [1, 64] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 128) { + // n in (64, 128] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 256) { + // n in (128, 256] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else { + // n in (256, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat16); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kBFloat16); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); +} + +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& bt_nzs, + torch::Tensor const& bt_meta, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); + TORCH_CHECK(bt_nzs.dtype() == torch::kInt8); + + using Cutlass3xGemmDefault = + typename sm90_config_default::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_int8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_int8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM32NBig = + typename sm90_int8_config_M32_NBig::Cutlass3xGemm; + using Cutlass3xGemmM32NSmall = + typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; + + uint32_t const n = out.size(1); + bool const is_small_n = n < 8192; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(32), next_pow_2(m)); // next power of 2 + + if (mp2 <= 32) { + // m in [1, 32] + if (is_small_n) { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else { + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } + } else if (mp2 <= 64) { + // m in (32, 64] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else if (mp2 <= 128) { + // m in (64, 128] + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } else { + // m in (128, inf) + return cutlass_sparse_gemm_caller( + out, a, bt_nzs, bt_meta, std::forward(args)...); + } +} + +template