Skip to content

Commit

Permalink
Fix #3583 (#3585)
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam authored and jacobhinkle committed Dec 16, 2024
1 parent 5b37bca commit 5efff49
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
6 changes: 4 additions & 2 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3763,8 +3763,10 @@ std::vector<IterDomain*> TensorDomain::allIDs() const {
while (!ids_to_be_sorted.empty()) {
auto it = ids_to_be_sorted.begin();
while (it != ids_to_be_sorted.end()) {
auto in_it = out2in.find(*it);
if (in_it == out2in.end() || sorted_ids.has(in_it->second)) {
auto range = out2in.equal_range(*it);
if (std::all_of(range.first, range.second, [&](const auto& kv) {
return sorted_ids.has(kv.second);
})) {
sorted_ids.pushBack(*it);
it = ids_to_be_sorted.erase(it);
} else {
Expand Down
45 changes: 45 additions & 0 deletions tests/cpp/test_gpu3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9225,6 +9225,51 @@ TEST_F(NVFuserTest, ParallelDimensionsInAllocation) {
ASSERT_TRUE(tidx_dim != nullptr);
}

// Check the topological ordering of TensorDomain::allIDs(). Repro of
// issue #3583
TEST_F(NVFuserTest, AllIdsMultipleDependencies) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({10, 20});
fusion.addInput(tv0);

auto tv1 = slice(
tv0,
{{fusion.zeroVal(), IrBuilder::create<Val>(2)},
{fusion.zeroVal(), tv0->getLogicalDomain().at(1)->extent()}});

fusion.addOutput(tv1);

tv1->merge(0);
tv1->split(0, 4);
tv1->split(0, 8);

fusion.print();

auto all_ids = tv1->domain()->allIDs();

auto split2 = tv1->axis(0)->definition()->as<Split>();
auto split1 = split2->input(0)->definition()->as<Split>();
auto merge = split1->input(0)->definition()->as<Merge>();
auto resize = merge->input(0)->definition()->as<Resize>();

std::vector<Expr*> exprs{resize, merge, split1, split2};

for (auto expr : exprs) {
for (auto inp : ir_utils::filterByType<IterDomain>(expr->inputs())) {
auto inp_it = std::find(all_ids.begin(), all_ids.end(), inp);
for (auto out : ir_utils::filterByType<IterDomain>(expr->outputs())) {
auto out_it = std::find(all_ids.begin(), all_ids.end(), out);
EXPECT_LT(inp_it, out_it)
<< "Invalid ordering: " << out->toString() << " detected before "
<< inp->toString() << ". All IDs: " << toDelimitedString(all_ids)
<< "\n";
}
}
}
}

// Test file size should be up to 10K LoC. Create a new file for more tests.

} // namespace nvfuser

0 comments on commit 5efff49

Please sign in to comment.