Skip to content

Commit

Permalink
Add support for stmatrix in the unit test HopperMatmulTest/HSH_NT_128BSwizzle (#3411)
Browse files Browse the repository at this point in the history

This demonstrates the use of stmatrix in a multi-tile hopper matmul.
  • Loading branch information
protonu authored Nov 23, 2024
1 parent ceed503 commit caa7f07
Showing 1 changed file with 39 additions and 6 deletions.
45 changes: 39 additions & 6 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3687,6 +3687,8 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
constexpr auto swizzle = MmaInputSmemSwizzle::B128;
const auto dtype = DataType::Half;

constexpr bool use_smem_epilogue = false;

constexpr int64_t stages = 4;
constexpr int64_t prefetch = 3;
const int64_t cta_m = 2 * getM(macro);
Expand All @@ -3705,7 +3707,6 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
tv2->commitLeafToLogical();

auto tv3 = castOp(DataType::Half, tv2);

fusion.addOutput(tv3);

auto mma_ops = ir_utils::getOpsOfType<MmaOp>(&fusion);
Expand All @@ -3722,7 +3723,21 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
tv0c->setMemoryType(MemoryType::Shared);
auto tv1c = tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
tv1c->setMemoryType(MemoryType::Shared);
auto tv3c = tv3->cacheBefore();

TensorView *tv3c = nullptr, *tv3_shmem = nullptr;
if (use_smem_epilogue) {
tv3_shmem = tv3->cacheBefore();
tv3c = tv3_shmem->cacheBefore();
tv3_shmem->setMemoryType(MemoryType::Shared);
tv3c->setMemoryType(MemoryType::Local);
tv3_shmem->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::StMatrix);
tv3->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);
} else {
tv3c = tv3->cacheBefore();
tv3c->setMemoryType(MemoryType::Local);
}

// gmem [K, M, 1] -TMA-> smem [K, M, 1]
// gmem [K, 1, N] -TMA-> smem [K, 1, N]
Expand Down Expand Up @@ -3775,12 +3790,30 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
tv2->axis(-3)->parallelize(ParallelType::Mma);
}

if (!use_smem_epilogue) {
    for (auto tv : {tv3c, tv3}) {
      auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
          tv->getLoopDomain());
      tv->setLoopDomain(s.as<IterDomain*>());
    }
  } else {
    auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
        tv3c->getLoopDomain());
    tv3c->setLoopDomain(s.as<IterDomain*>());
tv3c->setAllocationDomain(s.as<IterDomain*>(), true);

// We'll use stmatrix.x4 to store from reg to shared memory
fusion.manage("st_matrix_m_tile", 16);
fusion.manage("st_matrix_n_tile", 16);
fusion.manage("st_matrix_m", getM(macro));
fusion.manage("st_matrix_n", getN(macro));

// This internally calls
// mma_utils::MmaSwizzler::scheduleMmaOutputAllocation
mma_utils::scheduleStMatrixForMmaOutput(tv3_shmem, 16, 16);

mma_utils::scheduleTMAStoreForMmaOutput(tv3, M, N);
}
tv3->axis(-1)->parallelize(ParallelType::Vectorize);

inlineMost();

Expand Down

0 comments on commit caa7f07

Please sign in to comment.