diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index 1c7a9b0407c..568d3552ff4 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -840,12 +840,24 @@ void HopperMultipleMatmulScheduler::scheduleOperandSmemStores() { } void HopperMultipleMatmulScheduler::scheduleMmaResults() { + NVF_ERROR( + params_->tile_sizes.warp_tile == params_->tile_sizes.instruction_tile, + "Warp tile must match instruction tile for Hopper matmul but found ", + toString(params_->tile_sizes)); + NVF_ERROR( + params_->tile_sizes.instruction_tile.m == getM(params_->mma_macro) && + params_->tile_sizes.instruction_tile.n == getN(params_->mma_macro) && + params_->tile_sizes.instruction_tile.k == getK(params_->mma_macro), + "Instruction tile must match macro matmul but found instruction tile: ", + toString(params_->tile_sizes.instruction_tile), + " and macro: ", + toString(params_->mma_macro)); // If cta_tile is not divisible by instruction tile the mma instruction will // be predicated. NVF_ERROR( params_->tile_sizes.cta_tile.m % getM(params_->mma_macro) == 0 && - params_->tile_sizes.cta_tile.n % getN(params_->mma_macro) == 0 && - params_->tile_sizes.cta_tile.k % getK(params_->mma_macro) == 0, + params_->tile_sizes.cta_tile.n % getN(params_->mma_macro) == 0 && + params_->tile_sizes.cta_tile.k % getK(params_->mma_macro) == 0, "CTA tile must be divisible by macro size but found cta_tile: ", toString(params_->tile_sizes.cta_tile), " and macro: ", @@ -930,6 +942,7 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() { mma_result->axis(-2)->parallelize(ParallelType::Mma); mma_result->axis(-3)->parallelize(ParallelType::Mma); } + /* tv2 below is the consumer of mma_result (cacheAfter) { // Split by tile