Limit use of the indexing WAR for resize (#3515)
Fixes the following error from #3505:

```
Error from segmentation group 9:  INTERNAL ASSERT FAILED at "/Fuser/csrc/id_model/indexing_traversal.cpp":102, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Indexing path for resize not found: iblockIdx.y376{( ceilDiv(1280, blockDim.x) )}
```

The error happens when trying to use the indexing WAR for resize that
was recently added (#3454). The WAR itself is limited; in particular, it
does not work with promoted loop IDs. That limitation is fine for
the RoPE scheduling I've been working on, but it is a real issue in
general.

This PR avoids the issue by limiting the use of the WAR. Currently, the
WAR is used whenever there is at least one resize expr in a given
math expr. That is overly pessimistic, since the indexing issue
only happens when there are multiple resize exprs that result in a cycle
in the exact graph. For example, if there is only one resize, there can
be no cycle, so the indexing WAR is not necessary.

This PR limits the use of the WAR by performing a slightly deeper
analysis. The added check should entirely disable the WAR for the
current default scheduling, where resize is only allowed with fusion
inputs, which means there cannot be multiple dependent resize exprs in a
single fusion.

The limitation of the WAR remains, but it does not matter for RoPE, and
with this PR it should also not matter for general cases.
naoyam authored Dec 4, 2024
1 parent 3d1e735 commit 67127c9
Showing 3 changed files with 145 additions and 18 deletions.
49 changes: 38 additions & 11 deletions csrc/bfs.h
```diff
@@ -603,30 +603,57 @@ std::vector<typename GetValType<ExprT>::type> getOutputsOfExprPath(
   return getInputsOfExprPath(reverse(path), get_inputs, get_outputs);
 }
 
-// Given a set of vals, get all reachable ones from another set of vals
+// Given a set of exprs and vals, get all reachable ones from another set of
+// nodes
 template <typename BFSType, typename... AdditionalArgs>
-std::vector<typename BFSType::ValType> getReachableValsFrom(
-    const std::vector<typename BFSType::ValType>& from,
-    const std::vector<typename BFSType::ValType>& vals,
+std::vector<typename BFSType::NodeType> getReachableNodesFrom(
+    const std::vector<typename BFSType::NodeType>& from,
+    const std::vector<typename BFSType::NodeType>& nodes,
     Direction allowed_direction = Direction::Undefined,
     const AdditionalArgs&... additional_args) {
   BFSType bfs(
       additional_args...,
-      {from.begin(), from.end()},
-      {vals.begin(), vals.end()},
+      from,
+      nodes,
       /*require_all_to_visited=*/false,
       allowed_direction);
 
   bfs.traverse();
 
-  std::vector<typename BFSType::ValType> reachable_vals;
-  for (const auto& val : vals) {
-    if (bfs.isVisited(val) ||
-        std::find(from.begin(), from.end(), val) != from.end()) {
-      reachable_vals.push_back(val);
+  std::vector<typename BFSType::NodeType> reachable_nodes;
+  for (const auto& node : nodes) {
+    if (bfs.isVisited(node) ||
+        std::find(from.begin(), from.end(), node) != from.end()) {
+      reachable_nodes.push_back(node);
     }
   }
 
-  return reachable_vals;
+  return reachable_nodes;
+}
+
+// Shortcut of getReachableNodesFrom for Vals
+template <typename BFSType, typename... AdditionalArgs>
+std::vector<typename BFSType::ValType> getReachableValsFrom(
+    const std::vector<typename BFSType::ValType>& from,
+    const std::vector<typename BFSType::ValType>& vals,
+    Direction allowed_direction = Direction::Undefined,
+    const AdditionalArgs&... additional_args) {
+  auto reachable_nodes = getReachableNodesFrom<BFSType, AdditionalArgs...>(
+      {from.begin(), from.end()},
+      {vals.begin(), vals.end()},
+      allowed_direction,
+      additional_args...);
+
+  std::vector<typename BFSType::ValType> reachable_vals;
+  reachable_vals.reserve(reachable_nodes.size());
+  std::transform(
+      reachable_nodes.begin(),
+      reachable_nodes.end(),
+      std::back_inserter(reachable_vals),
+      [](const auto& node) {
+        return std::get<typename BFSType::ValType>(node);
+      });
+
+  return reachable_vals;
 }
```
68 changes: 61 additions & 7 deletions csrc/id_model/indexing_traversal.cpp
```diff
@@ -55,25 +55,79 @@ std::optional<IndexingTraversal::ExprPath> IndexingTraversal::
   auto consumer_tv = ir_utils::getTvOutput(expr);
   NVF_ERROR(consumer_tv != nullptr);
 
+  // First, try to limit the use of this WAR as much as possible since
+  // the WAR itself has a limitation that it assumes the loop domain
+  // is not promoted.
+
   IdModel local_model(
       std::vector<Expr*>{consumer_tv->definition()},
       /*additional_tvs=*/{},
       /*build_graphs=*/false);
 
   // If there's no resize in the producer and consumer tensors of this
   // expr, it should not need this WAR.
-  if (std::none_of(
-          local_model.idUses().begin(),
-          local_model.idUses().end(),
-          [](const auto& kv) {
-            const VectorOfUniqueEntries<Expr*>& exprs = kv.second;
-            return !exprs.empty() && exprs.at(0)->isA<Resize>();
-          })) {
+  std::vector<Resize*> resize_exprs;
+  for (const auto& [id, use_exprs] : local_model.idUses()) {
+    for (const auto& use_expr : use_exprs) {
+      if (auto resize = dynamic_cast<Resize*>(use_expr)) {
+        resize_exprs.push_back(resize);
+      }
+    }
+  }
+
+  if (resize_exprs.empty()) {
     return std::nullopt;
   }
+
+  // The indexing issue with resize may happen when a single iter
+  // domain is resized multiple times. In other words, there must be
+  // at least two connected resize exprs. If not, this WAR is not
+  // necessary.
+  //
+  // Note that the actual indexing is done from the loop IDs, which
+  // might be promoted to IDs outside of this particular expr. Thus,
+  // to get the true indexing path, the global IdModel may need to be
+  // used rather than the local model. Here, since we just need to
+  // know if there are multiple dependent resize exprs, and loop
+  // promotion should not further add resize exprs, it is sufficient
+  // to analyze only the IDs of this expr.
+
+  // Shortcut for a common case to avoid building the graph below
+  if (resize_exprs.size() < 2) {
+    return std::nullopt;
+  }
+
+  const auto& local_graph = local_model.buildAlmostExactGraph();
+
+  // See if these resize expr groups are connected. Note that in the
+  // current default scheduling method, any tensor op using resize
+  // should only appear with a fusion input as its input, so there
+  // must be no chained resize ops. With the default scheduling, this
+  // function should not move beyond this point. The new resize
+  // scheduler that is currently under development will have multiple
+  // chained resize ops, but that scheduler should explicitly set the
+  // loop domain such that no promotion would happen, thus avoiding
+  // the assertion down below.
+  ExprGroups resize_groups = local_graph.toGroups(resize_exprs);
+  bool single_id_resized_multiple_times = false;
+  for (const auto i : c10::irange(resize_groups.size() - 1)) {
+    const auto resize_i = resize_groups.at(i);
+    std::vector<ValGraphBFS::NodeType> other_resizes{
+        resize_groups.begin() + i + 1, resize_groups.end()};
+    auto reachable_nodes = getReachableNodesFrom<ValGraphBFS>(
+        {resize_i}, other_resizes, Direction::Undefined, local_graph);
+    if (!reachable_nodes.empty()) {
+      single_id_resized_multiple_times = true;
+      break;
+    }
+  }
+
+  // No connection between the resize exprs is found, which means they
+  // are all independent, and there's no need to use this WAR.
+  if (!single_id_resized_multiple_times) {
+    return std::nullopt;
+  }
 
   // from_ids are loop domains, which are representative
   // domains of loop groups and not necessarily domains of any
   // of the producer and the consumer. In that case, find an ID out
```
46 changes: 46 additions & 0 deletions tests/cpp/test_indexing.cpp
```diff
@@ -5406,6 +5406,52 @@ TEST_F(IndexingTest, ResizeRotation) {
   testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
 }
 
+// Repro of issue #3505. The indexing WAR for resize triggered an
+// assertion due to loop promotion.
+TEST_F(IndexingTest, Issue3505) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  const int64_t i0 = 2;
+  const int64_t i1 = 4;
+  const int64_t i2 = 8;
+  const auto zero = fusion.zeroVal();
+
+  EnableOptionsGuard enable_options_guard;
+  EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
+
+  auto tv0 = makeContigConcreteTensor({i1, i2});
+  fusion.addInput(tv0);
+  auto tv1 = makeContigConcreteTensor({i0, i1 / 2, i2 / 2});
+  fusion.addInput(tv1);
+
+  // One slice can reproduce the error but just to trigger the
+  // reachability check between multiple resize ops
+  auto tv2 = slice(
+      tv0,
+      {{zero, IrBuilder::create<Val>(i1 / 2)},
+       {zero, IrBuilder::create<Val>(i2 / 2)}});
+  auto tv3 = broadcast(tv2, {true, false, false});
+  auto tv4 = add(tv1, tv3);
+  fusion.addOutput(tv4);
+
+  for (auto tv : {tv2, tv3, tv4}) {
+    tv->flatten();
+  }
+  inlineMost();
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({i1, i2}, options);
+  auto t1 = at::randn({i0, i1 / 2, i2 / 2}, options);
+  std::vector<c10::IValue> inputs{t0, t1};
+
+  KernelExecutor ke;
+  ke.compile(&fusion, inputs);
+  auto outputs = ke.run(inputs);
+
+  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
+}
+
 TEST_F(IndexingTest, AlmostExactIndexingUpdate) {
   EnableOptionsGuard enable_options_guard;
   EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
```
