Add MatmulParams::cluster_dims parameter #3574

Merged (4 commits) — Dec 11, 2024
Changes shown are from the first 3 commits.
csrc/codegen.cpp — 3 changes: 1 addition & 2 deletions

```diff
@@ -276,8 +276,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
   code_ << "__global__ void ";
   if (kernel_->hasManaged("cluster_dims")) {
     auto cluster_dims =
-        kernel_->getManaged<std::tuple<int64_t, int64_t, int64_t>>(
-            "cluster_dims");
+        kernel_->getManaged<std::tuple<int, int, int>>("cluster_dims");
     code_ << "__cluster_dims__(" << std::get<0>(cluster_dims) << ", "
           << std::get<1>(cluster_dims) << ", " << std::get<2>(cluster_dims)
           << ") ";
```
csrc/scheduler/hopper_multi_matmul.cpp — 2 changes: 2 additions & 0 deletions

```diff
@@ -53,6 +53,8 @@ void HopperMultipleMatmulScheduler::run() {

   inspectPrologues();

+  setCGADims();
+
   scheduleOperands();

   // schedule mma instruction output (mma_result)
```
csrc/scheduler/hopper_multi_matmul.h — 8 changes: 8 additions & 0 deletions

```diff
@@ -149,6 +149,14 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
   std::vector<std::vector<MatmulDimRole>> blockTileTensors(
       const std::vector<TensorView*>& tvs);

+  //! Specifies the CGA dimensions by setting "cluster_dims" as fusion-managed
+  //! data
+  void setCGADims() const {
+    if (params_->cluster_dims != std::tuple<int, int, int>{1, 1, 1}) {
+      fusion_->manage("cluster_dims", params_->cluster_dims);
+    }
+  }
+
   //! Schedule the loads of all operands from global memory to shared memory.
   //! Starting from the basic tiled schedule, we swizzle the operand memory.
   //! Note that the cache op and LoadStoreOpType are already set during
```
csrc/scheduler/matmul_heuristic.h — 4 changes: 4 additions & 0 deletions

```diff
@@ -179,6 +179,10 @@ class MatmulParams : public HeuristicParams {
   //! axis and perform a grid reduction before the epilogue.
   int splitk_factor = 1;

+  //! This is the CGA size on Hopper+ devices. This parameter is ignored on
+  //! Ampere and Turing.
+  std::tuple<int, int, int> cluster_dims = {2, 1, 1};
+
   std::string toString() const override {
     std::stringstream ss;
     ss << "\n===== Matmul Parameters ========\n"
```
tests/cpp/test_matmul.cpp — 5 changes: 2 additions & 3 deletions

```diff
@@ -3663,7 +3663,7 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
   const int64_t cta_m = 2 * getM(macro);
   const int64_t cta_n = 1 * getN(macro);

-  constexpr std::tuple<int64_t, int64_t, int64_t> cluster_dims{2, 1, 1};
+  constexpr std::tuple<int, int, int> cluster_dims{2, 1, 1};
```
Review comment (Collaborator):

super nitpick: can we stick with int64_t for consistency?

Suggested change:

```diff
-  constexpr std::tuple<int, int, int> cluster_dims{2, 1, 1};
+  constexpr std::tuple<int64_t, int64_t, int64_t> cluster_dims{2, 1, 1};
```

Reply (Collaborator, Author):

Yeah, I was mostly doing that because the MatmulParams entries are int, but we should probably just change MatmulParams instead (in another PR).

Reply (Collaborator, Author):

Done. Much smaller PR now...

```diff
   auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
   auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
@@ -3680,8 +3680,7 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
   auto tv3 = castOp(DataType::Half, tv2);
   fusion.addOutput(tv3);

-  if constexpr (
-      cluster_dims != std::tuple<int64_t, int64_t, int64_t>{1, 1, 1}) {
+  if constexpr (cluster_dims != std::tuple<int, int, int>{1, 1, 1}) {
     fusion.manage("cluster_dims", cluster_dims);
   }
```