Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit use of the indexing war for resize #3515

Merged
merged 4 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions csrc/bfs.h
Original file line number Diff line number Diff line change
Expand Up @@ -588,30 +588,57 @@ std::vector<typename GetValType<ExprT>::type> getOutputsOfExprPath(
return getInputsOfExprPath(reverse(path), get_inputs, get_outputs);
}

// Given a set of vals, get all reachable ones from another set of vals
// Given a set of exprs and vals, get all reachable ones from another set of
// nodes
template <typename BFSType, typename... AdditionalArgs>
std::vector<typename BFSType::ValType> getReachableValsFrom(
const std::vector<typename BFSType::ValType>& from,
const std::vector<typename BFSType::ValType>& vals,
std::vector<typename BFSType::NodeType> getReachableNodesFrom(
const std::vector<typename BFSType::NodeType>& from,
const std::vector<typename BFSType::NodeType>& nodes,
Direction allowed_direction = Direction::Undefined,
const AdditionalArgs&... additional_args) {
BFSType bfs(
additional_args...,
{from.begin(), from.end()},
{vals.begin(), vals.end()},
from,
nodes,
/*require_all_to_visited=*/false,
allowed_direction);

bfs.traverse();

std::vector<typename BFSType::ValType> reachable_vals;
for (const auto& val : vals) {
if (bfs.isVisited(val) ||
std::find(from.begin(), from.end(), val) != from.end()) {
reachable_vals.push_back(val);
std::vector<typename BFSType::NodeType> reachable_nodes;
for (const auto& node : nodes) {
if (bfs.isVisited(node) ||
std::find(from.begin(), from.end(), node) != from.end()) {
reachable_nodes.push_back(node);
}
}

return reachable_nodes;
}

// Val-only convenience wrapper around getReachableNodesFrom.
//
// Both `from` and `vals` are converted to vectors of BFSType::NodeType
// and forwarded, together with `allowed_direction` and
// `additional_args`, to getReachableNodesFrom. Every node it returns
// is then unwrapped back to a BFSType::ValType with std::get, so the
// result has one entry per returned node, in the same order.
template <typename BFSType, typename... AdditionalArgs>
std::vector<typename BFSType::ValType> getReachableValsFrom(
    const std::vector<typename BFSType::ValType>& from,
    const std::vector<typename BFSType::ValType>& vals,
    Direction allowed_direction = Direction::Undefined,
    const AdditionalArgs&... additional_args) {
  using ValT = typename BFSType::ValType;
  using NodeT = typename BFSType::NodeType;

  // Lift the Val inputs to the generic node type expected by the
  // node-based overload.
  const std::vector<NodeT> from_nodes(from.begin(), from.end());
  const std::vector<NodeT> candidate_nodes(vals.begin(), vals.end());

  const auto found_nodes = getReachableNodesFrom<BFSType, AdditionalArgs...>(
      from_nodes, candidate_nodes, allowed_direction, additional_args...);

  // Only Vals were passed in as candidates, so each returned node is
  // expected to hold a ValT.
  std::vector<ValT> reachable_vals;
  reachable_vals.reserve(found_nodes.size());
  for (const NodeT& found : found_nodes) {
    reachable_vals.push_back(std::get<ValT>(found));
  }

  return reachable_vals;
}

Expand Down
68 changes: 61 additions & 7 deletions csrc/id_model/indexing_traversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,79 @@ std::optional<IndexingTraversal::ExprPath> IndexingTraversal::
auto consumer_tv = ir_utils::getTvOutput(expr);
NVF_ERROR(consumer_tv != nullptr);

// First, try to limit the use of this WAR as much as possible since
// the WAR itself has a limitation that it assumes the loop domain
// is not promoted.

IdModel local_model(
std::vector<Expr*>{consumer_tv->definition()},
/*additional_tvs=*/{},
/*build_graphs=*/false);

// If there's no resize in the producer and consumer tensors of this
// expr, it should not need this WAR.
if (std::none_of(
local_model.idUses().begin(),
local_model.idUses().end(),
[](const auto& kv) {
const VectorOfUniqueEntries<Expr*>& exprs = kv.second;
return !exprs.empty() && exprs.at(0)->isA<Resize>();
})) {
std::vector<Resize*> resize_exprs;
for (const auto& [id, use_exprs] : local_model.idUses()) {
for (const auto& use_expr : use_exprs) {
if (auto resize = dynamic_cast<Resize*>(use_expr)) {
resize_exprs.push_back(resize);
}
}
}

if (resize_exprs.empty()) {
return std::nullopt;
}

// The indexing issue with resize may happen when a single iter
// domain is resized multiple times. In other words, there must be
// at least two connected resize exprs. If not, this WAR is not
// necessary.
//
// Note that the actual indexing is done from the loop IDs, which
// might be promoted to IDs outside of this particular expr. Thus,
// to get the true indexing path, the global IdModel may need to be
// used rather than the local model. Here, since we just need to
// know if there are multiple dependent resize exprs, and loop
// promotion should not further add resize exprs, it is sufficient
// to analyze only the IDs of this expr.

// Shortcut for a common case to avoid building the graph below
if (resize_exprs.size() < 2) {
return std::nullopt;
}

const auto& local_graph = local_model.buildAlmostExactGraph();

// See if these resize expr groups are connected. Note that in the
// current default scheduling method, any tensor ops using resize
// should only show up with a fusion input as its input, so there
// must be no chained resize ops. With the default scheduling, this
// function should not move beyond this point. The new resize
// scheduler that is currently under development will have multiple
// chained resize ops, but the scheduler should explicitly set the
// loop domain such that no promotion would happen, thus avoiding
// hitting the assertion down below.
ExprGroups resize_groups = local_graph.toGroups(resize_exprs);
bool single_id_resized_multiple_times = false;
for (const auto i : c10::irange(resize_groups.size() - 1)) {
const auto resize_i = resize_groups.at(i);
std::vector<ValGraphBFS::NodeType> other_resizes{
resize_groups.begin() + i + 1, resize_groups.end()};
auto reachable_nodes = getReachableNodesFrom<ValGraphBFS>(
{resize_i}, other_resizes, Direction::Undefined, local_graph);
jacobhinkle marked this conversation as resolved.
Show resolved Hide resolved
if (!reachable_nodes.empty()) {
single_id_resized_multiple_times = true;
break;
}
}

// No connection between the resize exprs is found, which means they
// are all independent and there's no need to use this WAR
if (!single_id_resized_multiple_times) {
return std::nullopt;
}

// from_ids are loop domains, which are representative
// domains of loop groups and not necessarily domains of any
// of the producer and the consumer. In that case, find an ID out
Expand Down
46 changes: 46 additions & 0 deletions tests/cpp/test_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5406,6 +5406,52 @@ TEST_F(IndexingTest, ResizeRotation) {
testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
}

// Regression test for issue #3505, where the resize indexing WAR hit
// an assertion because of loop promotion.
TEST_F(IndexingTest, Issue3505) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Extents of the broadcast dimension and of the 2D input; the slice
  // below keeps the first half of each of the two input dimensions.
  const int64_t outer_extent = 2;
  const int64_t row_extent = 4;
  const int64_t col_extent = 8;
  const int64_t half_rows = row_extent / 2;
  const int64_t half_cols = col_extent / 2;
  const auto zero = fusion.zeroVal();

  EnableOptionsGuard enable_options_guard;
  EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});

  auto full_tv = makeContigConcreteTensor({row_extent, col_extent});
  fusion.addInput(full_tv);
  auto addend_tv = makeContigConcreteTensor({outer_extent, half_rows, half_cols});
  fusion.addInput(addend_tv);

  // Slicing a single dimension is already enough to reproduce the
  // error; both dimensions are sliced here so that the reachability
  // check between multiple resize ops is exercised as well.
  auto row_stop = IrBuilder::create<Val>(half_rows);
  auto col_stop = IrBuilder::create<Val>(half_cols);
  auto sliced_tv = slice(full_tv, {{zero, row_stop}, {zero, col_stop}});
  auto bcast_tv = broadcast(sliced_tv, {true, false, false});
  auto sum_tv = add(addend_tv, bcast_tv);
  fusion.addOutput(sum_tv);

  // Flatten every non-input tensor, then inline as much as possible.
  sliced_tv->flatten();
  bcast_tv->flatten();
  sum_tv->flatten();
  inlineMost();

  auto tensor_opts =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto full_input = at::randn({row_extent, col_extent}, tensor_opts);
  auto addend_input =
      at::randn({outer_extent, half_rows, half_cols}, tensor_opts);
  std::vector<c10::IValue> inputs{full_input, addend_input};

  KernelExecutor ke;
  ke.compile(&fusion, inputs);
  auto outputs = ke.run(inputs);

  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
}

TEST_F(IndexingTest, AlmostExactIndexingUpdate) {
EnableOptionsGuard enable_options_guard;
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
Expand Down
Loading