-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Split Hopper MMA by warp-tile before instruction tile #3642
base: main
Are you sure you want to change the base?
Changes from all commits
851669a
8b42cd6
7c6d417
521d5cc
dce16ad
f5e084c
be705bf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,20 +34,53 @@ void HopperMultipleMatmulScheduler::transformLikeMmaOutput( | |
bool is_mma_result) { | ||
// TODO Add constraints | ||
|
||
auto apply_k_dim_offset = [is_mma_result](int64_t idx) constexpr { | ||
return (is_mma_result) ? idx - 1 : idx; | ||
}; | ||
|
||
// Original: [..., Mo, No, Mi, Ni] | ||
tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro)); | ||
tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro)); | ||
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii] | ||
tv->reorder({{apply_k_dim_offset(-3), apply_k_dim_offset(-2)}}); | ||
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii] | ||
tv->merge(apply_k_dim_offset(-4)); | ||
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii] | ||
tv->axis(apply_k_dim_offset(-3))->parallelize(ParallelType::TIDy); | ||
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii] | ||
// The input is originally block tiled so that the inner dims are the CTA tile | ||
// size | ||
// Original: [..., M, N(, K)] | ||
// We split this into warp tiles then instruction tiles | ||
if (is_mma_result) { | ||
// Original: [..., M, N, K] | ||
tv->split(-3, params_->tile_sizes.warp_tile.m); | ||
tv->split(-3, getM(params_->mma_macro)); | ||
tv->split(-2, params_->tile_sizes.warp_tile.n); | ||
tv->split(-2, getN(params_->mma_macro)); | ||
// K dimension is present for mma_result | ||
tv->split(-1, params_->tile_sizes.warp_tile.k); | ||
tv->split(-1, getK(params_->mma_macro)); | ||
Comment on lines
+47
to
+49
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rdspring1 is this enough or is #3616 still needed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is all that is required for scheduler changes. |
||
// After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Ko, Kw, Ki] | ||
tv->reorder({ | ||
{-9, -9}, // Mo | ||
{-8, -6}, // Mw | ||
{-7, -3}, // Mi | ||
{-6, -8}, // No | ||
{-5, -5}, // Nw | ||
{-4, -2}, // Ni | ||
{-3, -7}, // Ko | ||
{-2, -4}, // Kw | ||
{-1, -1}, // Ki | ||
}); | ||
// After Reorder: [..., Mo, No, Ko, Mw, Nw, Kw, Mi, Ni, Ki] | ||
tv->merge(-9); | ||
// After Merge: [..., Mo * No, Ko, Mw, Nw, Kw, Mi, Ni] | ||
tv->axis(-8)->parallelize(ParallelType::TIDy); | ||
// After Parallelize: [..., Mo * No (TIDy), Ko, Mw, Nw, Kw, Mi, Ni, Ki] | ||
} else { | ||
// Original: [..., M, N] | ||
tv->split(-2, params_->tile_sizes.warp_tile.m); | ||
tv->split(-2, getM(params_->mma_macro)); | ||
tv->split(-1, params_->tile_sizes.warp_tile.n); | ||
tv->split(-1, getN(params_->mma_macro)); | ||
// After Split: [..., Mo, Mw, Mi, No, Nw, Ni] | ||
tv->reorder({ | ||
{-3, -5}, | ||
{-2, -3}, | ||
}); | ||
// After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni] | ||
tv->merge(-6); | ||
// After Merge: [..., Mo * No, Mw, Nw, Mi, Ni] | ||
tv->axis(-5)->parallelize(ParallelType::TIDy); | ||
// After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni] | ||
} | ||
} | ||
|
||
MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) { | ||
|
@@ -490,8 +523,8 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { | |
// tile is a multiple of the macro size because stmatrix stores results from | ||
// wgmma to shared memory. For maximum inlining and to reduce shared memory | ||
// usage, the tma tile is mma_macro size. | ||
const int64_t tma_m = getM(params_->mma_macro); | ||
const int64_t tma_n = getN(params_->mma_macro); | ||
const int64_t tma_m = params_->tile_sizes.warp_tile.m; | ||
const int64_t tma_n = params_->tile_sizes.warp_tile.n; | ||
|
||
fusion_->manage("st_matrix_m_tile", stmatrix_tile_m); | ||
fusion_->manage("st_matrix_n_tile", stmatrix_tile_n); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4243,4 +4243,140 @@ TEST_F(HopperMatmulTest, HSH_TT_UseScheduler) { | |
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); | ||
} | ||
|
||
// This tests that we can use a small instruction tile with a medium size | ||
// warpgroup tile and a large CTA tile. | ||
TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) { | ||
Fusion fusion; | ||
FusionGuard fg(&fusion); | ||
|
||
constexpr int64_t M = 2048, N = 2048, K = 8192; | ||
const auto dtype = DataType::Half; | ||
|
||
auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M | ||
auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N | ||
fusion.addInput(tv0); | ||
fusion.addInput(tv1); | ||
|
||
auto tv2 = fusedMultiplySum(tv0, tv1, {0}); | ||
|
||
// Reorder the accumulator as [M, N, K] | ||
// [K, M, N] -> [M, N, K] | ||
tv2->reorder({{-3, -1}}); | ||
tv2->commitLeafToLogical(); | ||
|
||
auto tv3 = castOp(DataType::Half, tv2); | ||
fusion.addOutput(tv3); | ||
|
||
auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); | ||
auto a_ref = at::randn({K, M, 1}, options); | ||
auto b_ref = at::randn({K, 1, N}, options); | ||
auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf); | ||
|
||
MatMulTileOptions gemm_tile; | ||
// Regardless of the instruction, this should result in 2 warp groups i.e. 256 | ||
// threads | ||
gemm_tile.cta_tile = GemmTile(256, 256, 32); | ||
gemm_tile.warp_tile = GemmTile(128, 128, 32); | ||
|
||
MatmulParams mparams; | ||
mparams.supported_vec_size = {8, 8, 8}; | ||
mparams.mma_macro = MmaMacro::Hopper_64_64_16; | ||
mparams.tile_sizes = gemm_tile; | ||
mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor; | ||
mparams.async_gmem_load_operands = true; | ||
mparams.circular_buffer_options.circular_buffer_smem_write = true; | ||
mparams.circular_buffer_options.circular_buffer_smem_read = false; | ||
mparams.circular_buffer_options.smem_circular_buffer_stage = 4; | ||
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; | ||
mparams.splitk_factor = 1; | ||
// NOTE: disabling smem use for this test since we currrently hit a bank | ||
// conflict. | ||
// TODO: enable smem epilogue once stmatrix is updated | ||
mparams.use_smem_epilogue = false; | ||
mparams.cluster_dims = {2, 1, 1}; | ||
mparams.promote_prologue_smem_reuse = false; | ||
|
||
SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) | ||
->schedule(&fusion, &mparams); | ||
|
||
std::vector<c10::IValue> inputs = {a_ref, b_ref}; | ||
|
||
KernelExecutor ke; | ||
ke.compile(&fusion, inputs); | ||
EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty()); | ||
EXPECT_FALSE( | ||
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel())); | ||
|
||
auto cg_outputs = ke.run(inputs); | ||
|
||
// Check number of launched threads matches what we expect | ||
EXPECT_EQ(ke.lastLaunchParams().bdimx(), 128); | ||
EXPECT_EQ(ke.lastLaunchParams().bdimy(), 4) | ||
<< " expected 4 warp groups (BIDy==4) but found BIDy==" | ||
<< ke.lastLaunchParams().bdimy(); | ||
|
||
// Relax tolerance for larger sum due to large K | ||
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); | ||
} | ||
|
||
TEST_F(HopperMatmulTest, ScheduleWithTranslation) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test is pretty much identical to the previous one, but it uses a
|
||
Fusion fusion; | ||
FusionGuard fg(&fusion); | ||
|
||
constexpr int64_t M = 2048, N = 2048, K = 8192; | ||
const auto dtype = DataType::Half; | ||
|
||
auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K | ||
auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // K, N | ||
// Note tv1 has allocation domain | ||
// tv1->setAllocationDomain({tv1->axis(1), tv1->axis(0)}, true); | ||
fusion.addInput(tv0); | ||
fusion.addInput(tv1); | ||
|
||
auto tv2 = matmul(tv0, tv1); | ||
|
||
fusion.addOutput(tv2); | ||
|
||
auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); | ||
auto a_ref = at::randn({M, K}, options); | ||
// auto b_ref = at::randn({N, K}, options).t(); | ||
auto b_ref = at::randn({K, N}, options); | ||
auto out_ref = at::matmul(a_ref, b_ref); | ||
|
||
MatMulTileOptions gemm_tile; | ||
gemm_tile.cta_tile = GemmTile(128, 256, 16); | ||
gemm_tile.warp_tile = GemmTile(64, 64, 16); | ||
|
||
MatmulParams mparams; | ||
mparams.supported_vec_size = {8, 8, 8}; | ||
mparams.mma_macro = MmaMacro::Hopper_64_64_16; | ||
mparams.tile_sizes = gemm_tile; | ||
mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor; | ||
mparams.async_gmem_load_operands = true; | ||
mparams.circular_buffer_options.circular_buffer_smem_write = true; | ||
mparams.circular_buffer_options.circular_buffer_smem_read = false; | ||
mparams.circular_buffer_options.smem_circular_buffer_stage = 3; | ||
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; | ||
mparams.splitk_factor = 1; | ||
mparams.use_smem_epilogue = true; | ||
mparams.cluster_dims = {1, 1, 1}; | ||
mparams.promote_prologue_smem_reuse = true; | ||
|
||
SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) | ||
->schedule(&fusion, &mparams); | ||
|
||
std::vector<c10::IValue> inputs = {a_ref, b_ref}; | ||
|
||
KernelExecutor ke; | ||
ke.compile(&fusion, inputs); | ||
EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty()); | ||
EXPECT_FALSE( | ||
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel())); | ||
|
||
auto cg_outputs = ke.run(inputs); | ||
|
||
// Relax tolerance for larger sum due to large K | ||
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); | ||
} | ||
|
||
} // namespace nvfuser |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: since there is no code in common between these branches, we should split this into two separate functions.