Resize war follow-up fix #3530

Merged · 2 commits · Dec 9, 2024
107 changes: 73 additions & 34 deletions csrc/id_model/indexing_traversal.cpp
@@ -64,25 +64,33 @@ std::optional<IndexingTraversal::ExprPath> IndexingTraversal::
/*additional_tvs=*/{},
/*build_graphs=*/false);

// If there's no resize in the producer and consumer tensors of this
// expr, it should not need this WAR.
std::vector<Resize*> resize_exprs;
for (const auto& [id, use_exprs] : local_model.idUses()) {
for (const auto& use_expr : use_exprs) {
if (auto resize = dynamic_cast<Resize*>(use_expr)) {
resize_exprs.push_back(resize);
// Gather all resize exprs for each of the inputs and outputs
std::unordered_map<Val*, std::vector<Resize*>> tv_resize_map;
for (auto inp : ir_utils::filterByType<TensorView>(expr->inputs())) {
for (auto expr : inp->domain()->allExprs()) {
if (auto resize = dynamic_cast<Resize*>(expr)) {
tv_resize_map[inp].push_back(resize);
}
}
}
for (auto out : ir_utils::filterByType<TensorView>(expr->outputs())) {
for (auto expr : out->domain()->allExprs()) {
if (auto resize = dynamic_cast<Resize*>(expr)) {
tv_resize_map[out].push_back(resize);
}
}
}

if (resize_exprs.empty()) {
// If there's no resize in the producer and consumer tensors of this
// expr, it should not need this WAR.
if (tv_resize_map.empty()) {
return std::nullopt;
}

// The indexing issue with resize may happen when a single iter
// domain is resized multiple times. In other words, there must be
// at least two connected resize exprs. If not, this WAR is not
// necessary.
// domain is resized multiple times between a producer and a
// consumer. In other words, there must be at least two connected
// resize exprs. If not, this WAR is not necessary.
//
// Note that the actual indexing is done from the loop IDs, which
// might be promoted to IDs outside of this particular expr. Thus,
@@ -92,32 +100,63 @@ std::optional<IndexingTraversal::ExprPath> IndexingTraversal::
// promotion should not further add resize exprs, it is sufficient
// to analyze only the IDs of this expr.

// Shortcut for a common case to avoid building the graph below
if (resize_exprs.size() < 2) {
return std::nullopt;
}

const auto& local_graph = local_model.buildAlmostExactGraph();

// See if these resize expr groups are connected. Note that in the
// current default scheduling method, any tensor op using resize
// should only show up with a fusion input as its input, so there
// must be no chained resize ops, and with the default scheduling
// this function should not move beyond this point. The new resize
// scheduler currently under development will have multiple chained
// resize ops, but that scheduler should explicitly set the loop
// domain such that no promotion would happen, thus avoiding the
// assertion below.
ExprGroups resize_groups = local_graph.toGroups(resize_exprs);
// The analysis below is done for each producer-consumer pair, so it
// can be rather expensive, but in practice most cases should just
// bail out at the first if condition.

auto isSingleIdResizedMultipleTimes = [&](TensorView* inp,
TensorView* out) -> bool {
auto num_resizes = tv_resize_map[inp].size() + tv_resize_map[out].size();
if (num_resizes < 2) {
return false;
}

std::vector<Resize*> resize_exprs;
resize_exprs.reserve(num_resizes);
resize_exprs.insert(
resize_exprs.end(),
tv_resize_map[inp].begin(),
tv_resize_map[inp].end());
resize_exprs.insert(
resize_exprs.end(),
tv_resize_map[out].begin(),
tv_resize_map[out].end());

// See if these resize expr groups are connected. Note that in the
// current default scheduling method, any tensor op using resize
// should only show up with a fusion input as its input, so there
// must be no chained resize ops, and with the default scheduling
// this function should not move beyond this point. The new resize
// scheduler currently under development will have multiple chained
// resize ops, but that scheduler should explicitly set the loop
// domain such that no promotion would happen, thus avoiding the
// assertion below.
ExprGroups resize_groups = local_graph.toGroups(resize_exprs);
for (const auto i : c10::irange(resize_groups.size() - 1)) {
const auto resize_i = resize_groups.at(i);
std::vector<ValGraphBFS::NodeType> other_resizes{
resize_groups.begin() + i + 1, resize_groups.end()};
auto reachable_nodes = getReachableNodesFrom<ValGraphBFS>(
{resize_i}, other_resizes, Direction::Undefined, local_graph);
if (!reachable_nodes.empty()) {
return true;
}
}

return false;
};

bool single_id_resized_multiple_times = false;
for (const auto i : c10::irange(resize_groups.size() - 1)) {
const auto resize_i = resize_groups.at(i);
std::vector<ValGraphBFS::NodeType> other_resizes{
resize_groups.begin() + i + 1, resize_groups.end()};
auto reachable_nodes = getReachableNodesFrom<ValGraphBFS>(
{resize_i}, other_resizes, Direction::Undefined, local_graph);
if (!reachable_nodes.empty()) {
single_id_resized_multiple_times = true;
for (auto out : ir_utils::filterByType<TensorView>(expr->outputs())) {
for (auto inp : ir_utils::filterByType<TensorView>(expr->inputs())) {
if (isSingleIdResizedMultipleTimes(inp, out)) {
single_id_resized_multiple_times = true;
break;
}
}
if (single_id_resized_multiple_times) {
break;
}
}
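For context, the pattern this WAR targets is a single iter domain that gets resized more than once between a producer and a consumer, as in the rotation tests below. The following is a minimal sketch, assuming the nvFuser C++ test helpers already used in this PR (makeContigConcreteTensor, slice, cat):

// A "rotation": both slices introduce a Resize on IDs derived from
// tv0's single root iter domain, and cat pads (resizes) the slice
// outputs once more, so the resize exprs are connected in the
// almost-exact graph and the WAR applies.
Fusion fusion;
FusionGuard fg(&fusion);

const int64_t i0 = 8;
const auto zero = fusion.zeroVal();

auto tv0 = makeContigConcreteTensor({i0});
fusion.addInput(tv0);

// Right half followed by left half
auto tv1 = slice(
    tv0, {{IrBuilder::create<Val>(i0 / 2), IrBuilder::create<Val>(i0)}});
auto tv2 = slice(tv0, {{zero, IrBuilder::create<Val>(i0 / 2)}});
auto tv3 = cat({tv1, tv2}, 0);
fusion.addOutput(tv3);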
50 changes: 49 additions & 1 deletion tests/cpp/test_indexing.cpp
@@ -5408,7 +5408,7 @@ TEST_F(IndexingTest, ResizeRotation) {

// Repro of issue #3505. The indexing WAR for resize triggered an
// assertion due to loop promotion.
TEST_F(IndexingTest, Issue3505) {
TEST_F(IndexingTest, Issue3505Repro1) {
Fusion fusion;
FusionGuard fg(&fusion);

@@ -5452,6 +5452,54 @@ TEST_F(IndexingTest, Issue3505) {
testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
}

// Another repro of issue #3505
TEST_F(IndexingTest, Issue3505Repro2) {
Fusion fusion;
FusionGuard fg(&fusion);

const int64_t i0 = 8;
const int64_t i1 = 2;
const auto zero = fusion.zeroVal();

EnableOptionsGuard enable_options_guard;
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});

auto tv0 = makeContigConcreteTensor({i0});
fusion.addInput(tv0);
auto tv1 = makeContigConcreteTensor({i1, i0 / 2});
fusion.addInput(tv1);

// Left half
auto tv2 = slice(tv0, {{zero, IrBuilder::create<Val>(i0 / 2)}});
// Right half
auto tv3 = slice(
tv0, {{IrBuilder::create<Val>(i0 / 2), IrBuilder::create<Val>(i0)}});

// The two inputs of this add expression have a resize of the same
// ID, but that alone should not require the resize WAR path.
auto tv4 = add(tv2, tv3);
auto tv5 = broadcast(tv4, {true, false});
auto tv6 = add(tv1, tv5);
fusion.addOutput(tv6);

// Make loop promotion required
for (auto tv : {tv2, tv3, tv4, tv5, tv6}) {
tv->flatten();
}
inlineMost();

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({i0}, options);
auto t1 = at::randn({i1, i0 / 2}, options);
std::vector<c10::IValue> inputs{t0, t1};

KernelExecutor ke;
ke.compile(&fusion, inputs);
auto outputs = ke.run(inputs);

testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
}

TEST_F(IndexingTest, AlmostExactIndexingUpdate) {
EnableOptionsGuard enable_options_guard;
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
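Assuming the standard googletest build of nvFuser (the test_nvfuser binary path is an assumption here), the renamed and newly added repros can be run in isolation with a gtest filter:

./bin/test_nvfuser --gtest_filter='IndexingTest.Issue3505Repro*'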