Skip to content

Commit

Permalink
second attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Dec 13, 2024
1 parent d8a2a24 commit d19dc1e
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 72 deletions.
72 changes: 1 addition & 71 deletions csrc/device_lower/pass/replace_size.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,75 +137,6 @@ std::unordered_map<Val*, Val*> getSimplificationMap(Fusion* fusion) {
return simplification_map;
}

// Mutator that applies a Val -> Val replacement map to every TensorView of a
// fusion. All of the work is done in the constructor: the statements reachable
// from each TensorView's domain are first gathered for every tensor, and only
// then dispatched for mutation, followed by the tensor's domain and the tensor
// itself.
class IterDomainReplacementMutator : private OptOutMutator {
 public:
  IterDomainReplacementMutator(
      Fusion* fusion,
      const std::unordered_map<Val*, Val*>& replacement_map)
      : replacement_map_(replacement_map) {
    FusionGuard fg(fusion);

    // Pass 1: for each tensor, record the statements to mutate. This is
    // completed for all tensors before any mutation is dispatched.
    std::unordered_map<TensorView*, std::vector<Statement*>> pending_by_tv;
    for (auto tv : fusion->allTvs()) {
      std::vector<Statement*> pending;

      // Appends everything StmtSort reports as required to produce `root`
      // (members and attributes traversed), in the order it is reported.
      const auto append_deps = [&pending](Statement* root) {
        for (auto dep : StmtSort::getStmtsTo({root->as<Val>()}, true, true)) {
          pending.push_back(dep);
        }
      };

      auto domain_stmts = tv->domain()->allStatements();
      for (auto domain_stmt : domain_stmts) {
        auto id = dynamic_cast<IterDomain*>(domain_stmt);
        if (id != nullptr) {
          // Members of the IterDomain come first, then the IterDomain itself.
          for (auto member : MemberStatements::get(id)) {
            append_deps(member);
          }
          pending.push_back(id);
          continue;
        }

        // Anything that is not an IterDomain is required to be an Expr.
        auto expr = dynamic_cast<Expr*>(domain_stmt);
        NVF_ERROR(expr != nullptr);
        for (auto attribute : expr->attributes()) {
          append_deps(attribute);
        }
        pending.push_back(expr);
      }

      pending_by_tv.emplace(tv, pending);
    }

    // Pass 2: dispatch each recorded statement at most once across all
    // tensors, then the tensor's domain and the tensor itself.
    std::unordered_set<Statement*> already_mutated;
    for (auto tv : fusion->allTvs()) {
      for (auto stmt : pending_by_tv.at(tv)) {
        if (already_mutated.count(stmt) == 0) {
          dispatchMutate(stmt);
          already_mutated.insert(stmt);
        }
      }

      dispatchMutate(tv->domain());
      dispatchMutate(tv);
    }
  }

 private:
  using OptOutMutator::dispatchMutate;

  // Vals present in the replacement map are registered as mutated to their
  // mapped value; every other Val falls through to the base mutator.
  void dispatchMutate(Val* val) final {
    const auto entry = replacement_map_.find(val);
    if (entry != replacement_map_.end()) {
      registerMutation(val, entry->second);
      return;
    }
    OptOutMutator::dispatchMutate(val);
  }

  const std::unordered_map<Val*, Val*>& replacement_map_;
};

} // namespace

void replaceSymbolicSizes(Fusion* fusion) {
Expand Down Expand Up @@ -335,8 +266,7 @@ void replaceSymbolicSizes(Fusion* fusion) {
}
}

