Skip to content

Commit

Permalink
Return a bool indicating if all nodes are visited (#3452)
Browse files Browse the repository at this point in the history
Just a mechanical change. This makes it a little cumbersome when
`require_all_to_visited` is not used, but sometimes we also need to
check if all nodes were visited.
  • Loading branch information
naoyam authored and Priya2698 committed Nov 20, 2024
1 parent ef01b15 commit 9bc3ecc
Show file tree
Hide file tree
Showing 21 changed files with 123 additions and 97 deletions.
7 changes: 4 additions & 3 deletions csrc/bfs.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,10 @@ class BFS {
}
}

// Find the shortest path from the from_ to to_. This
// Find the shortest path from the from_ to to_. A boolean value
// indicating if all nodes are visited is also returned. This
// must be only used once traversal is completed.
virtual ExprPath getShortestExprPath() {
virtual std::pair<ExprPath, bool> getShortestExprPath() {
NVF_ERROR(
!require_all_to_visited_ || allToNodesVisited(),
"Traveral is either not done or failed");
Expand Down Expand Up @@ -316,7 +317,7 @@ class BFS {
VectorOfUniqueEntries<std::pair<ExprT, Direction>> unique_path(
path.rbegin(), path.rend());

return unique_path.vector();
return std::make_pair(unique_path.vector(), allToNodesVisited());
}

// Check if a node is ready to visit. If yes, return the direction
Expand Down
7 changes: 4 additions & 3 deletions csrc/device_lower/analysis/tma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1025,9 +1025,10 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) {
std::list<std::pair<ExprGroup, Direction>> exprs = [&]() {
ValGraph& id_graph = GpuLower::current()->tensorIndexer().traversalGraph();
auto exprs_vec = ValGraphBFS::getExprsBetween(
id_graph,
id_graph.toGroups(consumer_tv->getLoopDomain()),
id_graph.toGroups(gmem_tv->getMaybeAllocationDomain()));
id_graph,
id_graph.toGroups(consumer_tv->getLoopDomain()),
id_graph.toGroups(gmem_tv->getMaybeAllocationDomain()))
.first;
return std::list(exprs_vec.begin(), exprs_vec.end());
}();

Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1832,7 +1832,7 @@ ValGroup getInnerMmaLoopGroup(TensorView* tv, const MmaOp* mma) {
ValGroup inner = alloc_domain.back();

auto exprs =
ValGraphBFS::getExprsBetween(id_graph, loop_domain, alloc_domain);
ValGraphBFS::getExprsBetween(id_graph, loop_domain, alloc_domain).first;
while (!exprs.empty()) {
auto [expr, direction] = exprs.back();
exprs.pop_back();
Expand Down
20 changes: 1 addition & 19 deletions csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,24 +280,6 @@ std::vector<TensorView*> getTvs(const std::vector<Val*>& vals) {
return tvs;
}

TensorView* getTvOutput(const Expr* expr) {
for (auto out : expr->outputs()) {
if (auto tv = getTv(out)) {
return tv;
}
}
return nullptr;
}

TensorView* getTvInput(const Expr* expr) {
for (auto inp : expr->inputs()) {
if (auto tv = getTv(inp)) {
return tv;
}
}
return nullptr;
}

bool isScalarOp(const Expr* expr) {
for (auto out : expr->outputs()) {
if (!out->isScalar()) {
Expand Down Expand Up @@ -1923,7 +1905,7 @@ Val* proveLinearAndGetStride(
// Propagate from linear_g to domain. Use frontier to keep track of the
// how linear_g lives in the current propagation front.
Projection frontier = linear_g;
auto path = ValGraphBFS::getExprsBetween(id_graph, domain, {linear_g});
auto path = ValGraphBFS::getExprsBetween(id_graph, domain, {linear_g}).first;
while (!path.empty()) {
const auto& [eg, direction] = path.back();
path.pop_back();
Expand Down
6 changes: 0 additions & 6 deletions csrc/device_lower/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,6 @@ bool isTV(const Val* const);
// Returns if Expr is a TensorView or TensorIndex Expr.
NVF_API bool isTvOp(const Expr*);

// Returns the first output of Expr that is a TensorView
NVF_API TensorView* getTvOutput(const Expr*);

// Returns the first input of Expr that is a TensorView
TensorView* getTvInput(const Expr*);

//! Returns the iterdomain that maps to the thread dimension grouped
//! to warps. Returns nullopt if the reduction is not to be lowered to
//! a warp reduction.
Expand Down
7 changes: 4 additions & 3 deletions csrc/device_lower/validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,9 +415,10 @@ class VectorizeValidator : public OptInDispatch {
const auto& graph = id_model.idGraph(IdMappingMode::EXACT);

auto expr_path = ValGraphBFS::getExprsBetween(
graph,
graph.toGroups(tv->getMaybeAllocationDomain()),
graph.toGroups(std::vector<Val*>{v_id}));
graph,
graph.toGroups(tv->getMaybeAllocationDomain()),
graph.toGroups(std::vector<Val*>{v_id}))
.first;
expr_path = reverse(expr_path);

ValGroup cur_group = graph.toGroup(v_id);
Expand Down
1 change: 1 addition & 0 deletions csrc/id_model/circular_buffer_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// clang-format on
#include <id_model/circular_buffer_indexing.h>
#include <id_model/indexing_utils.h>
#include <ir/utils.h>

namespace nvfuser {

Expand Down
1 change: 1 addition & 0 deletions csrc/id_model/indexing_traversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/
// clang-format on
#include <id_model/indexing_traversal.h>
#include <ir/utils.h>

namespace nvfuser {

Expand Down
2 changes: 1 addition & 1 deletion csrc/id_model/indexing_traversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class IndexingTraversal : public ValGraphBFS {
{from_groups.vector().begin(), from_groups.vector().end()},
{to_groups.vector().begin(), to_groups.vector().end()});
traversal.traverse();
return traversal.getShortestExprPath();
return traversal.getShortestExprPath().first;
}

using ValGraphBFS::isVisited;
Expand Down
7 changes: 4 additions & 3 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3722,9 +3722,10 @@ std::vector<IterDomain*> TensorDomain::allIDs() const {
continue;
}
auto path = IRBFS::getExprsBetween(
{all_domains[i]->begin(), all_domains[i]->end()},
{all_domains[j]->begin(), all_domains[j]->end()},
false);
{all_domains[i]->begin(), all_domains[i]->end()},
{all_domains[j]->begin(), all_domains[j]->end()},
false)
.first;
for (auto [expr, _] : path) {
discovered_ids.pushBack(
ir_utils::filterByType<IterDomain>(expr->outputs()));
Expand Down
24 changes: 22 additions & 2 deletions csrc/ir/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,10 @@ CompareDomainResult compareDomains(
toDelimitedString(dom1));

dom0.insert(dom0.end(), additional_ids.begin(), additional_ids.end());
auto exprs = IRBFS::getExprsBetween(
{dom0.begin(), dom0.end()}, {dom1.begin(), dom1.end()}, false);
auto exprs =
IRBFS::getExprsBetween(
{dom0.begin(), dom0.end()}, {dom1.begin(), dom1.end()}, false)
.first;

std::unordered_set<Val*> frontier(dom0.begin(), dom0.end());

Expand Down Expand Up @@ -1285,4 +1287,22 @@ ForLoop* createRangeLoop(int64_t size) {
return loop;
}

TensorView* getTvOutput(const Expr* expr) {
for (auto out : expr->outputs()) {
if (auto tv = getTv(out)) {
return tv;
}
}
return nullptr;
}

TensorView* getTvInput(const Expr* expr) {
for (auto inp : expr->inputs()) {
if (auto tv = getTv(inp)) {
return tv;
}
}
return nullptr;
}

} // namespace nvfuser::ir_utils
6 changes: 6 additions & 0 deletions csrc/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -727,4 +727,10 @@ int64_t getOperationCount(Val* val);
// for (int i = 0; i < size; i++)
ForLoop* createRangeLoop(int64_t size);

// Returns the first output of Expr that is a TensorView
TensorView* getTvOutput(const Expr*);

// Returns the first input of Expr that is a TensorView
TensorView* getTvInput(const Expr*);

} // namespace nvfuser::ir_utils
5 changes: 3 additions & 2 deletions csrc/iter_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1175,7 +1175,7 @@ std::vector<Val*> IRBFS::getValsBetween(
const std::vector<Val*>& from,
const std::vector<Val*>& to) {
auto path =
IRBFS::getExprsBetween(from, to, /*require_all_to_visited=*/false);
IRBFS::getExprsBetween(from, to, /*require_all_to_visited=*/false).first;

VectorOfUniqueEntries<Val*> unique_vals;
for (auto [expr, _] : path) {
Expand All @@ -1197,7 +1197,8 @@ std::vector<Val*> IRBFS::getValsBetween(
std::vector<Val*> IRBFS::getDependenciesTo(
const std::vector<Val*>& vals,
const std::vector<Val*>& to) {
auto path = IRBFS::getExprsBetween(vals, to, /*require_all_to_visited=*/true);
auto path =
IRBFS::getExprsBetween(vals, to, /*require_all_to_visited=*/true).first;

VectorOfUniqueEntries<Val*> unique_vals;

Expand Down
2 changes: 1 addition & 1 deletion csrc/iter_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ class IRBFS
// Find the shortest path from the from_groups_ to to_groups_ on a
// given graph. Dependency between vals and exprs must be satisfied.
// It is an error if no valid path is found.
static ExprPath getExprsBetween(
static std::pair<ExprPath, bool> getExprsBetween(
const std::vector<Val*>& from,
const std::vector<Val*>& to,
bool require_all_to_visited = true) {
Expand Down
1 change: 1 addition & 0 deletions csrc/kernel_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ir/builder.h>
#include <ir/cloner.h>
#include <ir/iostream.h>
#include <ir/utils.h>
#include <kernel.h>
#include <kernel_ir.h>
#include <type.h>
Expand Down
33 changes: 18 additions & 15 deletions csrc/scheduler/tools/loop_domain_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,30 +327,33 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const {
return all_ancestors_of_ref_.has(tv_target_domain);
})) {
return ValGraphBFS::getExprsBetween(
graph(),
ref_id_groups_,
tv_target_domains,
/*require_all_to_visited=*/true,
Direction::Backward);
graph(),
ref_id_groups_,
tv_target_domains,
/*require_all_to_visited=*/true,
Direction::Backward)
.first;
}

// Find the forward path from the ancestors to the target tensor
auto forward_path = ValGraphBFS::getExprsBetween(
graph(),
all_ancestors_of_ref_,
tv_target_domains,
/*require_all_to_visited=*/true,
Direction::Forward);
graph(),
all_ancestors_of_ref_,
tv_target_domains,
/*require_all_to_visited=*/true,
Direction::Forward)
.first;

// Find the path from the ref to the forward path.
auto inputs_of_forward_path = getInputsOfExprPath(graph(), forward_path);

auto backward_path = ValGraphBFS::getExprsBetween(
graph(),
ref_id_groups_,
inputs_of_forward_path,
/*require_all_to_visited=*/true,
Direction::Backward);
graph(),
ref_id_groups_,
inputs_of_forward_path,
/*require_all_to_visited=*/true,
Direction::Backward)
.first;

// Overall replay path = backward_path + forward_path
ValGraphBFS::ExprPath replay_path;
Expand Down
11 changes: 6 additions & 5 deletions csrc/val_graph_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,12 @@ std::unordered_set<ValGroup> ValGraphBFS::projectTo(
std::unordered_set<ValGroup> projection{from};
// Reverse order
auto exprs = ValGraphBFS::getExprsBetween(
id_graph,
to,
{from},
/*require_all_to_visited=*/false,
allowed_direction);
id_graph,
to,
{from},
/*require_all_to_visited=*/false,
allowed_direction)
.first;
while (!exprs.empty()) {
const auto [expr, direction] = exprs.back();
exprs.pop_back();
Expand Down
4 changes: 2 additions & 2 deletions csrc/val_graph_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class ValGraphBFS : public BFS<
// Find the shortest path from the from_groups_ to to_groups_ on a
// given graph. Dependency between vals and exprs must be satisfied.
// It is an error if no valid path is found.
static ExprPath getExprsBetween(
static std::pair<ExprPath, bool> getExprsBetween(
const ValGraph& graph,
std::vector<NodeType> from,
std::vector<NodeType> to,
Expand All @@ -214,7 +214,7 @@ class ValGraphBFS : public BFS<
bfs.traverse();
return bfs.getShortestExprPath();
}
static ExprPath getExprsBetween(
static std::pair<ExprPath, bool> getExprsBetween(
const ValGraph& graph,
const ValGroups& from,
const ValGroups& to,
Expand Down
Loading

0 comments on commit 9bc3ecc

Please sign in to comment.