Contiguous predicate indexing (#2860)
This is a follow-up to #2752 and concludes contiguous indexing support
in the new indexing system. The previous PR only supported tensor indexing;
this one extends the support to predicate indexing. The primary differences
from the tensor indexing case are:

- In predicate indexing, all logical domains can be considered contiguous,
and the ordering of merges doesn't matter either. This matches the behavior
of the current indexer.
- However, the contiguity analysis needs to take non-divisible splits into
account, since contiguous indexing through a non-divisible split is invalid
(see the sketch below). This also matches the current indexer.
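
As a concrete illustration of the second point, here is a small standalone
C++ sketch (plain C++, not nvFuser code; the extents, names, and the ceilDiv
helper are made up for this example) of why a non-divisible split rules out
contiguous predicate indexing: the split pads the iteration space via
ceilDiv, so a single predicate against the merged-back extent would admit
out-of-bounds elements.

#include <cassert>
#include <cstdint>
#include <iostream>

// ceilDiv as used by split: the outer extent is rounded up.
int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t logical_extent = 10; // extent of the logical domain
  const int64_t factor = 4;          // split factor; does not divide 10

  // After the split: outer = ceilDiv(10, 4) = 3, inner = 4.
  const int64_t outer = ceilDiv(logical_extent, factor);
  const int64_t inner = factor;

  // Merging outer and inner back yields extent 12, not 10.
  const int64_t merged_extent = outer * inner;
  std::cout << "merged extent: " << merged_extent << "\n"; // prints 12

  // A single contiguous predicate "idx < merged_extent" would wrongly accept
  // idx = 10 and idx = 11, so contiguous predicate indexing must stop at the
  // non-divisible split and predicate against the logical extent instead.
  assert(merged_extent > logical_extent);
  return 0;
}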
naoyam authored Sep 9, 2024
1 parent 29379c5 commit 42587be
Showing 4 changed files with 398 additions and 49 deletions.
89 changes: 71 additions & 18 deletions csrc/id_model/contiguity.cpp
@@ -15,10 +15,12 @@ ContigIDGroups::ContigIDGroups(
const std::vector<IterDomain*>& alloc_domains,
std::vector<bool> contiguity,
const ExprPath<ExprGroup>& path_from_alloc,
const ValGraph& graph)
const ValGraph& graph,
bool is_predicate_pass)
: graph_(graph),
alloc_domains_(alloc_domains),
alloc_contiguity_(std::move(contiguity)),
is_predicate_pass_(is_predicate_pass),
consistent_transform_info_(
std::make_unique<const OrderedIdGroupInformation>(
OrderedIdGroupInformation::get(
@@ -69,6 +71,14 @@
}
}

if (std::any_of(inputs.begin(), inputs.end(), [&](const ValGroup& inp) {
return non_divisible_deps_.count(inp) > 0;
})) {
for (const auto& out : outputs) {
non_divisible_deps_.insert(out);
}
}

dispatch(eg, direction);
}
}
@@ -79,10 +89,14 @@ void ContigIDGroups::handle(Merge* merge, Direction direction) {
return;
}

const bool is_indexing_pass = !is_predicate_pass_;

// If output is not consistently ordered or doesn't solely consume all
// allocation domains in its dependencies, then it can't be a contiguously
// indexable iterdomain.
if (!consistent_transform_info_->isConsistentlyOrdered(merge->out())) {
// indexable iterdomain. If it's a predicate pass, the ordering
// doesn't matter since it does not index any actual memory.
if (!(is_predicate_pass_ ||
consistent_transform_info_->isConsistentlyOrdered(merge->out()))) {
return;
}

@@ -92,22 +106,36 @@ void ContigIDGroups::handle(Merge* merge, Direction direction) {

// Check allocation domains for contiguity
auto alloc_ids_it = consistent_transform_info_->findAllocIDs(merge->out());
VectorOfUniqueEntries<IterDomain*> alloc_ids = alloc_ids_it->second;
for (auto alloc_id_i : c10::irange(alloc_domains_.size())) {
auto alloc_id = alloc_domains_[alloc_id_i];
if (alloc_ids.erase(alloc_id) == 0) {
continue;
}
auto alloc_contiguity = alloc_contiguity_.at(alloc_id_i);
// If we're indexing:
// we could still potentially consider this ID linearly indexable, as we
// could multiply the index by the last allocation's stride. See
// ContigIndexingTest.NonContigInnermost for a concrete example.
if (!alloc_contiguity && !alloc_ids.empty()) {
return;

// Contiguity doesn't matter for predicates
if (is_indexing_pass) {
VectorOfUniqueEntries<IterDomain*> alloc_ids = alloc_ids_it->second;
for (auto alloc_id_i : c10::irange(alloc_domains_.size())) {
auto alloc_id = alloc_domains_[alloc_id_i];
if (alloc_ids.erase(alloc_id) == 0) {
continue;
}
auto alloc_contiguity = alloc_contiguity_.at(alloc_id_i);
// If we're indexing:
// we could still potentially consider this ID linearly indexable, as we
// could multiply the index by the last allocation's stride. See
// ContigIndexingTest.NonContigInnermost for a concrete example.
if (!alloc_contiguity && !alloc_ids.empty()) {
return;
}
}
}

// If there's a non-divisible
// split in the history of merge->out then the extents of the inputs
// and also the outputs may be expanded due to ceilDiv. Predicate
// indexing needs to avoid contiguous indexing. Non-predicate
// indexing should have no such constraint.
if (is_predicate_pass_ &&
non_divisible_deps_.count(graph_.toGroup(merge->out()))) {
return;
}

// Don't allow contig indexing after resize as we need to traverse back
// at least to direct outputs of resize ops
if (resize_deps_.count(graph_.toGroup(merge->out()))) {
@@ -123,6 +151,26 @@
contig_ids_.emplace(graph_.toGroup(merge->out()));
}

// Avoid contiguous indexing if going through non-divisible
// splits. Not all non-divisible splits need specific predicates, so
// this condition could be relaxed.
void ContigIDGroups::handle(Split* split, Direction direction) {
if (direction == Direction::Forward) {
const auto& divisible_splits = GpuLower::current()->divisibleSplitSet();
const ExprGroup& split_group = graph_.toGroup(split);
bool divisible = std::any_of(
divisible_splits.begin(),
divisible_splits.end(),
[&](Split* divisible_split) -> bool {
return split_group->has(divisible_split);
});
if (!divisible) {
non_divisible_deps_.emplace(graph_.toGroup(split->outer()));
non_divisible_deps_.emplace(graph_.toGroup(split->inner()));
}
}
}

void ContigIDGroups::handle(Resize* resize, Direction direction) {
if (direction == Direction::Forward) {
resize_deps_.emplace(graph_.toGroup(resize->out()));
@@ -135,9 +183,14 @@ std::unordered_map<IterDomain*, ValGroup> getContigDomains(
const std::vector<IterDomain*>& alloc_domains,
const std::vector<bool>& alloc_contiguity,
const ExprPath<ExprGroup>& path_from_alloc,
const ValGraph& graph) {
const ValGraph& graph,
bool is_predicate_pass) {
ContigIDGroups contig_finder(
alloc_domains, alloc_contiguity, path_from_alloc, graph);
alloc_domains,
alloc_contiguity,
path_from_alloc,
graph,
is_predicate_pass);

return contig_finder.allocToContigIDs();
}
13 changes: 11 additions & 2 deletions csrc/id_model/contiguity.h
@@ -97,7 +98,8 @@ class ContigIDGroups {
const std::vector<IterDomain*>& alloc_domains,
std::vector<bool> contiguity,
const ExprPath<ExprGroup>& path_from_alloc,
const ValGraph& graph);
const ValGraph& graph,
bool is_predicate_pass);

void dispatch(const ExprGroup& eg, Direction direction) {
NVF_ERROR(!eg->empty());
@@ -110,13 +111,17 @@

if (auto merge = dynamic_cast<Merge*>(expr)) {
handle(merge, direction);
} else if (auto split = dynamic_cast<Split*>(expr)) {
handle(split, direction);
} else if (auto resize = dynamic_cast<Resize*>(expr)) {
handle(resize, direction);
}
}

void handle(Merge* merge, Direction direction);

void handle(Split* split, Direction direction);

void handle(Resize* resize, Direction direction);

const std::unordered_set<ValGroup>& contigIDs() const {
@@ -136,6 +141,7 @@ class ContigIDGroups {
const std::vector<IterDomain*> alloc_domains_;
// Contiguity of alloc_domains_
const std::vector<bool> alloc_contiguity_;
const bool is_predicate_pass_;
std::unique_ptr<const OrderedIdGroupInformation> consistent_transform_info_;

// Contig domain groups
@@ -144,6 +150,8 @@
std::unordered_map<IterDomain*, ValGroup> alloc_to_contig_ids_;
// All domains that have dependencies with resize ops
std::unordered_set<ValGroup> resize_deps_;
// All domains that have dependencies with non-divisible split ops
std::unordered_set<ValGroup> non_divisible_deps_;
};

// Get a contiguous indexing domain for a given allocation domain. If
@@ -153,6 +161,7 @@ std::unordered_map<IterDomain*, ValGroup> getContigDomains(
const std::vector<IterDomain*>& alloc_domains,
const std::vector<bool>& alloc_contiguity,
const ExprPath<ExprGroup>& path_from_alloc,
const ValGraph& graph);
const ValGraph& graph,
bool is_predicate_pass);

} // namespace nvfuser
83 changes: 70 additions & 13 deletions csrc/id_model/indexing.cpp
@@ -316,6 +318,8 @@ class AllocationDomainSetup : private kir::IrVisitor {
}
}

NVF_ERROR(allocation_domains.size() == contiguity.size());

return {allocation_domains, contiguity};
}

@@ -402,6 +404,9 @@
actual_contiguity.push_back(contig.value());
}

NVF_ERROR(actual_allocation_domains.size() == actual_strides.size());
NVF_ERROR(actual_allocation_domains.size() == actual_contiguity.size());

return IndexingAllocationInfo{
actual_allocation_domains, actual_strides, actual_contiguity};
}
@@ -756,8 +761,7 @@ TensorIndexer::TensorIndexer(IdModel& id_model) : id_model_(id_model) {

if (isDebugDumpEnabled(DebugDumpOption::IndexingVerbose)) {
std::ofstream ofs("indexing_traversal_graph.dot", std::ofstream::trunc);
auto dot_string =
id_model_.idGraph(IdMappingMode::ALMOSTEXACT).toGraphvizDotGraph();
auto dot_string = traversalGraph().toGraphvizDotGraph();
ofs << dot_string;
ofs.close();
}
@@ -937,22 +941,24 @@ Val* TensorIndexer::getLinearIndex(
getContigIndexFor(expr, as_consumer, alloc_info, for_loops);

// Linearize the indices with strides.
Val* index = tv->fusion()->zeroVal();
Val* linear_index = tv->fusion()->zeroVal();
for (const auto i : c10::irange(contig_indices.size())) {
Val* stride = contig_strides.at(i);
index = SimplifyingIrBuilder::addExpr(
index, SimplifyingIrBuilder::mulExpr(contig_indices.at(i), stride));
linear_index = SimplifyingIrBuilder::addExpr(
linear_index,
SimplifyingIrBuilder::mulExpr(contig_indices.at(i), stride));
}

// If a tensor is circular buffered, it also requires indexing of
// the circular buffer itself
if (tv->isCircularBuffered()) {
auto circular_buffer_offset =
getOffsetForCircularBufferTensor(tv, as_consumer, for_loops);
index = SimplifyingIrBuilder::addExpr(index, circular_buffer_offset);
linear_index =
SimplifyingIrBuilder::addExpr(linear_index, circular_buffer_offset);
}

return index;
return linear_index;
}

// Get the loop domains of a given expr, which are (potentially
@@ -1131,21 +1137,71 @@ std::vector<PredicateInfo> TensorIndexer::getPredicates(
/*is_start_predicate=*/false,
/*unswitched_loop=*/unswitched_loop);

const std::unordered_map<IterDomain*, ValGroup> contig_domains =
isContigIndexingEnabled()
? getContigDomains(
predicate_domains,
std::vector<bool>(predicate_domains.size(), true),
reverse(index_info.traversal_path),
traversalGraph(),
/*is_predicate_pass=*/true)
: std::unordered_map<IterDomain*, ValGroup>{};

auto getCoveredPredicatedDomains =
[&predicate_domains, &contig_domains](const ValGroup& contig_group) {
std::unordered_set<IterDomain*> covered_domains;
for (const auto& predicate_domain : predicate_domains) {
auto contig_domains_it = contig_domains.find(predicate_domain);
NVF_ERROR(contig_domains_it != contig_domains.end());
if (contig_group == contig_domains_it->second) {
covered_domains.emplace(predicate_domain);
}
}
return covered_domains;
};

const CircularBufferLoopStage loop_stage = getCircularBufferLoopStage(
tv, for_loops, id_model_.idGraph(IdMappingMode::LOOP));

std::vector<PredicateInfo> info_vec;
info_vec.reserve(predicate_domains.size());

std::unordered_set<ValGroup> already_indexed_domains;

// Follow the same approach as Index::getReferenceRootPredicates.
for (const auto& predicate_domain : predicate_domains) {
const auto& predicate_domain_group =
traversalGraph().toGroup(predicate_domain);
auto idx_it = index_map.find(predicate_domain_group);
IterDomain* actual_predicate_domain = predicate_domain;
ValGroup actual_predicate_domain_group = predicate_domain_group;
std::unordered_set<IterDomain*> actual_predicate_domains = {
predicate_domain};

if (isContigIndexingEnabled()) {
auto contig_domains_it = contig_domains.find(predicate_domain);
NVF_ERROR(
contig_domains_it != contig_domains.end(),
"No contig domain mapping found for ",
predicate_domain->toString());
const ValGroup& contig_domain_group = contig_domains_it->second;
if (already_indexed_domains.find(contig_domain_group) !=
already_indexed_domains.end()) {
continue;
}
already_indexed_domains.emplace(contig_domain_group);

actual_predicate_domain_group = contig_domain_group;
actual_predicate_domain =
actual_predicate_domain_group->front()->as<IterDomain>();
actual_predicate_domains =
getCoveredPredicatedDomains(contig_domain_group);
}

auto idx_it = index_map.find(actual_predicate_domain_group);
NVF_ERROR(
idx_it != index_map.end(),
"Index not found for ",
predicate_domain->toString());
nvfuser::toString(actual_predicate_domain_group));

Val* idx = idx_it->second;
Val* start_idx =
@@ -1170,13 +1226,13 @@ std::vector<PredicateInfo> TensorIndexer::getPredicates(

info.stop_predicate_ = SimplifyingIrBuilder::ltExpr(
SimplifyingIrBuilder::addExpr(stop_idx, info.stop_offset_),
predicate_domain->extent());
actual_predicate_domain->extent());

info.predicated_domains_ = {predicate_domain};
info.predicated_domains_ = actual_predicate_domains;

// Set the used loop ID groups for this predicated domain
const ValGroups& loop_deps =
index_info.loop_group_dependencies.at(predicate_domain_group);
index_info.loop_group_dependencies.at(actual_predicate_domain_group);
for (const auto& loop_dep : loop_deps) {
info.loop_domains_.insert(loop_dep->front()->as<IterDomain>());
}
@@ -1247,7 +1303,8 @@ std::pair<std::vector<ValGroup>, std::vector<Val*>> TensorIndexer::
alloc_info.domains,
alloc_info.contiguity,
reverse(traversal_path),
traversalGraph());
traversalGraph(),
/*is_predicate_pass=*/false);

// Find contiguous domains to index
std::unordered_set<ValGroup> already_indexed_domains;
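
For illustration only, here is a minimal sketch of what the contiguous
predicate grouping added to getPredicates above amounts to (plain C++, not
the nvFuser API; the extents and helper names are hypothetical): when two
predicate domains are covered by the same contiguous merge output, a single
bound check against the product of their extents can replace the two
per-domain checks, which is why only one predicate is emitted per contig
group.

#include <cstdint>
#include <iostream>

// Hypothetical extents of two logical (predicate) domains I0 and I1.
constexpr int64_t kExtentI0 = 8;
constexpr int64_t kExtentI1 = 6;

// Per-domain predicates: one bound check per logical domain.
bool separatePredicates(int64_t i0, int64_t i1) {
  return i0 < kExtentI0 && i1 < kExtentI1;
}

// Contiguous predicate: when I0 and I1 are merged with no non-divisible
// split in the history, the merged (linearized) index can be checked once
// against the product of the two extents.
bool contigPredicate(int64_t merged_idx) {
  return merged_idx < kExtentI0 * kExtentI1;
}

int main() {
  const int64_t i0 = 3;
  const int64_t i1 = 5;
  const int64_t merged_idx = i0 * kExtentI1 + i1; // linearized index: 23
  std::cout << separatePredicates(i0, i1) << " "
            << contigPredicate(merged_idx) << "\n"; // prints "1 1"
  return 0;
}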
