Skip to content

Commit

Permalink
Add more validation for hopper matmul tile sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobhinkle committed Nov 5, 2024
1 parent 40b6908 commit 14cf2ed
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions csrc/scheduler/hopper_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,12 +840,24 @@ void HopperMultipleMatmulScheduler::scheduleOperandSmemStores() {
}

void HopperMultipleMatmulScheduler::scheduleMmaResults() {
NVF_ERROR(
params_->tile_sizes.warp_tile == params_->tile_sizes.instruction_tile,
"Warp tile must match instruction tile for Hopper matmul but found ",
toString(params_->tile_sizes));
NVF_ERROR(
params_->tile_sizes.instruction_tile.m == getM(params_->mma_macro) &&
params_->tile_sizes.instruction_tile.n == getN(params_->mma_macro) &&
params_->tile_sizes.instruction_tile.k == getK(params_->mma_macro),
"Instruction tile must match macro matmul but found instruction tile: ",
toString(params_->tile_sizes.instruction_tile),
" and macro: ",
toString(params_->mma_macro));
// If cta_tile is not divisible by the instruction tile, the mma instruction
// will be predicated.
NVF_ERROR(
params_->tile_sizes.cta_tile.m % getM(params_->mma_macro) == 0 &&
params_->tile_sizes.cta_tile.n % getN(params_->mma_macro) == 0 &&
params_->tile_sizes.cta_tile.k % getK(params_->mma_macro) == 0,
params_->tile_sizes.cta_tile.n % getN(params_->mma_macro) == 0 &&
params_->tile_sizes.cta_tile.k % getK(params_->mma_macro) == 0,
"CTA tile must be divisible by macro size but found cta_tile: ",
toString(params_->tile_sizes.cta_tile),
" and macro: ",
Expand Down Expand Up @@ -930,6 +942,7 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
mma_result->axis(-2)->parallelize(ParallelType::Mma);
mma_result->axis(-3)->parallelize(ParallelType::Mma);
}

/* tv2 below is the consumer of mma_result (cacheAfter)
{
// Split by tile
Expand Down

0 comments on commit 14cf2ed

Please sign in to comment.