Grab all IDs and exprs with StmtSort::getAllStmts (and fix replaceSymbolicSizes) (#3578)

Stacked on #3585 

`StmtSort::getStmtsTo` may not grab all active iter domains if IDs are
connected in an unconventional way. For example, we can set the loop
domain of a tensor as a producer of its logical domain, but due to the
nature of `IterVisitor`, such ID dependency patterns are not supported,
meaning `StmtSort::getStmtsTo` would fail to grab all valid IDs and
their exprs.

I noticed this issue only recently while working on #3556, where it
showed up as an inconsistent replacement of extent vals. I've been
experimenting with such domain patterns, but I hadn't seen this before,
likely because I was using only static-shape tensors for convenience.

To fix the issue, I added a variation of `StmtSort::getStmtsTo`,
`StmtSort::getAllStmtsTo`, which traverses a fusion as usual but stops
at TensorView. For each TensorView, instead of using `IterVisitor`, it
uses `TensorDomain::allStatements()`, which combines both
`TensorDomain::allIDs()` and `TensorDomain::allExprs()`, and traverses
the IDs and exprs in the returned order.

It's a somewhat naive implementation, but I think it is good enough for
now, and I don't have any other immediate ideas to try.

I changed `ValReplacementMutator` to use the new interface. That's the
only use for now.

---------

Co-authored-by: Jacob Hinkle <[email protected]>
naoyam and jacobhinkle committed Dec 16, 2024
1 parent 36f49db commit 19524d8
Showing 9 changed files with 197 additions and 7 deletions.
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/expr_sort.cpp
@@ -1088,7 +1088,7 @@ bool ExprSegmentationSorter::interIterUpdate() {
NVF_ERROR(
!fallback_mode_enabled_,
"Couldn't succcessfully sort out the fusion expressions. ",
"There are remaining connections of the heirarchical segmentation which should have been ",
"There are remaining connections of the hierarchical segmentation which should have been ",
"flattened to a single ordered group, or disjoint ordered groups.\n",
toString());
// We didn't finish, but we haven't tried the fallback, try again with that.
1 change: 0 additions & 1 deletion csrc/device_lower/pass/replace_size.cpp
@@ -266,7 +266,6 @@ void replaceSymbolicSizes(Fusion* fusion) {
}
}

// Run mutation on the fusion with the tensor_dim_map
ir_utils::replaceValue(fusion, extent_simplification_map);
}

