(Yet another) indexing war for resize (#3454)
This is a WAR for #3455. The exact graph-based indexing doesn't work
because of the mapping introduced by the residual path. I think we
should investigate what the right graph should look like for indexing,
but to unblock the scheduler for RoPE, this PR tries to work around the
issue by creating a local graph that only includes the tensors involved
in the expression being indexed, thus removing the effect of the residual
path.

`IndexingTraversal::getExprsBetweenForResize` is the main addition. It
creates a new IdModel consisting only of the tensors of a given expr. If
a resize is used in any of the producers and consumers of the expr, we
use the path found by the local model. Currently, if it fails to find a
path, that is considered an error.
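
The core of the WAR, condensed from the implementation added in `csrc/id_model/indexing_traversal.cpp` (full diff below):

```cpp
// Build an IdModel from just the definition of the expr's consumer tensor, so
// its almost-exact graph cannot pick up mappings introduced by the residual
// path elsewhere in the fusion.
auto consumer_tv = ir_utils::getTvOutput(expr);
IdModel local_model(
    std::vector<Expr*>{consumer_tv->definition()},
    /*additional_tvs=*/{},
    /*build_graphs=*/false);
const auto& local_graph = local_model.buildAlmostExactGraph();
// The from/to IterDomains are re-mapped into local_graph and the shortest expr
// path is searched there. If a from/to ID cannot be mapped, that is currently
// an error; if the found path contains no resize, std::nullopt is returned and
// the caller falls back to the existing global-graph traversal.
```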

While this WAR works for the prototype scheduler for RoPE so far
(#3425), it does have some issues as well. For example, the local
IdModel doesn't have all the information necessary to identify loop
promotions, yet the loop domain of the expr may be promoted, so the
corresponding IDs may not be found within the local model. In other
words, if resize is used with inlined broadcast IDs,
`getExprsBetweenForResize` may fail to find a path, in which case we
fall back to the existing path, which may not be correct in the case of
#3455. However, this can be avoided by scheduling the loop domains such
that no promotion analysis is required. We can now do this by using
things like `TensorDomain::broadcast()` and
`scheduler_tools::scheduleLoopDomainsLike()`, so I don't think this
issue is a blocker.
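
As a rough illustration of that scheduling-based workaround (not part of this PR; the header path and exact signature of `scheduler_tools::scheduleLoopDomainsLike()` are assumed here and may differ), the idea is to explicitly make each tensor's loop domain match a reference so that indexing never has to resolve a promoted loop ID:

```cpp
#include <scheduler/tools/loop_domain_scheduler.h> // header path assumed

// Hedged sketch: set the loop domains of the given tensors to be equivalent to
// the reference loop domain, so no loop promotion analysis is needed when the
// local IdModel is used for resize indexing.
void scheduleLikeReference(
    const std::vector<TensorView*>& tvs_to_schedule,
    TensorView* reference_tv) {
  scheduler_tools::scheduleLoopDomainsLike(
      tvs_to_schedule, reference_tv->getLoopDomain());
}
```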

Much of the remaining diff is due to the change of the interface of
`IndexingTraversal::getExprsBetween`, which now takes
`std::vector<IterDomain*>` instead of `ValGroups`, since the `IterDomain`s
are needed to build the groups of the local IdModel.
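
For example, the allocation-index call site changes as follows (taken from the `csrc/id_model/indexing.cpp` diff below):

```cpp
// Before: callers converted the IterDomains to ValGroups of the global
// traversal graph up front.
const auto& index_groups = traversalGraph().toGroups(alloc_info.domains);
auto index_info = computeIndex(expr, index_groups, for_loops);

// After: the IterDomain*s are passed through directly, so the traversal can
// re-map them into the local IdModel when the resize WAR applies.
auto index_info = computeIndex(expr, alloc_info.domains, for_loops);
```
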
naoyam authored Nov 28, 2024
1 parent 2415d90 commit 8546b62
Showing 8 changed files with 382 additions and 108 deletions.
5 changes: 4 additions & 1 deletion csrc/id_model/id_model_index_compute.h
@@ -62,7 +62,10 @@ class IdGraphIndexCompute : public OptOutDispatch {
}

void setIndex(IterDomain* id, Val* idx) {
index_map_.emplace(toGroup(id), idx);
// May overwrite the index. When the graph is cyclic due to, e.g.,
// resize, the index obtained furthest along the indexing path
// should be used (see also PR #3454)
index_map_[toGroup(id)] = idx;
}

const ValGroup& toGroup(IterDomain* id) const {
50 changes: 25 additions & 25 deletions csrc/id_model/indexing.cpp
@@ -781,7 +781,9 @@ void TensorIndexer::buildLoopIndexMap() {
}
}

Val* TensorIndexer::getLoopIndex(IterDomain* loop_id) const {
Val* TensorIndexer::getLoopIndex(
IterDomain* loop_id,
const std::vector<ForLoop*>& for_loops) const {
// loop_id must be a loop domain.
const auto& loop_group =
id_model_.idGraph(IdMappingMode::LOOP).toGroup(loop_id);
@@ -792,6 +794,13 @@ Val* TensorIndexer::getLoopIndex(IterDomain* loop_id) const {
loop_id->toString());

Val* loop_index = loop_index_map_it->second;

// War for circular buffering
if (auto circular_buffer_loop_index =
getLoopIndexOfCircularBufferLoop(loop_id, for_loops, id_model_)) {
loop_index = circular_buffer_loop_index;
}

return loop_index;
}

@@ -803,16 +812,16 @@ std::unordered_map<ValGroup, Val*> TensorIndexer::getInitialIndexMap(
// For a given list of the loop domains, assign its corresponding
// index Val.
for (IterDomain* loop_id : loop_domains) {
Val* loop_index = getLoopIndex(loop_id);
Val* initial_index = getLoopIndex(loop_id, for_loops);
const auto& almost_exact_group = traversalGraph().toGroup(loop_id);

if (initial_index_map.find(almost_exact_group) != initial_index_map.end()) {
// Initial index already set. This can happen as this is an
// almost exact group. It should be just size-1 domain.
NVF_ERROR(
loop_index->isZeroInt(),
initial_index->isZeroInt(),
"Unexpected initial index: ",
loop_index->toInlineString());
initial_index->toInlineString());
auto existing_index = initial_index_map.at(almost_exact_group);
NVF_ERROR(
existing_index->isZeroInt(),
@@ -821,13 +830,7 @@ std::unordered_map<ValGroup, Val*> TensorIndexer::getInitialIndexMap(
continue;
}

// War for circular buffering
if (auto circular_buffer_loop_index =
getLoopIndexOfCircularBufferLoop(loop_id, for_loops, id_model_)) {
loop_index = circular_buffer_loop_index;
}

initial_index_map.emplace(almost_exact_group, loop_index);
initial_index_map.emplace(almost_exact_group, initial_index);
}

return initial_index_map;
@@ -836,12 +839,14 @@ std::unordered_map<ValGroup, Val*> TensorIndexer::getInitialIndexMap(
std::vector<Val*> TensorIndexer::getIndexFor(
const Expr* expr,
bool as_consumer,
const ValGroups& index_groups,
const std::vector<IterDomain*>& index_ids,
const std::vector<ForLoop*>& for_loops) const {
auto info = computeIndex(expr, index_groups, for_loops);
auto info = computeIndex(expr, index_ids, for_loops);
const auto& replacement_map = getIndexReplacementMap(
expr, as_consumer, info.loop_domains, for_loops, info.index_map);

const auto index_groups = traversalGraph().toGroups(index_ids);

std::vector<Val*> result;
result.reserve(index_groups.size());
for (const auto& g : index_groups) {
@@ -916,13 +921,13 @@ std::vector<IterDomain*> TensorIndexer::getLoopDomains(const Expr* expr) const {

IndexingInfo TensorIndexer::computeIndex(
const Expr* expr,
const ValGroups& index_groups,
const std::vector<IterDomain*>& index_ids,
const std::vector<ForLoop*>& for_loops) const {
const auto loop_domains = getLoopDomains(expr);
const auto loop_domains = getLoopIds(expr, id_model_);

const ValGroups loop_groups = traversalGraph().toGroups(loop_domains);
const ExprPath<ExprGroup> traversal_path = IndexingTraversal::getExprsBetween(
expr, traversalGraph(), loop_groups, index_groups);
expr, traversalGraph(), loop_domains, index_ids);

const std::unordered_map<ValGroup, Val*> initial_index_map =
getInitialIndexMap(loop_domains, for_loops);
@@ -978,11 +983,7 @@ std::unordered_map<Val*, Val*> TensorIndexer::getIndexReplacementMap(
std::unordered_map<Val*, Val*> replacement_map;

for (const auto loop_id : loop_domains) {
const ValGroup& loop_group = traversalGraph().toGroup(loop_id);
auto index_it = index_map.find(loop_group);
NVF_ERROR(index_it != index_map.end());
Val* cur_index = index_it->second;
NVF_ERROR(cur_index != nullptr);
Val* cur_index = getLoopIndex(loop_id, for_loops);

Val* replacement_index = nullptr;
// Replace the index of a vectorized/bulk domain with zero. Note that
@@ -1049,8 +1050,8 @@ std::vector<PredicateInfo> TensorIndexer::getPredicates(
const std::vector<IterDomain*>& predicate_domains =
getPredicateDomains(tv, expr);

const IndexingInfo& index_info = computeIndex(
expr, traversalGraph().toGroups(predicate_domains), for_loops);
const IndexingInfo& index_info =
computeIndex(expr, predicate_domains, for_loops);

const auto& index_map = index_info.index_map;

@@ -1282,8 +1283,7 @@ std::pair<std::vector<Val*>, std::vector<Val*>> TensorIndexer::
bool as_consumer,
const IndexingAllocationInfo& alloc_info,
const std::vector<ForLoop*>& for_loops) const {
const auto& index_groups = traversalGraph().toGroups(alloc_info.domains);
auto index_info = computeIndex(expr, index_groups, for_loops);
auto index_info = computeIndex(expr, alloc_info.domains, for_loops);
const auto& index_map = index_info.index_map;
const auto& replacement_map = getIndexReplacementMap(
expr, as_consumer, index_info.loop_domains, for_loops, index_map);
9 changes: 5 additions & 4 deletions csrc/id_model/indexing.h
@@ -69,14 +69,15 @@ class TensorIndexer {
const Expr* expr,
const std::vector<ForLoop*>& loops) const;

// Get the index of a loop domain. Intended to be used only for testing.
Val* getLoopIndex(IterDomain* loop_id) const;
// Get the index of a loop domain.
Val* getLoopIndex(IterDomain* loop_id, const std::vector<ForLoop*>& for_loops)
const;

// Get the index of the given ID groups
std::vector<Val*> getIndexFor(
const Expr* expr,
bool as_consumer,
const ValGroups& index_groups,
const std::vector<IterDomain*>& index_ids,
const std::vector<ForLoop*>& loops) const;

// Get the contig indices of the given ID groups with their strides
@@ -137,7 +138,7 @@ class TensorIndexer {
// getIndexFor.
IndexingInfo computeIndex(
const Expr* expr,
const ValGroups& index_groups,
const std::vector<IterDomain*>& index_ids,
const std::vector<ForLoop*>& for_loops) const;

// Propagate the loop indices of a given list of loop domains to the
129 changes: 127 additions & 2 deletions csrc/id_model/indexing_traversal.cpp
@@ -5,6 +5,7 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <id_model/id_model.h>
#include <id_model/indexing_traversal.h>
#include <ir/utils.h>

@@ -14,8 +15,9 @@ IndexingTraversal::IndexingTraversal(
const Expr* expr,
const ValGraph& graph,
std::vector<NodeType> from_groups,
std::vector<NodeType> to_groups)
: ValGraphBFS(graph, from_groups, to_groups) {
std::vector<NodeType> to_groups,
bool require_all_to_visited)
: ValGraphBFS(graph, from_groups, to_groups, require_all_to_visited) {
auto consumer_tv = ir_utils::getTvOutput(expr);
NVF_ERROR(consumer_tv != nullptr);
// Remember the resize exprs appearing in the consumer
@@ -44,4 +46,127 @@ IndexingTraversal::IndexingTraversal(
}
}

std::optional<IndexingTraversal::ExprPath> IndexingTraversal::
getExprsBetweenForResize(
const Expr* expr,
const ValGraph& graph,
const std::vector<IterDomain*>& from_ids,
const std::vector<IterDomain*>& to_ids) {
auto consumer_tv = ir_utils::getTvOutput(expr);
NVF_ERROR(consumer_tv != nullptr);

IdModel local_model(
std::vector<Expr*>{consumer_tv->definition()},
/*additional_tvs=*/{},
/*build_graphs=*/false);

// If there's no resize in the producer and consumer tensors of this
// expr, it should not need this WAR.
if (std::none_of(
local_model.idUses().begin(),
local_model.idUses().end(),
[](const auto& kv) {
const VectorOfUniqueEntries<Expr*>& exprs = kv.second;
return !exprs.empty() && exprs.at(0)->isA<Resize>();
})) {
return std::nullopt;
}

const auto& local_graph = local_model.buildAlmostExactGraph();

// from_ids are loop domains, which are representative
// domains of loop groups and not necessarily domains of any
// of the producer and the consumer. In that case, find an ID out
// of the global group that is mapped in the local graph.
ValGroups from_groups;
for (const auto i : c10::irange(from_ids.size())) {
auto from_id = from_ids.at(i);
if (local_graph.hasGroup(from_id)) {
from_groups.pushBack(local_graph.toGroup(from_id));
continue;
}
bool found = false;
const auto& global_group = graph.toGroup(from_id);
for (const auto& vg : local_graph.disjointValSets().disjointSets()) {
if (global_group->has(vg->front())) {
from_groups.pushBack(vg);
found = true;
break;
}
}
// If not found, it should mean it's promoted to some IDs of
// further consumer tensors. This WAR does not work then. We could
// simply fall back to the default ValGraph-based path, but that
// might hit the resize indexing issue (#3455). For now, this is
// considered an error.
NVF_ERROR(
found, "Indexing path for resize not found: ", from_id->toString());
}

// Similarly, to_ids may not be IDs found in any of the producer and
// consumer tensors of this expr. For example, if it's an allocation
// ID, it may be a loop promotion ID.
ValGroups to_groups;
for (auto to_id : to_ids) {
if (local_graph.hasGroup(to_id)) {
to_groups.pushBack(local_graph.toGroup(to_id));
continue;
}
// to_id is not found in the producer or consumer tensors of the
// expr. Look for a mapped ID in the ID group of the global graph.
bool found = false;
const auto& global_group = graph.toGroup(to_id);
for (const auto& vg : local_graph.disjointValSets().disjointSets()) {
if (global_group->has(vg->front())) {
to_groups.pushBack(vg);
found = true;
break;
}
}
NVF_ERROR(found, "Indexing path for resize not found: ", to_id->toString());
}

IndexingTraversal traversal(
expr,
local_graph,
{from_groups.vector().begin(), from_groups.vector().end()},
{to_groups.vector().begin(), to_groups.vector().end()},
/*require_all_to_visited=*/true);
traversal.traverse();
auto [path, all_visited] = traversal.getShortestExprPath();

for (const auto& [g, d] : path) {
if (g->front()->isA<Resize>()) {
return path;
}
}

// If resize doesn't appear, the default path should work fine.
return std::nullopt;
}

IndexingTraversal::ExprPath IndexingTraversal::getExprsBetween(
const Expr* expr,
const ValGraph& graph,
const std::vector<IterDomain*>& from_domains,
const std::vector<IterDomain*>& to_domains) {
// Take the path if found by the war for resize indexing
if (auto path =
getExprsBetweenForResize(expr, graph, from_domains, to_domains);
path.has_value()) {
return *path;
}

auto from_groups = graph.toGroups(from_domains);
auto to_groups = graph.toGroups(to_domains);

IndexingTraversal traversal(
expr,
graph,
{from_groups.vector().begin(), from_groups.vector().end()},
{to_groups.vector().begin(), to_groups.vector().end()});
traversal.traverse();
return traversal.getShortestExprPath().first;
}

} // namespace nvfuser
21 changes: 10 additions & 11 deletions csrc/id_model/indexing_traversal.h
@@ -25,23 +25,22 @@ class IndexingTraversal : public ValGraphBFS {
const Expr* expr,
const ValGraph& graph,
std::vector<NodeType> from_groups,
std::vector<NodeType> to_groups);
std::vector<NodeType> to_groups,
bool require_all_to_visited = true);

~IndexingTraversal() override = default;

static ExprPath getExprsBetween(
const Expr* expr,
const ValGraph& graph,
const ValGroups& from_groups,
const ValGroups& to_groups) {
IndexingTraversal traversal(
expr,
graph,
{from_groups.vector().begin(), from_groups.vector().end()},
{to_groups.vector().begin(), to_groups.vector().end()});
traversal.traverse();
return traversal.getShortestExprPath().first;
}
const std::vector<IterDomain*>& from_domains,
const std::vector<IterDomain*>& to_domains);

static std::optional<ExprPath> getExprsBetweenForResize(
const Expr* expr,
const ValGraph& graph,
const std::vector<IterDomain*>& from_domains,
const std::vector<IterDomain*>& to_domains);

using ValGraphBFS::isVisited;

22 changes: 22 additions & 0 deletions csrc/id_model/utils.h
@@ -10,6 +10,7 @@
#include <expr_simplifier.h>
#include <id_model/id_model.h>
#include <id_model/to_string.h>
#include <ir/utils.h>
#include <options.h>
#include <utils.h>

@@ -106,6 +107,27 @@ inline IterDomain* getLoopPromotion(
return loop_promotion_map_it->second;
}

// Get the loop domains of a given expr. Currently, they're always
// the loop domains of a consumer tensor, but in the future this
// function may return the loop domains of a producer for
// producer-based indexing.
inline std::vector<IterDomain*> getLoopIds(
const Expr* expr,
const IdModel& id_model) {
// Assume consumer-based indexing. Needs to revisit for ops like
// scatter
NVF_ERROR(!expr->outputs().empty());
auto output_tv = ir_utils::getTvOutput(expr);
NVF_ERROR(output_tv != nullptr);
auto loop_ids = output_tv->getLoopDomain();

for (auto& loop_id : loop_ids) {
loop_id = getLoopPromotion(loop_id, id_model);
}

return loop_ids;
}

inline ParallelType getParallelType(const ValGroup& loop_group) {
ParallelType common_pt = ParallelType::Serial;
for (const auto val : *loop_group) {