diff --git a/csrc/executor.cpp b/csrc/executor.cpp index dc0eac24575..2772b78e95e 100644 --- a/csrc/executor.cpp +++ b/csrc/executor.cpp @@ -1861,6 +1861,7 @@ std::vector FusionExecutor::evaluateFusionOutputs( for (const auto& out_val : fusion()->outputs()) { auto out_tensor = expr_eval.evaluate(out_val->as()).as(); + expr_eval.bind(out_val, out_tensor); outputs.emplace_back(out_tensor); } } diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index 0187b347cfa..184baf8aa68 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -214,6 +214,7 @@ std::unordered_map PairwiseLogicalDomainMap::map( NVF_ERROR(false, "Producer did not match any LinearOp input.") } + bool k_bcast = op->inA()->as()->axis(-1)->isBroadcast(); // LinearOp: // inputs (0) = {*, in_features} // weight (1) = {out_features, in_features} / {in_features} @@ -221,7 +222,8 @@ std::unordered_map PairwiseLogicalDomainMap::map( // output = {*, out_features} / {*} const std::vector& aligned_producer_ids = - ops::mapLinearOpIterDomains(producer_logical, input_position, out_size); + ops::mapLinearOpIterDomains( + producer_logical, input_position, out_size, k_bcast); pairwiseMapAllIds(aligned_producer_ids, consumer_root); return dom_map; } diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp index 44902f74fd2..8afd3fe8fba 100644 --- a/csrc/ops/composite.cpp +++ b/csrc/ops/composite.cpp @@ -64,32 +64,43 @@ static TensorView* newForLinear( auto input_domain = TensorDomain::noReductions(input->getLogicalDomain()); auto weight_domain = TensorDomain::noReductions(weight->getLogicalDomain()); + // Output has a reduction axis rK if K is not bcast + NVF_CHECK( + input_domain.back()->isBroadcast() == weight_domain.back()->isBroadcast(), + "K should be broadcast in both inputs and weights, or neither."); + bool k_bcast = input_domain.back()->isBroadcast(); + size_t red_dims = k_bcast ? 0 : 1; + // Linear: a = {*, in_features}, b = {out_features, in_features} / - // {in_features}.The linear output is {*, (out_features), rK}. - // The first out_size -2 dimensions are as the first input, followed by - // out_features (if present) and an additional reduction axis K. - auto ndims_out = input_domain.size() + weight_domain.size() - 1; + // {in_features}.The linear output is {*, (out_features), rK?}. + // Reduction K is present only when K is not bcast. + auto ndims_out = + (input_domain.size() - 1) + (weight_domain.size() - 1) + red_dims; const std::vector& mapping_a = - ops::mapLinearOpIterDomains(input_domain, 0, ndims_out); + ops::mapLinearOpIterDomains(input_domain, 0, ndims_out, k_bcast); const std::vector& mapping_b = - ops::mapLinearOpIterDomains(weight_domain, 1, ndims_out); + ops::mapLinearOpIterDomains(weight_domain, 1, ndims_out, k_bcast); std::vector mapping_bias(ndims_out, nullptr); if (bias != nullptr) { auto bias_domain = TensorDomain::noReductions(bias->getLogicalDomain()); - mapping_bias = ops::mapLinearOpIterDomains(bias_domain, 2, ndims_out); + mapping_bias = + ops::mapLinearOpIterDomains(bias_domain, 2, ndims_out, k_bcast); } std::vector out_domain(ndims_out, nullptr); - for (auto idx : c10::irange(ndims_out - 1)) { + for (auto idx : c10::irange(ndims_out - red_dims)) { out_domain[idx] = ops::newOutputIterDomain( {mapping_a.at(idx), mapping_b.at(idx), mapping_bias.at(idx)}); } - // Specify the iterdomain for K as reduction - out_domain[ndims_out - 1] = ops::newOutputIterDomain( - {mapping_a.back(), mapping_b.back()}, - /*force_iter_type=*/IterType::Reduction); + + if (!k_bcast) { + // Specify the iterdomain for K as reduction + out_domain[ndims_out - 1] = ops::newOutputIterDomain( + {mapping_a.back(), mapping_b.back()}, + /*force_iter_type=*/IterType::Reduction); + } TensorDomain* td = IrBuilder::create( out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)); @@ -335,14 +346,25 @@ static TensorView* newForMatmul(TensorView* tv_a, TensorView* tv_b) { auto ndims_a = orig_domain_a.size(); auto ndims_b = orig_domain_b.size(); + auto b_kpos = orig_domain_b.size() > 1 ? ndims_b - 2 : ndims_b - 1; + NVF_CHECK( + orig_domain_a.back()->isBroadcast() == + orig_domain_b.at(b_kpos)->isBroadcast(), + "K should be broadcast in both A and B, or neither."); + + // Output has a reduction axis rK if K is not bcast + bool k_bcast = orig_domain_a.back()->isBroadcast(); + size_t red_dims = k_bcast ? 0 : 1; + // Matmul output size is same as the higher dimensional input size if both A/B - // > 1D, but with 1 additional IterType::Reduction axis rK. - auto ndims_out = std::max(ndims_a, ndims_b) + 1; + // > 1D, but with 1 additional IterType::Reduction axis rK if K is not + // broadcast. + auto ndims_out = std::max(ndims_a, ndims_b) + red_dims; if (std::min(ndims_a, ndims_b) == 1) { // If one of the inputs is 1D, the output size is the same as the higher // dimensional input size, since we will include a Reduction axis for K in // the output. For example: [iM, iK] x [iK] -> [iM, rK] - ndims_out = std::max(ndims_a, ndims_b); + ndims_out = std::max(ndims_a, ndims_b) - 1 + red_dims; } std::vector out_domain(ndims_out, nullptr); @@ -352,14 +374,15 @@ static TensorView* newForMatmul(TensorView* tv_a, TensorView* tv_b) { const std::vector& mapping_b = ops::mapMatmulOpIterDomains(orig_domain_b, 1, ndims_out); - for (auto idx : c10::irange(ndims_out - 1)) { + for (auto idx : c10::irange(ndims_out - red_dims)) { out_domain[idx] = ops::newOutputIterDomain({mapping_a.at(idx), mapping_b.at(idx)}); } - - out_domain[ndims_out - 1] = ops::newOutputIterDomain( - {mapping_a.back(), mapping_b.back()}, - /*force_iter_type=*/IterType::Reduction); + if (!k_bcast) { + out_domain[ndims_out - 1] = ops::newOutputIterDomain( + {mapping_a.back(), mapping_b.back()}, + /*force_iter_type=*/IterType::Reduction); + } TensorDomain* td = IrBuilder::create( out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)); diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index b25dd326f4f..8dc6d04d406 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -188,20 +188,26 @@ std::vector mapMatmulOpIterDomains( std::vector mapping(out_size, nullptr); auto inp_size = (int64_t)input_domain.size(); - if (inp_size == 1) { - // Only reduction axis {K} - mapping[out_size - 1] = input_domain[0]; - return mapping; - } - // Input A to matmul: {*, M, K} // Input B to matmul: {*, K, N} - auto kpos = input_position == 0 ? inp_size - 1 : inp_size - 2; + auto kpos = inp_size - 1; + if (input_position == 1 && inp_size > 1) { + kpos = inp_size - 2; + } + bool k_bcast = input_domain.at(kpos)->isBroadcast(); + int64_t red_dims = k_bcast ? 0 : 1; - // Last position is a reduction dimension mapping to K - mapping[out_size - 1] = input_domain.at(kpos); + // Last position is a reduction dimension mapping to K if K is not broadcast. + if (!k_bcast) { + mapping[out_size - 1] = input_domain.at(kpos); + ; + } - for (auto out_idx = (int64_t)out_size - 2, inp_idx = inp_size - 1; + if (inp_size == 1) { + return mapping; + } + + for (auto out_idx = (int64_t)out_size - 1 - red_dims, inp_idx = inp_size - 1; inp_idx >= 0; inp_idx--) { if (inp_idx != kpos) { @@ -211,7 +217,7 @@ std::vector mapMatmulOpIterDomains( // Consider [iM, iK] x [iK]: [iM, rK]. Since out_size < inp_size, // input A and output are not right-aligned. In this case, the output index // pointer should not be moved when the reduction axis is encountered. - else if (inp_size <= (int64_t)out_size - 1) { + else if (inp_size <= (int64_t)out_size - red_dims) { out_idx--; } } @@ -222,7 +228,8 @@ std::vector mapMatmulOpIterDomains( std::vector mapLinearOpIterDomains( const std::vector& input_domain, int64_t input_position, - size_t out_size) { + size_t out_size, + bool k_bcast) { std::vector mapping(out_size, nullptr); auto inp_size = input_domain.size(); @@ -231,29 +238,37 @@ std::vector mapLinearOpIterDomains( "Input position must be 0, 1, or 2. Found ", input_position); + auto red_dims = k_bcast ? 0 : 1; + // Input A: {*, M, K} // Input B: {*, N, K} / {K} // Bias: {N} / {} + + // Map K if K is not bcast + if (input_position != 2 && !k_bcast) { + mapping[out_size - 1] = input_domain.back(); + } + switch (input_position) { case 0: { - // Linear output is same as input for all but the last dimension + // Linear output is same as input for inp_size - 1 dimensions. + // K is already mapped above if not broadcast. for (auto inx : c10::irange(inp_size - 1)) { mapping[inx] = input_domain[inx]; } - mapping[out_size - 1] = input_domain.back(); break; } case 1: { - for (auto inx : c10::irange(inp_size)) { - // Map N, K to the last two positions of the output. - mapping[out_size - 1 - inx] = input_domain[inp_size - 1 - inx]; + // Map N / out_features if present + if (inp_size > 1) { + mapping[out_size - 1 - red_dims] = input_domain.front(); } break; } case 2: { if (inp_size > 0) { // Bias is 1D tensor of shape {out_features} - mapping[out_size - 2] = input_domain[0]; + mapping[out_size - 1 - red_dims] = input_domain.front(); } break; } diff --git a/csrc/ops/utils.h b/csrc/ops/utils.h index 86e87f13b99..930f6cebf06 100644 --- a/csrc/ops/utils.h +++ b/csrc/ops/utils.h @@ -73,7 +73,8 @@ std::vector mapMatmulOpIterDomains( std::vector mapLinearOpIterDomains( const std::vector& input_domain, int64_t input_position, - size_t out_size); + size_t out_size, + bool k_bcast); // Takes a vector of aligned input iterdomains to create the output iterdomain. // This is used if the input iterdomains are not trivially mapped to the output diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index c82dcd779ee..2c413be26fe 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1809,13 +1809,14 @@ DimRolesMap MatmulPattern::getDimRoles(IdModel& id_model) const { } else if (output->definition()->isA()) { const std::vector& out_logical = output->getLogicalDomain(); + bool k_bcast = A->getLogicalDomain().back()->isBroadcast(); return matmulOrLinearOpDimRoles( permissive_graph, out_logical, ops::mapLinearOpIterDomains( - A->getLogicalDomain(), 0, out_logical.size()), + A->getLogicalDomain(), 0, out_logical.size(), k_bcast), ops::mapLinearOpIterDomains( - B->getLogicalDomain(), 1, out_logical.size())); + B->getLogicalDomain(), 1, out_logical.size(), k_bcast)); } // The code below handles MmaOp or mul-sum patterns diff --git a/tests/cpp/test_matmul_aten_evaluation.cpp b/tests/cpp/test_matmul_aten_evaluation.cpp index f7bcd7595fd..ddbbce1bdba 100644 --- a/tests/cpp/test_matmul_aten_evaluation.cpp +++ b/tests/cpp/test_matmul_aten_evaluation.cpp @@ -58,21 +58,22 @@ void checkMatmulOpIdMapping( const ValGraph& vg = id_model.idGraph(IdMappingMode::EXACT); vg.validateConsistency(); - // If K is Broadcast then we will not have a reduction dim + // If K is Broadcast then we will not have a reduction dim. bool k_bcast = A->axis(-1)->isBroadcast(); - auto out_ndims = std::max(A->nDims(), B->nDims()) + 1; + int64_t red_dims = k_bcast ? 0 : 1; + auto out_ndims = std::max(A->nDims(), B->nDims()) + red_dims; if (std::min(A->nDims(), B->nDims()) == 1) { - out_ndims = std::max(A->nDims(), B->nDims()); + out_ndims = std::max(A->nDims(), B->nDims()) - 1 + red_dims; } ASSERT_EQ(output->nDims(), out_ndims); if (A->nDims() > 1) { - int out_mpos = B->nDims() > 1 ? -3 : -2; + int out_mpos = B->nDims() > 1 ? -2 - red_dims : -1 - red_dims; EXPECT_TRUE(checkMapped(vg, A->axis(-2), output->axis(out_mpos))); // M } if (B->nDims() > 1) { - EXPECT_TRUE(checkMapped(vg, B->axis(-1), output->axis(-2))); // N + EXPECT_TRUE(checkMapped(vg, B->axis(-1), output->axis(-1 - red_dims))); // N } if (!k_bcast) { @@ -85,7 +86,8 @@ void checkMatmulOpIdMapping( // Note that A and B can have different dimensions, so here we count // backwards from the innermost batch dimension. Then we check that the axis // exists (is not negative) and is not Broadcast before checking mapping. - int batch_ndims = output->nDims() - (B->nDims() > 1) - (A->nDims() > 1) - 1; + int batch_ndims = + output->nDims() - (B->nDims() > 1) - (A->nDims() > 1) - red_dims; for (int64_t i : c10::irange(batch_ndims)) { int64_t i_a = A->nDims() - 3 - i; int64_t i_b = B->nDims() - 3 - i; @@ -115,7 +117,9 @@ void checkLinearOpIdMapping( // bias (optional): [out_features]/[] // output = [*, (out_features), rK] - ASSERT_EQ(output->nDims(), input->nDims() + weight->nDims() - 1); + bool k_bcast = input->axis(-1)->isBroadcast(); + int64_t red_dims = k_bcast ? 0 : 1; + ASSERT_EQ(output->nDims(), input->nDims() + weight->nDims() - 2 + red_dims); // Check that the first input_size - 1 dims are mapped for input for (auto i : c10::irange(input->nDims() - 1)) { @@ -126,10 +130,11 @@ void checkLinearOpIdMapping( // Check out_features dim is mapped in weight & bias if present. if (weight->nDims() > 1) { if (!weight->axis(0)->isBroadcast()) { - EXPECT_TRUE(checkMapped(vg, weight->axis(0), output->axis(-2))); + EXPECT_TRUE( + checkMapped(vg, weight->axis(0), output->axis(-1 - red_dims))); } if (bias != nullptr && bias->nDims() > 0 && !bias->axis(0)->isBroadcast()) { - EXPECT_TRUE(checkMapped(vg, bias->axis(0), output->axis(-2))); + EXPECT_TRUE(checkMapped(vg, bias->axis(0), output->axis(-1 - red_dims))); } } // Check mapping for reduction axis in input and weight diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 38e94bae663..e086dd4be22 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -4458,6 +4458,39 @@ def fusion_func(fd: FusionDefinition): for i in range(num_out): self.assertEqual(nvf_out[i].data_ptr(), inputs[0].data_ptr()) + # Tests broadcast reduction axis in matmul: Issue #2532. + def test_repro_issue2532(self): + def fusion_func(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, -1, 1], + contiguity=[True, None, True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[2, 0, 1], + ) + T1 = fd.define_tensor( + shape=[-1, 1, -1], + contiguity=[True, None, True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[2, 1, 0], + ) + T2 = fd.ops.sum(T1, dims=[0, 1], keepdim=False, dtype=DataType.Null) + T3 = fd.ops.matmul(T0, T1) + T4 = fd.ops.sum(T3, dims=[0], keepdim=False, dtype=DataType.Null) + fd.add_output(T2) + fd.add_output(T4) + + inputs = [ + torch.randn((262400,), dtype=torch.float32, device="cuda:0").as_strided( + (1025, 256, 1), (256, 1, 256) + ), + torch.randn((1049600,), dtype=torch.float32, device="cuda:0").as_strided( + (1025, 1, 1024), (1024, 1024, 1) + ), + ] + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + if __name__ == "__main__": run_tests()