From 970d6d0776076f17604077ba4d484cdadd604ceb Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 30 Dec 2024 04:22:13 -0500 Subject: [PATCH] [Build][Kernel] Update CUTLASS to v3.6.0 (#11607) Signed-off-by: Tyler Michael Smith --- CMakeLists.txt | 4 ++-- .../vllm_cutlass_library_extension.py | 18 +++++++++--------- csrc/quantization/machete/generate.py | 8 ++++---- .../machete/machete_collective_builder.cuh | 10 ++++------ csrc/quantization/machete/machete_mainloop.cuh | 11 ++++------- .../machete/machete_prepacked_layout.cuh | 5 ++--- 6 files changed, 25 insertions(+), 31 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 83c8033434f3b..3206d76125545 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227 + GIT_TAG v3.6.0 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW FALSE + GIT_SHALLOW TRUE ) endif() FetchContent_MakeAvailable(cutlass) diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index a5beea1a35e49..b401736c9824b 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -14,9 +14,9 @@ class VLLMDataType(enum.Enum): class MixedInputKernelScheduleType(enum.Enum): - TmaWarpSpecializedMixedInput = enum_auto() - TmaWarpSpecializedPingpongMixedInput = enum_auto() - TmaWarpSpecializedCooperativeMixedInput = enum_auto() + TmaWarpSpecialized = enum_auto() + TmaWarpSpecializedPingpong = enum_auto() + TmaWarpSpecializedCooperative = enum_auto() VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = { @@ -68,11 +68,11 @@ class MixedInputKernelScheduleType(enum.Enum): MixedInputKernelScheduleType, KernelScheduleType], str] = { **KernelScheduleTag, # type: ignore **{ - MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput: - "cutlass::gemm::KernelTmaWarpSpecializedMixedInput", - MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput: - "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput", - MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput: - "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecialized: + "cutlass::gemm::KernelTmaWarpSpecialized", + MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: + "cutlass::gemm::KernelTmaWarpSpecializedPingpong", + MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: + "cutlass::gemm::KernelTmaWarpSpecializedCooperative", } } diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index ac63afe79a255..2df4d181902f8 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -189,7 +189,7 @@ {{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT {{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT {{DataTypeTag[t.a_token_scale]}}, // TokenScaleT - cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, + cutlass::gemm::KernelTmaWarpSpecializedCooperative, Sch>; {% for sch in schs %} @@ -223,7 +223,7 @@ {{DataTypeTag[t.convert]}}, // ElementConvert {{DataTypeTag[t.accumulator]}}, // Accumulator cutlass::layout::ColumnMajor, - cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput> + cutlass::gemm::KernelTmaWarpSpecializedCooperative> >(args.B); } {%- endfor %} @@ -239,7 +239,7 @@ }; // namespace machete """ -TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput +TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative @@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str: # mostly unique shorter sch_sig def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str: kernel_terse_names_replace = { - "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_", + "KernelTmaWarpSpecializedCooperative": "TmaMI_", "TmaWarpSpecializedCooperative_": "TmaCoop_", "StreamKScheduler": "streamK", } diff --git a/csrc/quantization/machete/machete_collective_builder.cuh b/csrc/quantization/machete/machete_collective_builder.cuh index a74cf8b2dd455..ee825583dee1a 100644 --- a/csrc/quantization/machete/machete_collective_builder.cuh +++ b/csrc/quantization/machete/machete_collective_builder.cuh @@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder< ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType, cute::enable_if_t<( + cute::is_same_v || + cute::is_same_v || cute::is_same_v || - cute::is_same_v || - cute::is_same_v)>> { + KernelTmaWarpSpecializedCooperative>)>> { using CollectiveOp = machete::MacheteCollectiveMma< ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType>; }; -}; // namespace cutlass::gemm::collective \ No newline at end of file +}; // namespace cutlass::gemm::collective diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index 816f33a1078e5..4071b19a3564d 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -66,13 +66,11 @@ struct MacheteCollectiveMma { using Schedule = KernelScheduleType; static_assert( cute::is_same_v || - cute::is_same_v || + cute::is_same_v || + cute::is_same_v || cute::is_same_v || - cute::is_same_v || cute::is_same_v || - cute::is_same_v, + cute::is_same_v, "KernelSchedule must be one of the warp specialized policies"); public: @@ -113,8 +111,7 @@ struct MacheteCollectiveMma { // For coop schedules we have two warp groups cooperatively issuing wgmma // instructions so we use 2 atoms along the M dim (one for each warpgroup) using AtomLayoutMNK = cute::conditional_t< - cute::is_same_v, + cute::is_same_v, Layout>, Layout>>; using TiledMma = decltype(cute::make_tiled_mma( diff --git a/csrc/quantization/machete/machete_prepacked_layout.cuh b/csrc/quantization/machete/machete_prepacked_layout.cuh index 680a858a893c1..81aaa6c4f3a28 100644 --- a/csrc/quantization/machete/machete_prepacked_layout.cuh +++ b/csrc/quantization/machete/machete_prepacked_layout.cuh @@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate { // For coop schedules we have two warp groups cooperatively issuing wgmma // instructions so we use 2 atoms along the M dim (one for each warpgroup) using AtomLayoutMNK = cute::conditional_t< - cute::is_same_v, + cute::is_same_v, Layout>, Layout>>; using TiledMma = decltype(cute::make_tiled_mma( @@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate { } }; -}; // namespace machete \ No newline at end of file +}; // namespace machete