Resize war follow-up fix (#3530)
This is a second attempt to fix #3505; the first attempt is #3515. As
mentioned [here](#3505 (comment)),
the first fix isn't sufficient when an expr has multiple resized inputs,
such as concat. The actual condition we need to check is between each
producer and consumer pair, not between producers, so this second
attempt only changes how that condition is checked.
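
For illustration, a minimal sketch of the pattern that defeats the producer-only check, modeled on the `Issue3505Repro2` test added in this commit (the concat case is analogous, with each input padded instead of sliced). It assumes the usual test setup (`Fusion fusion; FusionGuard fg(&fusion);`), and the sizes are hypothetical:

```cpp
// Sketch only: reuses the APIs seen in the tests added by this commit.
const int64_t i0 = 8;
auto tv0 = makeContigConcreteTensor({i0});
fusion.addInput(tv0);

// Each slice introduces exactly one resize of tv0's iter domain.
auto tv2 = slice(tv0, {{fusion.zeroVal(), IrBuilder::create<Val>(i0 / 2)}});
auto tv3 = slice(
    tv0, {{IrBuilder::create<Val>(i0 / 2), IrBuilder::create<Val>(i0)}});

// tv4 has multiple resized inputs. Looking across the producers tv2 and
// tv3 finds two connected resize exprs (both resize the same ID of tv0),
// which would take the WAR path unnecessarily. Looking at each
// producer-consumer pair (tv2->tv4, tv3->tv4) finds only one resize, so
// the WAR is correctly skipped.
auto tv4 = add(tv2, tv3);
fusion.addOutput(tv4);
```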
naoyam authored Dec 9, 2024
1 parent 9346c8f commit 8694a34
Showing 2 changed files with 122 additions and 35 deletions.
107 changes: 73 additions & 34 deletions csrc/id_model/indexing_traversal.cpp
@@ -64,25 +64,33 @@ std::optional<IndexingTraversal::ExprPath> IndexingTraversal::
/*additional_tvs=*/{},
/*build_graphs=*/false);

// If there's no resize in the producer and consumer tensors of this
// expr, it should not need this WAR.
std::vector<Resize*> resize_exprs;
for (const auto& [id, use_exprs] : local_model.idUses()) {
for (const auto& use_expr : use_exprs) {
if (auto resize = dynamic_cast<Resize*>(use_expr)) {
resize_exprs.push_back(resize);
// Gather all resize exprs for each of the inputs and outputs
std::unordered_map<Val*, std::vector<Resize*>> tv_resize_map;
for (auto inp : ir_utils::filterByType<TensorView>(expr->inputs())) {
for (auto expr : inp->domain()->allExprs()) {
if (auto resize = dynamic_cast<Resize*>(expr)) {
tv_resize_map[inp].push_back(resize);
}
}
}
for (auto out : ir_utils::filterByType<TensorView>(expr->outputs())) {
for (auto expr : out->domain()->allExprs()) {
if (auto resize = dynamic_cast<Resize*>(expr)) {
tv_resize_map[out].push_back(resize);
}
}
}

if (resize_exprs.empty()) {
// If there's no resize in the producer and consumer tensors of this
// expr, it should not need this WAR.
if (tv_resize_map.empty()) {
return std::nullopt;
}

// The indexing issue with resize may happen when a single iter
// domain is resized multiple times. In other words, there must be
// at least two connected resize exprs. If not, this WAR is not
// necessary.
// domain is resized multiple times between a producer and a
// consumer. In other words, there must be at least two connected
// resize exprs. If not, this WAR is not necessary.
//
// Note that the actual indexing is done from the loop IDs, which
// might be promoted to IDs outside of this particular expr. Thus,
@@ -92,32 +100,63 @@ std::optional<IndexingTraversal::ExprPath> IndexingTraversal::
// promotion should not further add resize exprs, it is sufficient
// to analyze only the IDs of this expr.

// Shortcut for a common case to avoid building the graph below
if (resize_exprs.size() < 2) {
return std::nullopt;
}

const auto& local_graph = local_model.buildAlmostExactGraph();

// See if these resize expr groups are connected. Note that in the
// current default scheduling method, any tensor ops using resize
// should only show up with a fusion input as its input, so there
// must be no chained resize ops. With the default scheduling, this
// function should not move beyond this point. The new resize
// scheduler that is currently under development will have multiple
// chained resize ops, but that scheduler should explicitly set the
// loop domain such that no promotion would happen, thus avoiding
// hitting the assertion down below.
ExprGroups resize_groups = local_graph.toGroups(resize_exprs);
// The analysis below is done for each producer-consumer pair, so it
// can be rather expensive, but in practice most cases should just
// bail out at the first if condition.

auto isSingleIdResizedMultipleTimes = [&](TensorView* inp,
TensorView* out) -> bool {
auto num_resizes = tv_resize_map[inp].size() + tv_resize_map[out].size();
if (num_resizes < 2) {
return false;
}

std::vector<Resize*> resize_exprs;
resize_exprs.reserve(num_resizes);
resize_exprs.insert(
resize_exprs.end(),
tv_resize_map[inp].begin(),
tv_resize_map[inp].end());
resize_exprs.insert(
resize_exprs.end(),
tv_resize_map[out].begin(),
tv_resize_map[out].end());

// See if these resize expr groups are connected. Note that in the
// current default scheduling method, any tensor ops using resize
// should only show up with a fusion input as its input, so there
// must be no chained resize ops. With the default scheduling, this
// function should not move beyond this point. The new resize
// scheduler that is currently under development will have multiple
// chained resize ops, but that scheduler should explicitly set the
// loop domain such that no promotion would happen, thus avoiding
// hitting the assertion down below.
ExprGroups resize_groups = local_graph.toGroups(resize_exprs);
for (const auto i : c10::irange(resize_groups.size() - 1)) {
const auto resize_i = resize_groups.at(i);
std::vector<ValGraphBFS::NodeType> other_resizes{
resize_groups.begin() + i + 1, resize_groups.end()};
auto reachable_nodes = getReachableNodesFrom<ValGraphBFS>(
{resize_i}, other_resizes, Direction::Undefined, local_graph);
if (!reachable_nodes.empty()) {
return true;
}
}

return false;
};

bool single_id_resized_multiple_times = false;
for (const auto i : c10::irange(resize_groups.size() - 1)) {
const auto resize_i = resize_groups.at(i);
std::vector<ValGraphBFS::NodeType> other_resizes{
resize_groups.begin() + i + 1, resize_groups.end()};
auto reachable_nodes = getReachableNodesFrom<ValGraphBFS>(
{resize_i}, other_resizes, Direction::Undefined, local_graph);
if (!reachable_nodes.empty()) {
single_id_resized_multiple_times = true;
for (auto out : ir_utils::filterByType<TensorView>(expr->outputs())) {
for (auto inp : ir_utils::filterByType<TensorView>(expr->inputs())) {
if (isSingleIdResizedMultipleTimes(inp, out)) {
single_id_resized_multiple_times = true;
break;
}
}
if (single_id_resized_multiple_times) {
break;
}
}
50 changes: 49 additions & 1 deletion tests/cpp/test_indexing.cpp
@@ -5392,7 +5392,7 @@ TEST_F(IndexingTest, ResizeRotation) {

// Repro of issue #3505. The indexing WAR for resize triggered an
// assertion due to loop promotion.
TEST_F(IndexingTest, Issue3505) {
TEST_F(IndexingTest, Issue3505Repro1) {
Fusion fusion;
FusionGuard fg(&fusion);

@@ -5436,6 +5436,54 @@ TEST_F(IndexingTest, Issue3505) {
testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
}

// Another repro of issue #3505
TEST_F(IndexingTest, Issue3505Repro2) {
Fusion fusion;
FusionGuard fg(&fusion);

const int64_t i0 = 8;
const int64_t i1 = 2;
const auto zero = fusion.zeroVal();

EnableOptionsGuard enable_options_guard;
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});

auto tv0 = makeContigConcreteTensor({i0});
fusion.addInput(tv0);
auto tv1 = makeContigConcreteTensor({i1, i0 / 2});
fusion.addInput(tv1);

// Left half
auto tv2 = slice(tv0, {{zero, IrBuilder::create<Val>(i0 / 2)}});
// Right half
auto tv3 = slice(
tv0, {{IrBuilder::create<Val>(i0 / 2), IrBuilder::create<Val>(i0)}});

// The two inputs of this add expression have a resize of the same
// ID, but this should not mean the resize war path is required.
auto tv4 = add(tv2, tv3);
auto tv5 = broadcast(tv4, {true, false});
auto tv6 = add(tv1, tv5);
fusion.addOutput(tv6);

// Make loop promotion required
for (auto tv : {tv2, tv3, tv4, tv5, tv6}) {
tv->flatten();
}
inlineMost();

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({i0}, options);
auto t1 = at::randn({i1, i0 / 2}, options);
std::vector<c10::IValue> inputs{t0, t1};

KernelExecutor ke;
ke.compile(&fusion, inputs);
auto outputs = ke.run(inputs);

testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
}

TEST_F(IndexingTest, AlmostExactIndexingUpdate) {
EnableOptionsGuard enable_options_guard;
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
