diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp
index 415a28829c3..70a6ef151ea 100644
--- a/csrc/scheduler/matmul_utils.cpp
+++ b/csrc/scheduler/matmul_utils.cpp
@@ -800,9 +800,10 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
   //   scheduler.
   // 6. Check if the fusion is resharding.
 
+  const auto device_prop = at::cuda::getCurrentDeviceProperties();
+
   // #0
   {
-    const auto device_prop = at::cuda::getCurrentDeviceProperties();
     // Use a dummy problem shape to determine whether this is a supported
     // device.
     const auto mma_op = getMmaOp(
@@ -824,6 +825,16 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
   {
     for (const mma_utils::MatmulPattern& pattern : patterns) {
       Expr* op = pattern.output->definition();
+      if (device_prop->major >= 9 && op->isA<ReductionOp>()) {
+        bool found_reduction = false;
+        for (size_t dim : c10::irange((size_t)pattern.output->nDims())) {
+          if (found_reduction &&
+              !pattern.output->axis((int64_t)dim)->isReduction()) {
+            return "Mul+Sum patterns can only be translated to MmaOp "
+                   "on Hopper if the reduction dim is innermost";
+          }
+        }
+      }
       if (op->isA<MatmulOp>() || op->isA<LinearOp>()) {
         if (!isOptionEnabled(EnableOption::FuseMatmul)) {
           // Check for MatmulOp or LinearOp. If found, then only fuse if option
diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp
index e9a24851a5a..08d568d6c63 100644
--- a/csrc/scheduler/mma_utils.cpp
+++ b/csrc/scheduler/mma_utils.cpp
@@ -1777,7 +1777,7 @@ std::string MatmulPattern::toString() const {
   return ss.str();
 }
 
-MmaOp* MatmulPattern::translateToMmaOp() {
+MmaOp* MatmulPattern::translateToMmaOp(bool avoid_intermediates) {
   if (auto mma_op = dynamic_cast<MmaOp*>(output->definition())) {
     // No translation needed
     return mma_op;
@@ -1804,6 +1804,13 @@
   //   - bias, if present, can be zero or one dimensional. Bias can only be
   //     present if weight is 2D
   //
+  // When A has dimension greater than two, all the preceding dimensions
+  // are essentially also M dimensions. The output is shaped like
+  //
+  //   A   [ ... iS0{M} iS1{K} ]
+  //   B   [ iS2{N} iS3{K} ]
+  //   out [ ... iS4{M} iS5{N} rS6{K} ]
+  //
   // We translate by broadcasting input, weight, and bias such that the
   // contracted dimension K is in the last position (this is true of the
   // logical domains in input and weight already). Then we form an MmaOp and
@@ -1812,15 +1819,51 @@
     NVF_ERROR(
         A->nDims() > 1 && B->nDims() > 1,
         "Cannot translate LinearOp with 1D input");
-    std::vector<bool> bcast_dim((size_t)A->nDims() + 1, false);
-    bcast_dim[bcast_dim.size() - 2] = true; // N
-    A = broadcast(A, bcast_dim);
+    NVF_ERROR(
+        B->nDims() == 2, "Cannot translate LinearOp without 2D weight tensor");
+    if (avoid_intermediates) {
+      MmaOp::AxisMapping axis_mapping;
+      int64_t out_dim = A->nDims() + 1L;
+      axis_mapping.a_axes.reserve(out_dim);
+      for (int64_t d : c10::irange(out_dim - 2L)) {
+        axis_mapping.a_axes.push_back(d);
+      }
+      axis_mapping.a_axes.push_back(-1); // missing N dimension
+      axis_mapping.a_axes.push_back(A->nDims() - 1); // K dimension
+
+      axis_mapping.b_axes.reserve(out_dim);
+      axis_mapping.b_axes.resize(out_dim, -1);
+      axis_mapping.b_axes[out_dim - 2] = 0; // N
+      axis_mapping.b_axes[out_dim - 1] = 1; // K
+
+      int64_t num_M_dims = 1 + A->nDims() - B->nDims();
+
+      // Add loop broadcasts to A and B to mimic logical broadcasts for simpler
+      // scheduling
+      A->broadcast(-2); // There's always a single N dimension
+
+      for ([[maybe_unused]] size_t i : c10::irange((size_t)num_M_dims)) {
+        // Broadcast B for every M dimension in A
+        B->broadcast(0);
+      }
 
-    bcast_dim[bcast_dim.size() - 2] = false; // reset N
-    std::fill(bcast_dim.begin(), bcast_dim.end() - 2, true);
-    B = broadcast(B, bcast_dim);
+      fms = fusedMultiplySum(A, B, {-1}, /*init=*/nullptr, axis_mapping);
+    } else {
+      std::vector<bool> bcast_dim(A->nDims() + 1, false);
+      bcast_dim[bcast_dim.size() - 2] = true; // N
+      A = broadcast(A, bcast_dim);
+
+      bcast_dim[bcast_dim.size() - 2] = false; // reset N
+      std::fill(bcast_dim.begin(), bcast_dim.end() - 2, true);
+      B = broadcast(B, bcast_dim);
+
+      fms = fusedMultiplySum(A, B, {-1});
+    }
 
-    fms = fusedMultiplySum(A, B, {-1});
     mma_op = fms->definition()->as<MmaOp>();
 
     auto* bias = dynamic_cast<TensorView*>(lop->bias());
@@ -1835,19 +1878,82 @@
     // Also note that the output of MatmulOp is a tensor of shape [..., M, N]
     // whose dtype matches that of the inputs. We will most commonly then also
     // need to cast the output of the MmaOp to produce the output TensorView.
+    //
+    // There are two possibilities:
+    //
+    // Case 1: A->nDims() > B->nDims():
+    //
+    //   A [ ..., B1, ..., Bn, M, K ]
+    //   B [ B1, ..., Bn, K, N ]
+    //
+    //   All the preceding dimensions in A are additional M dimensions. There
+    //   are batch dimensions in between those and "M".
+    //
+    // Case 2: A->nDims() <= B->nDims():
+    //
+    //   A [ B1, ..., Bn, M, K ]
+    //   B [ ..., B1, ..., Bn, K, N ]
+    //
+    //   All the preceding dimensions in B are additional N dimensions. There
+    //   are batch dimensions in between those and "N".
+    //
+    // In either case, to form the output we transpose B in the last two dims,
+    // and prepend broadcasts to the lower dimensional input as needed.
     NVF_ERROR(
         A->nDims() > 1 && B->nDims() > 1,
         "Cannot translate MatmulOp with 1D input");
-    TensorView* Btrans = transpose(B, -2, -1);
-    A = unsqueeze(A, -2);
-    B = unsqueeze(Btrans, -3);
-    // A and B might have different dimensions. If so, broadcast the smaller one
-    // up to the size of the larger.
-    int64_t out_dims = std::max(A->nDims(), B->nDims());
-    // Add new outer broadcast dimensions if necessary
-    A = ops::maybe_broadcast_inner_to_rank(A, out_dims);
-    B = ops::maybe_broadcast_inner_to_rank(B, out_dims);
-    fms = fusedMultiplySum(A, B, {-1});
+    if (avoid_intermediates) {
+      MmaOp::AxisMapping axis_mapping;
+      int64_t out_dims = std::max(A->nDims(), B->nDims()) + 1;
+
+      axis_mapping.a_axes.resize((size_t)out_dims, -1);
+      axis_mapping.b_axes.resize((size_t)out_dims, -1);
+
+      for (size_t a_axis : c10::irange((size_t)A->nDims() - 1)) {
+        // Output is [ ... M, N, K ]
+        // This loop maps everything but N and K to A
+        int64_t out_axis = (int64_t)a_axis + (out_dims - 1 - A->nDims());
+        axis_mapping.a_axes.at((size_t)out_axis) = (int64_t)a_axis;
+      }
+      // Map the K dim, skipping one position
+      axis_mapping.a_axes.at((size_t)out_dims - 1) = A->nDims() - 1;
+
+      for (size_t b_axis : c10::irange((size_t)B->nDims() - 2)) {
+        // Output is [ ... M, N, K ]
+        // This loop maps everything before M to B, skipping the output M dim
+        int64_t out_axis = (int64_t)b_axis + (out_dims - B->nDims()) - 1;
+        axis_mapping.b_axes.at((size_t)out_axis) = (int64_t)b_axis;
+      }
+      // Skip the K dim and map N and K
+      axis_mapping.b_axes.at((size_t)out_dims - 2) = B->nDims() - 1;
+      axis_mapping.b_axes.at((size_t)out_dims - 1) = B->nDims() - 2;
+
+      fms = fusedMultiplySum(A, B, {-1}, /*init=*/nullptr, axis_mapping);
+
+      int64_t num_M_dims = std::max(1 + A->nDims() - B->nDims(), (int64_t)1);
+
+      // Reorder to BMNK.
+      // Add loop broadcasts to A and B to mimic logical broadcasts for simpler
+      // scheduling
+      A->broadcast(-2);
+
+      B->reorder({{-2, -1}});
+      for ([[maybe_unused]] size_t i : c10::irange((size_t)num_M_dims)) {
+        // Broadcast B for every M dimension in A
+        B->broadcast(-3);
+      }
+    } else {
+      TensorView* Btrans = transpose(B, -2, -1);
+      A = unsqueeze(A, -2);
+      B = unsqueeze(Btrans, -3);
+      // A and B might have different dimensions. If so, broadcast the smaller
+      // one up to the size of the larger.
+      int64_t out_dims = std::max(A->nDims(), B->nDims());
+      // Add new outer broadcast dimensions if necessary
+      A = ops::maybe_broadcast_inner_to_rank(A, out_dims);
+      B = ops::maybe_broadcast_inner_to_rank(B, out_dims);
+      fms = fusedMultiplySum(A, B, {-1});
+    }
     mma_op = fms->definition()->as<MmaOp>();
   } else {
     NVF_THROW(
@@ -1986,17 +2092,18 @@ DimRolesMap MatmulPattern::getDimRoles(IdModel& id_model) const {
   // for each valgroup, store a pair of flags. The first records whether the
   // group is present at all in the tv. The second records whether the value is
   // concrete (i.e. not reduction, broadcast, or device).
-  std::unordered_map<ValGroup, std::pair<DimPresence, DimPresence>> flags;
+  std::unordered_map<ValGroup, DimPresence> flags;
   const auto recordPresence = [&graph, &flags](
                                   TensorView* tv, size_t tensor_num) {
     for (IterDomain* id : tv->getLogicalDomain()) {
       const ValGroup& g = graph.toGroup(id);
-      auto& [present_flags, concrete_flags] = flags[g];
-      present_flags.set(tensor_num);
+      DimPresence& group_flags = flags[g];
+      // Note: broadcast or device dims will be initialized to have all false
+      // flags above
       if (id->isReduction() || id->isBroadcast() || id->isDeviceDim()) {
         continue;
       }
-      concrete_flags.set(tensor_num);
+      group_flags.set(tensor_num);
     }
   };
   recordPresence(A, 0);
@@ -2005,8 +2112,7 @@ DimRolesMap MatmulPattern::getDimRoles(IdModel& id_model) const {
 
   DimRolesMap dim_roles;
 
-  for (const auto& [g, f] : flags) {
-    const auto& [present_flags, concrete_flags] = f;
+  for (const auto& [g, concrete_flags] : flags) {
     if (concrete_flags.all() || concrete_flags.none()) {
       // Batch dimensions are any of those that are not concretized or reduced.
       // These could be all Iteration or all Broadcast
@@ -2019,9 +2125,25 @@ DimRolesMap MatmulPattern::getDimRoles(IdModel& id_model) const {
       dim_roles[g] = MatmulDimRole::N;
     } else {
       NVF_THROW(
-          "IterDomain ValGroup should be present in at least two of A, B, output.",
-          " present_flags: ",
-          present_flags);
+          "IterDomain ValGroup should be concrete in at least two of A, B, output.",
+          " concrete_flags: ",
+          concrete_flags);
+    }
+  }
+
+  // NOTE: For Hopper, we create loop broadcasts to mimic logical broadcasts
+  // when translating MatmulOp and LinearOp. Here we detect these and map them
+  // appropriately.
+  for (IterDomain* id : A->getLoopDomain()) {
+    const ValGroup& g = graph.toGroup(id);
+    if (dim_roles.count(g) == 0) {
+      dim_roles[g] = MatmulDimRole::N;
+    }
+  }
+  for (IterDomain* id : B->getLoopDomain()) {
+    const ValGroup& g = graph.toGroup(id);
+    if (dim_roles.count(g) == 0) {
+      dim_roles[g] = MatmulDimRole::M;
     }
   }
 
diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h
index c88fe4926e3..6bef370240e 100644
--- a/csrc/scheduler/mma_utils.h
+++ b/csrc/scheduler/mma_utils.h
@@ -327,7 +327,11 @@ struct MatmulPattern {
   //! there is a MatmulOp instead, this function modifies the fusion to insert
   //! an MmaOp. TensorViews A and B are unchanged, but this->output might be
   //! updated to reflect the replacement tensor.
-  MmaOp* translateToMmaOp();
+  //!
+  //! If avoid_intermediates is true, this function will use an
+  //! MmaOp::AxisMapping instead of broadcasting and permuting axes, in order to
+  //! avoid introducing unnecessary copies on Hopper and above.
+  MmaOp* translateToMmaOp(bool avoid_intermediates = false);
 
   //! Given an IdModel, map groups of IterDomains to dimension roles
   //! (MatmulDimRole). Note that ValGroup is a shared_ptr to a
diff --git a/csrc/scheduler/multi_matmul.cpp b/csrc/scheduler/multi_matmul.cpp
index 915e08ab8e8..33b350fd467 100644
--- a/csrc/scheduler/multi_matmul.cpp
+++ b/csrc/scheduler/multi_matmul.cpp
@@ -7,6 +7,7 @@
 // clang-format on
 #include
+#include
 #include
 #include
 #include
 
@@ -21,7 +22,23 @@ void MultipleMatmulScheduler::findPatterns() {
 void MultipleMatmulScheduler::translatePatterns() {
   mma_results_.reserve(patterns_.size());
   for (mma_utils::MatmulPattern& pattern : patterns_) {
-    MmaOp* mma = pattern.translateToMmaOp();
+    // TODO: properly handle all mul+sum patterns for Hopper. For now, these
+    // should work fine as long as the inner dimensions are the ones being
+    // reduced.
+    if (!isAmpere(params_->mma_macro) && !isTuring(params_->mma_macro) &&
+        pattern.output->definition()->isA<ReductionOp>()) {
+      bool found_reduction = false;
+      for (size_t dim : c10::irange((size_t)pattern.output->nDims())) {
+        NVF_ERROR(
+            !found_reduction ||
+                !pattern.output->axis((int64_t)dim)->isReduction(),
+            "Mul+Sum patterns can only be translated on Hopper if the reduction dim is innermost");
+      }
+    }
+
+    MmaOp* mma = pattern.translateToMmaOp(
+        /*avoid_intermediates=*/!isAmpere(params_->mma_macro) &&
+        !isTuring(params_->mma_macro));
     mma_results_.push_back(mma->out()->as<TensorView>());
   }
 
diff --git a/tests/cpp/test_translate_mma.cpp b/tests/cpp/test_translate_mma.cpp
index f290d30b576..ab6eb658cf6 100644
--- a/tests/cpp/test_translate_mma.cpp
+++ b/tests/cpp/test_translate_mma.cpp
@@ -43,10 +43,10 @@ namespace nvfuser {
 class CombineMulSumAsMmaTest : public NVFuserTest {
   void SetUp() override {
     // These test are enable for Turing and newer. Temporarily
-    // we are skipping Hopper since the matmul for it is under development.
+    // we are skipping Blackwell since the matmul for it is under development.
     auto lower_major = 8;
     auto lower_minor = 0;
-    auto upper_major = 9;
+    auto upper_major = 10;
     auto upper_minor = 0;
     if (cudaArchGuardShouldSkip(
             lower_major, lower_minor, upper_major, upper_minor)) {
@@ -55,8 +55,14 @@ class CombineMulSumAsMmaTest : public NVFuserTest {
                    << lower_minor << "and " << upper_major << "." << upper_minor
                    << " to run.\n";
     }
+
+    pre_hopper = at::cuda::getCurrentDeviceProperties()->major < 9;
+
     NVFuserTest::SetUp();
   }
+
+ protected:
+  bool pre_hopper;
 };
 
 class CombineMulSumAsMmaTestWithLayout
@@ -66,24 +72,30 @@
   MmaLayout layout;
   void SetUp() override {
     layout = GetParam();
-    // These test are enable for Turing and newer. Temporarily
-    // we are skipping Hopper since the matmul for it is under development.
+    // These tests are enabled for Turing and newer.
+    // We are skipping Blackwell since the matmul for it is under development.
     auto lower_major = 8;
     auto lower_minor = 0;
-    auto upper_major = 9;
+    auto upper_major = 10;
     auto upper_minor = 0;
     if (cudaArchGuardShouldSkip(
             lower_major, lower_minor, upper_major, upper_minor)) {
-      GTEST_SKIP() << "CombineMulSumAsMmaTest skipped "
+      GTEST_SKIP() << "CombineMulSumAsMmaTestWithLayout skipped "
                    << "Requires GPU capability between " << lower_major << "."
                    << lower_minor << "and " << upper_major << "." << upper_minor
                    << " to run.\n";
     }
+    pre_hopper = at::cuda::getCurrentDeviceProperties()->major < 9;
     NVFuserTest::SetUp();
   }
+
+  bool pre_hopper;
 };
 
-void performSubstitution(Fusion* fusion, bool should_not_find = false) {
+void performSubstitution(
+    Fusion* fusion,
+    bool avoid_intermediates,
+    bool should_not_find = false) {
   EXPECT_TRUE(ir_utils::getOpsOfType<MmaOp>(fusion).empty());
 
   std::vector<mma_utils::MatmulPattern> patterns =
@@ -96,14 +108,14 @@
   ASSERT_FALSE(patterns.empty());
   EXPECT_EQ(patterns.size(), 1);
 
-  patterns.front().translateToMmaOp();
+  patterns.front().translateToMmaOp(avoid_intermediates);
 
   ASSERT_FALSE(ir_utils::getOpsOfType<MmaOp>(fusion).empty());
 }
 
 // Test checks to see that the combiner can correctly replace
 // the mul-sum pair with a mma op.
-TEST_P(CombineMulSumAsMmaTestWithLayout, AmpereMulSumToMatmul_Pass) {
+TEST_P(CombineMulSumAsMmaTestWithLayout, MulSumToMatmul_Pass) {
   Fusion fusion;
   FusionGuard fg(&fusion);
   auto tv0 = makeContigTensor(2, DataType::Half);
@@ -119,13 +131,13 @@ TEST_P(CombineMulSumAsMmaTestWithLayout, AmpereMulSumToMatmul_Pass) {
 
   fusion.addOutput(tv3);
 
-  performSubstitution(&fusion);
+  performSubstitution(&fusion, /*avoid_intermediates=*/!pre_hopper);
 }
 
 // This test checks that the pattern matcher does not incorrectly identify
 // this mul-sum pair, as the mul is not fed by broadcasts ops; i.e. it is
 // not a matmul.
-TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail1) {
+TEST_F(CombineMulSumAsMmaTest, MulSumToMatmul_Fail1) {
   Fusion fusion;
   FusionGuard fg(&fusion);
   auto tv0 = makeContigTensor(3, DataType::Half);
@@ -138,11 +150,15 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail1) {
   auto tv3 = sum(tv2, {-1});
   fusion.addOutput(tv3);
 
-  performSubstitution(&fusion, /*should_not_find=*/true);
+  performSubstitution(
+      &fusion, /*avoid_intermediates=*/!pre_hopper, /*should_not_find=*/true);
 }
 
 // This fusion has Broadcast batch axes in each operand.
-TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_MultipleBroadcasts) {
+TEST_F(CombineMulSumAsMmaTest, MulSumToMatmul_MultipleBroadcasts) {
+  // This test explicitly broadcasts and transposes, so we cannot avoid
+  // intermediates on Hopper (yet).
+  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
   // Assumes layout is kAllSupportedMmaLayout::NT;
   std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
   Fusion* fusion = fusion_ptr.get();
@@ -170,7 +186,8 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_MultipleBroadcasts) {
   auto tv3 = sum(tv2, {-1});
   fusion->addOutput(tv3);
 
-  performSubstitution(fusion, /*should_not_find=*/false);
+  performSubstitution(
+      fusion, /*avoid_intermediates=*/!pre_hopper, /*should_not_find=*/false);
 
   // We test running this fusion also to verify that the broadcast batch
   // dimension does not cause unforeseen issues
@@ -192,6 +209,7 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_MultipleBroadcasts) {
 // pair with a mma op, we are able to schedule it as we did with
 // a fusion that had a mma op to begin with.
 TEST_P(CombineMulSumAsMmaTestWithLayout, AmpereMulSumToMatmul_Schedule) {
+  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
   // Keep multiples of 8 to keep vectorizable.
   int M = 504, N = 136, K = 248;
 
@@ -209,7 +227,7 @@ TEST_P(CombineMulSumAsMmaTestWithLayout, AmpereMulSumToMatmul_Schedule) {
 
   fusion.addOutput(tv2);
 
-  performSubstitution(&fusion);
+  performSubstitution(&fusion, /*avoid_intermediates=*/!pre_hopper);
 
   MatMulTileOptions gemm_tile;
   gemm_tile.cta_tile = GemmTile(128, 128, 32);
@@ -239,6 +257,7 @@
 }
 
 TEST_P(CombineMulSumAsMmaTestWithLayout, UseMatmulScheduler) {
+  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
   // Keep multiples of 8 to keep vectorizable.
   int M = 504, N = 136, K = 248;
   auto fusion = std::make_unique<Fusion>();
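The `MmaOp::AxisMapping` built in the LinearOp branch of `translateToMmaOp` is easiest to see on a concrete shape. The snippet below is a minimal standalone sketch of that index arithmetic only: the shapes are illustrative assumptions (not taken from the PR's tests), and plain `std::vector<int64_t>`s stand in for the `a_axes`/`b_axes` fields used in the diff.

```cpp
// Sketch of the LinearOp axis mapping for A = [B0, M, K] and weight B = [N, K].
// The MmaOp output is shaped [B0, M, N, rK]; -1 marks an output dimension that
// has no counterpart in the given operand.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t a_ndims = 3;           // A = [B0, M, K] (assumed example shape)
  const int64_t out_dim = a_ndims + 1; // output = [B0, M, N, rK]

  // a_axes: every A dim except K keeps its position, N is missing (-1),
  // and K maps to the last output position.
  std::vector<int64_t> a_axes;
  for (int64_t d = 0; d < out_dim - 2; ++d) {
    a_axes.push_back(d);
  }
  a_axes.push_back(-1);          // missing N dimension
  a_axes.push_back(a_ndims - 1); // K dimension

  // b_axes: only N and K exist in the 2D weight; everything else is -1.
  std::vector<int64_t> b_axes(out_dim, -1);
  b_axes[out_dim - 2] = 0; // N
  b_axes[out_dim - 1] = 1; // K

  auto print = [](const char* name, const std::vector<int64_t>& v) {
    std::cout << name << ":";
    for (int64_t x : v) {
      std::cout << ' ' << x;
    }
    std::cout << '\n';
  };
  print("a_axes", a_axes); // a_axes: 0 1 -1 2
  print("b_axes", b_axes); // b_axes: -1 -1 0 1
  return 0;
}
```

Because the missing dimensions are expressed through the mapping rather than materialized, this path avoids the broadcast/transpose copies that the `else` branch still performs for pre-Hopper targets.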