Create TT and NN tests
rdspring1 committed Nov 1, 2024
Parent: 1ce800a · Commit: 7c8f375
Showing 1 changed file with 148 additions and 0 deletions.
tests/cpp/test_matmul_scheduler.cpp (148 additions, 0 deletions)
@@ -3075,6 +3075,80 @@ TEST_F(MatmulSchedulerTest, OperandOrderIssue2434) {
NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
}

TEST_F(MatmulSchedulerTest, HSH_TT) {
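// "HSH" presumably denotes half-precision inputs, single-precision (float)
// accumulation, and a half-precision output. The arch guard below is expected
// to skip the test on devices outside roughly SM 8.0 up to (excluding) 10.0.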
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 10, 0);
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const auto dtype = DataType::Half;
constexpr auto layout = MmaLayout::TT;

auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // A [M, K, b]
auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // B [b, K, N]
fusion->addInput(tv0);
fusion->addInput(tv1);

auto tv2 = fusedMultiplySum(tv0, tv1, {1});
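// fusedMultiplySum above multiplies A and B and reduces over axis 1, the
// shared K dimension of A [M, K, b] and B [b, K, N]; it creates the single
// MmaOp that the check below verifies.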

// Reorder the accumulator as [M, N, K]
// [M, rK, N] -> [M, N, K]
tv2->reorder({{-2, -1}, {-1, -2}});
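// Committing the reordered loop domain as the new logical domain presumably
// lets the scheduler treat the accumulator as an [M, N, K]-ordered tensor.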
tv2->commitLeafToLogical();

auto tv3 = castOp(DataType::Half, tv2);
fusion->addOutput(tv3);

NVF_CHECK(
1 == ir_utils::getOpsOfType<MmaOp>(fusion.get()).size(),
"matmul fusion must have exactly one MmaOp");

// Create custom Matmul Params
MatMulTileOptions gemm_tile;
// TODO For Hopper, the CTA tile should be a multiple of the mma macro.
gemm_tile.cta_tile = GemmTile(128, 128, 32);

// TODO For Hopper, the warp tile is (macroM, macroN, macroK).
gemm_tile.warp_tile = GemmTile(64, 64, 32);

// TODO The instruction tile is not used for Hopper.
gemm_tile.instruction_tile = GemmTile(16, 8, 16);

MatmulParams mparams;
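// The three values below are presumably the vectorization widths (in
// elements) for the A operand, the B operand, and the epilogue output.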
mparams.supported_vec_size = {8, 8, 4};

// TODO Use a Hopper macro
// mparams.mma_macro = MmaMacro::Hopper_64_256_16;
mparams.mma_macro = MmaMacro::Ampere_16_8_16;

mparams.tile_sizes = gemm_tile;
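// Stage operand tiles through shared memory with asynchronous gmem loads and
// a 4-deep circular buffer, so tiles for upcoming iterations can be
// prefetched while the current tile is consumed.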
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = true;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;

// TODO Create a prefetch parameter
// mparams.circular_buffer_options.smem_circular_buffer_prefetch = 3;

// Schedule matmul fusion using custom parameters
SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(fusion.get(), &mparams);

const int M = 32, N = 32, K = 256;
auto inputs =
matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));

FusionExecutor fe;
fe.compileFusion(
fusion.get(),
{inputs.first, inputs.second},
LaunchParams(),
matmul_cparams);

auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
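// squeeze() drops the size-1 broadcast dimensions so atMatmul receives plain
// 2-D operands; atMatmul presumably applies the layout-appropriate transposes
// when computing the ATen reference.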
auto tref = atMatmul(inputs.first.squeeze(), inputs.second.squeeze(), layout);
EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
}

TEST_F(MatmulSchedulerTest, HSH_TN) {
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 10, 0);
auto fusion = std::make_unique<Fusion>();
@@ -3221,4 +3295,78 @@ TEST_F(MatmulSchedulerTest, HSH_NT) {
EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
}

TEST_F(MatmulSchedulerTest, HSH_NN) {
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 10, 0);
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const auto dtype = DataType::Half;
constexpr auto layout = MmaLayout::NN;

auto tv0 = makeContigConcreteTensor({1, -1, -1}, dtype); // A [b, K, M]
auto tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype); // B [N, K, b]
fusion->addInput(tv0);
fusion->addInput(tv1);

auto tv2 = fusedMultiplySum(tv0, tv1, {1});

// Reorder the accumulator as [M, N, K]
// [N, rK, M] -> [M, N, K]
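// Moving the last axis (M) to position -3 shifts N and rK one place to the
// right, yielding [M, N, rK].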
tv2->reorder({{-1, -3}});
tv2->commitLeafToLogical();

auto tv3 = castOp(DataType::Half, tv2);
fusion->addOutput(tv3);

NVF_CHECK(
1 == ir_utils::getOpsOfType<MmaOp>(fusion.get()).size(),
"matmul fusion must have exactly one MmaOp");

// Create custom Matmul Params
MatMulTileOptions gemm_tile;
// TODO For Hopper, the CTA tile should be a multiple of the mma macro.
gemm_tile.cta_tile = GemmTile(128, 128, 32);

// TODO For Hopper, the warp tile is (macroM, macroN, macroK).
gemm_tile.warp_tile = GemmTile(64, 64, 32);

// TODO The instruction tile is not used for Hopper.
gemm_tile.instruction_tile = GemmTile(16, 8, 16);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 4};

// TODO Use a Hopper macro
// mparams.mma_macro = MmaMacro::Hopper_64_256_16;
mparams.mma_macro = MmaMacro::Ampere_16_8_16;

mparams.tile_sizes = gemm_tile;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = true;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;

// TODO Create a prefetch parameter
// mparams.circular_buffer_options.smem_circular_buffer_prefetch = 3;

// Schedule matmul fusion using custom parameters
SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(fusion.get(), &mparams);

const int M = 32, N = 32, K = 256;
auto inputs =
matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));

FusionExecutor fe;
fe.compileFusion(
fusion.get(),
{inputs.first, inputs.second},
LaunchParams(),
matmul_cparams);

auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(inputs.first.squeeze(), inputs.second.squeeze(), layout);
EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
}

} // namespace nvfuser
