Skip to content

Commit

Permalink
Fix ComputeAtMap for non-linear ID dependencies (#3577)
Browse files Browse the repository at this point in the history
Just patching ComputeAtMap to exclude dead expressions and vals.
  • Loading branch information
naoyam authored and jacobhinkle committed Dec 16, 2024
1 parent f5dadfb commit a8cab11
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion csrc/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,9 @@ void IterDomainGraph::build(Fusion* fusion) {

// Grab all the logical ids.
for (auto consumer_tv : all_consumer_tvs) {
auto exprs = StmtSort::getExprsTo(
auto exprs = StmtSort::getExprsBetween(
{consumer_tv->getMaybeRootDomain().begin(),
consumer_tv->getMaybeRootDomain().end()},
{consumer_tv->getLogicalDomain().begin(),
consumer_tv->getLogicalDomain().end()});
for (auto expr : exprs) {
Expand Down Expand Up @@ -663,6 +665,20 @@ void IterDomainGraph::build(Fusion* fusion) {
continue;
}

// logical_id_uses are guaranteed to be a valid expr, but
// first_logical_id->definition() may not be part of the valid
// exprs
if (!prop_forward) {
if (std::any_of(
first_expr->inputs().begin(),
first_expr->inputs().end(),
[&](Val* id_input) {
return !all_ids_.has(id_input->as<IterDomain>());
})) {
continue;
}
}

if (visited_exprs.find(first_expr) != visited_exprs.end()) {
continue;
}
Expand Down Expand Up @@ -1282,6 +1298,13 @@ void ComputeAtMap::buildUniqueExactExprMaps() {
if (id->definition() != nullptr) {
auto id_inputs =
ir_utils::filterByType<IterDomain>(id->definition()->inputs());
// If any input ID is not included in the map, this definition
// should not be included either.
if (std::any_of(id_inputs.begin(), id_inputs.end(), [&](auto id_input) {
return !idExistsInMap(id_input);
})) {
continue;
}
if (std::any_of(id_inputs.begin(), id_inputs.end(), [&](auto id_input) {
return disjoint_set_shared_ptr->has(id_input);
})) {
Expand Down

0 comments on commit a8cab11

Please sign in to comment.