Skip to content

Commit

Permalink
Enable compilation in Hopper MMA test without input broadcasts (#3406)
Browse files Browse the repository at this point in the history
Stacked on #3410, #3414, and #3416

This simply enables compilation of the test which uses #3391.
  • Loading branch information
jacobhinkle authored Dec 12, 2024
1 parent 4382f28 commit d5af72f
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
@@ -3819,9 +3819,9 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) {
  Fusion fusion;
  FusionGuard fg(&fusion);

- // constexpr int64_t M = 2048, N = 2048, K = 8192;
+ constexpr int64_t M = 2048, N = 2048, K = 8192;
  constexpr auto macro = MmaMacro::Hopper_64_256_16;
- // constexpr auto layout = MmaLayout::NT; // [K, M] x [K, N] -> [M, N]
+ constexpr auto layout = MmaLayout::NT; // [K, M] x [K, N] -> [M, N]
  constexpr auto swizzle = MmaInputSmemSwizzle::B128;
  const auto dtype = DataType::Half;

@@ -3954,7 +3954,6 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) {
  // of 3
  ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}});
  inlineMost();
- tmp_fusion.printMath();
  ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}});
  EXPECT_EQ(ir_cloner.clone(tv0c)->getComputeAtPosition(), 1);
  // The outermost loop dim of tv1c is a broadcast Mo axis, so
@@ -3981,7 +3980,17 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) {
  pred_checker.handle(kernel->topLevelExprs());
  ASSERT_TRUE(pred_checker.found_mma);

- // TODO: compile and run kernel once inlining is fixed
+ auto [A3d, B3d] =
+ matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
+ at::Tensor A = A3d.squeeze();
+ at::Tensor B = B3d.squeeze();
+ std::vector<c10::IValue> inputs{A, B};
+
+ KernelExecutor ke;
+ ke.compile(&fusion, inputs, LaunchParams(), matmul_cparams);
+ auto cg_outputs = ke.run(inputs);
+ auto tref = atMatmul(A, B, layout);
+ EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
  }

  } // namespace nvfuser

0 comments on commit d5af72f

Please sign in to comment.