From fbc974d34f1dde32d021dd16cbb64a5ac039c9ad Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Wed, 18 Dec 2024 22:48:44 +0000
Subject: [PATCH 1/2] [Bugfix][Build/CI] Fix sparse CUTLASS compilation on CUDA [12.0, 12.2)

Signed-off-by: Tyler Michael Smith
---
 CMakeLists.txt                                | 39 ++++++++++++++-----
 csrc/ops.h                                    |  2 +
 csrc/sparse/cutlass/sparse_compressor_c3x.cu  |  4 +-
 .../sparse/cutlass/sparse_compressor_entry.cu |  4 +-
 csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu   |  2 +-
 csrc/sparse/cutlass/sparse_scaled_mm_entry.cu | 15 ++++++-
 csrc/torch_bindings.cpp                       |  7 ++++
 vllm/_custom_ops.py                           |  5 +++
 .../schemes/compressed_tensors_24.py          | 17 ++++++++
 9 files changed, 80 insertions(+), 15 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 51b49a18dddf2..83c8033434f3b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -273,15 +273,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
                     " in CUDA target architectures")
   endif()
 
-  #
-  # The cutlass_scaled_mm cutlass_scaled_sparse_mm, and cutlass_compressor kernels
-  # For Hopper (c3x, i.e. CUTLASS 3.x) require
+  # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
   # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
   cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
-      "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
-      "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
+    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -290,12 +286,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
   else()
     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
-      message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
+      message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
                      "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
-                     "later if you intend on running FP8 sparse or quantized models on "
+                     "later if you intend on running FP8 quantized models on "
                      "Hopper.")
     else()
-      message(STATUS "Not building cutlass_c3x as no compatible archs found "
+      message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
                      "in CUDA target architectures")
     endif()
 
@@ -329,6 +325,31 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()
 
+  #
+  # 2:4 Sparse Kernels
+
+  # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
+  # require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
+    set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
+      "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
+    message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
+      message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
+                     "not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
+                     "if you intend on running FP8 sparse quantized models on Hopper.")
+    else()
+      message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found "
+                     "in CUDA target architectures")
+    endif()
+  endif()
+
   #
   # Machete kernels
diff --git a/csrc/ops.h b/csrc/ops.h
index c145e4eda0845..347c502845d8f 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -163,6 +163,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                            c10::optional<torch::Tensor> const& azp,
                            c10::optional<torch::Tensor> const& bias);
 
+bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);
+
 void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b, torch::Tensor const& e,
                               torch::Tensor const& a_scales,
diff --git a/csrc/sparse/cutlass/sparse_compressor_c3x.cu b/csrc/sparse/cutlass/sparse_compressor_c3x.cu
index 218c5317b4de6..bd53695503241 100644
--- a/csrc/sparse/cutlass/sparse_compressor_c3x.cu
+++ b/csrc/sparse/cutlass/sparse_compressor_c3x.cu
@@ -2,6 +2,7 @@
 // clang-format off
 #include
 
+#if defined CUDA_VERSION && CUDA_VERSION >= 12020
 #include "sparse_scaled_mm_c3x.cuh"
 
 #include "cutlass/numeric_conversion.h"
@@ -160,4 +161,5 @@ bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
     return cutlass_sparse_compress(a_nzs, a_meta, a);
   }
   return false;
-}
\ No newline at end of file
+}
+#endif
diff --git a/csrc/sparse/cutlass/sparse_compressor_entry.cu b/csrc/sparse/cutlass/sparse_compressor_entry.cu
index d23d937b6ac28..3401761c1b703 100644
--- a/csrc/sparse/cutlass/sparse_compressor_entry.cu
+++ b/csrc/sparse/cutlass/sparse_compressor_entry.cu
@@ -5,7 +5,7 @@
 
 #include "cutlass_extensions/common.hpp"
 
-#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
+#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
 bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                                   torch::Tensor const& a);
 #endif
@@ -28,7 +28,7 @@ bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
   int32_t version_num = get_sm_version_num();
 
   // Guard against compilation issues for sm90 kernels
-#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
+#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
   if (version_num >= 90) {
     return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
   }
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
index b50e9a3a2c240..6223dc8cca704 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -2,7 +2,7 @@
 // clang-format off
 #include
 
-#if defined CUDA_VERSION && CUDA_VERSION >= 12000
+#if defined CUDA_VERSION && CUDA_VERSION >= 12020
 
 #include "sparse_scaled_mm_c3x.cuh"
 // clang-format on
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu b/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
index 4c930b603c9e4..d464b045b895f 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
@@ -5,7 +5,18 @@
 
 #include "cutlass_extensions/common.hpp"
 
-#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
+bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
+  // sparse CUTLASS kernels need at least
+  // CUDA 12.2 and SM90 (Hopper)
+
+#if defined CUDA_VERSION
+  return CUDA_VERSION >= 12020 && cuda_device_capability >= 90;
+#endif
+
+  return false;
+}
+
+#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
 void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                                    torch::Tensor const& b, torch::Tensor const& e,
@@ -43,7 +54,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
   int32_t version_num = get_sm_version_num();
 
   // Guard against compilation issues for sm90 kernels
-#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
+#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
   if (version_num >= 90) {
     cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
                                   bias);
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 88a4e60c75cbe..956258c1001d3 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -321,6 +321,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
   ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
 
