Skip to content

Commit

Permalink
Some minor changes extracted from #3425 (#3522)
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam authored Dec 4, 2024
1 parent 285602d commit 3d1e735
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 1 deletion.
15 changes: 15 additions & 0 deletions csrc/bfs.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,21 @@ class BFS {
}
}
}
ss << " (from: ";
for (const auto& from : from_) {
ss << " " << toString(from);
if (const ExprT* e = std::get_if<ExprT>(&from)) {
ss << " " << toString(*e);
}
}
ss << ")";
ss << ", visited: (";
for (const auto& visited : visited_) {
if (const ValT* v = std::get_if<ValT>(&visited)) {
ss << " " << toString(visited);
}
}
ss << ")";
NVF_THROW("BFS traversal could not visit some nodes: ", ss.str());
}
}
Expand Down
16 changes: 16 additions & 0 deletions csrc/disjoint_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,22 @@ class DisjointSets {
return disjoint_sets_;
}

typename DisjointSetMap::iterator find(T entry) {
return disjoint_set_maps_.find(entry);
}

typename DisjointSetMap::iterator end() {
return disjoint_set_maps_.end();
}

typename DisjointSetMap::const_iterator find(T entry) const {
return disjoint_set_maps_.find(entry);
}

typename DisjointSetMap::const_iterator end() const {
return disjoint_set_maps_.end();
}

// Return the entire disjoint set of provided entry
const VectorOfUniqueEntries<T, Hash>& getDisjointSetOf(T entry) const {
auto set_it = disjoint_set_maps_.find(entry);
Expand Down
4 changes: 4 additions & 0 deletions csrc/id_model/id_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ class IdModel : public PolymorphicBase {
return tvs_.empty();
}

const std::vector<TensorView*>& tvs() const {
return tvs_;
}

Fusion* fusion() const {
return fusion_;
}
Expand Down
2 changes: 1 addition & 1 deletion csrc/predicate_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class ParallelizedDomainPredicate {
public:
explicit PredicateInfo(ParallelType pt) : pt_(pt) {}

//! Adds a domain that is parallized by the same paralell type
//! Adds a domain that is parallized by the same parallel type
bool addDomain(IterDomain* id);

const std::vector<IterDomain*>& ids() const {
Expand Down

0 comments on commit 3d1e735

Please sign in to comment.