From 8571ac4672c8b599338cb95e23dfd624016aab36 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Mon, 5 Aug 2024 15:13:43 -0400
Subject: [PATCH] [Kernel] Update CUTLASS to 3.5.1 (#7085)

---
 CMakeLists.txt                          |   6 +-
 .../broadcast_load_epilogue_c3x.hpp     | 192 ++++++++++--------
 .../cutlass_w8a8/scaled_mm_c2x.cuh      |   8 +-
 .../cutlass_w8a8/scaled_mm_c3x.cu       |  30 +--
 4 files changed, 129 insertions(+), 107 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 922613ec5ddaa..e5ac5516c2e46 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -193,8 +193,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   FetchContent_Declare(
         cutlass
         GIT_REPOSITORY https://github.com/nvidia/cutlass.git
-        # CUTLASS 3.5.0
-        GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
+        # CUTLASS 3.5.1
+        GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9
         # Shallow clone with depth 1
         GIT_SHALLOW TRUE
         GIT_PROGRESS TRUE
@@ -237,7 +237,7 @@ define_gpu_extension_target(
   SOURCES ${VLLM_EXT_SRC}
   COMPILE_FLAGS ${VLLM_GPU_FLAGS}
   ARCHITECTURES ${VLLM_GPU_ARCHES}
-  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
+  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
   USE_SABI 3
   WITH_SOABI)
 
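The CUTLASS dependency above is pinned by commit hash only. As a sketch (not part of this patch), a compile-time guard built from the CUTLASS_MAJOR/CUTLASS_MINOR/CUTLASS_PATCH macros in cutlass/version.h could make the minimum version explicit:

// Hypothetical guard, not in the patch: fail fast if an older CUTLASS is picked up,
// since the rewritten Sm90RowOrScalarBroadcast below relies on 3.5.1 behavior.
#include "cutlass/version.h"

static_assert(CUTLASS_MAJOR > 3 ||
                  (CUTLASS_MAJOR == 3 &&
                   (CUTLASS_MINOR > 5 || (CUTLASS_MINOR == 5 && CUTLASS_PATCH >= 1))),
              "vLLM's cutlass_w8a8 kernels expect CUTLASS >= 3.5.1");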
diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
index e4bc9752ed7db..58b1e8ff159fb 100644
--- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
+++ b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
@@ -64,8 +64,6 @@ using namespace detail;
 
 // Row vector broadcast
 template<
-  // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
-  // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
   int Stages,
   class CtaTileShapeMNK,
   class Element,
@@ -73,14 +71,12 @@ template<
   int Alignment = 128 / sizeof_bits_v<Element>
 >
 struct Sm90RowOrScalarBroadcast {
-  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
-  static_assert(
-    (cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias
-    (cute::is_same_v<StrideMNL, Stride<_0,_1,int>>));  // batched row vector broadcast
+  static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
+  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
+  static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
 
-  // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
-  struct SharedStorage {
-    alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
+  struct SharedStorage {
+    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
   };
 
   // This struct has been modified to have a bool indicating that ptr_row is a
@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
     return args;
   }
 
+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    return true;
+  }
+
   template <class ProblemShape>
   static size_t
   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
 
   CUTLASS_HOST_DEVICE
   Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
-      : params(params),
-        smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }
+      : params(params)
+      , smem(const_cast<Element*>(shared_storage.smem.data())) { }
 
   Params params;
-  Element* smem_row;
+  Element *smem = nullptr;
 
   CUTLASS_DEVICE bool
   is_producer_load_needed() const {
-    return true;
+    return false;
   }
 
   CUTLASS_DEVICE bool
@@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
     return (!params.row_broadcast && *(params.ptr_row) == Element(0));
   }
 
-  template <class GTensor, class STensor, int EpiTiles>
-  struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
-    CUTLASS_DEVICE
-    ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
-      : gRow(cute::forward<GTensor>(gRow)),
-        sRow(cute::forward<STensor>(sRow)),
-        params(params) {}
-
-    GTensor gRow;                                                                 // (CTA_M,CTA_N)
-    STensor sRow;                                                                 // (CTA_M,CTA_N,PIPE)
-    Params const& params;
-
-    CUTLASS_DEVICE void
-    begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
-      if (!params.row_broadcast) {
-        return;
-      }
-
-      if (issue_tma_load) {
-        // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
-        constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
-        cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
-        // Issue the TMA bulk copy
-        auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
-        // Filter so we don't issue redundant copies over stride-0 modes
-        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
-        copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
-      }
-    }
-  };
-
   template <class... Args>
   CUTLASS_DEVICE auto
   get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
-
-    auto [M, N, K, L] = args.problem_shape_mnkl;
-    auto [m, n, k, l] = args.tile_coord_mnkl;
-    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
-    Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l));      // (CTA_M,CTA_N)
-    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                      // (CTA_M,CTA_N,PIPE)
-                    make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
-                    make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
-
-    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
-    return ProducerLoadCallbacks<decltype(gRow), decltype(sRow), EpiTiles>(
-      cute::move(gRow), cute::move(sRow), params);
+    return EmptyProducerLoadCallbacks{};
   }
 
-  template <class RTensor, class STensor, int EpiTiles>
+  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
   struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
     CUTLASS_DEVICE
-    ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
-      : tCrRow(cute::forward<RTensor>(tCrRow)),
-        tCsRow(cute::forward<STensor>(tCsRow)),
-        params(params) {}
-
-    RTensor tCrRow;                                                               // (CPY,CPY_M,CPY_N)
-    STensor tCsRow;                                                               // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
+    ConsumerStoreCallbacks(
+        GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
+        GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
+        SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
+        CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
+      : tGS_gRow(tGS_gRow_)
+      , tGS_sRow(tGS_sRow_)
+      , tGS_cRow(tGS_cRow_)
+      , tiled_G2S(tiled_g2s_)
+      , tSR_sRow(tSR_sRow_)
+      , tSR_rRow(tSR_rRow_)
+      , tCcRow(tCcRow_)
+      , residue_tCcRow(residue_tCcRow_)
+      , params(params_) {}
+
+    GS_GTensor tGS_gRow;                                                          // (CPY,CPY_M,CPY_N)
+    GS_STensor tGS_sRow;                                                          // (CPY,CPY_M,CPY_N)
+    GS_CTensor tGS_cRow;                                                          // (CPY,CPY_M,CPY_N)
+    Tiled_G2S tiled_G2S;
+
+    SR_STensor tSR_sRow;                                                          // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    SR_RTensor tSR_rRow;                                                          // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+
+    CTensor tCcRow;                                                               // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
+    ThrResidue residue_tCcRow;                                                    // (m, n)
+    ThrNum thr_num;
     Params const& params;
 
     CUTLASS_DEVICE void
-    previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
+    begin() {
       if (!params.row_broadcast) {
-        fill(tCrRow, *(params.ptr_row));
+        fill(tSR_rRow, *(params.ptr_row));
         return;
       }
 
+      auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
+      Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
+      Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
+      Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
+
+      for (int i = 0; i < size(tGS_gRow_flt); ++i) {
+        if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
+          continue; // OOB of SMEM,
+        }
+        if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
+          tGS_sRow_flt(i) = tGS_gRow_flt(i);
+        }
+        else {
+          tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
+        }
+      }
+      synchronize();
+    }
+
+    CUTLASS_DEVICE void
+    begin_loop(int epi_m, int epi_n) {
       if (epi_m == 0) { // Assumes M-major subtile loop
-        // Filter so we don't issue redundant copies over stride-0 modes
-        // (only works if 0-strides are in same location, which is by construction)
-        int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
-        copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
+        if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
+        Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
+        Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
+        copy(tSR_sRow_flt, tSR_rRow_flt);
       }
     }
 
@@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {
 
     CUTLASS_PRAGMA_UNROLL
     for (int i = 0; i < FragmentSize; ++i) {
-      frg_row[i] = tCrRow(epi_v * FragmentSize + i);
+      frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
     }
 
     return frg_row;
@@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
   >
   CUTLASS_DEVICE auto
   get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+    auto [M, N, K, L] = args.problem_shape_mnkl;
+    auto [m, n, k, l] = args.tile_coord_mnkl;
+    using ThreadCount = decltype(size(args.tiled_copy));
 
-    Tensor sRow = make_tensor(make_smem_ptr(smem_row),                            // (CTA_M,CTA_N,PIPE)
-                    make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
-                    make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
-    Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>(                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
-                      sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
-    Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow));                          // (CPY,CPY_M,CPY_N)
-
-    constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
-    return ConsumerStoreCallbacks<decltype(tCrRow), decltype(tCsRow), EpiTiles>(
-      cute::move(tCrRow), cute::move(tCsRow), params);
+    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
+    Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
+    Tensor sRow = make_tensor(make_smem_ptr(smem),
+        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
+
+    //// G2S: Gmem to Smem
+    auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                                     Layout< Shape<_1, ThreadCount>,
+                                            Stride<_0,          _1>>{},
+                                     Layout<_1>{});
+    auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
+    Tensor tGS_gRow = thr_g2s.partition_S(gRow);
+    Tensor tGS_sRow = thr_g2s.partition_D(sRow);
+
+    //// G2S: Coord
+    auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
+    Tensor tGS_cRow = thr_g2s.partition_S(cRow);
+
+    //// S2R: Smem to Reg
+    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
+    Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow));                      // (CPY,CPY_M,CPY_N)
+
+    return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
+      tGS_gRow,
+      tGS_sRow,
+      tGS_cRow, tiled_g2s,
+      tSR_sRow,
+      tSR_rRow,
+      args.tCcD,
+      args.residue_cD,
+      ThreadCount{},
+      params);
   }
 };
 
@@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
     return args;
   }
 
+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    return true;
+  }
+
   template <class ProblemShape>
   static size_t
   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
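The rewritten Sm90RowOrScalarBroadcast above no longer pipelines the row vector through the TMA producer warp; the consumer threads themselves stage one CTA_N-wide slice of the row into shared memory once, zero-filling out-of-bounds columns so the later shared-memory loads need no predicates, and they skip the copy entirely when the "row" is really a scalar. A stand-alone CUDA sketch of that idea, with assumed names (CTA_N, n_offset, n_extent), not the CuTe code above:

// Simplified illustration only, not the CUTLASS implementation.
template <int CTA_N>
__device__ void stage_row_broadcast(float const* __restrict__ ptr_row,
                                    bool row_broadcast,
                                    int n_offset, int n_extent,
                                    float* __restrict__ smem_row /* CTA_N elements */) {
  if (!row_broadcast) {
    return;  // scalar case: consumers read *ptr_row directly, no smem traffic
  }
  // Cooperative, predicated gmem -> smem copy of this CTA's slice of the row.
  for (int i = threadIdx.x; i < CTA_N; i += blockDim.x) {
    int n = n_offset + i;
    smem_row[i] = (n < n_extent) ? ptr_row[n] : 0.f;  // zero-fill OOB columns
  }
  __syncthreads();  // stands in for the named EpilogueBarrier in the real code
}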
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
index ba620e85117b5..be8a5c0e54e8e 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
@@ -10,8 +10,6 @@
 #include "cute/atom/mma_atom.hpp"
 #include "cutlass/numeric_types.h"
 
-#include "cutlass/util/device_memory.h"
-
 #include "cutlass/cutlass.h"
 #include "cutlass/gemm_coord.h"
 #include "cutlass/arch/mma_sm75.h"
@@ -301,12 +299,14 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   // Launch the CUTLASS GEMM kernel.
   typename Gemm::Op gemm_op;
   size_t workspace_size = gemm_op.get_workspace_size(args);
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
 
   auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
 
   CUTLASS_CHECK(gemm_op.can_implement(args));
-  cutlass::Status status = gemm_op(args, workspace.get(), stream);
+  cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
   CUTLASS_CHECK(status);
 }
 
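The workspace change above (the same pattern is applied to scaled_mm_c3x.cu below) replaces cutlass::device_memory::allocation, which performs a raw cudaMalloc/cudaFree, with a torch uint8 tensor on the input's device, so the scratch space is served by PyTorch's CUDA caching allocator. A condensed sketch of the pattern, where GemmOp and Args are placeholders for the already-configured CUTLASS op and its arguments:

#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include "cutlass/cutlass.h"

// Sketch only: allocate the CUTLASS workspace through PyTorch and run the op.
template <typename GemmOp, typename Args>
void run_with_torch_workspace(GemmOp& gemm_op, Args const& args, torch::Tensor const& a) {
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty({static_cast<int64_t>(workspace_size)}, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  TORCH_CHECK(gemm_op.can_implement(args) == cutlass::Status::kSuccess);
  cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess);
}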
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
index b3f5b62086609..088185188770d 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -18,8 +18,6 @@
 #include "cute/atom/mma_atom.hpp"
 #include "cutlass/numeric_types.h"
 
-#include "cutlass/util/device_memory.h"
-
 #include "cutlass/gemm/device/gemm_universal_adapter.h"
 #include "cutlass/gemm/kernel/gemm_universal.hpp"
 #include "cutlass/epilogue/collective/collective_builder.hpp"
@@ -72,13 +70,9 @@ struct ScaledEpilogueBase {
       0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
       Stride<Int<1>, Int<0>, Int<0>>>;
 
-  using ScaleBDescriptor =
-      cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
-          EpilogueDescriptor, float>;
-
   using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
-      ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
-      typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
+      Stride<Int<0>, Int<1>, Int<0>>>;
 };
 
 /*
@@ -154,12 +148,8 @@ struct ScaledEpilogueBias
       cutlass::multiply_add, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
 
-  using BiasDescriptor =
-      cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
-          EpilogueDescriptor, ElementD>;
-
   using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
-      BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD,
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
       Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
 
  public:
@@ -251,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   int64_t ldb = b.stride(1);
   int64_t ldc = out.stride(0);
 
-  using StrideA = Stride<int64_t, Int<1>, Int<0>>;
-  using StrideB = Stride<int64_t, Int<1>, Int<0>>;
+  using StrideA = Stride<int64_t, Int<1>, int64_t>;
+  using StrideB = Stride<int64_t, Int<1>, int64_t>;
   using StrideC = typename Gemm::StrideC;
 
-  StrideA a_stride{lda, Int<1>{}, Int<0>{}};
-  StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
+  StrideA a_stride{lda, Int<1>{}, 0};
+  StrideB b_stride{ldb, Int<1>{}, 0};
   StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
 
   using GemmKernel = typename Gemm::GemmKernel;
@@ -282,11 +272,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   CUTLASS_CHECK(gemm_op.can_implement(args));
 
   size_t workspace_size = gemm_op.get_workspace_size(args);
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
 
   auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
 
-  cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
+  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
   CUTLASS_CHECK(status);
 }
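One further point on the c3x caller above: the batch (L) mode of the A/B strides moves from a compile-time Int<0> to a runtime int64_t that the caller simply sets to 0. A minimal CuTe sketch of the two type shapes, assuming only that CUTLASS's include directory is on the path:

#include "cute/tensor.hpp"
using namespace cute;

// Before: the batch stride was the static type Int<0>.
using OldStrideA = Stride<int64_t, Int<1>, Int<0>>;
// After: the batch stride is dynamic; the caller passes 0 for the single-batch case here.
using NewStrideA = Stride<int64_t, Int<1>, int64_t>;

inline NewStrideA make_a_stride(int64_t lda) {
  return NewStrideA{lda, Int<1>{}, 0};  // mirrors `StrideA a_stride{lda, Int<1>{}, 0};` above
}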