Skip to content

Commit

Permalink
Rename RootDomainMap -> LogicalDomainMap (NVIDIA#2603)
Browse files Browse the repository at this point in the history
`LogicalDomainMap` is not a great name, but `RootDomainMap` is worse. It
is actually a `LogicalToRootDomainMap`, but that is too long.
  • Loading branch information
zasdfgbnm authored Jul 16, 2024
1 parent ab79e81 commit 7faa6f5
Show file tree
Hide file tree
Showing 80 changed files with 563 additions and 548 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/preseg_passes/propagate_shardings.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp
${NVFUSER_SRCS_DIR}/rng.cpp
${NVFUSER_SRCS_DIR}/root_domain_map.cpp
${NVFUSER_SRCS_DIR}/logical_domain_map.cpp
${NVFUSER_SRCS_DIR}/scheduler/cache_policy_refiner.cpp
${NVFUSER_SRCS_DIR}/scheduler/heuristic_types.cpp
${NVFUSER_SRCS_DIR}/scheduler/mark_aliases.cpp
Expand Down
4 changes: 2 additions & 2 deletions csrc/alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include <ir/iostream.h>
#include <ir/utils.h>
#include <linked_hash_map.h>
#include <root_domain_map.h>
#include <logical_domain_map.h>

namespace nvfuser {

Expand Down Expand Up @@ -144,7 +144,7 @@ std::pair<bool, std::optional<bool>> mergeContiguity(
}

std::unordered_map<IterDomain*, IterDomain*> in_logical_to_out_root =
PairwiseRootDomainMap(in, out).mapProducerToConsumer();
PairwiseLogicalDomainMap(in, out).mapProducerToConsumer();

Layout preferred_out_layout;
for (const auto i : c10::irange(preferred_in_layout.size())) {
Expand Down
4 changes: 2 additions & 2 deletions csrc/compute_at.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <ir/all_nodes.h>
#include <ir/iostream.h>
#include <ir/utils.h>
#include <root_domain_map.h>
#include <logical_domain_map.h>
#include <transform_iter.h>

#include <c10/util/irange.h>
Expand Down Expand Up @@ -221,7 +221,7 @@ void ComputeAt::runAt(
auto selected = getPropagationSubgraph(producer, consumer);
ComputeAtSelector selector(selected);

MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector);
MaxLogicalDomainInfoSpanningTree path(consumer, consumer_position, &selector);

if (mode == ComputeAtMode::MostInlined) {
MostInlinedTransformPropagator propagator;
Expand Down
2 changes: 1 addition & 1 deletion csrc/compute_at.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

#include <exceptions.h>
#include <inlining.h>
#include <root_domain_map.h>
#include <logical_domain_map.h>
#include <transform_replay.h>

#include <deque>
Expand Down
8 changes: 4 additions & 4 deletions csrc/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include <device_lower/lower2device.h>
#include <disjoint_set.h>
#include <ir/utils.h>
#include <root_domain_map.h>
#include <logical_domain_map.h>
#include <transform_iter.h>

#include <tuple>
Expand Down Expand Up @@ -436,7 +436,7 @@ void IterDomainGraph::build(Fusion* fusion) {
auto tv_inputs = ir_utils::filterByType<TensorView>(expr->inputs());

for (auto p_tv : tv_inputs) {
auto pairwise_map = PairwiseRootDomainMap(p_tv, c_tv);
auto pairwise_map = PairwiseLogicalDomainMap(p_tv, c_tv);

// Look for matching ID transformations in producer and consumer, replay
// producer as consumer. We use the symmetric API of BestEffortReplay so
Expand All @@ -459,7 +459,7 @@ void IterDomainGraph::build(Fusion* fusion) {
// Note on the boolean flags: swizzles and resizes are skipped
// in the permissive-resize map
const auto pairwise_resize_map =
PairwiseRootDomainMap(p_tv, c_tv).mapIndexedDomains(true);
PairwiseLogicalDomainMap(p_tv, c_tv).mapIndexedDomains(true);
const auto permissive_resize_disjoint_sets =
BestEffortReplay::replayPasC(
p_tv, c_tv, -1, pairwise_resize_map, true, true, true)
Expand All @@ -468,7 +468,7 @@ void IterDomainGraph::build(Fusion* fusion) {
// For exact mapings do not map any broadcast dimensions to
// non-broadcast dimensions. Prevent any broadcasted axes being mapped
// to non-broadcasted axes.
auto exact_c2p_root_map = PairwiseRootDomainMap(p_tv, c_tv)
auto exact_c2p_root_map = PairwiseLogicalDomainMap(p_tv, c_tv)
.mapBroadcast(false)
.mapConsumerToProducer();

Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/analysis/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ std::unordered_map<IterDomain*, IterDomain*> mapAllProducerDomainsToConsumer(
producer_tv,
consumer_tv,
-1,
PairwiseRootDomainMap(producer_tv, consumer_tv));
PairwiseLogicalDomainMap(producer_tv, consumer_tv));

// Grab consumer domain entries and reverse replay map. TODO: Maybe
// TransformReplay::replayPasC could return this map
Expand Down
7 changes: 4 additions & 3 deletions csrc/device_lower/analysis/predicate_elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch {
return true;
}

auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
auto pairwise_map = PairwiseLogicalDomainMap(producer, consumer);
auto c2p =
BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
.getReplay();
Expand Down Expand Up @@ -320,8 +320,9 @@ class PredicateChcker : public IterVisitor {
expr->toString());

for (auto i : c10::irange(tv_inputs.size())) {
const auto root_p2c = PairwiseRootDomainMap(tv_inputs[i], tv_outputs[i])
.mapProducerToConsumer();
const auto root_p2c =
PairwiseLogicalDomainMap(tv_inputs[i], tv_outputs[i])
.mapProducerToConsumer();
for (auto entry : root_p2c) {
auto p_id = entry.first;
auto c_id = entry.second;
Expand Down
6 changes: 3 additions & 3 deletions csrc/device_lower/analysis/trivial_broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
// clang-format on
#include <ir/utils.h>
#include <iter_visitor.h>
#include <root_domain_map.h>
#include <logical_domain_map.h>

#include <device_lower/analysis/trivial_broadcast.h>

namespace nvfuser {

ConcretizedBroadcastDomains::ConcretizedBroadcastDomains(Fusion* fusion) {
exact_map_ = std::make_unique<ExactRootDomainMap>(fusion);
exact_map_ = std::make_unique<ExactLogicalDomainMap>(fusion);

// Initialize the origin map with input broadcast domains
auto inputs = fusion->inputsAndCreated();
Expand Down Expand Up @@ -107,7 +107,7 @@ void ConcretizedBroadcastDomains::dispatch(Expr* expr) {
}

for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
auto p2c_map = PairwiseRootDomainMap(producer, consumer)
auto p2c_map = PairwiseLogicalDomainMap(producer, consumer)
.mapProducerToConsumer(&producer_broadcasts);
for (const auto& kv : p2c_map) {
auto p_id = kv.first;
Expand Down
6 changes: 3 additions & 3 deletions csrc/device_lower/analysis/trivial_broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

#include <exceptions.h>
#include <ir/all_nodes.h>
#include <root_domain_map.h>
#include <logical_domain_map.h>
#include <visibility.h>

namespace nvfuser {
Expand All @@ -20,7 +20,7 @@ namespace nvfuser {
//! domains in input tensors. Then, a new entry is added to the origin
//! map when a broadcast op is encountered during a forward traversal
//! of the given fusion. For non-broadcast ops, mappings are just
//! propagated forward using PairwiseRootDomainMap.
//! propagated forward using PairwiseLogicalDomainMap.
//!
//! When the mapped consumer domain is not broadcast, it means the
//! producer broadcast domain is concretized, and its origin broadcast
Expand Down Expand Up @@ -69,7 +69,7 @@ class NVF_API ConcretizedBroadcastDomains : private IterVisitor {
std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
broadcast_to_concrete_map_;

std::unique_ptr<ExactRootDomainMap> exact_map_;
std::unique_ptr<ExactLogicalDomainMap> exact_map_;
};

} // namespace nvfuser
2 changes: 1 addition & 1 deletion csrc/device_lower/lower2device.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
#include <ir/all_nodes.h>
#include <kernel.h>
#include <kernel_ir.h>
#include <logical_domain_map.h>
#include <non_divisible_split.h>
#include <options.h>
#include <parallel_dimension_map.h>
#include <root_domain_map.h>
#include <vectorization_info.h>
#include <visibility.h>

Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1834,7 +1834,7 @@ Val* IndexLowering::getIterationIndexForBroadcast(
"Expected broadcast ID but found ",
broadcast_id->toString());

auto c2p_root_map = PairwiseRootDomainMap(producer_tv, consumer_tv)
auto c2p_root_map = PairwiseLogicalDomainMap(producer_tv, consumer_tv)
.mapBroadcast(false)
.mapConsumerToProducer();

Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include <instrumentation.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <root_domain_map.h>
#include <logical_domain_map.h>

#include <unordered_set>
#include <vector>
Expand Down
4 changes: 2 additions & 2 deletions csrc/device_lower/pass/misaligned_vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,8 @@ class MisalignedVectorizationModifier : public kir::ExprMutator {

// Get full extent for the inner-most, merged root domain
Val* getVectorizeExtent(TensorView* producer_tv, TensorView* consumer_tv) {
auto p2c =
PairwiseRootDomainMap(producer_tv, consumer_tv).mapProducerToConsumer();
auto p2c = PairwiseLogicalDomainMap(producer_tv, consumer_tv)
.mapProducerToConsumer();

auto consumer_root_right_of_ca_domains = IterVisitor::getInputsTo(
{consumer_tv->getLoopDomain().begin() +
Expand Down
4 changes: 2 additions & 2 deletions csrc/device_lower/pass/replace_size.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include <ir/builder.h>
#include <ir/iostream.h>
#include <ir/utils.h>
#include <root_domain_map.h>
#include <logical_domain_map.h>

#include <device_lower/pass/replace_size.h>

Expand Down Expand Up @@ -90,7 +90,7 @@ std::unordered_map<Val*, Val*> getSimplificationMap(Fusion* fusion) {
for (auto producer_tv : ir_utils::filterByType<TensorView>(fusion_vals)) {
auto consumer_tvs = ir_utils::consumerTvsOf(producer_tv);
for (auto consumer_tv : consumer_tvs) {
auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv);
auto pairwise_map = PairwiseLogicalDomainMap(producer_tv, consumer_tv);
auto c2p_root_map = pairwise_map.mapConsumerToProducer();
for (auto entry : c2p_root_map) {
auto c_id = entry.first;
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/unroll.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include <device_lower/utils.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <root_domain_map.h>
#include <logical_domain_map.h>

#include <bitset>
#include <unordered_map>
Expand Down
5 changes: 3 additions & 2 deletions csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
#include <ir/utils.h>
#include <iter_visitor.h>
#include <kernel_ir_dispatch.h>
#include <logical_domain_map.h>
#include <ops/arith.h>
#include <root_domain_map.h>

#include <expr_simplifier.h>
#include <algorithm>
Expand Down Expand Up @@ -909,7 +909,8 @@ std::array<UnitDim, 2> getMmaLayout(const MmaOp* expr) {
continue;
}
NVF_ERROR(in_tv->getMemoryType() == MemoryType::Shared);
auto out2in = PairwiseRootDomainMap(in_tv, out_tv).mapConsumerToProducer();
auto out2in =
PairwiseLogicalDomainMap(in_tv, out_tv).mapConsumerToProducer();
auto reduction_id_in = out2in.at(reduction_id);
auto inner_id = in_tv->getMaybeAllocationDomain().back();
while (inner_id != reduction_id_in && inner_id->definition() != nullptr) {
Expand Down
6 changes: 3 additions & 3 deletions csrc/device_lower/validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ void checkContiguity(
!consumer->hasAllocation() && !producer->hasAllocation(),
"Misaligned vectorization for allocation domain is not supported.");
auto alloc_c2p =
PairwiseRootDomainMap(producer, consumer).mapConsumerToProducer();
PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer();

std::unordered_map<IterDomain*, std::optional<bool>>
producer_domain_contiguity;
Expand Down Expand Up @@ -515,7 +515,7 @@ class VectorizeValidator : public OptInDispatch {
vectorized_set_info.vectorized_consumer_alloc_id = consumer_vectorized_id;

// Validate producer
auto pairwise_map = PairwiseRootDomainMap(producer_tv, tv);
auto pairwise_map = PairwiseLogicalDomainMap(producer_tv, tv);
auto producer_replayed_as_consumer =
TransformReplay::replayPasC(
producer_tv,
Expand Down Expand Up @@ -1097,7 +1097,7 @@ void validateReductions(Fusion* fusion) {
for (auto rop : ir_utils::getOpsOfType<ReductionOp>(fusion)) {
auto in = rop->in()->as<TensorView>();
auto out = rop->out()->as<TensorView>();
PairwiseRootDomainMap c2p_map(in, out);
PairwiseLogicalDomainMap c2p_map(in, out);
c2p_map.mapBroadcast(true);
auto c2p = c2p_map.mapConsumerToProducer();
for (auto out_id : out->getMaybeRootDomain()) {
Expand Down
4 changes: 2 additions & 2 deletions csrc/dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ static bool hasTrivialReduction(
TensorView* out,
std::vector<int64_t>& reduction_axes) {
bool has_trivial_reduction = false;
PairwiseRootDomainMap p2c_map(in, out);
PairwiseLogicalDomainMap p2c_map(in, out);
// We need to map broadcasts in order to detect reductions of broadcasts
p2c_map.mapBroadcast(true);
auto p2c = p2c_map.mapProducerToConsumer();
Expand Down Expand Up @@ -1303,7 +1303,7 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer(
std::vector<std::unordered_map<IterDomain*, IterDomain*>> c2p_maps;
bool is_factory_output = true;
for (auto producer : ir_utils::filterByType<TensorView>(def->inputs())) {
PairwiseRootDomainMap root_map(producer, consumer);
PairwiseLogicalDomainMap root_map(producer, consumer);
// We map symbolic domains here regardless of whether their extents match.
// This is safe because we are propagating from a producer which should have
// already been concretized. The consumer might have a different extent
Expand Down
4 changes: 2 additions & 2 deletions csrc/expr_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <ir/all_nodes.h>
#include <ir/iostream.h>
#include <ir/utils.h>
#include <root_domain_map.h>
#include <logical_domain_map.h>

#include <functional>
#include <iostream>
Expand Down Expand Up @@ -293,7 +293,7 @@ void ExpressionEvaluator::propagateBoundValuesThroughExactMaps(Fusion* fusion) {
// We map Symbolic IterDomains here only if their extents match. This avoids
// mapping between symbolic domains that might concretize to an (Iteration,
// Broadcast) pair from a resolved broadcast.
const auto mapped_sets = ExactRootDomainMap(fusion).getMappedSets();
const auto mapped_sets = ExactLogicalDomainMap(fusion).getMappedSets();

for (const auto& set : mapped_sets.disjointSets()) {
int64_t known_size = -1;
Expand Down
2 changes: 1 addition & 1 deletion csrc/expr_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class ExpressionEvaluator {
//! Augment the evaluator with the exact root-domain map such that
//! if the extent of a root ID is known, the extents of all other
//! root IDs that are exactly mapped also get bound to the same
//! value. This is currently just done with ExactRootDomainMap, but
//! value. This is currently just done with ExactLogicalDomainMap, but
//! can be similarly done with the Exact CA map as well.
void propagateBoundValuesThroughExactMaps(Fusion* fusion);

Expand Down
4 changes: 2 additions & 2 deletions csrc/grouped_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// clang-format on
#include <ir/builder.h>
#include <ir/utils.h>
#include <root_domain_map.h>
#include <logical_domain_map.h>
#include <transform_iter.h>

#include <grouped_reduction.h>
Expand Down Expand Up @@ -60,7 +60,7 @@ bool validateReductionGrouping(
NVF_ERROR(
fusion != nullptr, "Grouping of reductions must be done within a Fusion");

ExactRootDomainMap exact_map(fusion);
ExactLogicalDomainMap exact_map(fusion);

// Pick the first output TV as a reference and compare it with the
// rest. Do not allow grouping if any mismatch is detected.
Expand Down
8 changes: 4 additions & 4 deletions csrc/id_model/id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include <device_lower/utils.h>
#include <disjoint_set.h>
#include <ir/utils.h>
#include <root_domain_map.h>
#include <logical_domain_map.h>
#include <transform_iter.h>
#include <val_graph_visitor.h>

Expand Down Expand Up @@ -308,7 +308,7 @@ void IdModel::buildExactGraph() {
// For exact mapings do not map any broadcast dimensions to
// non-broadcast dimensions. Prevent any broadcasted axes being mapped
// to non-broadcasted axes.
auto exact_c2p_root_map = PairwiseRootDomainMap(p_tv, c_tv)
auto exact_c2p_root_map = PairwiseLogicalDomainMap(p_tv, c_tv)
.mapBroadcast(false)
.mapConsumerToProducer();

Expand Down Expand Up @@ -454,7 +454,7 @@ void IdModel::buildPermissiveGraph() {
}

auto permissive_c2p_root_map =
PairwiseRootDomainMap(p_tv, c_tv).mapBroadcast(true);
PairwiseLogicalDomainMap(p_tv, c_tv).mapBroadcast(true);

for (auto entry : permissive_c2p_root_map.mapConsumerToProducer()) {
idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry.second);
Expand All @@ -472,7 +472,7 @@ namespace {
std::vector<std::pair<IterDomain*, IterDomain*>> resolvedRootBroadcasts(
TensorView* producer,
TensorView* consumer) {
auto p2c_map = PairwiseRootDomainMap(producer, consumer)
auto p2c_map = PairwiseLogicalDomainMap(producer, consumer)
.mapBroadcast(true)
.mapProducerToConsumer();

Expand Down
Loading

0 comments on commit 7faa6f5

Please sign in to comment.