diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 0016cf758e2c4..a98e24c5672d7 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -78,15 +78,15 @@ namespace machete { template -using Kernel = KernelTemplate< +using Kernel = MacheteKernelTemplate< {{DataTypeTag[type_config.element_a]}}, // ElementA {{DataTypeTag[type_config.element_b]}}, // ElementB {{DataTypeTag[type_config.element_d]}}, // ElementD {{DataTypeTag[type_config.accumulator]}}, // Accumulator {{DataTypeTag[type_config.element_b_scale]}}, // Scales {{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints - cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput ->::Speacialization; + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, + Config, with_C, with_scales, with_zeropoints>; {% for sch in schedules %} {% set schedule_name = gen_sch_name(sch) -%} @@ -138,7 +138,7 @@ {{DataTypeTag[type_config.element_b_scale]}}, // Scales {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints -using PrepackedLayoutB = PrepackedLayoutBBTemplate< +using PrepackedLayoutB = PrepackedLayoutBTemplate< {{DataTypeTag[type_config.element_a]}}, // ElementA {{DataTypeTag[type_config.element_b]}}, // ElementB {{DataTypeTag[type_config.element_d]}}, // ElementD diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index 20121bffec9a8..22beb04f60619 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -999,8 +999,6 @@ struct MacheteCollectiveMma { CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - // CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // - // PIPE CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE // diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh index 1f4b06725ea27..6d8c734a86f64 100644 --- a/csrc/quantization/machete/machete_mm_kernel.cuh +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -37,179 +37,157 @@ using namespace cute; // we compute the transpose to move it to the left-hand side. template -struct KernelTemplate { + class KernelSchedule, typename ScheduleConfig, bool with_C, + bool with_scales, bool with_zeropoints> +struct MacheteKernelTemplate { using MmaType = ElementA_; using ElementA = ElementA_; using ElementB = ElementB_; using ElementD = ElementD_; - using ElementAccumulator = AccumulatorT; - - using LayoutA_ = cutlass::layout::RowMajor; - using LayoutScale_ = cutlass::layout::RowMajor; + using ElementC = cute::conditional_t; + using ElementZero = ZeroT; + using ElementScale = ScaleT; + using ElementAccumulator = + AccumulatorT; // Element type for internal accumulation + using ElementCompute = AccumulatorT; // For Epilogue + + using BTypeTuple = cute::conditional_t< + with_scales, + cute::conditional_t, + cute::tuple>, + ElementB>; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + using LayoutScale = cutlass::layout::RowMajor; + // not actually used since B has the prepacked layout, but required by cutlass + using _LayoutB = cutlass::layout::ColumnMajor; using LayoutA_Transpose = - typename cutlass::layout::LayoutTranspose::type; + typename cutlass::layout::LayoutTranspose::type; + using LayoutC_Transpose = + typename cutlass::layout::LayoutTranspose::type; + using LayoutD_Transpose = + typename cutlass::layout::LayoutTranspose::type; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; - using PrepackedLayoutBB = - PrepackedLayoutBBTemplate; - - // clang-format off - template - // clang-format on - struct Speacialization { - using MmaType = ElementA_; - using ElementA = ElementA_; - using ElementB = ElementB_; - using ElementD = ElementD_; - using ElementC = cute::conditional_t; - using ElementZero = ZeroT; - using ElementScale = ScaleT; - using ElementAccumulator = - AccumulatorT; // Element type for internal accumulation - using ElementCompute = AccumulatorT; // For Epilogue - - using BTypeTuple = cute::conditional_t< - with_scales, - cute::conditional_t, - cute::tuple>, - ElementB>; - - using LayoutA = LayoutA_; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::RowMajor; - using LayoutD = LayoutC; - using LayoutScale = cutlass::layout::RowMajor; - - using LayoutB_Transpose = - typename cutlass::layout::LayoutTranspose::type; - using LayoutC_Transpose = - typename cutlass::layout::LayoutTranspose::type; - using LayoutD_Transpose = - typename cutlass::layout::LayoutTranspose::type; - - static int constexpr TileShapeK = - 128 * 8 / cutlass::sizeof_bits::value; - static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v; - static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v; - static int constexpr AlignmentC = - (with_C) ? 128 / cutlass::sizeof_bits_v : 0; - static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v; - - using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{}, - cute::Int{})); - using ClusterShape = typename ScheduleConfig::ClusterShape; - using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; - using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; - using TileScheduler = typename ScheduleConfig::TileScheduler; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, - ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose, - AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, - EpilogueSchedule>::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::VLLMCollectiveBuilder< - cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass, - BTypeTuple, PrepackedLayoutBB, AlignmentB, ElementA, - LayoutA_Transpose, AlignmentA, ElementAccumulator, TileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, // Indicates ProblemShape - CollectiveMainloop, CollectiveEpilogue, TileScheduler>; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - using StrideA = cutlass::detail::TagToStrideA_t; - using StrideB = cutlass::detail::TagToStrideB_t; - using StrideC = typename GemmKernel::StrideC; - using StrideD = typename GemmKernel::StrideD; - using StrideS = typename CollectiveMainloop::StrideScale; - - using Arguments = typename Gemm::Arguments; - using MainloopArguments = typename GemmKernel::MainloopArguments; - using EpilogueArguments = typename GemmKernel::EpilogueArguments; - - static Arguments create_arguments(cudaStream_t stream, int M, int N, int K, - ElementA const* A, ElementB const* B, - ElementC const* C, ElementD* D, - ElementScale const* scales, - ElementZero const* zeros, - ElementCompute alpha, ElementCompute beta, - std::optional maybe_group_size) { - // if we have zeropoints we need scales - static_assert(!with_zeropoints || with_scales); - // if beta != 0 then we need C - TORCH_CHECK(with_C || (!with_C && beta == 0)); - // if with_scales, we need a scales pointer - TORCH_CHECK(with_scales || !scales); - // if with_zeropoints, we need a zeros pointer - TORCH_CHECK(with_zeropoints || !zeros); - - static int constexpr L = 1; - int const group_size = maybe_group_size.value_or(K); - int const scale_k = (K + group_size - 1) / group_size; - - // not stride_B is unused - auto stride_A = make_cute_stride(StrideA{}, N, K, L); - auto stride_B = make_cute_stride(StrideB{}, M, K, L); - auto stride_C = make_cute_stride(StrideC{}, N, M, L); - auto stride_D = make_cute_stride(StrideD{}, N, M, L); - auto stride_S = make_cute_stride(StrideS{}, N, scale_k, L); - - MainloopArguments mainloop_arguments{}; - EpilogueArguments epilogue_arguments{ - {alpha, beta}, C, stride_C, D, stride_D}; - - if constexpr (with_scales && with_zeropoints) { - mainloop_arguments = MainloopArguments{ - B, stride_B, A, stride_A, scales, stride_S, group_size, zeros}; - } else if constexpr (with_scales) { - mainloop_arguments = MainloopArguments{ - B, stride_B, A, stride_A, scales, stride_S, group_size}; - } else { - mainloop_arguments = MainloopArguments{B, stride_B, A, stride_A}; - } - - return Arguments{cutlass::gemm::GemmUniversalMode::kGemm, - {N, M, K, 1}, - mainloop_arguments, - epilogue_arguments}; - }; - - static size_t get_workspace_size(Arguments const& args) { - return Gemm::get_workspace_size(args); + using PrepackedLayoutB = + PrepackedLayoutBTemplate; + + static int constexpr TileShapeK = + 128 * 8 / cutlass::sizeof_bits::value; + static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v; + static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v; + static int constexpr AlignmentC = + (with_C) ? 128 / cutlass::sizeof_bits_v : 0; + static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v; + + using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{}, + cute::Int{})); + using ClusterShape = typename ScheduleConfig::ClusterShape; + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; + using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; + using TileScheduler = typename ScheduleConfig::TileScheduler; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose, + AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::VLLMCollectiveBuilder< + cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass, + BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = cutlass::detail::TagToStrideA_t; + using StrideC = typename GemmKernel::StrideC; + using StrideD = typename GemmKernel::StrideD; + using StrideS = typename CollectiveMainloop::StrideScale; + + // stride_B is unused (since B is prepacked), but still required by cutlass + using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>; + + using Arguments = typename Gemm::Arguments; + using MainloopArguments = typename GemmKernel::MainloopArguments; + using EpilogueArguments = typename GemmKernel::EpilogueArguments; + + static Arguments create_arguments(cudaStream_t stream, int M, int N, int K, + ElementA const* A, ElementB const* B, + ElementC const* C, ElementD* D, + ElementScale const* scales, + ElementZero const* zeros, + ElementCompute alpha, ElementCompute beta, + std::optional maybe_group_size) { + static_assert(!with_zeropoints || with_scales); + TORCH_CHECK(with_C || (!with_C && beta == 0)); + TORCH_CHECK(with_scales || !scales); + TORCH_CHECK(with_zeropoints || !zeros); + + static int constexpr L = 1; + int const group_size = maybe_group_size.value_or(K); + int const scale_k = (K + group_size - 1) / group_size; + + // stride_B is unused (since B is prepacked), but still required by cutlass + auto stride_A = make_cute_stride(StrideA{}, N, K, L); + auto stride_B = make_cute_stride(_StrideB{}, M, K, L); + auto stride_C = make_cute_stride(StrideC{}, N, M, L); + auto stride_D = make_cute_stride(StrideD{}, N, M, L); + auto stride_S = make_cute_stride(StrideS{}, N, scale_k, L); + + MainloopArguments mainloop_arguments{}; + EpilogueArguments epilogue_arguments{ + {alpha, beta}, C, stride_C, D, stride_D}; + + if constexpr (with_scales && with_zeropoints) { + mainloop_arguments = MainloopArguments{ + B, stride_B, A, stride_A, scales, stride_S, group_size, zeros}; + } else if constexpr (with_scales) { + mainloop_arguments = MainloopArguments{ + B, stride_B, A, stride_A, scales, stride_S, group_size}; + } else { + mainloop_arguments = MainloopArguments{B, stride_B, A, stride_A}; } - static bool can_implement(Arguments const& args) { - return Gemm::can_implement(args) == cutlass::Status::kSuccess; - } + return Arguments{cutlass::gemm::GemmUniversalMode::kGemm, + {N, M, K, 1}, + mainloop_arguments, + epilogue_arguments}; + }; - static void run(Arguments const& args, void* workspace, - cudaStream_t stream) { - Gemm gemm_op; + static size_t get_workspace_size(Arguments const& args) { + return Gemm::get_workspace_size(args); + } - cutlass::Status status = gemm_op.initialize(args, workspace, stream); - TORCH_CHECK(status == cutlass::Status::kSuccess, - "Machete kernel failed to initialize workspace"); + static bool can_implement(Arguments const& args) { + return Gemm::can_implement(args) == cutlass::Status::kSuccess; + } - status = gemm_op.run(stream); - TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed"); - } - }; + static void run(Arguments const& args, void* workspace, cudaStream_t stream) { + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(args, workspace, stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, + "Machete kernel failed to initialize workspace"); + + status = gemm_op.run(stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed"); + } }; }; // namespace machete diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh index 984bd7bce5841..0ad9af656d05c 100644 --- a/csrc/quantization/machete/machete_mm_launcher.cuh +++ b/csrc/quantization/machete/machete_mm_launcher.cuh @@ -20,26 +20,25 @@ struct PyTorchArguments { c10::optional schedule; }; -template +template torch::Tensor run_impl(PyTorchArguments args) { const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A)); auto device = args.A.device(); auto stream = at::cuda::getCurrentCUDAStream(device.index()); - using ElementA = typename KernelSpecialization::ElementA; - using ElementB = typename KernelSpecialization::ElementB; - using ElementC = typename KernelSpecialization::ElementC; - using ElementD = typename KernelSpecialization::ElementD; - using ElementScale = typename KernelSpecialization::ElementScale; - using ElementZero = typename KernelSpecialization::ElementZero; + using ElementA = typename MacheteKernel::ElementA; + using ElementB = typename MacheteKernel::ElementB; + using ElementC = typename MacheteKernel::ElementC; + using ElementD = typename MacheteKernel::ElementD; + using ElementScale = typename MacheteKernel::ElementScale; + using ElementZero = typename MacheteKernel::ElementZero; - using LayoutA = typename KernelSpecialization::LayoutA; - using LayoutB = typename KernelSpecialization::LayoutB; - using LayoutC = typename KernelSpecialization::LayoutC; - using LayoutD = typename KernelSpecialization::LayoutD; - using LayoutScale = typename KernelSpecialization::LayoutScale; - using LayoutZero = typename KernelSpecialization::LayoutScale; + using LayoutA = typename MacheteKernel::LayoutA; + using LayoutC = typename MacheteKernel::LayoutC; + using LayoutD = typename MacheteKernel::LayoutD; + using LayoutScale = typename MacheteKernel::LayoutScale; + using LayoutZero = typename MacheteKernel::LayoutScale; int M = args.A.size(0); int N = args.B.size(1); @@ -52,7 +51,7 @@ torch::Tensor run_impl(PyTorchArguments args) { .device(device)); auto A_ptr = data_ptr(args.A, "A"); - auto B_ptr = data_ptr(args.B, "B"); + auto B_ptr = data_ptr(args.B, "B"); auto D_ptr = data_ptr(D, "D"); auto C_ptr = maybe_data_ptr(args.C, "C"); auto scales_ptr = @@ -60,19 +59,19 @@ torch::Tensor run_impl(PyTorchArguments args) { auto zeros_ptr = maybe_data_ptr(args.zeros, "zeros"); - auto arguments = KernelSpecialization::create_arguments( + auto arguments = MacheteKernel::create_arguments( stream, M, N, K, A_ptr, B_ptr, C_ptr, D_ptr, scales_ptr, zeros_ptr, args.alpha.value_or(1), args.beta.value_or(0), args.group_size.value_or(K)); - TORCH_CHECK(KernelSpecialization::can_implement(arguments), + TORCH_CHECK(MacheteKernel::can_implement(arguments), "Machete kernel cannot be run with these arguments"); - size_t workspace_size = KernelSpecialization::get_workspace_size(arguments); + size_t workspace_size = MacheteKernel::get_workspace_size(arguments); torch::Tensor workspace = torch::empty( workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device)); - KernelSpecialization::run(arguments, workspace.mutable_data_ptr(), stream); + MacheteKernel::run(arguments, workspace.mutable_data_ptr(), stream); return D; }; diff --git a/csrc/quantization/machete/machete_prepacked_layout.cuh b/csrc/quantization/machete/machete_prepacked_layout.cuh index 84372900fb2fa..46a654b34b76e 100644 --- a/csrc/quantization/machete/machete_prepacked_layout.cuh +++ b/csrc/quantization/machete/machete_prepacked_layout.cuh @@ -29,7 +29,7 @@ using namespace cute; template // clang-format on -struct PrepackedLayoutBBTemplate { +struct PrepackedLayoutBTemplate { using MmaType = ElementA_; using ElementA = ElementA_; using ElementB = ElementB_;