create TN and NT tests
rdspring1 committed Nov 1, 2024
1 parent 3ae7baa · commit 1ce800a
Showing 1 changed file with 146 additions and 0 deletions.
146 changes: 146 additions & 0 deletions in tests/cpp/test_matmul_scheduler.cpp
@@ -3075,4 +3075,150 @@ TEST_F(MatmulSchedulerTest, OperandOrderIssue2434) {
NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
}

TEST_F(MatmulSchedulerTest, HSH_TN) {
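// Guard: run only on devices with compute capability in [8.0, 10.0), i.e. Ampere through Hopper.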
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 10, 0);
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const auto dtype = DataType::Half;
constexpr auto layout = MmaLayout::TN; // [M, K] x [N, K] -> [M, N]

auto tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype);
auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
fusion->addInput(tv0);
fusion->addInput(tv1);

// [M, b, K] x [b, N, K] -> [M, N, rK]
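// fusedMultiplySum produces the single MmaOp verified below; axis -1 (K) is reduced.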
auto tv2 = fusedMultiplySum(tv0, tv1, {-1});

// [M, N]
auto tv3 = castOp(DataType::Half, tv2);
fusion->addOutput(tv3);

NVF_CHECK(
    1 == ir_utils::getOpsOfType<MmaOp>(fusion.get()).size(),
    "matmul fusion must have exactly one MmaOp");

// Create custom matmul parameters
MatMulTileOptions gemm_tile;
// TODO: on Hopper, the CTA tile must be a multiple of the MMA macro.
gemm_tile.cta_tile = GemmTile(128, 128, 32);

// TODO: on Hopper, the warp tile is (macroM, macroN, macroK).
gemm_tile.warp_tile = GemmTile(64, 64, 32);

// TODO: the instruction tile is unused on Hopper.
gemm_tile.instruction_tile = GemmTile(16, 8, 16);

MatmulParams mparams;
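// Vectorization widths in elements for operand A, operand B, and the epilogue.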
mparams.supported_vec_size = {8, 8, 4};

// TODO: use a Hopper macro, e.g.
// mparams.mma_macro = MmaMacro::Hopper_64_256_16;
mparams.mma_macro = MmaMacro::Ampere_16_8_16;

mparams.tile_sizes = gemm_tile;
mparams.async_gmem_load_operands = true;
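// Circular-buffer the shared-memory operand loads across 4 stages so that
// global->shared copies overlap with MMA compute.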
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = true;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;

// TODO: add a prefetch parameter, e.g.
// mparams.circular_buffer_options.smem_circular_buffer_prefetch = 3;

// Schedule matmul fusion using custom parameters
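// makeSchedulerInstance applies the explicit mparams above instead of
// heuristically derived parameters.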
SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
    ->schedule(fusion.get(), &mparams);

const int M = 32, N = 32, K = 256;
auto inputs =
    matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));

FusionExecutor fe;
fe.compileFusion(
    fusion.get(),
    {inputs.first, inputs.second},
    LaunchParams(),
    matmul_cparams);

auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
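// ATen reference: squeeze() drops the size-1 broadcast dims before the 2D matmul.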
auto tref = atMatmul(inputs.first.squeeze(), inputs.second.squeeze(), layout);
EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
}

TEST_F(MatmulSchedulerTest, HSH_NT) {
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 10, 0);
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const auto dtype = DataType::Half;
constexpr auto layout = MmaLayout::NT; // [K, M] x [K, N] -> [M, N]

// [K, M, b] x [K, b, N] (size-1 dims are broadcasts)
auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
fusion->addInput(tv0);
fusion->addInput(tv1);

auto tv2 = fusedMultiplySum(tv0, tv1, {0});

// Reorder the accumulator as [M, N, K]
// [K, M, N] -> [M, N, K]
tv2->reorder({{-3, -1}});
tv2->commitLeafToLogical();
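// commitLeafToLogical replaces tv2's logical domain with the reordered loop
// domain, so downstream scheduling sees [M, N, K].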

auto tv3 = castOp(DataType::Half, tv2);

fusion->addOutput(tv3);

NVF_CHECK(
    1 == ir_utils::getOpsOfType<MmaOp>(fusion.get()).size(),
    "matmul fusion must have exactly one MmaOp");

// Create custom matmul parameters
MatMulTileOptions gemm_tile;
// TODO: on Hopper, the CTA tile must be a multiple of the MMA macro.
gemm_tile.cta_tile = GemmTile(128, 128, 32);

// TODO: on Hopper, the warp tile is (macroM, macroN, macroK).
gemm_tile.warp_tile = GemmTile(64, 64, 32);

// TODO: the instruction tile is unused on Hopper.
gemm_tile.instruction_tile = GemmTile(16, 8, 16);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 4};

// TODO: use a Hopper macro, e.g.
// mparams.mma_macro = MmaMacro::Hopper_64_256_16;
mparams.mma_macro = MmaMacro::Ampere_16_8_16;

mparams.tile_sizes = gemm_tile;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = true;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;

// TODO: add a prefetch parameter, e.g.
// mparams.circular_buffer_options.smem_circular_buffer_prefetch = 3;

// Schedule matmul fusion using custom parameters
SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
    ->schedule(fusion.get(), &mparams);

const int M = 32, N = 32, K = 256;
auto inputs =
    matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));

FusionExecutor fe;
fe.compileFusion(
    fusion.get(),
    {inputs.first, inputs.second},
    LaunchParams(),
    matmul_cparams);

auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(inputs.first.squeeze(), inputs.second.squeeze(), layout);
EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
}

} // namespace nvfuser
