[Bugfix][Build/CI] Fix sparse CUTLASS compilation on CUDA [12.0, 12.2) #11311

Merged
39 changes: 30 additions & 9 deletions CMakeLists.txt
@@ -273,15 +273,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()

#
# The cutlass_scaled_mm cutlass_scaled_sparse_mm, and cutlass_compressor kernels
# For Hopper (c3x, i.e. CUTLASS 3.x) require
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/sparse/cutlass/sparse_compressor_c3x.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -290,12 +286,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 sparse or quantized models on "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building cutlass_c3x as no compatible archs found "
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
"in CUDA target architectures")
endif()

@@ -329,6 +325,31 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

#
# 2:4 Sparse Kernels

# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
"if you intend on running FP8 sparse quantized models on Hopper.")
else()
message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()


#
# Machete kernels
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -163,6 +163,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);

bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);

void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& e,
torch::Tensor const& a_scales,
4 changes: 3 additions & 1 deletion csrc/sparse/cutlass/sparse_compressor_c3x.cu
@@ -2,6 +2,7 @@
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"

#include "cutlass/numeric_conversion.h"
@@ -160,4 +161,5 @@ bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
}
return false;
}
}
#endif
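
As an aside on the guard value used above and in sparse_scaled_mm_c3x.cu: the `CUDA_VERSION` macro packs the toolkit version as `major * 1000 + minor * 10`, so `12020` is the first value corresponding to CUDA 12.2. A tiny standalone check (illustrative sketch, not part of this PR):

```python
# CUDA_VERSION packs the toolkit version as major * 1000 + minor * 10,
# e.g. CUDA 12.0 -> 12000 and CUDA 12.2 -> 12020, which is why the sparse
# sources are guarded with `CUDA_VERSION >= 12020`.
def cuda_version_macro(major: int, minor: int) -> int:
    return major * 1000 + minor * 10

assert cuda_version_macro(12, 0) == 12000
assert cuda_version_macro(12, 2) == 12020
```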
4 changes: 2 additions & 2 deletions csrc/sparse/cutlass/sparse_compressor_entry.cu
@@ -5,7 +5,7 @@

#include "cutlass_extensions/common.hpp"

#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a);
#endif
@@ -28,7 +28,7 @@ bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
int32_t version_num = get_sm_version_num();

// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if (version_num >= 90) {
return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
}
2 changes: 1 addition & 1 deletion csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -2,7 +2,7 @@
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on

15 changes: 13 additions & 2 deletions csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
@@ -5,7 +5,18 @@

#include "cutlass_extensions/common.hpp"

#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
// sparse CUTLASS kernels need at least
// CUDA 12.2 and SM90 (Hopper)

#if defined CUDA_VERSION
return CUDA_VERSION >= 12020 && cuda_device_capability >= 90;
#endif

return false;
}

#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& e,
@@ -43,7 +54,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
int32_t version_num = get_sm_version_num();

// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if (version_num >= 90) {
cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
bias);
7 changes: 7 additions & 0 deletions csrc/torch_bindings.cpp
@@ -321,6 +321,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);

// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
// given capability
ops.def(
"cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool");
ops.impl("cutlass_sparse_scaled_mm_supported",
&cutlass_sparse_scaled_mm_supported);

// CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
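
Once vLLM's `_C` extension is loaded, the newly registered op can be queried from Python before dispatching to the sparse kernels. A minimal usage sketch (illustrative only, not part of this PR; it assumes a CUDA device is present and encodes the capability as `major * 10 + minor`, the same integer form `capability_tuple.to_int()` produces in the w8a8_utils change below):

```python
import torch
import vllm  # noqa: F401  (loads the _C extension that registers the op)

# Encode the current device's compute capability as major*10 + minor,
# e.g. Hopper (9.0) -> 90, which is what the binding expects.
major, minor = torch.cuda.get_device_capability()
capability = major * 10 + minor

# True only when the extension was built with CUDA >= 12.2 and the device
# is SM90 or newer (see cutlass_sparse_scaled_mm_supported above).
if torch.ops._C.cutlass_sparse_scaled_mm_supported(capability):
    print("2:4 sparse CUTLASS scaled_mm kernels are available")
else:
    print("2:4 sparse CUTLASS kernels unavailable for this build/GPU")
```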
5 changes: 4 additions & 1 deletion tests/kernels/test_semi_structured.py
@@ -8,6 +8,8 @@
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported)
from vllm.platforms import current_platform

CUDA_DEVICES = [
@@ -102,10 +104,11 @@ def baseline_scaled_mm(a: torch.Tensor,
return output


@pytest.mark.skipif(not current_platform.has_device_capability(90),
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.")
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():

big_m = 1024
m, n, k = 512, 512, 512

8 changes: 5 additions & 3 deletions tests/quantization/test_compressed_tensors.py
@@ -14,6 +14,8 @@
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported)
from vllm.platforms import current_platform


@@ -212,7 +214,7 @@ def test_compressed_tensors_kv_cache(vllm_runner):
assert output


@pytest.mark.skipif(not current_platform.has_device_capability(90),
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.")
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
@@ -254,7 +256,7 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
assert output


@pytest.mark.skipif(not current_platform.has_device_capability(90),
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize("args_2of4", [
("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
@@ -279,7 +281,7 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
assert output


@pytest.mark.skipif(not current_platform.has_device_capability(90),
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.")
@pytest.mark.parametrize(
"args_2of4",
5 changes: 5 additions & 0 deletions vllm/_custom_ops.py
@@ -552,6 +552,11 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
return out


def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_sparse_scaled_mm_supported(
cuda_device_capability)


def cutlass_sparse_compress(a: torch.Tensor) \
-> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -9,7 +9,7 @@
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)
convert_to_channelwise, sparse_cutlass_supported)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
@@ -40,6 +40,11 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

if not sparse_cutlass_supported():
raise ValueError(
    "Sparse CUTLASS not supported. vLLM must be built with "
    "CUDA 12.2 or later to use this feature")

self.output_dtype = params_dtype
layer.logical_widths = output_partition_sizes
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
11 changes: 11 additions & 0 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -10,6 +10,17 @@
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)


def sparse_cutlass_supported() -> bool:
# sparse cutlass is not supported on Rocm
if current_platform.is_rocm():
return False

capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()

return ops.cutlass_sparse_scaled_mm_supported(capability)


def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm
if current_platform.is_rocm():
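
The new helper is what the test files above now use in their `skipif` markers; downstream code can gate sparse paths on it the same way. A short illustrative sketch (the `choose_mm_path` function is hypothetical, not from the PR):

```python
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported)


def choose_mm_path() -> str:
    # sparse_cutlass_supported() returns False on ROCm, on pre-Hopper GPUs,
    # and on builds compiled with a CUDA toolkit older than 12.2.
    if sparse_cutlass_supported():
        return "cutlass_scaled_sparse_mm (2:4 sparse)"
    return "dense cutlass_scaled_mm fallback"


print(choose_mm_path())
```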