Skip to content

Commit

Permalink
[Bugfix][Build/CI] Fix sparse CUTLASS compilation on CUDA [12.0, 12.2)
Browse files Browse the repository at this point in the history
Signed-off-by: Tyler Michael Smith <[email protected]>
  • Loading branch information
tlrmchlsmth committed Dec 18, 2024
1 parent ca5f54a commit fbc974d
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 15 deletions.
39 changes: 30 additions & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()

#
# The cutlass_scaled_mm cutlass_scaled_sparse_mm, and cutlass_compressor kernels
# For Hopper (c3x, i.e. CUTLASS 3.x) require
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/sparse/cutlass/sparse_compressor_c3x.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
Expand All @@ -290,12 +286,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 sparse or quantized models on "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building cutlass_c3x as no compatible archs found "
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
"in CUDA target architectures")
endif()

Expand Down Expand Up @@ -329,6 +325,31 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

#
# 2:4 Sparse Kernels

# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
"if you intend on running FP8 sparse quantized models on Hopper.")
else()
message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()


#
# Machete kernels
Expand Down
2 changes: 2 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);

bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);

void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& e,
torch::Tensor const& a_scales,
Expand Down
4 changes: 3 additions & 1 deletion csrc/sparse/cutlass/sparse_compressor_c3x.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"

#include "cutlass/numeric_conversion.h"
Expand Down Expand Up @@ -160,4 +161,5 @@ bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
}
return false;
}
}
#endif
4 changes: 2 additions & 2 deletions csrc/sparse/cutlass/sparse_compressor_entry.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include "cutlass_extensions/common.hpp"

#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a);
#endif
Expand All @@ -28,7 +28,7 @@ bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
int32_t version_num = get_sm_version_num();

// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if (version_num >= 90) {
return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
}
Expand Down
2 changes: 1 addition & 1 deletion csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on

Expand Down
15 changes: 13 additions & 2 deletions csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,18 @@

#include "cutlass_extensions/common.hpp"

#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
// sparse CUTLASS kernels need at least
// CUDA 12.2 and SM90 (Hopper)

#if defined CUDA_VERSION
return CUDA_VERSION >= 12020 && cuda_device_capability >= 90;
#endif

return false;
}

#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& e,
Expand Down Expand Up @@ -43,7 +54,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
int32_t version_num = get_sm_version_num();

// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if (version_num >= 90) {
cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
bias);
Expand Down
7 changes: 7 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);

// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
// given capability
ops.def(
"cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool");
ops.impl("cutlass_sparse_scaled_mm_supported",
&cutlass_sparse_scaled_mm_supported);

// CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
Expand Down
5 changes: 5 additions & 0 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,11 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
return out


def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_sparse_scaled_mm_supported(
cuda_device_capability)


def cutlass_sparse_compress(a: torch.Tensor) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,22 @@
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform

__all__ = ["CompressedTensors24"]


def sparse_cutlass_supported() -> bool:
# sparse cutlass is not supported on Rocm
if current_platform.is_rocm():
return False

capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()

return ops.cutlass_sparse_scaled_mm_supported(capability)


class CompressedTensors24(CompressedTensorsScheme):

def __init__(self,
Expand All @@ -40,6 +52,11 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

if not sparse_cutlass_supported():
raise ValueError(
"Sparse CUTLASS not supported. vLLM must be built with"
"CUDA 12.2 or later to use this feature")

self.output_dtype = params_dtype
layer.logical_widths = output_partition_sizes
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
Expand Down

0 comments on commit fbc974d

Please sign in to comment.