+  // Check if cutlass sparse scaled_mm is supported for CUDA devices of the
+  // given capability
+  ops.def(
+      "cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool");
+  ops.impl("cutlass_sparse_scaled_mm_supported",
+           &cutlass_sparse_scaled_mm_supported);
+
   // CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias
   ops.def(
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index f6b5514f8987d..19f31b8ec419d 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -552,6 +552,11 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
     return out
 
 
+def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
+    return torch.ops._C.cutlass_sparse_scaled_mm_supported(
+        cuda_device_capability)
+
+
 def cutlass_sparse_compress(a: torch.Tensor) \
         -> Tuple[torch.Tensor, torch.Tensor]:
     """
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index af266769aef89..d0390b2bff15a 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -14,10 +14,22 @@
                                            ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
+from vllm.platforms import current_platform
 
 __all__ = ["CompressedTensors24"]
 
 
+def sparse_cutlass_supported() -> bool:
+    # sparse cutlass is not supported on Rocm
+    if current_platform.is_rocm():
+        return False
+
+    capability_tuple = current_platform.get_device_capability()
+    capability = -1 if capability_tuple is None else capability_tuple.to_int()
+
+    return ops.cutlass_sparse_scaled_mm_supported(capability)
+
+
 class CompressedTensors24(CompressedTensorsScheme):
 
     def __init__(self,
@@ -40,6 +52,11 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        if not sparse_cutlass_supported():
+            raise ValueError(
+                "Sparse CUTLASS not supported. vLLM must be built with "
+                "CUDA 12.2 or later to use this feature")
+
         self.output_dtype = params_dtype
         layer.logical_widths = output_partition_sizes
         self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)

From 117c978a79683f2aee952b6506a852a46b3ec4d2 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Wed, 18 Dec 2024 23:38:47 +0000
Subject: [PATCH 2/2] skip tests

Signed-off-by: Tyler Michael Smith
---
 tests/kernels/test_semi_structured.py         |  5 ++++-
 tests/quantization/test_compressed_tensors.py |  8 +++++---
 .../schemes/compressed_tensors_24.py          | 14 +-------------
 .../layers/quantization/utils/w8a8_utils.py   | 11 +++++++++++
 4 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/tests/kernels/test_semi_structured.py b/tests/kernels/test_semi_structured.py
index 34244a8fe4ca7..4316d6ab30e33 100644
--- a/tests/kernels/test_semi_structured.py
+++ b/tests/kernels/test_semi_structured.py
@@ -8,6 +8,8 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    sparse_cutlass_supported)
 from vllm.platforms import current_platform
 
 CUDA_DEVICES = [
@@ -102,10 +104,11 @@ def baseline_scaled_mm(a: torch.Tensor,
     return output
 
 
-@pytest.mark.skipif(not current_platform.has_device_capability(90),
+@pytest.mark.skipif(not sparse_cutlass_supported(),
                     reason="Sparse FP8 is not yet supported on this GPU type.")
 # Test working with a subset of A and B for sparse matmul
 def test_cutlass_sparse_subset():
+
     big_m = 1024
     m, n, k = 512, 512, 512
 
diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 21fec990aa873..38e02f6018dee 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -14,6 +14,8 @@
     CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
     CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    sparse_cutlass_supported)
 from vllm.platforms import current_platform
 
 
@@ -212,7 +214,7 @@ def test_compressed_tensors_kv_cache(vllm_runner):
     assert output
 
 
-@pytest.mark.skipif(not current_platform.has_device_capability(90),
+@pytest.mark.skipif(not sparse_cutlass_supported(),
                     reason="Sparse FP8 is not yet supported on this GPU type.")
 def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
     assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
@@ -254,7 +256,7 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
     assert output
 
 
-@pytest.mark.skipif(not current_platform.has_device_capability(90),
+@pytest.mark.skipif(not sparse_cutlass_supported(),
                     reason="Sparse FP8 is not yet supported on this GPU type.")
 @pytest.mark.parametrize("args_2of4", [
     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
@@ -279,7 +281,7 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
     assert output
 
 
-@pytest.mark.skipif(not current_platform.has_device_capability(90),
+@pytest.mark.skipif(not sparse_cutlass_supported(),
                     reason="Sparse FP8 is not yet supported on this GPU type.")
 @pytest.mark.parametrize(
     "args_2of4",
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index d0390b2bff15a..bc697ef93b34b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -9,27 +9,15 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    convert_to_channelwise)
+    convert_to_channelwise, sparse_cutlass_supported)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
-from vllm.platforms import current_platform
 
 __all__ = ["CompressedTensors24"]
 
 
-def sparse_cutlass_supported() -> bool:
-    # sparse cutlass is not supported on Rocm
-    if current_platform.is_rocm():
-        return False
-
-    capability_tuple = current_platform.get_device_capability()
-    capability = -1 if capability_tuple is None else capability_tuple.to_int()
-
-    return ops.cutlass_sparse_scaled_mm_supported(capability)
-
-
 class CompressedTensors24(CompressedTensorsScheme):
 
     def __init__(self,
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 4037bcb963b25..d77722499d0e9 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -10,6 +10,17 @@
 TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
 
 
+def sparse_cutlass_supported() -> bool:
+    # sparse cutlass is not supported on Rocm
+    if current_platform.is_rocm():
+        return False
+
+    capability_tuple = current_platform.get_device_capability()
+    capability = -1 if capability_tuple is None else capability_tuple.to_int()
+
+    return ops.cutlass_sparse_scaled_mm_supported(capability)
+
+
 def cutlass_fp8_supported() -> bool:
     # cutlass is not supported on Rocm
     if current_platform.is_rocm():