// NOLINTNEXTLINE(bugprone-unused-raii)
IterDomainReplacementMutator(fusion, extent_simplification_map);
ir_utils::replaceValue(fusion, extent_simplification_map);
}

} // namespace nvfuser
19 changes: 18 additions & 1 deletion csrc/ir/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,9 @@ class ValReplacementMutator : private OptOutMutator {
// typically not used by anything else. If we don't grab that count, then it
// would be a tensorview that doesn't get updated extents. Therefore, first
// grab all leaves towards outputs and grab stmts from there.
auto stmts = StmtSort::getStmtsTo(allLeafOuts(fusion), true, true);
// auto stmts = StmtSort::getStmtsTo(allLeafOuts(fusion), true,
// true);
auto stmts = StmtSort::getAllStmtsTo(allLeafOuts(fusion), true, true);

// Some fusions, such as standalone rand_like, can have disconnected DAG, so
// we need some mechanism to make sure our replacement set is as complete as
Expand Down Expand Up @@ -501,6 +503,21 @@ class ValReplacementMutator : private OptOutMutator {
std::unordered_set<Val*> outputs;
std::vector<Val*> ordered_outputs;
for (auto expr : exprs) {
if (std::any_of(
expr->outputs().begin(), expr->outputs().end(), [](Val* output) {
return output->isA<IterDomain>();
})) {
NVF_ERROR(std::all_of(
expr->outputs().begin(), expr->outputs().end(), [](Val* output) {
return output->isA<IterDomain>();
}));
NVF_ERROR(std::all_of(
expr->inputs().begin(), expr->inputs().end(), [](Val* input) {
return input->isA<IterDomain>();
}));
continue;
}

inputs.insert(expr->inputs().begin(), expr->inputs().end());
outputs.insert(expr->outputs().begin(), expr->outputs().end());
ordered_outputs.insert(
Expand Down
81 changes: 81 additions & 0 deletions csrc/iter_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,87 @@ std::vector<Statement*> StmtSort::getStmtsTo(
return es.stmts;
}

// Convenience overload of getAllStmtsTo that uses the fusion's terminating
// outputs as the traversal targets. The traversal flags are forwarded
// unchanged.
std::vector<Statement*> StmtSort::getAllStmts(
    Fusion* fusion,
    bool traverse_members,
    bool traverse_attributes,
    bool traverse_siblings) {
  const std::vector<Val*> terminating_outputs =
      fusion->getTerminatingOutputs();
  return getAllStmtsTo(
      terminating_outputs,
      traverse_members,
      traverse_attributes,
      traverse_siblings);
}

// Collects the statements leading to `to`. When traverse_members is true,
// each visited TensorView is expanded into all statements of its domain
// (tv->domain()->allStatements()): the dependencies of every IterDomain
// member and, if traverse_attributes is set, of every expression attribute
// are emitted first, then the domain statement itself, and finally the
// tensor's domain and the tensor. Entries are accumulated in a
// VectorOfUniqueEntries.
std::vector<Statement*> StmtSort::getAllStmtsTo(
    const std::vector<Val*>& to,
    bool traverse_members,
    bool traverse_attributes,
    bool traverse_siblings) {
  // Without member traversal there is nothing extra to collect, so the
  // result is exactly what getStmtsTo produces.
  if (!traverse_members) {
    return getStmtsTo(
        to, traverse_members, traverse_attributes, traverse_siblings);
  }

  // Targets are expected to be TensorViews or Vals of ValType::Others only.
  NVF_ERROR(std::all_of(to.begin(), to.end(), [](Val* to_val) {
    return to_val->vtype() == ValType::TensorView ||
        to_val->vtype() == ValType::Others;
  }));

  VectorOfUniqueEntries<Statement*> collected;

  // Adds every statement required to produce `root`, in the order reported
  // by getStmtsTo. Members are always traversed on this path.
  const auto collect_deps = [&](Statement* root) {
    const auto deps = StmtSort::getStmtsTo(
        {root->as<Val>()},
        /*traverse_members=*/true,
        traverse_attributes,
        traverse_siblings);
    for (auto dep : deps) {
      collected.pushBack(dep);
    }
  };

  // Start from the traversal that does not descend into tensor members, then
  // expand each TensorView in place.
  const auto top_level_stmts =
      getStmtsTo(to, false, traverse_attributes, traverse_siblings);

  for (auto stmt : top_level_stmts) {
    if (auto tv = dynamic_cast<TensorView*>(stmt)) {
      auto domain_stmts = tv->domain()->allStatements();
      for (auto domain_stmt : domain_stmts) {
        auto id = dynamic_cast<IterDomain*>(domain_stmt);
        if (id != nullptr) {
          for (auto member : MemberStatements::get(id)) {
            collect_deps(member);
          }
          collected.pushBack(id);
          continue;
        }

        // Anything that is not an IterDomain is required to be an Expr.
        auto expr = dynamic_cast<Expr*>(domain_stmt);
        NVF_ERROR(expr != nullptr);
        if (traverse_attributes) {
          for (auto attribute : expr->attributes()) {
            collect_deps(attribute);
          }
        }
        collected.pushBack(expr);
      }

      collected.pushBack(tv->domain());
      collected.pushBack(tv);
    } else {
      collected.pushBack(stmt);
    }
  }

  return collected.vector();
}

std::vector<Statement*> StmtSort::getStmtsBetween(
const std::vector<Val*>& from,
const std::vector<Val*>& to,
Expand Down
12 changes: 12 additions & 0 deletions csrc/iter_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,18 @@ class StmtSort : public IterVisitor {
bool traverse_attributes = false,
bool traverse_siblings = false);

// Same as getAllStmtsTo, using fusion->getTerminatingOutputs() as the
// traversal targets. The three traverse_* flags are forwarded unchanged.
NVF_API static std::vector<Statement*> getAllStmts(
Fusion* fusion,
bool traverse_members = false,
bool traverse_attributes = false,
bool traverse_siblings = false);

// Returns the statements leading to `to`. With traverse_members == false this
// is exactly getStmtsTo. With traverse_members == true, each TensorView
// found by the member-less traversal is expanded into the statements of
// tv->domain()->allStatements() (preceded by the dependencies of IterDomain
// members and, if traverse_attributes is set, of expression attributes),
// followed by the tensor's domain and the tensor itself. In that mode every
// entry of `to` must have ValType::TensorView or ValType::Others.
NVF_API static std::vector<Statement*> getAllStmtsTo(
const std::vector<Val*>& to,
bool traverse_members = false,
bool traverse_attributes = false,
bool traverse_siblings = false);

// Returns ordered Statements required to produce from, including from.
// Stops traversal once hitting any Statements in to. Includes Statements in
// to.
Expand Down

0 comments on commit d19dc1e

Please sign in to comment.