3 changes: 3 additions & 0 deletions csrc/ir/internal_base_nodes.h
@@ -610,6 +610,9 @@ class TensorDomain : public Val {
// Similar to allIDs but returns all ID expressions.
std::vector<Expr*> allExprs() const;

// Combine allIDs and allExprs
std::vector<Statement*> allStatements() const;

const std::vector<IterDomain*>& maybeAllocation() const {
return hasAllocation() ? allocation_domain_ : logical();
};
28 changes: 28 additions & 0 deletions csrc/ir/nodes.cpp
@@ -3803,6 +3803,34 @@ std::vector<Expr*> TensorDomain::allExprs() const {
return exprs.vector();
}

std::vector<Statement*> TensorDomain::allStatements() const {
auto all_ids = allIDs();
std::unordered_set<Val*> all_id_set{all_ids.begin(), all_ids.end()};

VectorOfUniqueEntries<Statement*> stmts;
for (auto id : all_ids) {
// Visit definition if available and all inputs are already visited
auto def = id->definition();
if (def != nullptr) {
if (std::all_of(
def->inputs().begin(), def->inputs().end(), [&](Val* inp) {
return all_id_set.find(inp) != all_id_set.end();
})) {
stmts.pushBack(def);
} else {
NVF_ERROR(std::none_of(
def->inputs().begin(), def->inputs().end(), [&](Val* inp) {
return all_id_set.find(inp) != all_id_set.end();
}));
}
}

stmts.pushBack(id);
}

return stmts.vector();
}

Split::Split(
IrBuilderPasskey passkey,
IterDomain* outer,
20 changes: 19 additions & 1 deletion csrc/ir/utils.cpp
@@ -451,7 +451,7 @@ class ValReplacementMutator : private OptOutMutator {
// typically not used by anything else. If we don't grab that count, then it
// would be a tensorview that doesn't get updated extents. Therefore, first
// grab all leaves towards outputs and grab stmts from there.
auto stmts = StmtSort::getStmtsTo(allLeafOuts(fusion), true, true);
auto stmts = StmtSort::getAllStmtsTo(allLeafOuts(fusion), true, true);

// Some fusions, such as standalone rand_like, can have disconnected DAG, so
// we need some mechanism to make sure our replacement set is as complete as
@@ -501,6 +501,24 @@ class ValReplacementMutator : private OptOutMutator {
std::unordered_set<Val*> outputs;
std::vector<Val*> ordered_outputs;
for (auto expr : exprs) {
// Iter domains and their exprs are taken care by traversing
// from TensorDomain with TensorDomain::allStatements, so they
// don't need to be included here
if (std::any_of(
expr->outputs().begin(), expr->outputs().end(), [](Val* output) {
return output->isA<IterDomain>();
})) {
NVF_ERROR(std::all_of(
expr->outputs().begin(), expr->outputs().end(), [](Val* output) {
return output->isA<IterDomain>();
}));
NVF_ERROR(std::all_of(
expr->inputs().begin(), expr->inputs().end(), [](Val* input) {
return input->isA<IterDomain>();
}));
continue;
}

inputs.insert(expr->inputs().begin(), expr->inputs().end());
outputs.insert(expr->outputs().begin(), expr->outputs().end());
ordered_outputs.insert(
92 changes: 92 additions & 0 deletions csrc/iter_visitor.cpp
@@ -938,6 +938,98 @@ std::vector<Statement*> StmtSort::getStmtsTo(
return es.stmts;
}

std::vector<Statement*> StmtSort::getAllStmts(
Fusion* fusion,
bool traverse_members,
bool traverse_attributes,
bool traverse_siblings) {
return getAllStmtsTo(
fusion->getTerminatingOutputs(),
traverse_members,
traverse_attributes,
traverse_siblings);
}

std::vector<Statement*> StmtSort::getAllStmtsTo(
const std::vector<Val*>& to,
bool traverse_members,
bool traverse_attributes,
bool traverse_siblings) {
// If members are not traversed, this can just be handled by getStmts
if (!traverse_members) {
return getStmtsTo(
to, traverse_members, traverse_attributes, traverse_siblings);
}

// to is assumed to include only scalar or TensorView
NVF_ERROR(std::all_of(to.begin(), to.end(), [](Val* to_val) {
return to_val->vtype() == ValType::TensorView ||
to_val->vtype() == ValType::Others;
}));

// First, grab all statements without traversing tensor members
auto stmts = getStmtsTo(to, false, traverse_attributes, traverse_siblings);

VectorOfUniqueEntries<Statement*> all_stmts;

// For TensorView, further traversing into its members. Note that
// traverse_members is always true here
for (auto stmt : stmts) {
auto tv = dynamic_cast<TensorView*>(stmt);
if (tv == nullptr) {
all_stmts.pushBack(stmt);
continue;
}

// Instead of using MemberStatements, grab the iter domains and
// their exprs with TensorDomain::allStatements(), which
// internally uses IRBFS.
auto all_id_stmts = tv->domain()->allStatements();

// For iter domains, traverse further their members and then visit
// themselves. For ID exprs, traverse attributes then the expr
// themselves.
for (auto id_stmt : all_id_stmts) {
if (auto id = dynamic_cast<IterDomain*>(id_stmt)) {
auto id_members = MemberStatements::next(id);
// Note that traverse_members is always true at this point
for (auto id_member : id_members) {
for (auto stmt_dep : StmtSort::getStmtsTo(
{id_member->as<Val>()},
/*traverse_members=*/true,
traverse_attributes,
traverse_siblings)) {
all_stmts.pushBack(stmt_dep);
}
}
all_stmts.pushBack(id);
} else {
auto expr = dynamic_cast<Expr*>(id_stmt);
NVF_ERROR(expr != nullptr);
if (traverse_attributes) {
for (auto attr : expr->attributes()) {
for (auto stmt_dep : StmtSort::getStmtsTo(
{attr->as<Val>()},
/*traverse_members=*/true,
traverse_attributes,
traverse_siblings)) {
all_stmts.pushBack(stmt_dep);
}
}
}
all_stmts.pushBack(expr);
}
}

// All depednent vals and exprs for this TensorDomain are in
// all_stmts. Append TensorDomain and then TensorView
all_stmts.pushBack(tv->domain());
all_stmts.pushBack(tv);
}

return all_stmts.vector();
}

std::vector<Statement*> StmtSort::getStmtsBetween(
const std::vector<Val*>& from,
const std::vector<Val*>& to,
22 changes: 22 additions & 0 deletions csrc/iter_visitor.h
@@ -316,6 +316,28 @@ class StmtSort : public IterVisitor {
bool traverse_attributes = false,
bool traverse_siblings = false);

// Returns all ordered Statements of a given fusion. Unlike
// getStmts, for TensorDomain, all of its iter domains and exprs are
// grabbed and returned in a topological order.
NVF_API static std::vector<Statement*> getAllStmts(
Fusion* fusion,
bool traverse_members = false,
bool traverse_attributes = false,
bool traverse_siblings = false);

// Returns ordered Statements required to produce 'to', including
// 'to'. Unlike getStmtsTo, for TensorDomain, all of its iter domains and
// exprs are grabbed and returned in a topological order, regardless of
// `traverse_members`.
//
// The to vals are assumed to be either TensorView or scalar
// Val. This assumption could be removed if desired.
NVF_API static std::vector<Statement*> getAllStmtsTo(
const std::vector<Val*>& to,
bool traverse_members = false,
bool traverse_attributes = false,
bool traverse_siblings = false);

// Returns ordered Statements required to produce from, including from.
// Stops traversal once hiting any Statements in to. Includes Statements in
// to.
32 changes: 32 additions & 0 deletions tests/cpp/test_loop_domain_scheduling.cpp
@@ -9,6 +9,7 @@
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>

#include <iter_visitor.h>
#include <ops/all_ops.h>
#include <scheduler/tools/inlining.h>
#include <scheduler/tools/loop_domain_scheduler.h>
@@ -20,6 +21,25 @@

namespace nvfuser {

namespace {

void checkGetAllStmts(Fusion* fusion) {
// Check if StmtSort can grab all IDS, including those that are
// producers of root IDs
auto all_stmts = StmtSort::getAllStmts(fusion, /*traverse_members=*/true);
std::unordered_set<Statement*> all_stmt_set{
all_stmts.begin(), all_stmts.end()};
for (auto tv : fusion->allTvs()) {
for (auto id_or_expr : tv->domain()->allStatements()) {
EXPECT_TRUE(all_stmt_set.count(id_or_expr))
<< "Not found: " << id_or_expr->toString() << " of "
<< tv->toString();
}
}
}

} // namespace

class LoopDomainSchedulingTest : public NVFuserTest {
protected:
void SetUp() override {
@@ -82,6 +102,8 @@ TEST_F(LoopDomainSchedulingTest, ReshapeSplitThenMerge) {
}
}

checkGetAllStmts(&fusion);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({10}, options);
std::vector<c10::IValue> inputs({t0});
@@ -143,6 +165,8 @@ TEST_F(LoopDomainSchedulingTest, Slice) {
tv->axis(1)->parallelize(ParallelType::TIDx);
}

checkGetAllStmts(&fusion);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn(shape, options);
std::vector<c10::IValue> aten_inputs({t0});
@@ -245,6 +269,8 @@ TEST_F(LoopDomainSchedulingTest, ReshapeTraversalDirection) {
tv5_loop_to_logical.at(3).first,
tv4->getLogicalDomain().at(0)->definition()) &&
tv5_loop_to_logical.at(3).second == Direction::Forward);

checkGetAllStmts(&fusion);
}

// Using the same fusion as ReshapeTraversalDirection, try each one of
@@ -309,6 +335,8 @@ TEST_F(LoopDomainSchedulingTest, ManyReshape) {
EXPECT_EQ(tv->getComputeAtPosition(), tv->getLoopDomain().size());
}

checkGetAllStmts(&fusion_copy);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({12}, options);
std::vector<c10::IValue> aten_inputs({t0});
@@ -383,6 +411,8 @@ TEST_F(LoopDomainSchedulingTest, ScheduleLoopDomainsBy1) {

EXPECT_EQ(tv1->getLoopDomain(), tv1_loop_domain);
EXPECT_EQ(tv2->getLoopDomain(), tv2_loop_domain);

checkGetAllStmts(&fusion);
}

// Testing scheduleLoopDomainBy on its insertion position of new IDs
@@ -414,6 +444,8 @@ TEST_F(LoopDomainSchedulingTest, ScheduleLoopDomainsBy2) {
EXPECT_EQ(
exact_graph.toGroups(tv1->getLoopDomain()),
exact_graph.toGroups(tv2->getLoopDomain()));

checkGetAllStmts(&fusion);
}

} // namespace nvfuser
4 changes: 0 additions & 4 deletions tests/cpp/test_resize.cpp
@@ -4090,8 +4090,6 @@ TEST_F(ResizeTest, PropagateSliceToInputsWithReshape1) {
// Fusion should have a uniform loop domain
checkLoopDomainEquivalence(ref_tv);

fusion.print();

// Schedule the reference
ref_tv->flatten();
// For TIDx
@@ -4241,8 +4239,6 @@ TEST_F(ResizeTest, PropagateMultipleSlicesToInputs) {
// Fusion should have a uniform loop domain
checkLoopDomainEquivalence(ref_tv);

fusion.print();

// Schedule the reference
ref_tv->flatten();
// For TIDx