diff --git a/csrc/id_model/indexing.cpp b/csrc/id_model/indexing.cpp index 46edcfe8367..8cc0da959e2 100644 --- a/csrc/id_model/indexing.cpp +++ b/csrc/id_model/indexing.cpp @@ -855,14 +855,18 @@ std::vector TensorIndexer::getIndexFor( const auto& replacement_map = getIndexReplacementMap( expr, as_consumer, info.loop_domains, for_loops, info.index_map); - const auto index_groups = traversalGraph().toGroups(index_ids); + // Note that IDs of index_ids may be mapped as the traversal graph + // is the AlmostExact graph. std::vector result; - result.reserve(index_groups.size()); - for (const auto& g : index_groups) { - auto it = info.index_map.find(g); + result.reserve(index_ids.size()); + for (IterDomain* index_id : index_ids) { + const auto& index_group = traversalGraph().toGroup(index_id); + auto it = info.index_map.find(index_group); NVF_ERROR( - it != info.index_map.end(), "Index not found for ", g->toString()); + it != info.index_map.end(), + "Index not found for ", + index_id->toString()); result.push_back( ir_utils::replaceValRecursively(it->second, replacement_map)); }