diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp
index ef68909601e..a3b72013be6 100644
--- a/tests/cpp/test_matmul_scheduler.cpp
+++ b/tests/cpp/test_matmul_scheduler.cpp
@@ -3075,6 +3075,80 @@ TEST_F(MatmulSchedulerTest, OperandOrderIssue2434) {
   NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
 }
 
+TEST_F(MatmulSchedulerTest, HSH_TT) {
+  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 10, 0);
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  const auto dtype = DataType::Half;
+  constexpr auto layout = MmaLayout::TT;
+
+  auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // A [M, K, b]
+  auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // B [b, K, N]
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+
+  auto tv2 = fusedMultiplySum(tv0, tv1, {1});
+
+  // Reorder the accumulator as [M, N, K]
+  // [M, rK, N] -> [M, N, K]
+  tv2->reorder({{-2, -1}, {-1, -2}});
+  tv2->commitLeafToLogical();
+
+  auto tv3 = castOp(DataType::Half, tv2);
+  fusion->addOutput(tv3);
+
+  NVF_CHECK(
+      1 == ir_utils::getOpsOfType<MmaOp>(fusion.get()).size(),
+      "matmul fusion must have at least one MmaOp");
+
+  // Create custom Matmul Params
+  MatMulTileOptions gemm_tile;
+  // TODO cta tile is a multiple of mma macro for hopper.
+  gemm_tile.cta_tile = GemmTile(128, 128, 32);
+
+  // TODO warp tile is (macroM, macroN, macroK) for hopper.
+  gemm_tile.warp_tile = GemmTile(64, 64, 32);
+
+  // TODO instruction tile is not used for hopper.
+  gemm_tile.instruction_tile = GemmTile(16, 8, 16);
+
+  MatmulParams mparams;
+  mparams.supported_vec_size = {8, 8, 4};
+
+  // TODO use hopper macro
+  // mparams.mma_macro = MmaMacro::Hopper_64_256_16;
+  mparams.mma_macro = MmaMacro::Ampere_16_8_16;
+
+  mparams.tile_sizes = gemm_tile;
+  mparams.async_gmem_load_operands = true;
+  mparams.circular_buffer_options.circular_buffer_smem_write = true;
+  mparams.circular_buffer_options.circular_buffer_smem_read = true;
+  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
+
+  // TODO Create prefetch parameter
+  // mparams.circular_buffer_options.smem_circular_buffer_prefetch = 3;
+
+  // Schedule matmul fusion using custom parameters
+  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
+      ->schedule(fusion.get(), &mparams);
+
+  const int M = 32, N = 32, K = 256;
+  auto inputs =
+      matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
+
+  FusionExecutor fe;
+  fe.compileFusion(
+      fusion.get(),
+      {inputs.first, inputs.second},
+      LaunchParams(),
+      matmul_cparams);
+
+  auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
+  auto tref = atMatmul(inputs.first.squeeze(), inputs.second.squeeze(), layout);
+  EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
+}
+
 TEST_F(MatmulSchedulerTest, HSH_TN) {
   NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 10, 0);
   auto fusion = std::make_unique<Fusion>();
@@ -3221,4 +3295,78 @@ TEST_F(MatmulSchedulerTest, HSH_NT) {
   EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
 }
 
+TEST_F(MatmulSchedulerTest, HSH_NN) {
+  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 10, 0);
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  const auto dtype = DataType::Half;
+  constexpr auto layout = MmaLayout::NN;
+
+  auto tv0 = makeContigConcreteTensor({1, -1, -1}, dtype); // A [b, K, M]
+  auto tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype); // B [N, K, b]
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+
+  auto tv2 = fusedMultiplySum(tv0, tv1, {1});
+
+  // Reorder the accumulator as [M, N, K]
+  // [N, rK, M] -> [M, N, K]
+  tv2->reorder({{-1, -3}});
+  tv2->commitLeafToLogical();
+
+  auto tv3 = castOp(DataType::Half, tv2);
+  fusion->addOutput(tv3);
+
+  NVF_CHECK(
+      1 == ir_utils::getOpsOfType<MmaOp>(fusion.get()).size(),
+      "matmul fusion must have at least one MmaOp");
+
+  // Create custom Matmul Params
+  MatMulTileOptions gemm_tile;
+  // TODO cta tile is a multiple of mma macro for hopper.
+  gemm_tile.cta_tile = GemmTile(128, 128, 32);
+
+  // TODO warp tile is (macroM, macroN, macroK) for hopper.
+  gemm_tile.warp_tile = GemmTile(64, 64, 32);
+
+  // TODO instruction tile is not used for hopper.
+  gemm_tile.instruction_tile = GemmTile(16, 8, 16);
+
+  MatmulParams mparams;
+  mparams.supported_vec_size = {8, 8, 4};
+
+  // TODO use hopper macro
+  // mparams.mma_macro = MmaMacro::Hopper_64_256_16;
+  mparams.mma_macro = MmaMacro::Ampere_16_8_16;
+
+  mparams.tile_sizes = gemm_tile;
+  mparams.async_gmem_load_operands = true;
+  mparams.circular_buffer_options.circular_buffer_smem_write = true;
+  mparams.circular_buffer_options.circular_buffer_smem_read = true;
+  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
+
+  // TODO Create prefetch parameter
+  // mparams.circular_buffer_options.smem_circular_buffer_prefetch = 3;
+
+  // Schedule matmul fusion using custom parameters
+  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
+      ->schedule(fusion.get(), &mparams);
+
+  const int M = 32, N = 32, K = 256;
+  auto inputs =
+      matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
+
+  FusionExecutor fe;
+  fe.compileFusion(
+      fusion.get(),
+      {inputs.first, inputs.second},
+      LaunchParams(),
+      matmul_cparams);
+
+  auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
+  auto tref = atMatmul(inputs.first.squeeze(), inputs.second.squeeze(), layout);
+  EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
+}
+
 } // namespace nvfuser