Fix multiple-broadcasts test.
This fixes #2273
jacobhinkle committed May 20, 2024
1 parent a583d40 commit f5ec534
Showing 4 changed files with 62 additions and 12 deletions.
csrc/ir/utils.cpp (1 addition, 1 deletion)
@@ -1290,7 +1290,7 @@ MmaOpDetails getMmaOpDetails(
   const auto validateOutputDetails = [](const TensorViewDetails& details,
                                         const std::string& desc) {
     // TODO: revise rules when add support for batch gemms
-    NVF_ERROR(details.bcasts.empty(), desc, ": has broadcast domains.");
+    // NVF_ERROR(details.bcasts.empty(), desc, ": has broadcast domains.");
     NVF_ERROR(!details.rdomains.empty(), desc, ": has no reduction domains.");
     NVF_ERROR(
         (details.cdomains.size() >= expected_gemm_cdomains),
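The check relaxed above matters for outputs like the one in the updated test below, whose root domain carries a Broadcast batch axis. A hedged illustration (domain spellings are assumed, not taken from the diff):

// Hedged illustration, domain spellings assumed: an MmaOp output whose
// root domain includes a Broadcast batch axis, e.g.
//   tv3: [ iS{M}, bS{1}, iS{N}, rS{K} ]
// is no longer rejected here; the reduction-domain and concrete-domain
// checks that follow in validateOutputDetails still apply.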
csrc/scheduler/mma_utils.cpp (36 additions, 2 deletions)
@@ -1259,10 +1259,14 @@ RolesMapOpt getTensorsRoles(
   bool has_m = false, has_n = false, has_k = false, has_unmapped = false;
   for (IterDomain* id :
        TensorDomain::noReductions(tv->getMaybeRFactorDomain())) {
+    if (id->isBroadcast()) {
+      // Ignore broadcasts in output
+      continue;
+    }
     const ValGroup& g = exact_graph.toGroup(id);
     auto it = group_to_domain.find(g);
     if (it == group_to_domain.end()) {
-      // tv has an unmapped dimension
+      // output tv has an unmapped non-broadcast dimension
       has_unmapped = true;
       continue;
     }
@@ -1459,6 +1463,19 @@ class MatmulPatternMatcher : IterVisitor {
     if (bop->getBinaryOpType() != BinaryOpType::Mul) {
       return;
     }
+    // TODO: Allow multiple K dimensions
+    // Check that there's a single K dimension
+    bool has_k = false;
+    for (IterDomain* id :
+         rop->out()->as<TensorView>()->getMaybeRFactorDomain()) {
+      if (id->isReduction()) {
+        if (has_k) {
+          return;
+        }
+        has_k = true;
+      }
+    }
+
     // Remember that we are just gathering the immediate inputs to the
     // matmul, so there should be no prologue between a, b and the mul/sum.

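To make the new single-K guard concrete, here is a hedged sketch, written in the style of tests/cpp/test_combine_mul_sum.cpp, of a fusion the guard rejects. The test name, shapes, and broadcast masks are illustrative and not part of this commit:

// Hypothetical test, not in this commit: a mul-sum with two reduction
// axes (K1 and K2) should produce no MatmulPattern.
TEST_F(CombineMulSumAsMmaTest, MultipleKDimsNotMatched) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  Fusion* fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  auto tv0 = makeContigTensor(3, DataType::Half); // [M, K1, K2]
  auto tv1 = makeContigTensor(3, DataType::Half); // [N, K1, K2]
  fusion->addInput(tv0);
  fusion->addInput(tv1);

  auto tv0b = broadcast(tv0, {false, true, false, false}); // [M, 1, K1, K2]
  auto tv1b = broadcast(tv1, {true, false, false, false}); // [1, N, K1, K2]
  auto tv2 = mul(tv0b, tv1b);                              // [M, N, K1, K2]
  // Summing over both K1 and K2 leaves two reduction IterDomains in the
  // output, so has_k trips the early return in the matcher.
  auto tv3 = sum(tv2, {-2, -1});
  fusion->addOutput(tv3);

  performSubstitution(fusion, /*should_not_find=*/true);
}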
@@ -1481,8 +1498,16 @@ class MatmulPatternMatcher : IterVisitor {
     bool has_m = false, has_n = false;
     for (size_t i : c10::irange(lrf.size())) {
       if (lrf[i]->isBroadcast() && !rrf[i]->isBroadcast()) {
+        if (has_m) {
+          // TODO: Handle multiple M dimensions
+          return;
+        }
         has_m = true;
       } else if (!lrf[i]->isBroadcast() && rrf[i]->isBroadcast()) {
+        if (has_n) {
+          // TODO: Handle multiple N dimensions
+          return;
+        }
         has_n = true;
       }
       if (red_root[i]->isReduction()) {
@@ -1493,7 +1518,7 @@
       }
     }
     if (!has_m || !has_n) {
-      // This is an ordinary reduction, not a matmul
+      // This is an ordinary reduction or mat-vec, not a matmul
       return;
     }

@@ -1514,6 +1539,15 @@ std::vector<MatmulPattern> findMatmulPatterns(Fusion* fusion) {
   return MatmulPatternMatcher::run(fusion);
 }
 
+std::string MatmulPattern::toString() const {
+  std::stringstream ss;
+  ss << "MatmulPattern{";
+  ss << "\n A=" << A->toString();
+  ss << "\n B=" << B->toString();
+  ss << "\n output=" << output->toString() << "\n}";
+  return ss.str();
+}
+
 MmaOp* MatmulPattern::translateToMmaOp() {
   if (auto mma_op = dynamic_cast<MmaOp*>(output->definition())) {
     // No translation needed
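The new toString makes detected patterns easy to inspect. A hedged usage sketch; the helper name is ours, and it assumes the nvfuser headers and namespace are in scope:

// Hypothetical debugging helper, not part of this commit.
#include <iostream>

void dumpMatmulPatterns(Fusion* fusion) {
  for (MatmulPattern& pattern : findMatmulPatterns(fusion)) {
    // Prints A, B, and output TensorViews of each matched pattern.
    std::cout << pattern.toString() << std::endl;
  }
}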
csrc/scheduler/mma_utils.h (2 additions, 0 deletions)
@@ -272,6 +272,8 @@ struct MatmulPattern {
   //! object can safely outlive id_model.
   std::unordered_map<ValGroup, MatmulDomain> getDimRoles(
       IdModel& id_model) const;
+
+  std::string toString() const;
 };
 
 //! Traverse the fusion to find supported matmul patterns
tests/cpp/test_combine_mul_sum.cpp (23 additions, 9 deletions)
@@ -127,18 +127,17 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_Fail1) {
   performSubstitution(&fusion, /*should_not_find=*/true);
 }
 
-// This fusion has more than one broadcasted dimension for each operand, so it
-// is currently rejected isMatmulFusionDefinitionSupported. Still, it is a valid
-// MatmulPattern so we check that it is found.
+// This fusion has Broadcast batch axes in each operand.
 TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_MultipleBroadcasts) {
   // Assumes layout is kAllSupportedMmaLayout::NT;
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  Fusion* fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
   auto tv0 = makeContigTensor(2, DataType::Half);
   auto tv1 = makeContigTensor(2, DataType::Half);
 
-  fusion.addInput(tv0);
-  fusion.addInput(tv1);
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
 
   auto tv0t = transpose(tv0, 0, 1);
   auto tv1t = transpose(tv1, 0, 1);
@@ -155,9 +154,24 @@ TEST_F(CombineMulSumAsMmaTest, AmpereMulSumToMatmul_MultipleBroadcasts) {
   auto tv1b = broadcast(tv1t, bcast_dims);
   auto tv2 = mul(tv0b, tv1b);
   auto tv3 = sum(tv2, {-1});
-  fusion.addOutput(tv3);
+  fusion->addOutput(tv3);
 
+  performSubstitution(fusion, /*should_not_find=*/false);
+
+  // We test running this fusion also to verify that the broadcast batch
+  // dimension does not cause unforeseen issues
+
+  int64_t M = 256, N = 128, K = 64;
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto t0 = at::randn({K, M}, options);
+  auto t1 = at::randn({K, N}, options);
+  auto tref = at::linear(t0.t(), t1.t()).unsqueeze(1);
+
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs({t0, t1});
+
-  performSubstitution(&fusion, /*should_not_find=*/false);
+  testValidate(
+      executor_cache.fusion(), outputs, {t0, t1}, {tref}, __LINE__, __FILE__);
 }
 
 // As a sanity check we test that after replacing a mul-sum
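For reference, the expected shapes in the new test body, assuming standard ATen semantics (at::linear(x, w) computes x.matmul(w.t()) for 2-D inputs without bias):

// t0 : [K, M]  ->  t0.t() : [M, K]
// t1 : [K, N]  ->  t1.t() : [N, K]
// at::linear(t0.t(), t1.t())              : [M, N]
// at::linear(t0.t(), t1.t()).unsqueeze(1) : [M, 1, N]
// The singleton axis lines up with the broadcast batch dimension that the
// (collapsed) broadcast() calls insert, so tref matches tv3's shape.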
