diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp
index fbb95d46df2..a5b6f4d2bb7 100644
--- a/csrc/scheduler/hopper_multi_matmul.cpp
+++ b/csrc/scheduler/hopper_multi_matmul.cpp
@@ -53,6 +53,8 @@ void HopperMultipleMatmulScheduler::run() {
 
   inspectPrologues();
 
+  setCGADims();
+
   scheduleOperands();
 
   // schedule mma instruction output (mma_result)
diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h
index bf7bc1df0f5..1d77785cc99 100644
--- a/csrc/scheduler/hopper_multi_matmul.h
+++ b/csrc/scheduler/hopper_multi_matmul.h
@@ -149,6 +149,14 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
   std::vector<std::vector<MatmulDimRole>> blockTileTensors(
       const std::vector<TensorView*>& tvs);
 
+  //! Specifies the CGA dimensions by setting "cluster_dims" as fusion-managed
+  //! data
+  void setCGADims() const {
+    if (params_->cluster_dims != std::tuple{1, 1, 1}) {
+      fusion_->manage("cluster_dims", params_->cluster_dims);
+    }
+  }
+
   //! Schedule the loads of all operands from global memory to shared memory.
   //! Starting from the basic tiled schedule, we swizzle the operand memory.
   //! Note that the cache op and LoadStoreOpType are already set during
diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h
index 6a92d31fd2c..f66cd12e618 100644
--- a/csrc/scheduler/matmul_heuristic.h
+++ b/csrc/scheduler/matmul_heuristic.h
@@ -179,6 +179,10 @@ class MatmulParams : public HeuristicParams {
   //! axis and perform a grid reduction before the epilogue.
   int splitk_factor = 1;
 
+  //! This is the CGA size on Hopper+ devices. This parameter is ignored on
+  //! Ampere and Turing.
+  std::tuple<int64_t, int64_t, int64_t> cluster_dims = {2, 1, 1};
+
   std::string toString() const override {
     std::stringstream ss;
     ss << "\n===== Matmul Parameters ========\n"