diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh index 114a14cd61b88..1842fab8b2cac 100644 --- a/csrc/cutlass_extensions/cute_utils.cuh +++ b/csrc/cutlass_extensions/cute_utils.cuh @@ -25,8 +25,9 @@ CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { else { constexpr auto coalesced_layout = coalesce(Layout{}); if constexpr (rank(coalesced_layout) == 1 && - stride<0>(coalesced_layout) == 1) + stride<0>(coalesced_layout) == 1) { return true; + } return false; } } @@ -51,16 +52,17 @@ static constexpr auto get_logical_ptr(PointerType* ptr) { template CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() { constexpr auto bits = sizeof_bits_v * Elements{}; - if constexpr (bits % 128 == 0) + if constexpr (bits % 128 == 0) { return AutoVectorizingCopyWithAssumedAlignment<128>{}; - else if constexpr (bits % 64 == 0) + } else if constexpr (bits % 64 == 0) { return AutoVectorizingCopyWithAssumedAlignment<64>{}; - else if constexpr (bits % 32 == 0) + } else if constexpr (bits % 32 == 0) { return AutoVectorizingCopyWithAssumedAlignment<32>{}; - else if constexpr (bits % 16 == 0) + } else if constexpr (bits % 16 == 0) { return AutoVectorizingCopyWithAssumedAlignment<16>{}; - else + } else { return AutoVectorizingCopyWithAssumedAlignment<8>{}; + } } }; // namespace cute diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp index ec8b21a62f894..1618a340ce10e 100644 --- a/csrc/cutlass_extensions/torch_utils.hpp +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -17,7 +17,7 @@ namespace detail { template CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g, seq) { - return g(f(get(static_cast(t)), I)...); + return g(f(cute::get(static_cast(t)), I)...); } template @@ -29,7 +29,7 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq) { template CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) { - if constexpr (is_tuple::value) { + if constexpr (cute::is_tuple::value) { return detail::tapply_with_idx( t, f, [](auto const&... a) { return cute::make_tuple(a...); }, tuple_seq{}); @@ -72,8 +72,9 @@ static inline auto make_cute_layout(torch::Tensor const& tensor, } } else { // Extra strides are assumed to be 0 or 1 - if constexpr (cute::is_static_v) + if constexpr (cute::is_static_v) { static_assert(StrideEle::value == 0 || StrideEle::value == 1); + } return StrideEle{}; } }); diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh index 7561b2505b10e..2ad914f8e9868 100644 --- a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh +++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh @@ -524,7 +524,7 @@ struct NumericArrayConverter { // Below constructs the following temporary: uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; static_assert(RegArray::kElements <= 4, - "Too many inputs for BF16 -> I4 vector converter"); + "Too many inputs for uint4b8_t -> BF16 vector converter"); CUTLASS_PRAGMA_UNROLL for (int ii = 0; ii < RegArray::kElements; ++ii) { asm volatile( diff --git a/csrc/ops.h b/csrc/ops.h index b7f07f5da7d1e..6bf0cff232528 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -88,7 +88,7 @@ namespace machete { std::vector supported_schedules( vllm::ScalarTypeTorchPtr const& btype); -torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B, +torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, vllm::ScalarTypeTorchPtr const& btype, c10::optional const& scales, c10::optional const& zeros, @@ -97,7 +97,7 @@ torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B, c10::optional alpha, c10::optional beta, c10::optional schedule); -torch::Tensor prepack_B(torch::Tensor const B, +torch::Tensor prepack_B(torch::Tensor const& B, vllm::ScalarTypeTorchPtr const& btype); }; // namespace machete diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index c711748ec37b9..3d574ad99efda 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -1,5 +1,27 @@ +// // Based off of: -// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +// cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +// Specifically: +// https://github.com/NVIDIA/cutlass/tree/06b21349bcf6ddf6a1686a47a137ad1446579db9/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +// Referred to as upstream from in the comments +// +// The main optimization machete implements compared to upstream is to prepack +// the weight matrix to more closely match the shape of the wgmma instructions +// allowing for wider (ideally 128bit) shared memory loads. For subbyte types +// this is done by packing values from multiple wgmma loads (for a single +// thread) into a single 128bit load. This is very similar to layout used in +// Marlin, although specific to the wgmma instructions. +// +// Since the wgmma instructions only support sourcing from registers for the A +// operand, and we want to upconvert/decompress the weight values/elements +// before feeding them into the tensor cores in registers, we need the weight +// matrix to be A. To achieve this we compute the transpose of Y = XW^t as +// Y^t = W^tX^t. This is mostly done outside of this file in +// csrc/quantization/machete/machete_mm_kernel.cuh, but this why A is the +// quantized/narrow type and has the prepacked layout despite the API being: +// B_prepacked = machete_prepack_B(B) +// Y = machete_mm(A, B_prepacked) +// #pragma once // clang-format off @@ -87,6 +109,9 @@ struct MacheteCollectiveMma { gmma_rs_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = gmma_rs_tag_to_major_B(); + + // For coop schedules we have two warp groups cooperatively issuing wgmma + // instructions so we use 2 atoms along the M dim (one for each warpgroup) using AtomLayoutMNK = cute::conditional_t< cute::is_same_v, @@ -98,6 +123,23 @@ struct MacheteCollectiveMma { AtomLayoutMNK{})); private: + // + // the setup section (until "section setup end") contains a combination of + // modified code from (used as a starting point): + // `cutlass/gemm/collective/builders/sm90_gmma_builder.inl` + // `cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp` + // (upstream) + // + // however in-order to simplify the code we combine a lot of the logic from + // `CollectiveMma` and `CollectiveBuilder` into this class, this also makes + // sense given that we have flexibility on layouts here. We also simplify the + // code by only supporting scales and zeros for A (in the transposed problem, + // B from an API perspective), also since we force A to be the narrow type + // (i.e. the type to be upconverted) we can remove all the `SwapAB` logic in + // the upstream also simplifying the code. This section includes new logic + // (compared ustream) for handling the prepacked-A layouts (in the transposed + // problem, B from an API perspective) + // using ElementScale = deduce_mixed_width_dtype_t<1, ElementATuple_>; using ElementZero = deduce_mixed_width_dtype_t<2, ElementATuple_>; @@ -321,6 +363,7 @@ struct MacheteCollectiveMma { KernelConversionMode == ConversionMode::ConvertAndScale || KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + // Same as upstream, should be kept the same when possible static constexpr auto elements_per_smem_scale() { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { return 0; @@ -332,6 +375,7 @@ struct MacheteCollectiveMma { } } + // Same as upstream, should be kept the same when possible static constexpr auto elements_per_smem_zero() { if constexpr (KernelConversionMode == ConversionMode::DirectConvert || KernelConversionMode == ConversionMode::ConvertAndScale) { @@ -345,49 +389,43 @@ struct MacheteCollectiveMma { } } - // These methods use some the public members of the class. For that reason, we - // define them after the public section. - static constexpr uint32_t compute_tma_transaction_bytes_mk() { - constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes( - size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * - static_cast(cute::sizeof_bits_v)); + // Same as upstream, should be kept the same when possible, not formatte for + // easier comparison + // clang-format off + // These methods use some the public members of the class. For that reason, we define them after the public section. + static constexpr uint32_t + compute_tma_transaction_bytes_mk() { + constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { return baseline_bytes; - } else if constexpr (ModeHasScales) { - constexpr uint32_t scale_tx_bytes = - (size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * - static_cast(cute::sizeof_bits_v) / 8); - static_assert( - scale_tx_bytes % 128 == 0, - "Each scale stage must be 128B aligned."); // required by TMA + } + else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return baseline_bytes + scale_tx_bytes; - } else if constexpr (KernelConversionMode == - ConversionMode::ConvertAndScaleWithZero) { + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { // Scale and zero share smem layout - constexpr uint32_t zero_tx_bytes = - (size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * - static_cast(cute::sizeof_bits_v) / 8); - static_assert( - zero_tx_bytes % 128 == 0, - "Each zero stage must be 128B aligned."); // required by TMA + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA return baseline_bytes + scale_tx_bytes + zero_tx_bytes; - } else { - static_assert(cutlass::detail::dependent_false, - "Type not handled in tma transaction bytes computation."); } - } else { - static_assert(cutlass::detail::dependent_false, - "Type not handled in tma transaction bytes computation."); + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); } } - static constexpr uint32_t compute_tma_transaction_bytes_nk() { - return cutlass::bits_to_bytes( - size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * - static_cast(cute::sizeof_bits_v)); + static constexpr uint32_t + compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); } + // clang-format on // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset( @@ -441,29 +479,26 @@ struct MacheteCollectiveMma { } public: - static constexpr size_t SmemAlignmentA = - cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // with `RealInternalElementA` -> `ElementA` since we support `SwapAB` logic + // clang-format off + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); - static constexpr size_t SmemAlignmentB = - cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); - // Just pick the max alignment of A and B since it is required to be at least - // 128B - static constexpr size_t SmemAlignmentScale = - cute::max(SmemAlignmentA, SmemAlignmentB); + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); - static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, - "Require at least 128B alignment"); + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); - struct SharedStorage { + struct SharedStorage + { static constexpr int scale_elements = elements_per_smem_scale(); static constexpr int zero_elements = elements_per_smem_zero(); - struct TensorStorage - : cute::aligned_struct { + struct TensorStorage : cute::aligned_struct { cute::ArrayEngine> smem_A; - cute::ArrayEngine> - smem_B; + cute::ArrayEngine> smem_B; cute::ArrayEngine smem_scale; cute::ArrayEngine smem_zero; } tensors; @@ -471,7 +506,6 @@ struct MacheteCollectiveMma { using PipelineStorage = typename MainloopPipeline::SharedStorage; PipelineStorage pipeline; }; - using TensorStorage = typename SharedStorage::TensorStorage; using PipelineStorage = typename SharedStorage::PipelineStorage; @@ -487,7 +521,16 @@ struct MacheteCollectiveMma { ElementZero const* ptr_Z = nullptr; uint32_t mma_promotion_interval = 4; }; + // clang-format on + + // + // section setup end + // + // Similar (but not idendtical) to upstream, should be kept the same when + // possible + // compared to upstream we use `make_tma_copy_A`, `make_tma_copy_B` etc. to + // define the TMA types // Device side kernel params struct Params { public: @@ -497,6 +540,8 @@ struct MacheteCollectiveMma { using TMA_Zero = decltype(make_tma_copy_zero()); using TMA_B = decltype(make_tma_copy_B()); + // required by outer loop: i.e. + // cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp TMA_A tma_load_a; TMA_B tma_load_b; TMA_Scale tma_load_scale; @@ -512,6 +557,10 @@ struct MacheteCollectiveMma { // Methods // + // Similar (but not idendtical) to upstream, should be kept the same when + // possible + // compared to upstream we use `make_tma_copy_A` and `TVbNbKL_to_offset` here + // to handle the prepacked layout template static constexpr Params to_underlying_arguments( ProblemShape const& problem_shape, Arguments const& args, @@ -566,109 +615,91 @@ struct MacheteCollectiveMma { } } - template - CUTLASS_HOST_DEVICE static bool can_implement( + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // with `SwapAB ? N : M -> M` since we dont support SwapAB + // clang-format off + template + static bool + can_implement( ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { constexpr int tma_alignment_bits = 128; auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - + auto [M,N,K,L] = problem_shape_MNKL; + bool implementable = true; - constexpr int min_tma_aligned_elements_A = - tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = - implementable && - cutlass::detail::check_alignment( - cute::make_shape(M, K, L), StrideA{}); - constexpr int min_tma_aligned_elements_B = - tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = - implementable && - cutlass::detail::check_alignment( - cute::make_shape(N, K, L), StrideB{}); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { implementable = implementable && (args.ptr_S == nullptr); implementable = implementable && (args.ptr_Z == nullptr); - } else if constexpr (ModeHasScales) { + } + else if constexpr (ModeHasScales) { const int scale_mn = M; const int scale_k = (K + args.group_size - 1) / args.group_size; - constexpr int min_tma_aligned_elements_scale = - tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = - implementable && - cutlass::detail::check_alignment( - cute::make_shape(scale_mn, scale_k, L), StrideScale{}); - implementable = - implementable && (args.group_size == K || - ((args.group_size % size<2>(TileShape{})) == 0)); + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); implementable = implementable && args.group_size != 0; implementable = implementable && (args.ptr_S != nullptr); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { implementable = implementable && (args.ptr_Z == nullptr); - } else if constexpr (KernelConversionMode == - ConversionMode::ConvertAndScaleWithZero) { - constexpr int min_tma_aligned_elements_zero = - tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = - implementable && - cutlass::detail::check_alignment( - cute::make_shape(scale_mn, scale_k, L), StrideScale{}); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); implementable = implementable && (args.ptr_Z != nullptr); - } else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled in can_implement."); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); } - } else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled in can_implement."); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); } if (!implementable) { - CUTLASS_TRACE_HOST( - " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment " - "requirements for TMA.\n"); + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); } return implementable; } static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr uint32_t TmaTransactionBytesMK = - compute_tma_transaction_bytes_mk(); - static constexpr uint32_t TmaTransactionBytesNK = - compute_tma_transaction_bytes_nk(); - static constexpr uint32_t TmaTransactionBytes = - TmaTransactionBytesMK + TmaTransactionBytesNK; - - // Issue Tma Descriptor Prefetch -- ideally from a single thread for best - // performance + static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& mainloop_params) { - cute::prefetch_tma_descriptor( - mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor( - mainloop_params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { // Nothing extra to do - } else if constexpr (KernelConversionMode == - ConversionMode::ConvertAndScale) { - cute::prefetch_tma_descriptor( - mainloop_params.tma_load_scale.get_tma_descriptor()); - } else if constexpr (KernelConversionMode == - ConversionMode::ConvertAndScaleWithZero) { - cute::prefetch_tma_descriptor( - mainloop_params.tma_load_scale.get_tma_descriptor()); - cute::prefetch_tma_descriptor( - mainloop_params.tma_load_zero.get_tma_descriptor()); - } else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled in TMA prefetch."); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor()); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); } + } + // clang-format off + // Modified from upstream, should be kept close to that when possible + // the main difference is special handling for the prepacked A layout + // // Set up the data needed by this collective for load and mma. // Returns a tuple of tensors. The collective and the kernel layer have the // contract Returned tuple must contain at least two elements, with the first @@ -734,69 +765,72 @@ struct MacheteCollectiveMma { } } - // Perform a collective-scoped matrix multiply-accumulate - // Producer Perspective - // This overload gets triggered when we have scales. - template - CUTLASS_DEVICE void load(Params const& mainloop_params, - MainloopPipeline pipeline, - PipelineState smem_pipe_write, - cute::tuple const& load_inputs, - BlockCoord const& blk_coord, - KTileIterator k_tile_iter, int k_tile_count, - int thread_idx, uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) { + // Similar to upstream, should be kept close to that when possible + // the main difference is in the layout comments + // clang-format off + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + /// This overload gets triggered when we have scales. + template < + class... Ts, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - static_assert(sizeof...(Ts) == 2, "Direct convert needs two inputs"); - } else if constexpr (KernelConversionMode == - ConversionMode::ConvertAndScale) { - static_assert(sizeof...(Ts) == 3, "Scaled convert needs three inputs"); - } else if constexpr (KernelConversionMode == - ConversionMode::ConvertAndScaleWithZero) { - static_assert(sizeof...(Ts) == 4, - "Scaled and zero convert needs four inputs"); - } else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled in TMA load."); + static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs"); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); } int lane_predicate = cute::elect_one_sync(); if (lane_predicate) { - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), - SmemLayoutACopy{}); // (TILE_V,TILE_B,PIPE) - Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), - SmemLayoutB{}); // (TILE_N,TILE_K,PIPE) - Tensor sB = - as_position_independent_swizzle_tensor(sB_); // (TILE_N,TILE_K,PIPE) + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) // // Prepare the TMA loads for A, B and Scales // - + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, - block_rank_in_cluster / cluster_shape_x}; + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; Tensor gA_mkl = get<0>(load_inputs); Tensor gB_nkl = get<1>(load_inputs); - auto block_tma_a = - mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = - mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); // Partition the inputs based on the current block coordinates. auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (TILE_V,TILE_B,k) - Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (TILE_N,TILE_K,k) + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (TILE_V,TILE_B,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (TILE_N,TILE_K,k) // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) uint16_t mcast_mask_a = 0; uint16_t mcast_mask_b = 0; @@ -805,32 +839,24 @@ struct MacheteCollectiveMma { // Issue TmaLoads // Maps the tile -> block, value if constexpr (cute::is_same_v) { - auto block_layout = - Layout{}; // (m,n) -> - // block_id + auto block_layout = Layout{}; // (m,n) -> block_id for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, - n, Int<0>{})); + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); } } if constexpr (cute::is_same_v) { - auto block_layout = - Layout{}; // (m,n) -> - // block_id + auto block_layout = Layout{}; // (m,n) -> block_id for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_b |= (uint16_t(1) << block_layout( - m, cluster_local_block_id.y, Int<0>{})); + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); } } - auto extra_input_partitions = partition_extra_tma_inputs( - mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, - m_coord, l_coord); + auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); // Mainloop CUTLASS_PRAGMA_NO_UNROLL - for (; k_tile_count > 0; --k_tile_count) { + for ( ; k_tile_count > 0; --k_tile_count) { // LOCK smem_pipe_write for _writing_ pipeline.producer_acquire(smem_pipe_write); @@ -839,51 +865,41 @@ struct MacheteCollectiveMma { // using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = - pipeline.producer_get_barrier(smem_pipe_write); + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); int write_stage = smem_pipe_write.index(); - copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), - tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage)); - copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), - tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage)); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { // Nothing extra to do. - } else if constexpr (ModeHasScales) { + } + else if constexpr (ModeHasScales) { auto tSgS = get<0>(extra_input_partitions); auto tSsS = get<1>(extra_input_partitions); - // Temporary factor which will determine which k tile to reload from - // gmem. Needed so we don't modify tma transaction bytes on the fly. - // We must do a ceiling divide here to correctly handle with - // group_size == K. In that case, we don't require that K is a - // multiple of the threadblock tile K - const int ReloadFactor = - (mainloop_params.group_size + size<2>(TileShape{}) - 1) / - size<2>(TileShape{}); - const int scale_load_k = - *k_tile_iter / - ReloadFactor; // This will always be 0 when group_size == K. - copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), - tSgS(_, _, _, scale_load_k), tSsS(_, _, _, write_stage)); - - if constexpr (KernelConversionMode == - ConversionMode::ConvertAndScale) { + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes + // on the fly. + // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K + // is a multiple of the threadblock tile K + const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K. + copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { // Nothing extra to do - } else if constexpr (KernelConversionMode == - ConversionMode::ConvertAndScaleWithZero) { + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { auto tZgZ = get<2>(extra_input_partitions); auto tZsZ = get<3>(extra_input_partitions); - copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), - tZgZ(_, _, _, scale_load_k), tZsZ(_, _, _, write_stage)); - } else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled for TMA copy op."); + copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); } - } else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled for TMA copy op."); + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); } ++k_tile_iter; @@ -893,24 +909,33 @@ struct MacheteCollectiveMma { } } } + // clang-format off + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, - PipelineState smem_pipe_write) { + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { int lane_predicate = cute::elect_one_sync(); // Issue the epilogue waits if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all + * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was + * then would just be acquired since the phase was * still inverted from make_producer_start_state */ pipeline.producer_tail(smem_pipe_write); } } + // clang-format on + // Modified from upstream, should be kept close to that when possible + // the main differences are handling the prepacked A layout, and separating + // the loading of A from upcoverting A + // // Perform a collective-scoped matrix multiply-accumulate // Consumer Perspective template @@ -1168,118 +1193,134 @@ struct MacheteCollectiveMma { } private: - // Utilities for any additional inputs inside of the TMA load + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Utilities for any additional inputs inside of the TMA load template - CUTLASS_DEVICE auto partition_extra_tma_inputs( - Params const& mainloop_params, cute::tuple const& load_inputs, - TensorStorage& shared_tensors, uint2 const& cluster_local_block_id, - int const m_coord, int const l_coord) { + CUTLASS_DEVICE + auto partition_extra_tma_inputs( + Params const& mainloop_params, + cute::tuple const& load_inputs, + TensorStorage& shared_tensors, + uint2 const& cluster_local_block_id, + int const m_coord, + int const l_coord) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return cute::tuple{}; - } else if constexpr (ModeHasScales) { - Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), - SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) Tensor gS_mkl = get<2>(load_inputs); - auto block_tma_s = - mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); - Tensor gS = gS_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) - Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tSgS, tSsS); - } else if constexpr (KernelConversionMode == - ConversionMode::ConvertAndScaleWithZero) { - Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), - SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) Tensor gZ_mkl = get<3>(load_inputs); - auto block_tma_z = - mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); - Tensor gZ = gZ_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) - Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) - return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); - } else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled for input partitioning."); + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); } - } else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled for input partitioning."); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); } } + // clang-format off - // Utilities for partitioning extra inputs for loading from smem in the - // mainloop. + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Utilities for partitioning extra inputs for loading from smem in the mainloop. template - CUTLASS_DEVICE auto partition_extra_mma_info(ThreadMma const& thread_mma, - TensorStorage& shared_tensors) { + CUTLASS_DEVICE + auto partition_extra_mma_info( + ThreadMma const& mma_thread_slice, + TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - // noting to do - return cute::tuple{}; - } else if constexpr (ModeHasScales) { - Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), - SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) - Tensor tCsS = thread_mma.partition_A(sS); - Tensor tCrS = make_tensor( - thread_mma.partition_fragment_A(sS(_, _, Int<0>{})).shape()); + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).shape()); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tCsS, tCrS); - } else if constexpr (KernelConversionMode == - ConversionMode::ConvertAndScaleWithZero) { - Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), - SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) - Tensor tCsZ = thread_mma.partition_A(sZ); - Tensor tCrZ = make_tensor( - thread_mma.partition_fragment_A(sZ(_, _, Int<0>{})).shape()); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = mma_thread_slice.partition_A(sZ); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).shape()); return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); - } else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled in A -> RF path."); } - } else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled in A -> RF path."); + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } } + // clang-format on - // Returns the tiled copy and copy views for the extra inputs. + // Same as upstream, should be kept the same when possible, not formatted for + // easier comparison + // clang-format off + /// Returns the tiled copy and copy views for the extra inputs. template - CUTLASS_DEVICE auto retile_extra_mma_info( - TiledMma const& tiled_mma, cute::tuple& partitioned_extra_info, - int const warp_group_thread_idx) { - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - // noting to do - return cute::tuple{}; - } else if constexpr (ModeHasScales) { - auto smem_tiled_copy_S = - make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); - auto smem_thr_copy_S = - smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); - Tensor tCrS_copy_view = smem_thr_copy_S.retile_D( - cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + CUTLASS_DEVICE + auto retile_extra_mma_info( + TiledMma const& tiled_mma, + cute::tuple& partitioned_extra_info, + int const warp_group_thread_idx) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); - } else if constexpr (KernelConversionMode == - ConversionMode::ConvertAndScaleWithZero) { - Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D( - cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) - return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, - tCrZ_copy_view); - } else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled in A -> RF path."); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } - } else { - static_assert(cutlass::detail::dependent_false, - "Conversion mode not handled in A -> RF path."); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); } } + // clang-format on - // Utilities to copy A and extra inputs from smem to RF + // Similar to `copy_A_and_extra_info` upstream, should be kept the same when + // possible + // the main differences this only loads the extra info into registers and + // not A (since we now preload more of A in the main pipeline) + // Load scales and zeros into registers if required template CUTLASS_DEVICE void load_extra_info_to_registers( cute::tuple const& partitioned_mma_extra_info, @@ -1314,6 +1355,10 @@ struct MacheteCollectiveMma { } } + // Similar to upstream, should be kept the same when possible. + // the main differences are that `convert_tensor` supports interleaved + // layouts and bfloat16 has been optimized. `transform_internal_A` has also + // been inlined for code simplicity. // Utilities to transform A. template CUTLASS_DEVICE void transform_A_kblock( @@ -1373,6 +1418,8 @@ struct MacheteCollectiveMma { } } + // Modified from upstream, should be kept the same when possible + // the main differences is that this version supports interleaved converts // Utilities for transforming the A operand prior to issuing tensorcore math. template @@ -57,6 +68,12 @@ struct PrepackedLayoutBTemplate { // TODO (LucasWilkinson): compare the performance for other sizes // Prepacked block shape, smallest layout atom for loading into registers // (can contain multiple wgmma instructions worth of data in one block) + // We ideally want this to be configured such that a thread can perform 128bit + // loads, i.e. we amount of data associated with each thread within a + // prepacked block is a multiple of 128bits, when using a cooperative sechdule + // we have 256 threads working a single block at a time, this means each + // thread works on `sizeof_bits_v * (128*64) / 256` bits of data, + // for a 4bit type this would be 128bits using PPBlockShape_NK = Shape<_128, _64>; // Create the shape of the tile anticipated to be used by the GEMM kernel, @@ -70,6 +87,9 @@ struct PrepackedLayoutBTemplate { static constexpr cute::GMMA::Major GmmaMajorB = gmma_rs_tag_to_major_B(); + + // For coop schedules we have two warp groups cooperatively issuing wgmma + // instructions so we use 2 atoms along the M dim (one for each warpgroup) using AtomLayoutMNK = cute::conditional_t< cute::is_same_v, diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu index 0f68dfdcd0528..ef36a490c3c50 100644 --- a/csrc/quantization/machete/machete_pytorch.cu +++ b/csrc/quantization/machete/machete_pytorch.cu @@ -42,7 +42,7 @@ std::vector supported_schedules(ScalarTypeTorchPtr const& btype) { }); } -torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B, +torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, ScalarTypeTorchPtr const& btype, c10::optional const& scales, c10::optional const& zeros, @@ -69,7 +69,7 @@ torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B, }); } -torch::Tensor prepack_B(torch::Tensor const B, +torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeTorchPtr const& btype) { return scalar_type_dispatch(*btype, [&](auto BType) { return PrepackBDispatcher::dispatch(B); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index f4c8d406c671b..6d1f53b75f4e2 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -135,11 +135,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. ops.def("machete_supported_schedules", &machete::supported_schedules); - ops.impl("machete_supported_schedules", torch::kCPU, - &machete::supported_schedules); - ops.def("machete_gemm", &machete::gemm); + ops.def( + "machete_gemm(Tensor A, Tensor B," + " __torch__.torch.classes._core_C.ScalarType btype," + " Tensor? scales, Tensor? zeros, int? group_size," + " Tensor? C, float? alpha, float? beta, str? schedule)" + "-> Tensor"); ops.impl("machete_gemm", torch::kCUDA, &machete::gemm); - ops.def("machete_prepack_B", &machete::prepack_B); + ops.def( + "machete_prepack_B(Tensor B," + " __torch__.torch.classes._core_C.ScalarType btype)" + "-> Tensor"); ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B); // gptq_marlin Optimized Quantized GEMM for GPTQ. diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1c285cfbfb20f..b89a90ef0f70c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -336,7 +336,7 @@ def machete_supported_schedules(b_type: ScalarType) -> List[str]: def machete_gemm( a: torch.Tensor, - b_q: torch.Tensor, + b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B b_type: ScalarType, b_scales: Optional[torch.Tensor] = None, b_zeros: Optional[torch.Tensor] = None,