Skip to content

Commit

Permalink
addressing reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
protonu committed Nov 19, 2024
1 parent 4022817 commit 6dc6c48
Showing 1 changed file with 21 additions and 29 deletions.
50 changes: 21 additions & 29 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3707,27 +3707,7 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
tv2->commitLeafToLogical();

auto tv3 = castOp(DataType::Half, tv2);

TensorView* tv4 = nullptr;
if (use_smem_epilogue) {
// Copy from shared memory to global using TMA
tv4 = set(tv3);

tv3->setMemoryType(MemoryType::Shared);
tv4->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);

fusion.addOutput(tv4);

// We'll use stmatrix.x4 to store from reg to shared memory
fusion.manage("st_matrix_m_tile", 16);
fusion.manage("st_matrix_n_tile", 16);
fusion.manage("st_matrix_m", getM(macro));
fusion.manage("st_matrix_n", getN(macro));

} else {
fusion.addOutput(tv3);
}
fusion.addOutput(tv3);

auto mma_ops = ir_utils::getOpsOfType<MmaOp>(&fusion);
NVF_CHECK(
Expand All @@ -3743,10 +3723,20 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
tv0c->setMemoryType(MemoryType::Shared);
auto tv1c = tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
tv1c->setMemoryType(MemoryType::Shared);
auto tv3c = tv3->cacheBefore();

TensorView *tv3c = nullptr, *tv3_shmem = nullptr;
if (use_smem_epilogue) {
tv3->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::StMatrix);
tv3_shmem = tv3->cacheBefore();
tv3c = tv3_shmem->cacheBefore();
tv3_shmem->setMemoryType(MemoryType::Shared);
tv3c->setMemoryType(MemoryType::Local);
tv3_shmem->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::StMatrix);
tv3->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);
} else {
tv3c = tv3->cacheBefore();
tv3c->setMemoryType(MemoryType::Local);
}

// gmem [K, M, 1] -TMA-> smem [K, M, 1]
Expand Down Expand Up @@ -3812,15 +3802,17 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) {
tv3c->setLoopDomain(s.as<IterDomain*>());
tv3c->setAllocationDomain(s.as<IterDomain*>(), true);

// We need to split and merge dimensions which have
// been marked parallel/vectorized. So we revert the
// inner-most dimension back to serial.
tv3->axis(-1)->parallelize(ParallelType::Serial);
// We'll use stmatrix.x4 to store from reg to shared memory
fusion.manage("st_matrix_m_tile", 16);
fusion.manage("st_matrix_n_tile", 16);
fusion.manage("st_matrix_m", getM(macro));
fusion.manage("st_matrix_n", getN(macro));

// This internally calls
// mma_utils::MmaSwizzler::scheduleMmaOutputAllocation
mma_utils::scheduleStMatrixForMmaOutput(tv3, 16, 16);
mma_utils::scheduleStMatrixForMmaOutput(tv3_shmem, 16, 16);

mma_utils::scheduleTMAStoreForMmaOutput(tv4, M, N);
mma_utils::scheduleTMAStoreForMmaOutput(tv3, M, N);
}

inlineMost();
Expand Down

0 comments on commit 6dc6c48

Please sign in to comment.