From 7faa6f5f346b30aaa3cf4af99567f56232a8d1aa Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Tue, 16 Jul 2024 11:34:02 -0700 Subject: [PATCH] Rename RootDomainMap -> LogicalDomainMap (#2603) `LogicalDomainMap` is not a great name, but `RootDomainMap` is worse. It is actually a `LogicalToRootDomainMap`, but that is too long. --- CMakeLists.txt | 2 +- csrc/alias_analysis.cpp | 4 +- csrc/compute_at.cpp | 4 +- csrc/compute_at.h | 2 +- csrc/compute_at_map.cpp | 8 +- csrc/device_lower/analysis/index_compute.cpp | 2 +- .../analysis/predicate_elimination.cpp | 7 +- .../analysis/trivial_broadcast.cpp | 6 +- .../device_lower/analysis/trivial_broadcast.h | 6 +- csrc/device_lower/lower2device.h | 2 +- csrc/device_lower/pass/index.cpp | 2 +- csrc/device_lower/pass/index.h | 2 +- .../pass/misaligned_vectorization.cpp | 4 +- csrc/device_lower/pass/replace_size.cpp | 4 +- csrc/device_lower/pass/unroll.h | 2 +- csrc/device_lower/utils.cpp | 5 +- csrc/device_lower/validation.cpp | 6 +- csrc/dynamic_transform.cpp | 4 +- csrc/expr_evaluator.cpp | 4 +- csrc/expr_evaluator.h | 2 +- csrc/grouped_reduction.cpp | 4 +- csrc/id_model/id_model.cpp | 8 +- csrc/index_compute.cpp | 14 +- csrc/index_compute.h | 2 +- csrc/inlining.cpp | 13 +- csrc/ir/nodes.cpp | 2 +- ..._domain_map.cpp => logical_domain_map.cpp} | 233 +++++++++--------- ...root_domain_map.h => logical_domain_map.h} | 105 ++++---- csrc/maxinfo_propagator.cpp | 50 ++-- csrc/maxinfo_propagator.h | 14 +- csrc/multidevice/lower_communication.cpp | 3 +- csrc/multidevice/utils.cpp | 6 +- csrc/predicate_compute.h | 2 +- .../allocation_order_inference.cpp | 2 +- .../exact_mapped_extent_substitution.cpp | 12 +- csrc/preseg_passes/remove_bcast_squeeze.cpp | 2 +- csrc/python_frontend/python_bindings.cpp | 4 +- csrc/scheduler/cache_policy_refiner.cpp | 14 +- csrc/scheduler/mma_utils.cpp | 4 +- csrc/scheduler/normalization_utils.cpp | 2 +- csrc/scheduler/pointwise.cpp | 2 +- csrc/scheduler/reduction.cpp | 2 +- csrc/scheduler/reduction_utils.cpp | 2 +- csrc/scheduler/registry_utils.cpp | 14 +- csrc/scheduler/registry_utils.h | 4 +- csrc/scheduler/transpose.cpp | 8 +- csrc/scheduler/utils.cpp | 16 +- csrc/scheduler/vectorize_helper.cpp | 9 +- csrc/tensor_view.cpp | 2 +- csrc/transform_iter.cpp | 6 +- csrc/transform_iter.h | 6 +- csrc/transform_replay.cpp | 14 +- csrc/transform_replay.h | 10 +- tests/cpp/test_ca_root_domain_map.cpp | 44 ++-- tests/cpp/test_circular_buffering.cpp | 20 +- tests/cpp/test_expr_sort.cpp | 6 +- tests/cpp/test_gather.cpp | 2 +- tests/cpp/test_gpu1.cpp | 24 +- tests/cpp/test_gpu2.cpp | 36 +-- tests/cpp/test_gpu3.cpp | 74 +++--- tests/cpp/test_gpu_compute_with.cpp | 12 +- tests/cpp/test_gpu_fused_reduction.cpp | 58 ++--- tests/cpp/test_gpu_indexing.cpp | 2 +- tests/cpp/test_gpu_outer_reduction.cpp | 4 +- tests/cpp/test_gpu_tensorcore.cpp | 2 +- tests/cpp/test_gpu_transpose.cpp | 6 +- tests/cpp/test_gpu_view.cpp | 22 +- tests/cpp/test_id_model.cpp | 24 +- tests/cpp/test_indexing.cpp | 36 +-- tests/cpp/test_memory.cpp | 4 +- tests/cpp/test_multidevice_pipeline.cpp | 2 +- tests/cpp/test_persistent_buffer.cpp | 2 +- tests/cpp/test_predicate_elimination.cpp | 10 +- tests/cpp/test_resize.cpp | 32 +-- tests/cpp/test_scalar_hoisting.cpp | 2 +- tests/cpp/test_sdpa_node.cpp | 4 +- tests/cpp/test_serial_gridreduce.cpp | 2 +- tests/cpp/test_swizzle.cpp | 4 +- tests/cpp/test_translate_mma.cpp | 2 +- tests/cpp/test_tutorial.cpp | 6 +- 80 files changed, 563 insertions(+), 548 deletions(-) rename csrc/{root_domain_map.cpp => logical_domain_map.cpp} (86%) rename csrc/{root_domain_map.h => logical_domain_map.h} (84%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c0a061e30c..ceb31084e18 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -193,7 +193,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/preseg_passes/propagate_shardings.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp ${NVFUSER_SRCS_DIR}/rng.cpp - ${NVFUSER_SRCS_DIR}/root_domain_map.cpp + ${NVFUSER_SRCS_DIR}/logical_domain_map.cpp ${NVFUSER_SRCS_DIR}/scheduler/cache_policy_refiner.cpp ${NVFUSER_SRCS_DIR}/scheduler/heuristic_types.cpp ${NVFUSER_SRCS_DIR}/scheduler/mark_aliases.cpp diff --git a/csrc/alias_analysis.cpp b/csrc/alias_analysis.cpp index 9d02ef12285..f719837a321 100644 --- a/csrc/alias_analysis.cpp +++ b/csrc/alias_analysis.cpp @@ -16,7 +16,7 @@ #include #include #include -#include +#include namespace nvfuser { @@ -144,7 +144,7 @@ std::pair> mergeContiguity( } std::unordered_map in_logical_to_out_root = - PairwiseRootDomainMap(in, out).mapProducerToConsumer(); + PairwiseLogicalDomainMap(in, out).mapProducerToConsumer(); Layout preferred_out_layout; for (const auto i : c10::irange(preferred_in_layout.size())) { diff --git a/csrc/compute_at.cpp b/csrc/compute_at.cpp index e32042220ad..52e70472e97 100644 --- a/csrc/compute_at.cpp +++ b/csrc/compute_at.cpp @@ -11,7 +11,7 @@ #include #include #include -#include +#include #include #include @@ -221,7 +221,7 @@ void ComputeAt::runAt( auto selected = getPropagationSubgraph(producer, consumer); ComputeAtSelector selector(selected); - MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector); + MaxLogicalDomainInfoSpanningTree path(consumer, consumer_position, &selector); if (mode == ComputeAtMode::MostInlined) { MostInlinedTransformPropagator propagator; diff --git a/csrc/compute_at.h b/csrc/compute_at.h index 1c1891fd433..45ab268f7ba 100644 --- a/csrc/compute_at.h +++ b/csrc/compute_at.h @@ -9,7 +9,7 @@ #include #include -#include +#include #include #include diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index 74aa90bb610..c65412226c0 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include @@ -436,7 +436,7 @@ void IterDomainGraph::build(Fusion* fusion) { auto tv_inputs = ir_utils::filterByType(expr->inputs()); for (auto p_tv : tv_inputs) { - auto pairwise_map = PairwiseRootDomainMap(p_tv, c_tv); + auto pairwise_map = PairwiseLogicalDomainMap(p_tv, c_tv); // Look for matching ID transformations in producer and consumer, replay // producer as consumer. We use the symmetric API of BestEffortReplay so @@ -459,7 +459,7 @@ void IterDomainGraph::build(Fusion* fusion) { // Note on the boolean flags: swizzles and resizes are skipped // in the permissive-resize map const auto pairwise_resize_map = - PairwiseRootDomainMap(p_tv, c_tv).mapIndexedDomains(true); + PairwiseLogicalDomainMap(p_tv, c_tv).mapIndexedDomains(true); const auto permissive_resize_disjoint_sets = BestEffortReplay::replayPasC( p_tv, c_tv, -1, pairwise_resize_map, true, true, true) @@ -468,7 +468,7 @@ void IterDomainGraph::build(Fusion* fusion) { // For exact mapings do not map any broadcast dimensions to // non-broadcast dimensions. Prevent any broadcasted axes being mapped // to non-broadcasted axes. - auto exact_c2p_root_map = PairwiseRootDomainMap(p_tv, c_tv) + auto exact_c2p_root_map = PairwiseLogicalDomainMap(p_tv, c_tv) .mapBroadcast(false) .mapConsumerToProducer(); diff --git a/csrc/device_lower/analysis/index_compute.cpp b/csrc/device_lower/analysis/index_compute.cpp index b9b2914474b..02f04c64beb 100644 --- a/csrc/device_lower/analysis/index_compute.cpp +++ b/csrc/device_lower/analysis/index_compute.cpp @@ -48,7 +48,7 @@ std::unordered_map mapAllProducerDomainsToConsumer( producer_tv, consumer_tv, -1, - PairwiseRootDomainMap(producer_tv, consumer_tv)); + PairwiseLogicalDomainMap(producer_tv, consumer_tv)); // Grab consumer domain entries and reverse replay map. TODO: Maybe // TransformReplay::replayPasC could return this map diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 9217f8c14ec..0332f63fe48 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -85,7 +85,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { return true; } - auto pairwise_map = PairwiseRootDomainMap(producer, consumer); + auto pairwise_map = PairwiseLogicalDomainMap(producer, consumer); auto c2p = BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) .getReplay(); @@ -320,8 +320,9 @@ class PredicateChcker : public IterVisitor { expr->toString()); for (auto i : c10::irange(tv_inputs.size())) { - const auto root_p2c = PairwiseRootDomainMap(tv_inputs[i], tv_outputs[i]) - .mapProducerToConsumer(); + const auto root_p2c = + PairwiseLogicalDomainMap(tv_inputs[i], tv_outputs[i]) + .mapProducerToConsumer(); for (auto entry : root_p2c) { auto p_id = entry.first; auto c_id = entry.second; diff --git a/csrc/device_lower/analysis/trivial_broadcast.cpp b/csrc/device_lower/analysis/trivial_broadcast.cpp index 36b267f814c..627f5c2c109 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.cpp +++ b/csrc/device_lower/analysis/trivial_broadcast.cpp @@ -7,14 +7,14 @@ // clang-format on #include #include -#include +#include #include namespace nvfuser { ConcretizedBroadcastDomains::ConcretizedBroadcastDomains(Fusion* fusion) { - exact_map_ = std::make_unique(fusion); + exact_map_ = std::make_unique(fusion); // Initialize the origin map with input broadcast domains auto inputs = fusion->inputsAndCreated(); @@ -107,7 +107,7 @@ void ConcretizedBroadcastDomains::dispatch(Expr* expr) { } for (auto consumer : ir_utils::filterByType(expr->outputs())) { - auto p2c_map = PairwiseRootDomainMap(producer, consumer) + auto p2c_map = PairwiseLogicalDomainMap(producer, consumer) .mapProducerToConsumer(&producer_broadcasts); for (const auto& kv : p2c_map) { auto p_id = kv.first; diff --git a/csrc/device_lower/analysis/trivial_broadcast.h b/csrc/device_lower/analysis/trivial_broadcast.h index ea9a9c6d846..13f4be924ef 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.h +++ b/csrc/device_lower/analysis/trivial_broadcast.h @@ -9,7 +9,7 @@ #include #include -#include +#include #include namespace nvfuser { @@ -20,7 +20,7 @@ namespace nvfuser { //! domains in input tensors. Then, a new entry is added to the origin //! map when a broadcast op is encountered during a forward traversal //! of the given fusion. For non-broadcast ops, mappings are just -//! propagated forward using PairwiseRootDomainMap. +//! propagated forward using PairwiseLogicalDomainMap. //! //! When the mapped consumer domain is not broadcast, it means the //! producer broadcast domain is concretized, and its origin broadcast @@ -69,7 +69,7 @@ class NVF_API ConcretizedBroadcastDomains : private IterVisitor { std::unordered_map> broadcast_to_concrete_map_; - std::unique_ptr exact_map_; + std::unique_ptr exact_map_; }; } // namespace nvfuser diff --git a/csrc/device_lower/lower2device.h b/csrc/device_lower/lower2device.h index d7f1b1095ec..fadfc6337ed 100644 --- a/csrc/device_lower/lower2device.h +++ b/csrc/device_lower/lower2device.h @@ -29,10 +29,10 @@ #include #include #include +#include #include #include #include -#include #include #include diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 898fb329fa7..0ac753b31ee 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1834,7 +1834,7 @@ Val* IndexLowering::getIterationIndexForBroadcast( "Expected broadcast ID but found ", broadcast_id->toString()); - auto c2p_root_map = PairwiseRootDomainMap(producer_tv, consumer_tv) + auto c2p_root_map = PairwiseLogicalDomainMap(producer_tv, consumer_tv) .mapBroadcast(false) .mapConsumerToProducer(); diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index 3419cb57a04..448933d3835 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -12,7 +12,7 @@ #include #include #include -#include +#include #include #include diff --git a/csrc/device_lower/pass/misaligned_vectorization.cpp b/csrc/device_lower/pass/misaligned_vectorization.cpp index fc7e3e607fa..1985efb0659 100644 --- a/csrc/device_lower/pass/misaligned_vectorization.cpp +++ b/csrc/device_lower/pass/misaligned_vectorization.cpp @@ -470,8 +470,8 @@ class MisalignedVectorizationModifier : public kir::ExprMutator { // Get full extent for the inner-most, merged root domain Val* getVectorizeExtent(TensorView* producer_tv, TensorView* consumer_tv) { - auto p2c = - PairwiseRootDomainMap(producer_tv, consumer_tv).mapProducerToConsumer(); + auto p2c = PairwiseLogicalDomainMap(producer_tv, consumer_tv) + .mapProducerToConsumer(); auto consumer_root_right_of_ca_domains = IterVisitor::getInputsTo( {consumer_tv->getLoopDomain().begin() + diff --git a/csrc/device_lower/pass/replace_size.cpp b/csrc/device_lower/pass/replace_size.cpp index 48bc9b7a535..1dc48319d45 100644 --- a/csrc/device_lower/pass/replace_size.cpp +++ b/csrc/device_lower/pass/replace_size.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include @@ -90,7 +90,7 @@ std::unordered_map getSimplificationMap(Fusion* fusion) { for (auto producer_tv : ir_utils::filterByType(fusion_vals)) { auto consumer_tvs = ir_utils::consumerTvsOf(producer_tv); for (auto consumer_tv : consumer_tvs) { - auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv); + auto pairwise_map = PairwiseLogicalDomainMap(producer_tv, consumer_tv); auto c2p_root_map = pairwise_map.mapConsumerToProducer(); for (auto entry : c2p_root_map) { auto c_id = entry.first; diff --git a/csrc/device_lower/pass/unroll.h b/csrc/device_lower/pass/unroll.h index da04b7d2875..0af5f945e16 100644 --- a/csrc/device_lower/pass/unroll.h +++ b/csrc/device_lower/pass/unroll.h @@ -12,7 +12,7 @@ #include #include #include -#include +#include #include #include diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index 8c5e5b4a0fa..c3fbeb5af70 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -15,8 +15,8 @@ #include #include #include +#include #include -#include #include #include @@ -909,7 +909,8 @@ std::array getMmaLayout(const MmaOp* expr) { continue; } NVF_ERROR(in_tv->getMemoryType() == MemoryType::Shared); - auto out2in = PairwiseRootDomainMap(in_tv, out_tv).mapConsumerToProducer(); + auto out2in = + PairwiseLogicalDomainMap(in_tv, out_tv).mapConsumerToProducer(); auto reduction_id_in = out2in.at(reduction_id); auto inner_id = in_tv->getMaybeAllocationDomain().back(); while (inner_id != reduction_id_in && inner_id->definition() != nullptr) { diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index a49ac06873c..cdd96ca381f 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -237,7 +237,7 @@ void checkContiguity( !consumer->hasAllocation() && !producer->hasAllocation(), "Misaligned vectorization for allocation domain is not supported."); auto alloc_c2p = - PairwiseRootDomainMap(producer, consumer).mapConsumerToProducer(); + PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); std::unordered_map> producer_domain_contiguity; @@ -515,7 +515,7 @@ class VectorizeValidator : public OptInDispatch { vectorized_set_info.vectorized_consumer_alloc_id = consumer_vectorized_id; // Validate producer - auto pairwise_map = PairwiseRootDomainMap(producer_tv, tv); + auto pairwise_map = PairwiseLogicalDomainMap(producer_tv, tv); auto producer_replayed_as_consumer = TransformReplay::replayPasC( producer_tv, @@ -1097,7 +1097,7 @@ void validateReductions(Fusion* fusion) { for (auto rop : ir_utils::getOpsOfType(fusion)) { auto in = rop->in()->as(); auto out = rop->out()->as(); - PairwiseRootDomainMap c2p_map(in, out); + PairwiseLogicalDomainMap c2p_map(in, out); c2p_map.mapBroadcast(true); auto c2p = c2p_map.mapConsumerToProducer(); for (auto out_id : out->getMaybeRootDomain()) { diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp index b44a14684a0..cdddb29337f 100644 --- a/csrc/dynamic_transform.cpp +++ b/csrc/dynamic_transform.cpp @@ -1214,7 +1214,7 @@ static bool hasTrivialReduction( TensorView* out, std::vector& reduction_axes) { bool has_trivial_reduction = false; - PairwiseRootDomainMap p2c_map(in, out); + PairwiseLogicalDomainMap p2c_map(in, out); // We need to map broadcasts in order to detect reductions of broadcasts p2c_map.mapBroadcast(true); auto p2c = p2c_map.mapProducerToConsumer(); @@ -1303,7 +1303,7 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer( std::vector> c2p_maps; bool is_factory_output = true; for (auto producer : ir_utils::filterByType(def->inputs())) { - PairwiseRootDomainMap root_map(producer, consumer); + PairwiseLogicalDomainMap root_map(producer, consumer); // We map symbolic domains here regardless of whether their extents match. // This is safe because we are propagating from a producer which should have // already been concretized. The consumer might have a different extent diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index 7cdead9810c..52e214737fa 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -13,7 +13,7 @@ #include #include #include -#include +#include #include #include @@ -293,7 +293,7 @@ void ExpressionEvaluator::propagateBoundValuesThroughExactMaps(Fusion* fusion) { // We map Symbolic IterDomains here only if their extents match. This avoids // mapping between symbolic domains that might concretize to an (Iteration, // Broadcast) pair from a resolved broadcast. - const auto mapped_sets = ExactRootDomainMap(fusion).getMappedSets(); + const auto mapped_sets = ExactLogicalDomainMap(fusion).getMappedSets(); for (const auto& set : mapped_sets.disjointSets()) { int64_t known_size = -1; diff --git a/csrc/expr_evaluator.h b/csrc/expr_evaluator.h index f2a952e95f7..159693e2d84 100644 --- a/csrc/expr_evaluator.h +++ b/csrc/expr_evaluator.h @@ -84,7 +84,7 @@ class ExpressionEvaluator { //! Augment the evaluator with the exact root-domain map such that //! if the extent of a root ID is known, the extents of all other //! root IDs that are exactly mapped also get bound to the same - //! value. This is currently just done with ExactRootDomainMap, but + //! value. This is currently just done with ExactLogicalDomainMap, but //! can be similarly done with the Exact CA map as well. void propagateBoundValuesThroughExactMaps(Fusion* fusion); diff --git a/csrc/grouped_reduction.cpp b/csrc/grouped_reduction.cpp index 99fd65d4e5d..a35122e92e7 100644 --- a/csrc/grouped_reduction.cpp +++ b/csrc/grouped_reduction.cpp @@ -7,7 +7,7 @@ // clang-format on #include #include -#include +#include #include #include @@ -60,7 +60,7 @@ bool validateReductionGrouping( NVF_ERROR( fusion != nullptr, "Grouping of reductions must be done within a Fusion"); - ExactRootDomainMap exact_map(fusion); + ExactLogicalDomainMap exact_map(fusion); // Pick the first output TV as a reference and compare it with the // rest. Do not allow grouping if any mismatch is detected. diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index c2cd23ba639..59da76d59e1 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -16,7 +16,7 @@ #include #include #include -#include +#include #include #include @@ -308,7 +308,7 @@ void IdModel::buildExactGraph() { // For exact mapings do not map any broadcast dimensions to // non-broadcast dimensions. Prevent any broadcasted axes being mapped // to non-broadcasted axes. - auto exact_c2p_root_map = PairwiseRootDomainMap(p_tv, c_tv) + auto exact_c2p_root_map = PairwiseLogicalDomainMap(p_tv, c_tv) .mapBroadcast(false) .mapConsumerToProducer(); @@ -454,7 +454,7 @@ void IdModel::buildPermissiveGraph() { } auto permissive_c2p_root_map = - PairwiseRootDomainMap(p_tv, c_tv).mapBroadcast(true); + PairwiseLogicalDomainMap(p_tv, c_tv).mapBroadcast(true); for (auto entry : permissive_c2p_root_map.mapConsumerToProducer()) { idGraph(IdMappingMode::PERMISSIVE).mapVals(entry.first, entry.second); @@ -472,7 +472,7 @@ namespace { std::vector> resolvedRootBroadcasts( TensorView* producer, TensorView* consumer) { - auto p2c_map = PairwiseRootDomainMap(producer, consumer) + auto p2c_map = PairwiseLogicalDomainMap(producer, consumer) .mapBroadcast(true) .mapProducerToConsumer(); diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index 2a1b22c334b..b0f9a4654db 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -23,8 +23,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -1356,7 +1356,7 @@ std::unordered_map mapAllProducerDomainsToConsumer( producer_tv, consumer_tv, -1, - PairwiseRootDomainMap(producer_tv, consumer_tv)); + PairwiseLogicalDomainMap(producer_tv, consumer_tv)); // Grab consumer domain entries and reverse replay map. TODO: Maybe // TransformReplay::replayPasC could return this map @@ -1394,7 +1394,7 @@ std::vector Index::getNonGlobalProducerStridedIndices( const auto gpu_lower = GpuLower::current(); // Replay producer to look like consumer so we can index on producer since our // loop nests look like consumer - auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv); + auto pairwise_map = PairwiseLogicalDomainMap(producer_tv, consumer_tv); // Resize ops can be and should be replayed. auto producer_replayed_as_consumer = TransformReplay::replayPasC( @@ -1420,7 +1420,7 @@ std::vector Index::getNonGlobalProducerStridedIndices( // Map sent to best effort replay needs to match the exact incantation for // compute_at_mode.cpp with MappingMode::Index - auto c2p_root_map = PairwiseRootDomainMap(producer_tv, consumer_tv) + auto c2p_root_map = PairwiseLogicalDomainMap(producer_tv, consumer_tv) .mapBroadcast(false) .mapConsumerToProducer(); @@ -1720,7 +1720,7 @@ std::vector Index::getProducerAllocationIndices( // Replay producer to look like consumer so we can index on producer since // our loop nests look like consumer auto pairwise_map = - PairwiseRootDomainMap(producer_tv, consumer_tv).mapBroadcast(true); + PairwiseLogicalDomainMap(producer_tv, consumer_tv).mapBroadcast(true); TensorDomain* producerAsC = TransformReplay::replayPasC( producer_tv, @@ -1735,7 +1735,7 @@ std::vector Index::getProducerAllocationIndices( // Map sent to best effort replay needs to match the exact incantation for // compute_at_mode.cpp with MappingMode::Index - auto c2p_root_map = PairwiseRootDomainMap(producer_tv, consumer_tv) + auto c2p_root_map = PairwiseLogicalDomainMap(producer_tv, consumer_tv) .mapBroadcast(false) .mapConsumerToProducer(); @@ -1766,7 +1766,7 @@ std::vector Index::getProducerAllocationIndices( // If we add I1->I6 and I2->I7, the c2p map will no longer be injective, which // is not what we want. const auto p2c_map = invertOneToOneMap(c2p_map); - for (const auto& kv : PairwiseRootDomainMap(producer_tv, consumer_tv) + for (const auto& kv : PairwiseLogicalDomainMap(producer_tv, consumer_tv) .mapBroadcast(false) .mapDifferentExtents(true) .mapConsumerToProducer()) { diff --git a/csrc/index_compute.h b/csrc/index_compute.h index 46a6f9549bd..f5082082e67 100644 --- a/csrc/index_compute.h +++ b/csrc/index_compute.h @@ -9,7 +9,7 @@ #include #include -#include +#include #include #include diff --git a/csrc/inlining.cpp b/csrc/inlining.cpp index 05c7399c267..d875e636848 100644 --- a/csrc/inlining.cpp +++ b/csrc/inlining.cpp @@ -7,7 +7,7 @@ // clang-format on #include #include -#include +#include #include #include @@ -27,15 +27,16 @@ void MaxPosCalculator::buildUnmappableDims(bool compute_at_only) { if (compute_at_only) { return; } - ComputeAtRootDomainMap root_map; + ComputeAtLogicalDomainMap root_map; root_map.build(); auto all_tvs = ir_utils::allTvs(FusionGuard::getCurFusion()); for (auto tv : all_tvs) { auto consumers = ir_utils::consumerTvsOf(tv); for (auto consumer : consumers) { // Grab dimensions in producer and consumer that are mappable to eachother - // based on the computeAtRootDomainMap. This will tell us which dimensions - // can be inlined based on avoiding trying to inline reduction structures. + // based on the computeAtLogicalDomainMap. This will tell us which + // dimensions can be inlined based on avoiding trying to inline reduction + // structures. auto mappable_roots = root_map.getMappableDims(tv->domain(), consumer->domain()); for (auto tv_logical_id : tv->getLogicalDomain()) { @@ -129,7 +130,7 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( TensorView* producer, TensorView* consumer, bool best_effort) const { - auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); + auto pairwise_root_map = PairwiseLogicalDomainMap(producer, consumer); auto replay_CasP = BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map); auto p2c_replay_map = replay_CasP.getReplay(); @@ -283,7 +284,7 @@ std::unordered_map getPositionsMappedTo( TensorView* reference_tv, int64_t reference_pos) { std::unordered_map mapped_positions; - MaxRootDomainInfoSpanningTree tree(reference_tv, reference_pos); + MaxLogicalDomainInfoSpanningTree tree(reference_tv, reference_pos); FindMappedPositions propagator(mapped_positions, reference_tv, reference_pos); tree.traverse(&propagator); return mapped_positions; diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 1ac37fe4b22..811fcbc94cc 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -16,8 +16,8 @@ #include #include #include +#include #include -#include #include #include #include diff --git a/csrc/root_domain_map.cpp b/csrc/logical_domain_map.cpp similarity index 86% rename from csrc/root_domain_map.cpp rename to csrc/logical_domain_map.cpp index 47f871619ae..6dc4ade4466 100644 --- a/csrc/root_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -9,48 +9,48 @@ #include #include #include +#include #include -#include #include namespace nvfuser { -std::unordered_map RootDomainMap:: +std::unordered_map LogicalDomainMap:: mapProducerToConsumer( const TensorDomain* producer, const TensorDomain* consumer, - const std::unordered_set& logical_dims_to_map) const { - return map(producer, consumer, logical_dims_to_map, true); + const std::unordered_set& dims_to_map) const { + return map(producer, consumer, dims_to_map, true); } -std::unordered_map RootDomainMap:: +std::unordered_map LogicalDomainMap:: mapProducerToConsumer( const TensorDomain* producer, const TensorDomain* consumer) const { - std::unordered_set logical_dims_to_map( + std::unordered_set dims_to_map( producer->logical().begin(), producer->logical().end()); - return mapProducerToConsumer(producer, consumer, logical_dims_to_map); + return mapProducerToConsumer(producer, consumer, dims_to_map); } -std::unordered_map RootDomainMap:: +std::unordered_map LogicalDomainMap:: mapConsumerToProducer( const TensorDomain* consumer, const TensorDomain* producer, - const std::unordered_set& root_dims_to_map) const { - return map(producer, consumer, root_dims_to_map, false); + const std::unordered_set& dims_to_map) const { + return map(producer, consumer, dims_to_map, false); } -std::unordered_map RootDomainMap:: +std::unordered_map LogicalDomainMap:: mapConsumerToProducer( const TensorDomain* consumer, const TensorDomain* producer) const { - std::unordered_set root_dims_to_map( + std::unordered_set dims_to_map( consumer->maybeRoot().begin(), consumer->maybeRoot().end()); - return mapConsumerToProducer(consumer, producer, root_dims_to_map); + return mapConsumerToProducer(consumer, producer, dims_to_map); } -PairwiseRootDomainMap::PairwiseRootDomainMap( +PairwiseLogicalDomainMap::PairwiseLogicalDomainMap( const TensorView* producer, const TensorView* consumer) : producer_tv_(producer), consumer_tv_(consumer) { @@ -99,10 +99,10 @@ std::pair getIndexedDomainInfo( } // namespace -std::unordered_map PairwiseRootDomainMap::map( +std::unordered_map PairwiseLogicalDomainMap::map( const TensorDomain* producer, const TensorDomain* consumer, - const std::unordered_set& root_dims_to_map, + const std::unordered_set& dims_to_map, bool producer_to_consumer) const { std::vector broadcast_flags; if (BroadcastOp* bop = @@ -127,8 +127,8 @@ std::unordered_map PairwiseRootDomainMap::map( // true. // 2. Do not map Symbolic ID if the extents are not identical unless // map_symbolic_ = true. - auto updatePairwiseRootDomainMap = [&](IterDomain* producer_id, - IterDomain* consumer_id) { + auto updatePairwiseLogicalDomainMap = [&](IterDomain* producer_id, + IterDomain* consumer_id) { if (!map_broadcast_ && producer_id->isBroadcast() != consumer_id->isBroadcast()) { return; @@ -159,7 +159,7 @@ std::unordered_map PairwiseRootDomainMap::map( std::swap(map_key_id, map_value_id); } - if (root_dims_to_map.find(map_key_id) != root_dims_to_map.end()) { + if (dims_to_map.find(map_key_id) != dims_to_map.end()) { dom_map.insert(std::make_pair(map_key_id, map_value_id)); } }; @@ -175,7 +175,7 @@ std::unordered_map PairwiseRootDomainMap::map( if (producer_id == nullptr) { continue; } - updatePairwiseRootDomainMap(producer_id, consumer_id); + updatePairwiseLogicalDomainMap(producer_id, consumer_id); } }; @@ -241,23 +241,24 @@ std::unordered_map PairwiseRootDomainMap::map( // Map N, H from any input (query/key/value) for (auto idx : c10::irange(consumer_root.size())) { if (idx >= num_device_dim && idx < (2 + num_device_dim)) { - updatePairwiseRootDomainMap( + updatePairwiseLogicalDomainMap( producer_logical.at(idx), consumer_root.at(idx)); } // Map L, E from query and value respectively if (idx == (2 + num_device_dim) && producer_tv_->sameAs(op->query())) { - updatePairwiseRootDomainMap( + updatePairwiseLogicalDomainMap( producer_logical.at(idx), consumer_root.at(idx)); } // Map Ev from value to output if (idx == (3 + num_device_dim) && producer_tv_->sameAs(op->value())) { - updatePairwiseRootDomainMap( + updatePairwiseLogicalDomainMap( producer_logical.at(idx), consumer_root.at(idx)); } } // Map D from any input (query/key/value) to output only. if (num_device_dim == 1 && consumer_root.size() > 3) { - updatePairwiseRootDomainMap(producer_logical.at(0), consumer_root.at(0)); + updatePairwiseLogicalDomainMap( + producer_logical.at(0), consumer_root.at(0)); } return dom_map; } @@ -316,7 +317,7 @@ std::unordered_map PairwiseRootDomainMap::map( continue; } - updatePairwiseRootDomainMap(producer_id, consumer_id); + updatePairwiseLogicalDomainMap(producer_id, consumer_id); itc++; itp++; @@ -324,31 +325,31 @@ std::unordered_map PairwiseRootDomainMap::map( return dom_map; } -std::unordered_map PairwiseRootDomainMap:: +std::unordered_map PairwiseLogicalDomainMap:: mapProducerToConsumer( - const std::unordered_set* root_dims_to_map) const { - if (root_dims_to_map == nullptr) { - return RootDomainMap::mapProducerToConsumer( + const std::unordered_set* dims_to_map) const { + if (dims_to_map == nullptr) { + return LogicalDomainMap::mapProducerToConsumer( producerTv()->domain(), consumerTv()->domain()); } else { - return RootDomainMap::mapProducerToConsumer( - producerTv()->domain(), consumerTv()->domain(), *root_dims_to_map); + return LogicalDomainMap::mapProducerToConsumer( + producerTv()->domain(), consumerTv()->domain(), *dims_to_map); } } -std::unordered_map PairwiseRootDomainMap:: +std::unordered_map PairwiseLogicalDomainMap:: mapConsumerToProducer( - const std::unordered_set* root_dims_to_map) const { - if (root_dims_to_map == nullptr) { - return RootDomainMap::mapConsumerToProducer( + const std::unordered_set* dims_to_map) const { + if (dims_to_map == nullptr) { + return LogicalDomainMap::mapConsumerToProducer( consumerTv()->domain(), producerTv()->domain()); } else { - return RootDomainMap::mapConsumerToProducer( - consumerTv()->domain(), producerTv()->domain(), *root_dims_to_map); + return LogicalDomainMap::mapConsumerToProducer( + consumerTv()->domain(), producerTv()->domain(), *dims_to_map); } } -std::string PairwiseRootDomainMap::toString() const { +std::string PairwiseLogicalDomainMap::toString() const { std::stringstream ss; ss << "{producer: " << producerTv() << ", consumer: " << consumerTv(); auto p2c = mapProducerToConsumer(); @@ -444,7 +445,7 @@ class FindInputDomains : BackwardVisitor { } void propagate(TensorView* in_tv, TensorView* out_tv) { - auto c2p = PairwiseRootDomainMap(in_tv, out_tv).mapConsumerToProducer(); + auto c2p = PairwiseLogicalDomainMap(in_tv, out_tv).mapConsumerToProducer(); for (auto root_dom : out_tv->getMaybeRootDomain()) { DomainKey out_key({out_tv->domain(), root_dom}); if (input_keys_.find(out_key) == input_keys_.end()) { @@ -530,7 +531,7 @@ void UnmappableReductionDomains::handle(WelfordOp* op) { bool UnmappableReductionDomains::isReductionOutputMapped( const DomainKeySet& consumer_domains, - const ComputeAtRootDomainMap& root_map) const { + const ComputeAtLogicalDomainMap& logical_map) const { // Check each reduction domain if any of the consumer domains // conflicts with it for (const auto& kv : reduction_domains_) { @@ -565,11 +566,12 @@ bool UnmappableReductionDomains::isReductionOutputMapped( // to be an input to reduction domain and also used by the // consumers, it becomes a persistent tensor. if (input_key.id()->isBroadcast()) { - if (!root_map.isConcretized(input_key.td(), input_key.id())) { + if (!logical_map.isConcretized( + input_key.td(), input_key.id())) { return false; } } - return root_map.canMap( + return logical_map.canMap( consumer_domain.td(), consumer_domain.id(), input_key.td(), @@ -598,7 +600,7 @@ bool UnmappableReductionDomains::isReductionOutputMapped( incompatible_domains.begin(), incompatible_domains.end(), [&](const DomainKey& incompatible_domain) { - return root_map.canMap( + return logical_map.canMap( consumer_domain.td(), consumer_domain.id(), incompatible_domain.td(), @@ -632,16 +634,16 @@ std::string UnmappableReductionDomains::toString() const { return ss.str(); } -void ComputeAtRootDomainMap::build(bool map_through_reduction) { +void ComputeAtLogicalDomainMap::build(bool map_through_reduction) { // Make sure we start from scratch. Throw away previous results. eq_set_.clear(); bcast_map_.clear(); new_broadcast_domains_.clear(); removed_broadcast_domains_.clear(); - ComputeAtRootDomainMapBuilder builder(*this, map_through_reduction); + ComputeAtLogicalDomainMapBuilder builder(*this, map_through_reduction); } -bool ComputeAtRootDomainMap::canMap( +bool ComputeAtLogicalDomainMap::canMap( const TensorDomain* td_a, const IterDomain* id_a, const TensorDomain* td_b, @@ -691,7 +693,7 @@ bool ComputeAtRootDomainMap::canMap( return mappable_pair_found; } -bool ComputeAtRootDomainMap::canMap( +bool ComputeAtLogicalDomainMap::canMap( const DomainKey& key_a, const TensorDomain* td_b, const IterDomain* id_b) const { @@ -728,13 +730,13 @@ bool ComputeAtRootDomainMap::canMap( return mappable_pair_found; } -bool ComputeAtRootDomainMap::canMap( +bool ComputeAtLogicalDomainMap::canMap( const DomainKey& key_a, const DomainKey& key_b) const { return key_a == key_b || eq_set_.permissiveAreMapped(key_a, key_b); } -void ComputeAtRootDomainMap::setAlias( +void ComputeAtLogicalDomainMap::setAlias( const TensorDomain* td, const TensorDomain* td_alias) { auto tmp_bcast_map = bcast_map_; @@ -775,7 +777,7 @@ void ComputeAtRootDomainMap::setAlias( removed_broadcast_domains_ = tmp_removed_broadcast_domains; } -std::vector ComputeAtRootDomainMap::getConcretizedKeys( +std::vector ComputeAtLogicalDomainMap::getConcretizedKeys( const TensorDomain* td, const IterDomain* id) const { DomainKey key(td, id); @@ -792,7 +794,7 @@ std::vector ComputeAtRootDomainMap::getConcretizedKeys( return domains; } -std::unordered_set& ComputeAtRootDomainMap:: +std::unordered_set& ComputeAtLogicalDomainMap:: getConcretizedDomains(const TensorDomain* td, const IterDomain* id) { DomainKey key(td, id); auto it = bcast_map_.find(key); @@ -800,7 +802,7 @@ std::unordered_set& ComputeAtRootDomainMap:: return it->second; } -bool ComputeAtRootDomainMap::isConcretized( +bool ComputeAtLogicalDomainMap::isConcretized( const TensorDomain* td, const IterDomain* id) const { DomainKey key(td, id); @@ -808,15 +810,15 @@ bool ComputeAtRootDomainMap::isConcretized( return it != bcast_map_.end(); } -std::unordered_map ComputeAtRootDomainMap:: +std::unordered_map ComputeAtLogicalDomainMap:: mapBestEffort( const TensorDomain* from_td, - const std::vector& from_root, + const std::vector& from_dom, const TensorDomain* to_td, - const std::vector& to_root) const { + const std::vector& to_dom) const { std::unordered_map id_map; - for (auto& from_id : from_root) { - for (const auto& to_id : to_root) { + for (auto& from_id : from_dom) { + for (const auto& to_id : to_dom) { if (canMap(from_td, from_id, to_td, to_id)) { NVF_ERROR( id_map.insert({from_id, to_id}).second, @@ -828,10 +830,10 @@ std::unordered_map ComputeAtRootDomainMap:: return id_map; } -std::unordered_map ComputeAtRootDomainMap::map( +std::unordered_map ComputeAtLogicalDomainMap::map( const TensorDomain* producer, const TensorDomain* consumer, - const std::unordered_set& root_dims_to_map, + const std::unordered_set& dims_to_map, bool producer_to_consumer) const { const auto& producer_logical = TensorDomain::noReductions(producer->logical()); @@ -844,7 +846,7 @@ std::unordered_map ComputeAtRootDomainMap::map( std::unordered_map id_map = mapBestEffort(from_td, from_ids, to_td, to_ids); for (auto& from_id : from_ids) { - if (root_dims_to_map.find(from_id) == root_dims_to_map.end()) { + if (dims_to_map.find(from_id) == dims_to_map.end()) { // Remove mapping if exists id_map.erase(from_id); continue; @@ -857,7 +859,7 @@ std::unordered_map ComputeAtRootDomainMap::map( // 2. from_id is a removed broadcast of a producer domain; or // 3. from_id is a window axis of a consumer domain; or // 4. from_id is a ViewAsScalar domain - // Note that reduction domains are removed from the producer root domain. + // Note that reduction domains are removed from the producer logical domain. if ((!producer_to_consumer && (new_broadcast_domains_.find(DomainKey(from_td, from_id)) != new_broadcast_domains_.end() || @@ -885,7 +887,7 @@ std::unordered_map ComputeAtRootDomainMap::map( return id_map; } -std::unordered_set ComputeAtRootDomainMap::getMappableDims( +std::unordered_set ComputeAtLogicalDomainMap::getMappableDims( const TensorDomain* producer, const TensorDomain* consumer) const { //! This funciton previously used mapBestEffort but it can fail when @@ -910,15 +912,15 @@ std::unordered_set ComputeAtRootDomainMap::getMappableDims( return mappable_ids; } -std::string ComputeAtRootDomainMap::toString() const { +std::string ComputeAtLogicalDomainMap::toString() const { return eq_set_.toString(); } -ComputeAtRootDomainMapBuilder::ComputeAtRootDomainMapBuilder( - ComputeAtRootDomainMap& root_map, +ComputeAtLogicalDomainMapBuilder::ComputeAtLogicalDomainMapBuilder( + ComputeAtLogicalDomainMap& logical_map, bool map_through_reduction) : BackwardVisitor(false), - root_map_(root_map), + logical_map_(logical_map), map_through_reduction_(map_through_reduction) { Fusion* fusion = FusionGuard::getCurFusion(); NVF_ERROR(fusion != nullptr); @@ -940,13 +942,13 @@ ComputeAtRootDomainMapBuilder::ComputeAtRootDomainMapBuilder( // Set concrete domains for broadcast domains that never get joined // with a concrete domain. Just set its own domain as a concrete // domain, which is not concrete but is sufficient for this analysis. -void ComputeAtRootDomainMapBuilder::initializeBcastMap( +void ComputeAtLogicalDomainMapBuilder::initializeBcastMap( const TensorView* tv, const IterDomain* id) { NVF_ERROR(id->isBroadcast(), "Not a broadcast axis"); auto key = DomainKey(tv->domain(), id); - auto it = root_map_.bcast_map_.find(key); - if (it != root_map_.bcast_map_.end()) { + auto it = logical_map_.bcast_map_.find(key); + if (it != logical_map_.bcast_map_.end()) { // already initialized. return; } @@ -957,17 +959,17 @@ void ComputeAtRootDomainMapBuilder::initializeBcastMap( // pairwise map has no mapping for the broadcast. for (auto consumer : ir_utils::consumerTvsOf(tv)) { const auto p2c = - PairwiseRootDomainMap(tv, consumer).mapProducerToConsumer(); + PairwiseLogicalDomainMap(tv, consumer).mapProducerToConsumer(); // Unfortunately, const_cast is required as our const model is // broken. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) NVF_ERROR(p2c.find(const_cast(id)) == p2c.end()); } - root_map_.bcast_map_.insert({key, {id}}); + logical_map_.bcast_map_.insert({key, {id}}); } -void ComputeAtRootDomainMapBuilder::addToPendingList( +void ComputeAtLogicalDomainMapBuilder::addToPendingList( const DomainKey& producer, const DomainKey& consumer) { auto it = ensureMapping(pending_map_, producer, {}); @@ -975,28 +977,28 @@ void ComputeAtRootDomainMapBuilder::addToPendingList( consumer_set.insert(consumer); } -void ComputeAtRootDomainMapBuilder::setMapped( +void ComputeAtLogicalDomainMapBuilder::setMapped( const DomainKey& producer, const DomainKey& consumer) { - root_map_.eq_set_.mapEntries(producer, consumer); + logical_map_.eq_set_.mapEntries(producer, consumer); } -void ComputeAtRootDomainMapBuilder::setInvalid( +void ComputeAtLogicalDomainMapBuilder::setInvalid( const DomainKey& key1, const DomainKey& key2) { invalid_mappings_.emplace_back(key1, key2); } -bool ComputeAtRootDomainMapBuilder::isInvalid( +bool ComputeAtLogicalDomainMapBuilder::isInvalid( const DomainKeySet& domains) const { // First, collect all invalid mappings for each of the keys in domains DomainKeyMap invalid_key_map; for (const auto& key : domains) { DomainKeySet invalid_keys; for (const auto& invalid_pair : invalid_mappings_) { - if (root_map_.canMap(key, invalid_pair.first)) { + if (logical_map_.canMap(key, invalid_pair.first)) { invalid_keys.insert(invalid_pair.second); - } else if (root_map_.canMap(key, invalid_pair.second)) { + } else if (logical_map_.canMap(key, invalid_pair.second)) { invalid_keys.insert(invalid_pair.first); } } @@ -1025,7 +1027,7 @@ bool ComputeAtRootDomainMapBuilder::isInvalid( invalid_keys_for_i.begin(), invalid_keys_for_i.end(), [&](const auto& invalid_key_for_i) { - return root_map_.canMap(key_j, invalid_key_for_i); + return logical_map_.canMap(key_j, invalid_key_for_i); })) { return true; } @@ -1034,7 +1036,7 @@ bool ComputeAtRootDomainMapBuilder::isInvalid( return false; } -void ComputeAtRootDomainMapBuilder::setMaybeMapped( +void ComputeAtLogicalDomainMapBuilder::setMaybeMapped( const TensorDomain* producer_td, const IterDomain* producer_id, const TensorDomain* consumer_td, @@ -1043,7 +1045,7 @@ void ComputeAtRootDomainMapBuilder::setMaybeMapped( const DomainKey consumer_key(consumer_td, consumer_id); if (producer_id->isBroadcast()) { - ensureMapping(root_map_.bcast_map_, producer_key, {}); + ensureMapping(logical_map_.bcast_map_, producer_key, {}); } if (consumer_id->isBroadcast()) { @@ -1056,9 +1058,9 @@ void ComputeAtRootDomainMapBuilder::setMaybeMapped( consumer_id->toString()); // Get bcast_map_ entry for consumer_id const auto consumer_bcast_domains = - root_map_.getConcretizedKeys(consumer_td, consumer_id); + logical_map_.getConcretizedKeys(consumer_td, consumer_id); auto& producer_domains = - root_map_.getConcretizedDomains(producer_td, producer_id); + logical_map_.getConcretizedDomains(producer_td, producer_id); // If consumer id is broadcasted, make sure to propagate its concrete_id(s) // to producer @@ -1073,7 +1075,7 @@ void ComputeAtRootDomainMapBuilder::setMaybeMapped( if (producer_id->isBroadcast()) { const auto concrete_id = consumer_id; auto& producer_domains = - root_map_.getConcretizedDomains(producer_td, producer_id); + logical_map_.getConcretizedDomains(producer_td, producer_id); producer_concrete_key = DomainKey(producer_td, producer_id, concrete_id); producer_domains.insert(concrete_id); } @@ -1081,7 +1083,7 @@ void ComputeAtRootDomainMapBuilder::setMaybeMapped( } } -void ComputeAtRootDomainMapBuilder::dispatch(Expr* e) { +void ComputeAtLogicalDomainMapBuilder::dispatch(Expr* e) { // Avoid visiting expressions multiple times if (visited_.find(e) != visited_.end()) { return; @@ -1090,7 +1092,7 @@ void ComputeAtRootDomainMapBuilder::dispatch(Expr* e) { visited_.insert(e); } -void ComputeAtRootDomainMapBuilder::mapPointwiseLikeOp(Expr* expr) { +void ComputeAtLogicalDomainMapBuilder::mapPointwiseLikeOp(Expr* expr) { if (expr->output(0)->getValType() != ValType::TensorView) { return; } @@ -1115,9 +1117,10 @@ void ComputeAtRootDomainMapBuilder::mapPointwiseLikeOp(Expr* expr) { for (auto producer_tv : ir_utils::filterByType(expr->inputs())) { for (auto consumer_tv : ir_utils::filterByType(expr->outputs())) { - for (const auto& mapping : PairwiseRootDomainMap(producer_tv, consumer_tv) - .mapBroadcast(true) - .mapProducerToConsumer()) { + for (const auto& mapping : + PairwiseLogicalDomainMap(producer_tv, consumer_tv) + .mapBroadcast(true) + .mapProducerToConsumer()) { setMaybeMapped( producer_tv->domain(), mapping.first, @@ -1128,7 +1131,7 @@ void ComputeAtRootDomainMapBuilder::mapPointwiseLikeOp(Expr* expr) { } } -void ComputeAtRootDomainMapBuilder::handle(BroadcastOp* op) { +void ComputeAtLogicalDomainMapBuilder::handle(BroadcastOp* op) { const TensorDomain* in_td = op->in()->as()->domain(); const TensorDomain* out_td = op->out()->as()->domain(); const auto in_logical = TensorDomain::noReductions(in_td->logical()); @@ -1146,7 +1149,7 @@ void ComputeAtRootDomainMapBuilder::handle(BroadcastOp* op) { if (bcast_dim_flags.at(std::distance(out_root.begin(), out_it))) { // new broadcast dim. No matching dimension in the input // tensor. - root_map_.new_broadcast_domains_.insert(DomainKey(out_td, *out_it)); + logical_map_.new_broadcast_domains_.insert(DomainKey(out_td, *out_it)); ++out_it; continue; } @@ -1171,11 +1174,11 @@ void ComputeAtRootDomainMapBuilder::handle(BroadcastOp* op) { *out_it, " of ", out_td); - root_map_.new_broadcast_domains_.insert(DomainKey(out_td, *out_it)); + logical_map_.new_broadcast_domains_.insert(DomainKey(out_td, *out_it)); } } -void ComputeAtRootDomainMapBuilder::handle(SqueezeOp* op) { +void ComputeAtLogicalDomainMapBuilder::handle(SqueezeOp* op) { const TensorDomain* in_td = op->in()->as()->domain(); const TensorDomain* out_td = op->out()->as()->domain(); const auto in_logical = TensorDomain::noReductions(in_td->logical()); @@ -1193,7 +1196,7 @@ void ComputeAtRootDomainMapBuilder::handle(SqueezeOp* op) { if (squeeze_dim_flags.at(std::distance(in_logical.begin(), in_it))) { // new broadcast dim. No matching dimension in the input // tensor. - root_map_.removed_broadcast_domains_.insert(DomainKey(in_td, *in_it)); + logical_map_.removed_broadcast_domains_.insert(DomainKey(in_td, *in_it)); ++in_it; continue; } @@ -1218,11 +1221,11 @@ void ComputeAtRootDomainMapBuilder::handle(SqueezeOp* op) { *in_it, " of ", in_td); - root_map_.removed_broadcast_domains_.insert(DomainKey(in_td, *in_it)); + logical_map_.removed_broadcast_domains_.insert(DomainKey(in_td, *in_it)); } } -void ComputeAtRootDomainMapBuilder::handle(ViewAsScalar* op) { +void ComputeAtLogicalDomainMapBuilder::handle(ViewAsScalar* op) { const TensorView* out_tv = op->output(0)->as(); const TensorDomain* out_td = out_tv->domain(); const auto& out_root = out_td->maybeRoot(); @@ -1236,7 +1239,7 @@ void ComputeAtRootDomainMapBuilder::handle(ViewAsScalar* op) { in_logical.size() + 1 == out_root.size(), "\nExpression: ", op, - "\nInput root domain: ", + "\nInput logical domain: ", in_logical, "\nOutput root domain: ", out_root); @@ -1252,7 +1255,7 @@ void ComputeAtRootDomainMapBuilder::handle(ViewAsScalar* op) { "The last dim of ViewDtypeOp's output must be a ViewAsScalar"); } -void ComputeAtRootDomainMapBuilder::mapAllPendingMappings( +void ComputeAtLogicalDomainMapBuilder::mapAllPendingMappings( const DomainKey& key) { auto it = pending_map_.find(key); if (it == pending_map_.end()) { @@ -1273,11 +1276,11 @@ void ComputeAtRootDomainMapBuilder::mapAllPendingMappings( pending_map_.erase(it); } -void ComputeAtRootDomainMapBuilder::mapAllPendingMappings( +void ComputeAtLogicalDomainMapBuilder::mapAllPendingMappings( const TensorDomain* td, IterDomain* id) { if (id->isBroadcast()) { - for (const auto& key : root_map_.getConcretizedKeys(td, id)) { + for (const auto& key : logical_map_.getConcretizedKeys(td, id)) { mapAllPendingMappings(key); } } else { @@ -1285,11 +1288,11 @@ void ComputeAtRootDomainMapBuilder::mapAllPendingMappings( } } -void ComputeAtRootDomainMapBuilder::handle(RNGOp* rop) { +void ComputeAtLogicalDomainMapBuilder::handle(RNGOp* rop) { handle(rop->output(0)->as()); } -void ComputeAtRootDomainMapBuilder::handle(TensorView* tv) { +void ComputeAtLogicalDomainMapBuilder::handle(TensorView* tv) { const TensorDomain* td = tv->domain(); const auto logical = TensorDomain::noReductions(td->logical()); for (auto id : logical) { @@ -1339,13 +1342,13 @@ void ComputeAtRootDomainMapBuilder::handle(TensorView* tv) { // Checks whether all consumers of a producer can be joined without // introducing unsupported mappings, i.e., requiring recomputations. -bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) { +bool ComputeAtLogicalDomainMapBuilder::safeToMap(const DomainKeySet& domains) { if (domains.size() <= 1) { return true; } // Can't map if reduction output domains would be mapped - if (incompatible_domains_.isReductionOutputMapped(domains, root_map_) && + if (incompatible_domains_.isReductionOutputMapped(domains, logical_map_) && !map_through_reduction_) { return false; } @@ -1357,9 +1360,9 @@ bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) { } namespace { -class ExactRootDomainMapBuilder : private IterVisitor { +class ExactLogicalDomainMapBuilder : private IterVisitor { public: - ExactRootDomainMapBuilder( + ExactLogicalDomainMapBuilder( Fusion* fusion, DisjointSets& eq_sets) : eq_sets_(eq_sets) { @@ -1373,7 +1376,7 @@ class ExactRootDomainMapBuilder : private IterVisitor { for (auto producer : ir_utils::filterByType(expr->inputs())) { for (auto consumer : ir_utils::filterByType(expr->outputs())) { - PairwiseRootDomainMap pwise_map(producer, consumer); + PairwiseLogicalDomainMap pwise_map(producer, consumer); pwise_map.mapBroadcast(false); const auto mappings = pwise_map.mapProducerToConsumer(); for (const auto& mapping : mappings) { @@ -1389,11 +1392,11 @@ class ExactRootDomainMapBuilder : private IterVisitor { } // namespace -ExactRootDomainMap::ExactRootDomainMap(Fusion* fusion) { - ExactRootDomainMapBuilder builder(fusion, eq_sets_); +ExactLogicalDomainMap::ExactLogicalDomainMap(Fusion* fusion) { + ExactLogicalDomainMapBuilder builder(fusion, eq_sets_); } -bool ExactRootDomainMap::areMapped( +bool ExactLogicalDomainMap::areMapped( const IterDomain* id_a, const IterDomain* id_b) const { // With expand going into a view operation there can be an instance where an @@ -1408,10 +1411,10 @@ bool ExactRootDomainMap::areMapped( return eq_sets_.strictAreMapped(id_a, id_b); } -std::unordered_map ExactRootDomainMap::map( +std::unordered_map ExactLogicalDomainMap::map( const TensorDomain* producer, const TensorDomain* consumer, - const std::unordered_set& root_dims_to_map, + const std::unordered_set& dims_to_map, bool producer_to_consumer) const { const auto& producer_logical = TensorDomain::noReductions(producer->logical()); @@ -1423,7 +1426,7 @@ std::unordered_map ExactRootDomainMap::map( std::unordered_map id_map; for (auto& from_id : from_ids) { - if (root_dims_to_map.find(from_id) == root_dims_to_map.end()) { + if (dims_to_map.find(from_id) == dims_to_map.end()) { continue; } for (const auto& to_id : to_ids) { @@ -1439,11 +1442,11 @@ std::unordered_map ExactRootDomainMap::map( return id_map; } -std::string ExactRootDomainMap::toString() const { +std::string ExactLogicalDomainMap::toString() const { return eq_sets_.toString(); } -const DisjointSets& ExactRootDomainMap::getMappedSets() +const DisjointSets& ExactLogicalDomainMap::getMappedSets() const { return eq_sets_; } diff --git a/csrc/root_domain_map.h b/csrc/logical_domain_map.h similarity index 84% rename from csrc/root_domain_map.h rename to csrc/logical_domain_map.h index 1fe9cfb285b..8a5b3883f8a 100644 --- a/csrc/root_domain_map.h +++ b/csrc/logical_domain_map.h @@ -16,19 +16,19 @@ namespace nvfuser { -//! Generic interface for mapping root domains of a producer-consumer pair. -class RootDomainMap : public PolymorphicBase { +//! Generic interface for mapping logical domains of a producer-consumer pair. +class LogicalDomainMap : public PolymorphicBase { public: //! Return a map from a producer TensorDomain to a consumer //! TensorDomain //! //! \param producer A producer TensorDomain //! \param consumer A consumer TensorDomain - //! \param root_dims_to_map Maps only producer root domains in this set + //! \param dims_to_map Maps only producer logical domains in this set std::unordered_map mapProducerToConsumer( const TensorDomain* producer, const TensorDomain* consumer, - const std::unordered_set& root_dims_to_map) const; + const std::unordered_set& dims_to_map) const; //! Return a map from a producer TensorDomain to a consumer //! TensorDomain @@ -44,11 +44,11 @@ class RootDomainMap : public PolymorphicBase { //! //! \param consumer A consumer TensorDomain //! \param producer A producer TensorDomain - //! \param root_dims_to_map Maps only consumer root domains in this set + //! \param dims_to_map Maps only consumer root domains in this set std::unordered_map mapConsumerToProducer( const TensorDomain* consumer, const TensorDomain* producer, - const std::unordered_set& root_dims_to_map) const; + const std::unordered_set& dims_to_map) const; //! Return a map from a consumer TensorDomain to a producer //! TensorDomain @@ -60,27 +60,27 @@ class RootDomainMap : public PolymorphicBase { const TensorDomain* producer) const; protected: - //! Return a map between root IterDomains of a producer-consumer + //! Return a map between logical IterDomains of a producer-consumer //! pair. //! //! \param producer A producer TensorDomain //! \param consumer A consumer TensorDomain - //! \param root_dims_to_map Maps only from IterDomains in this set + //! \param dims_to_map Maps only from IterDomains in this set //! \param producer_to_consumer Maps from producer to consumer if true virtual std::unordered_map map( const TensorDomain* producer, const TensorDomain* consumer, - const std::unordered_set& root_dims_to_map, + const std::unordered_set& dims_to_map, bool producer_to_consumer) const = 0; }; -//! Maps root domains of a producer-consumer pair. This class only +//! Maps logical domains of a producer-consumer pair. This class only //! looks at the given pair of TensorViews and does not take into //! consideration the constraints of the computeAt transformation, //! i.e., unable to compute the same tensors multiple times. This //! should not be used for transformations implementing computeAt, but //! should be valid otherwise. -class NVF_API PairwiseRootDomainMap : public RootDomainMap { +class NVF_API PairwiseLogicalDomainMap : public LogicalDomainMap { public: //! When require_same_extent is false, domains that may have //! different extents are also mapped. For example, IDs of lookup @@ -90,11 +90,11 @@ class NVF_API PairwiseRootDomainMap : public RootDomainMap { //! //! \param producer The producer tensor of a producer-consumer pair. //! \param consumer The consumer tensor of a producer-consumer pair. - explicit PairwiseRootDomainMap( + explicit PairwiseLogicalDomainMap( const TensorView* producer, const TensorView* consumer); - PairwiseRootDomainMap& mapBroadcast(bool b) { + PairwiseLogicalDomainMap& mapBroadcast(bool b) { map_broadcast_ = b; return *this; } @@ -102,17 +102,17 @@ class NVF_API PairwiseRootDomainMap : public RootDomainMap { //! If b is true: map symbolic domains with other IterDomains even if their //! extents don't match. If b is false (default): map symbolic domains with //! other IterDomains only if their extents match. - PairwiseRootDomainMap& mapSymbolic(bool b) { + PairwiseLogicalDomainMap& mapSymbolic(bool b) { map_symbolic_ = b; return *this; } - PairwiseRootDomainMap& mapDifferentExtents(bool b) { + PairwiseLogicalDomainMap& mapDifferentExtents(bool b) { map_different_extents_ = b; return *this; } - PairwiseRootDomainMap& mapIndexedDomains(bool b) { + PairwiseLogicalDomainMap& mapIndexedDomains(bool b) { map_indexed_domains_ = b; return *this; } @@ -127,20 +127,20 @@ class NVF_API PairwiseRootDomainMap : public RootDomainMap { std::string toString() const; - // Helper methods on top of RootDomainMap::mapProducerToConsumer and - // RootDomainMap::mapConsumerToProducer. This way, the caller doesn't have to - // specify the producer domain and the consumer domain, which is redundant and - // error-prone. + // Helper methods on top of LogicalDomainMap::mapProducerToConsumer and + // LogicalDomainMap::mapConsumerToProducer. This way, the caller doesn't have + // to specify the producer domain and the consumer domain, which is redundant + // and error-prone. std::unordered_map mapProducerToConsumer( - const std::unordered_set* root_dims_to_map = nullptr) const; + const std::unordered_set* dims_to_map = nullptr) const; std::unordered_map mapConsumerToProducer( - const std::unordered_set* root_dims_to_map = nullptr) const; + const std::unordered_set* dims_to_map = nullptr) const; protected: std::unordered_map map( const TensorDomain* producer, const TensorDomain* consumer, - const std::unordered_set& root_dims_to_map, + const std::unordered_set& dims_to_map, bool producer_to_consumer) const override; private: @@ -163,13 +163,13 @@ class NVF_API PairwiseRootDomainMap : public RootDomainMap { }; //! Represents an iteration domain of a TensorDomain. Only used for -//! root domain mapping. +//! logical domain mapping. //! //! Note that an IterDomain object may be reused //! across multiple TensorDomains, but an IterDomain in a //! TensorDomain may not be necessarily mappable to the same //! IterDomain used in a different TensorDomain. Thus, for the purpose -//! of root domain mapping, an iteration domain needs to be identified +//! of logical domain mapping, an iteration domain needs to be identified //! with an IterDomain and its TensorDomain. class DomainKey { public: @@ -216,7 +216,7 @@ using DomainKeySet = std::unordered_set; template using DomainKeyMap = std::unordered_map; -class ComputeAtRootDomainMap; +class ComputeAtLogicalDomainMap; //! A helper class to find all DomainKeys that are consumers of //! reduction outputs. Such consumer IterDomains may not be mapped to @@ -231,10 +231,10 @@ class UnmappableReductionDomains : private IterVisitor { //! reduction output domain to be mapped with a consumer domain of //! the redution. It needs to be avoided as computing consumers of //! reduction outputs within the corresponding reduction loop is not - //! possible. This routine is used to build root domain mappings. + //! possible. This routine is used to build logical domain mappings. bool isReductionOutputMapped( const DomainKeySet& consumer_domains, - const ComputeAtRootDomainMap& root_map) const; + const ComputeAtLogicalDomainMap& logical_map) const; std::string toString() const; @@ -254,7 +254,7 @@ class UnmappableReductionDomains : private IterVisitor { DomainKeyMap reduction_domain_inputs_; }; -//! Models root-domain mappings for computeAt +//! Models logical-domain mappings for computeAt //! //! Two iteration domains are mapped when computeAt of one iteration //! domain is possible at another iteration domain. Consider a simple @@ -267,8 +267,8 @@ class UnmappableReductionDomains : private IterVisitor { //! fail. Currently, the only use of this class is getMappableDims, //! which just grabs any domain that is mappable, which works no //! matter view is used or not. -class NVF_API ComputeAtRootDomainMap : public RootDomainMap { - friend class ComputeAtRootDomainMapBuilder; +class NVF_API ComputeAtLogicalDomainMap : public LogicalDomainMap { + friend class ComputeAtLogicalDomainMapBuilder; public: //! Builds a mapping table by analyzing the current @@ -300,7 +300,7 @@ class NVF_API ComputeAtRootDomainMap : public RootDomainMap { //! //! This is for the computeAt transformation, where TensorViews are //! updated with new TensorDomains. Since they keep using the same - //! root doamins, the root mapping remains valid but needs to + //! logical doamins, the logical mapping remains valid but needs to //! reflect the use of new TensorDomains as aliases of the existing //! ones. //! @@ -312,23 +312,22 @@ class NVF_API ComputeAtRootDomainMap : public RootDomainMap { //! //! Unlike the other map functions, two TensorDomains do not need to //! be a producer-consumer pair. Since they may not be a - //! producer-consumer pair, this function requires proper root - //! domains, which may be root or logical domains. Also, no error - //! check is done as we do not assume producer-consumer - //! relationship. + //! producer-consumer pair, this function requires proper domains, which may + //! be root or logical domains. Also, no error check is done as we do not + //! assume producer-consumer relationship. //! //! Note that an exception is thrown when a domain is found to be //! mapped to multiple domains, which can happen with views. //! //! \param from_td A TensorDomain from which a map is created - //! \param from_root A root domain of from_td + //! \param from_dom A root/logical domain of from_td //! \param to_td A TensorDomain to which a map is created - //! \param to_root A root domain of to_td + //! \param to_dom A root/logical domain of to_td std::unordered_map mapBestEffort( const TensorDomain* from_td, - const std::vector& from_root, + const std::vector& from_dom, const TensorDomain* to_td, - const std::vector& to_root) const; + const std::vector& to_dom) const; // Returns an unordered set of all iter domains in producer and consumer that // can map to eachother @@ -370,17 +369,17 @@ class NVF_API ComputeAtRootDomainMap : public RootDomainMap { const TensorDomain* td, const IterDomain* id); - //! Return a map between root IterDomains of a producer-consumer + //! Return a map between logical IterDomains of a producer-consumer //! pair. //! //! \param producer A producer TensorDomain //! \param consumer A consumer TensorDomain - //! \param root_dims_to_map Maps only from IterDomains in this set + //! \param dims_to_map Maps only from IterDomains in this set //! \param producer_to_consumer Maps from producer to consumer if true std::unordered_map map( const TensorDomain* producer, const TensorDomain* consumer, - const std::unordered_set& root_dims_to_map, + const std::unordered_set& dims_to_map, bool producer_to_consumer) const override; private: @@ -402,14 +401,14 @@ class NVF_API ComputeAtRootDomainMap : public RootDomainMap { std::unordered_set window_axes_; }; -//! Create a DisjointSets of root IterDomains by traversing the +//! Create a DisjointSets of logical IterDomains by traversing the //! current fusion entirely. IterDomains that can be mapped each //! other with computeAt are grouped into the same subset in the //! DisjointSets. -class ComputeAtRootDomainMapBuilder : private BackwardVisitor { +class ComputeAtLogicalDomainMapBuilder : private BackwardVisitor { public: - explicit ComputeAtRootDomainMapBuilder( - ComputeAtRootDomainMap& root_map, + explicit ComputeAtLogicalDomainMapBuilder( + ComputeAtLogicalDomainMap& logical_map, bool map_through_reduction = false); private: @@ -426,7 +425,7 @@ class ComputeAtRootDomainMapBuilder : private BackwardVisitor { bool isInvalid(const DomainKeySet& domains) const; //! Track a pair of producer-consumer domains as potentially mappable. Inserts - //! entries into pending_map_, but does not add anything into the root_map_ + //! entries into pending_map_, but does not add anything into the logical_map_ //! (added when handle is called on a TensorView). Maybe mapped will, however, //! immediately propagate broadcast iter domains. void setMaybeMapped( @@ -534,7 +533,7 @@ class ComputeAtRootDomainMapBuilder : private BackwardVisitor { bool safeToMap(const DomainKeySet& domains); private: - ComputeAtRootDomainMap& root_map_; + ComputeAtLogicalDomainMap& logical_map_; //! Keep track of what we want to try and map DomainKeyMap pending_map_; std::unordered_set visited_; @@ -548,11 +547,11 @@ class ComputeAtRootDomainMapBuilder : private BackwardVisitor { bool map_through_reduction_ = false; }; -//! Maps root domains of an entire fusion. Does not map broadcast +//! Maps logical domains of an entire fusion. Does not map broadcast //! domains with non-broadcast domains. -class NVF_API ExactRootDomainMap : public RootDomainMap { +class NVF_API ExactLogicalDomainMap : public LogicalDomainMap { public: - ExactRootDomainMap(Fusion* fusion); + ExactLogicalDomainMap(Fusion* fusion); bool areMapped(const IterDomain* id_a, const IterDomain* id_b) const; @@ -564,7 +563,7 @@ class NVF_API ExactRootDomainMap : public RootDomainMap { std::unordered_map map( const TensorDomain* producer, const TensorDomain* consumer, - const std::unordered_set& root_dims_to_map, + const std::unordered_set& dims_to_map, bool producer_to_consumer) const override; private: diff --git a/csrc/maxinfo_propagator.cpp b/csrc/maxinfo_propagator.cpp index bdfbfaac756..29385e54eff 100644 --- a/csrc/maxinfo_propagator.cpp +++ b/csrc/maxinfo_propagator.cpp @@ -5,8 +5,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include #include -#include namespace nvfuser { @@ -159,11 +159,11 @@ void MaxInfoSpanningTree::traverse(Propagator* propagator) { propagator->tearDown(); } -MaxRootDomainInfoSpanningTree::DomainInfo::operator bool() const { +MaxLogicalDomainInfoSpanningTree::DomainInfo::operator bool() const { return !info.empty(); } -bool MaxRootDomainInfoSpanningTree::DomainInfo::operator<( +bool MaxLogicalDomainInfoSpanningTree::DomainInfo::operator<( const Information& r) const { auto rr = dynamic_cast(r); if (info.size() != rr.info.size()) { @@ -234,11 +234,11 @@ std::unordered_set mapLogicalToRoot( // to first map it to the logical domain of the producer, then we can map it to // the consumer's root domain. The computed info will be represented by root // domain as root domain contains the raw information. -std::shared_ptr MaxRootDomainInfoSpanningTree:: - computeInfoP2C( - TensorView* from, - TensorView* to, - std::shared_ptr from_info) { +std::shared_ptr +MaxLogicalDomainInfoSpanningTree::computeInfoP2C( + TensorView* from, + TensorView* to, + std::shared_ptr from_info) { DomainInfo result; TensorView* producer = from; @@ -246,7 +246,7 @@ std::shared_ptr MaxRootDomainInfoSpanningTree: const auto& producer_root_id_info = std::dynamic_pointer_cast(from_info)->info; - auto pairwise_map = PairwiseRootDomainMap(producer, consumer); + auto pairwise_map = PairwiseLogicalDomainMap(producer, consumer); auto p2c_map = pairwise_map.mapProducerToConsumer(); for (auto& info : producer_root_id_info) { @@ -290,11 +290,11 @@ std::shared_ptr MaxRootDomainInfoSpanningTree: // need to first map it to the root domain of the consumer, then we can map it // to the producer's logical domain. The computed info will be represented by // logical domain as logical domain contains the raw information. -std::shared_ptr MaxRootDomainInfoSpanningTree:: - computeInfoC2P( - TensorView* from, - TensorView* to, - std::shared_ptr from_info) { +std::shared_ptr +MaxLogicalDomainInfoSpanningTree::computeInfoC2P( + TensorView* from, + TensorView* to, + std::shared_ptr from_info) { DomainInfo result; TensorView* producer = to; @@ -302,7 +302,7 @@ std::shared_ptr MaxRootDomainInfoSpanningTree: const auto& consumer_root_id_info = std::dynamic_pointer_cast(from_info)->info; - auto pairwise_map = PairwiseRootDomainMap(producer, consumer); + auto pairwise_map = PairwiseLogicalDomainMap(producer, consumer); auto c2p_map = pairwise_map.mapConsumerToProducer(); for (auto& info : consumer_root_id_info) { @@ -357,8 +357,8 @@ std::shared_ptr MaxRootDomainInfoSpanningTree: return std::make_shared(std::move(result)); } -std::shared_ptr -MaxRootDomainInfoSpanningTree::getReferenceIDInfo(TensorView* tv) { +std::shared_ptr +MaxLogicalDomainInfoSpanningTree::getReferenceIDInfo(TensorView* tv) { DomainInfo result; const auto& root_domain = tv->getMaybeRootDomain(); result.info.reserve(root_domain.size()); @@ -368,8 +368,8 @@ MaxRootDomainInfoSpanningTree::getReferenceIDInfo(TensorView* tv) { return std::make_shared(std::move(result)); } -std::shared_ptr -MaxRootDomainInfoSpanningTree::getReferenceIDInfo( +std::shared_ptr +MaxLogicalDomainInfoSpanningTree::getReferenceIDInfo( TensorView* tv, int64_t loop_pos) { if (loop_pos < 0) { @@ -377,7 +377,7 @@ MaxRootDomainInfoSpanningTree::getReferenceIDInfo( } NVF_CHECK( loop_pos >= 0 && loop_pos <= int64_t(tv->nDims()), - "MaxRootDomainInfoSpanningTree called on an loop_pos outside valid range."); + "MaxLogicalDomainInfoSpanningTree called on an loop_pos outside valid range."); DomainInfo result; const auto& logical_domain = tv->getLogicalDomain(); const auto& loop_domain = tv->getLoopDomain(); @@ -403,11 +403,11 @@ MaxRootDomainInfoSpanningTree::getReferenceIDInfo( // replay state, so sibling info is always identical by definition, except that // we need to replace the IDs stored in the info with the corresponding IDs in // `to`. -std::shared_ptr MaxRootDomainInfoSpanningTree:: - computeInfoSibling( - TensorView* from, - TensorView* to, - std::shared_ptr from_info) { +std::shared_ptr +MaxLogicalDomainInfoSpanningTree::computeInfoSibling( + TensorView* from, + TensorView* to, + std::shared_ptr from_info) { DomainInfo result; const auto& from_root_id_info = diff --git a/csrc/maxinfo_propagator.h b/csrc/maxinfo_propagator.h index 9b3a5ca30ac..47afe1dada4 100644 --- a/csrc/maxinfo_propagator.h +++ b/csrc/maxinfo_propagator.h @@ -155,7 +155,7 @@ class MaxInfoSpanningTree { virtual ~MaxInfoSpanningTree() = default; }; -// MaxRootDomainInfoSpanningTree is a subclass of MaxInfoSpanningTree which +// MaxLogicalDomainInfoSpanningTree is a subclass of MaxInfoSpanningTree which // generates the maximum spanning tree that perserves the most amount of root // domain information from the reference tensor. //* @@ -164,7 +164,7 @@ class MaxInfoSpanningTree { // level. This information is stored as a vector of `IDInfo`, where each // item in the vector corresponds to one ID in the reference tensor's root // domain. -class NVF_API MaxRootDomainInfoSpanningTree : public MaxInfoSpanningTree { +class NVF_API MaxLogicalDomainInfoSpanningTree : public MaxInfoSpanningTree { protected: // This is a struct storing how the information about a root ID in the // starting tensor is preserved during path-finding. If during path-finding, @@ -235,23 +235,23 @@ class NVF_API MaxRootDomainInfoSpanningTree : public MaxInfoSpanningTree { int64_t loop_pos); public: - MaxRootDomainInfoSpanningTree( + MaxLogicalDomainInfoSpanningTree( TensorView* reference, std::shared_ptr reference_info, Selector* selector = nullptr) : MaxInfoSpanningTree(reference, reference_info, selector) {} - MaxRootDomainInfoSpanningTree( + MaxLogicalDomainInfoSpanningTree( TensorView* reference, Selector* selector = nullptr) - : MaxRootDomainInfoSpanningTree( + : MaxLogicalDomainInfoSpanningTree( reference, getReferenceIDInfo(reference), selector) {} - MaxRootDomainInfoSpanningTree( + MaxLogicalDomainInfoSpanningTree( TensorView* reference, int64_t loop_pos, Selector* selector = nullptr) - : MaxRootDomainInfoSpanningTree( + : MaxLogicalDomainInfoSpanningTree( reference, getReferenceIDInfo(reference, loop_pos), selector) {} diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index 11ba1d0a726..957d81e507d 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -312,7 +312,8 @@ bool isLowerableToCommunication(Expr* expr) { return false; } // We check whether the reduced axis is sharded on the input - const auto c2p_map = PairwiseRootDomainMap(in, out).mapConsumerToProducer(); + const auto c2p_map = + PairwiseLogicalDomainMap(in, out).mapConsumerToProducer(); auto c2p_map_it = c2p_map.find(reduction_axis.at(0)); return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim(); } else { diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 65d1935d62a..855f8ff48e4 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -10,10 +10,10 @@ #include #include #include +#include #include #include #include -#include #include #include @@ -79,7 +79,7 @@ std::pair, std::vector> getShardingChanges std::vector shard_additions; std::vector shard_deletions; - auto rootmap = PairwiseRootDomainMap(input, output).mapBroadcast(false); + auto rootmap = PairwiseLogicalDomainMap(input, output).mapBroadcast(false); const auto c2p_map = rootmap.mapConsumerToProducer(); for (IterDomain* out_root : output->getMaybeRootDomain()) { @@ -154,7 +154,7 @@ bool haveDifferentShardings(TensorView* producer, TensorView* consumer) { // over producer's iterdomain and compare sharding type with consumer's // iterdomain const auto p2c_map = - PairwiseRootDomainMap(producer, consumer).mapProducerToConsumer(); + PairwiseLogicalDomainMap(producer, consumer).mapProducerToConsumer(); for (auto p_id : TensorDomain::noReductions(producer->getLogicalDomain())) { auto p2c_map_it = p2c_map.find(p_id); NVF_ERROR( diff --git a/csrc/predicate_compute.h b/csrc/predicate_compute.h index 8c277d9bca7..2eb25dd9057 100644 --- a/csrc/predicate_compute.h +++ b/csrc/predicate_compute.h @@ -12,7 +12,7 @@ #include #include #include -#include +#include namespace nvfuser { diff --git a/csrc/preseg_passes/allocation_order_inference.cpp b/csrc/preseg_passes/allocation_order_inference.cpp index 6e52b373c2b..888d50faf2c 100644 --- a/csrc/preseg_passes/allocation_order_inference.cpp +++ b/csrc/preseg_passes/allocation_order_inference.cpp @@ -9,8 +9,8 @@ #include #include #include +#include #include -#include namespace nvfuser::preseg_passes { diff --git a/csrc/preseg_passes/exact_mapped_extent_substitution.cpp b/csrc/preseg_passes/exact_mapped_extent_substitution.cpp index 43bd92ecb0a..666076efc90 100644 --- a/csrc/preseg_passes/exact_mapped_extent_substitution.cpp +++ b/csrc/preseg_passes/exact_mapped_extent_substitution.cpp @@ -7,9 +7,9 @@ // clang-format on #include #include +#include #include #include -#include namespace nvfuser::preseg_passes { @@ -30,7 +30,7 @@ void exactMappedExtentSubstitution(Fusion* fusion) { // map non-const extents to const extents std::unordered_map replacement_map; - const auto mapped_sets = ExactRootDomainMap(fusion).getMappedSets(); + const auto mapped_sets = ExactLogicalDomainMap(fusion).getMappedSets(); // Loop over each exact root domain set for (const auto& set_ptr : mapped_sets.disjointSets()) { // (1) pick a const extent @@ -77,16 +77,16 @@ void exactMappedExtentSubstitution(Fusion* fusion) { void ExactMappedExtentSubstitutionPass::runPass(Fusion* fusion) { if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { - debug() << "ExactRootDomainMap before " << name() << ":" << std::endl; - const auto mapped_sets = ExactRootDomainMap(fusion).getMappedSets(); + debug() << "ExactLogicalDomainMap before " << name() << ":" << std::endl; + const auto mapped_sets = ExactLogicalDomainMap(fusion).getMappedSets(); debug() << mapped_sets.toString() << std::endl; } exactMappedExtentSubstitution(fusion); if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { - debug() << "ExactRootDomainMap after " << name() << ":" << std::endl; - const auto mapped_sets = ExactRootDomainMap(fusion).getMappedSets(); + debug() << "ExactLogicalDomainMap after " << name() << ":" << std::endl; + const auto mapped_sets = ExactLogicalDomainMap(fusion).getMappedSets(); debug() << mapped_sets.toString() << std::endl; } } diff --git a/csrc/preseg_passes/remove_bcast_squeeze.cpp b/csrc/preseg_passes/remove_bcast_squeeze.cpp index e894c7d7481..2859816e986 100644 --- a/csrc/preseg_passes/remove_bcast_squeeze.cpp +++ b/csrc/preseg_passes/remove_bcast_squeeze.cpp @@ -7,9 +7,9 @@ // clang-format on #include #include +#include #include #include -#include namespace nvfuser::preseg_passes { diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index b2cbf149002..ea92abde6f4 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -3129,7 +3129,7 @@ void initNvFuserPythonBindings(PyObject* module) { if (selected_tensors.empty()) { // Propagate scheduler transformations on reference TensorView to the // rest of the fusion. - MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(reference_tv).traverse(&propagator); } else { // Propagate scheduler transformations on reference TensorView to the // subset of the fusion. @@ -3144,7 +3144,7 @@ void initNvFuserPythonBindings(PyObject* module) { }); SetSelector selector( {selected_tv_set.begin(), selected_tv_set.end()}); - MaxRootDomainInfoSpanningTree(reference_tv, &selector) + MaxLogicalDomainInfoSpanningTree(reference_tv, &selector) .traverse(&propagator); } }, diff --git a/csrc/scheduler/cache_policy_refiner.cpp b/csrc/scheduler/cache_policy_refiner.cpp index b7f87de9417..7e6eab7eb18 100644 --- a/csrc/scheduler/cache_policy_refiner.cpp +++ b/csrc/scheduler/cache_policy_refiner.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include @@ -37,16 +37,16 @@ bool pointwiseExpands(const Expr* expr, const TensorView* in_tv) { } const auto* out_tv = out->as(); - auto root_domain_map = PairwiseRootDomainMap(in_tv, out_tv) - .mapBroadcast(true) - .mapProducerToConsumer(); + auto logical_domain_map = PairwiseLogicalDomainMap(in_tv, out_tv) + .mapBroadcast(true) + .mapProducerToConsumer(); return std::find_if( - root_domain_map.begin(), - root_domain_map.end(), + logical_domain_map.begin(), + logical_domain_map.end(), [](const auto& mapping) { return mapping.first->isBroadcast() && !mapping.second->isBroadcast(); - }) != root_domain_map.end(); + }) != logical_domain_map.end(); } bool isLoadGlobalToLocal(const Expr* expr) { diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index 3a3bfcf2a07..d71afd9ebe9 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -11,9 +11,9 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -764,7 +764,7 @@ bool isLdMatrixTranspose(const LoadStoreOp* ldst) { // This gives us the ID in the consumer root domain. // We'll later map this ID to one in the producer. - const PairwiseRootDomainMap map_across_ldst(producer, consumer); + const PairwiseLogicalDomainMap map_across_ldst(producer, consumer); const auto c2p_map = map_across_ldst.mapConsumerToProducer(); const auto id_in_proc_rfactor = c2p_map.at(corresponding_id_in_consumer_root); diff --git a/csrc/scheduler/normalization_utils.cpp b/csrc/scheduler/normalization_utils.cpp index 31eb497a244..c698907191d 100644 --- a/csrc/scheduler/normalization_utils.cpp +++ b/csrc/scheduler/normalization_utils.cpp @@ -1119,7 +1119,7 @@ bool checkReductionPattern( // Ensure that the reduction operations share the same axes in their root // domains FusionGuard fg(fusion); - ComputeAtRootDomainMap root_map; + ComputeAtLogicalDomainMap root_map; root_map.build(true); // Helper function to check the pattern equivalence for a list of diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index b78d4efe746..fbb9504cbb0 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -851,7 +851,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { } TransformPropagator propagator(reference_tv); - MaxRootDomainInfoSpanningTree spanning_tree(reference_tv); + MaxLogicalDomainInfoSpanningTree spanning_tree(reference_tv); spanning_tree.traverse(&propagator); scheduler_utils::parallelizeAllLike(reference_tv); diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index 672f2c77042..a4ad9482f65 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -953,7 +953,7 @@ bool ReductionScheduler::canScheduleCompileTime(Fusion* fusion) { // Use root domain map to check the reduction ops have the same axes FusionGuard fg(fusion); - ComputeAtRootDomainMap root_map; + ComputeAtLogicalDomainMap root_map; root_map.build(true); // red_ops.size()>1 checked before diff --git a/csrc/scheduler/reduction_utils.cpp b/csrc/scheduler/reduction_utils.cpp index d6df9a2627c..dce70fc2a5f 100644 --- a/csrc/scheduler/reduction_utils.cpp +++ b/csrc/scheduler/reduction_utils.cpp @@ -366,7 +366,7 @@ void propagateTransformation( const std::unordered_set& boundaryNodesSet) { InternalBoundarySelector ibSelector(boundaryNodesSet); TransformPropagator propagator(reference_tv); - MaxRootDomainInfoSpanningTree(reference_tv, &ibSelector) + MaxLogicalDomainInfoSpanningTree(reference_tv, &ibSelector) .traverse(&propagator); } diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index dc11d567062..97b032463c4 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -7,7 +7,7 @@ // clang-format on #include #include -#include +#include #include #include #include @@ -19,7 +19,7 @@ namespace registry_utils { bool checkPatternEquivalence( TensorView* out_tv0, TensorView* out_tv1, - const ComputeAtRootDomainMap& root_map) { + const ComputeAtLogicalDomainMap& root_map) { const auto& out_root0 = out_tv0->getMaybeRootDomain(); const auto& out_root1 = out_tv1->getMaybeRootDomain(); const auto domain0 = out_tv0->domain(); @@ -616,7 +616,7 @@ bool SchedulerTopologyChecker::hasNonNormalizePostReductionBCast( continue; } - auto forward_pairwise_root_map = PairwiseRootDomainMap( + auto forward_pairwise_root_map = PairwiseLogicalDomainMap( forward_running_producer, forward_running_consumer); auto forward_p2c_root_map = forward_pairwise_root_map.mapProducerToConsumer(); @@ -671,7 +671,7 @@ bool SchedulerTopologyChecker::hasNonNormalizePostReductionBCast( // see TakeAlongAxisIntermediateTensorNormalization1_CUDA bool at_leat_one_id_mapped = false; auto forward_pairwise_root_map = - PairwiseRootDomainMap(tmp_producer, tmp_consumer); + PairwiseLogicalDomainMap(tmp_producer, tmp_consumer); auto forward_p2c_root_map = forward_pairwise_root_map.mapProducerToConsumer(); for (size_t entry_i = ids_to_resolve.size(); entry_i > 0; entry_i--) { @@ -736,7 +736,7 @@ bool SchedulerTopologyChecker::hasNonNormalizePostReductionBCast( std::vector running_resolved_ids; - auto backward_pairwise_root_map = PairwiseRootDomainMap( + auto backward_pairwise_root_map = PairwiseLogicalDomainMap( backward_running_producer, backward_running_consumer); auto backward_c2p_root_map = @@ -814,7 +814,7 @@ bool SchedulerTopologyChecker::hasPostReductionBCast(Fusion* fusion) { tv_dep_chain.pop_front(); auto pairwise_root_map = - PairwiseRootDomainMap(running_producer, running_consumer); + PairwiseLogicalDomainMap(running_producer, running_consumer); auto p2c_root_map = pairwise_root_map.mapProducerToConsumer(); // Check if any TensorViews have a resolved broadcast @@ -959,7 +959,7 @@ bool SchedulerTopologyChecker::hasGatherToBroadcastBeforeReduction( // If the broadcast IDs are mapped with the reduction TVs, the // reduction scheduler should be able to schedule the gather // output TVs. This mapping can be PERMISSIVE as the broadcast IDs - // may be concretized. ExactRootDomainMap may be enough as + // may be concretized. ExactLogicalDomainMap may be enough as // broadcasts should not be removed by rfactor exprs. // Consider reusing a CA map diff --git a/csrc/scheduler/registry_utils.h b/csrc/scheduler/registry_utils.h index 62b80c39900..85c5ed51509 100644 --- a/csrc/scheduler/registry_utils.h +++ b/csrc/scheduler/registry_utils.h @@ -11,7 +11,7 @@ namespace nvfuser { class TensorView; -class ComputeAtRootDomainMap; +class ComputeAtLogicalDomainMap; class ComputeAtMap; class ExpressionEvaluator; class KernelArgumentHolder; @@ -21,7 +21,7 @@ namespace registry_utils { bool checkPatternEquivalence( TensorView* out_tv0, TensorView* out_tv1, - const ComputeAtRootDomainMap& root_map); + const ComputeAtLogicalDomainMap& root_map); // Reusing some code from lowering specifically in lower_trivial_broadcast.cpp // ConcretizedBroadcastDomains::maybeNonUniquelyConcretized this checks if diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index fd585b60a2d..5c211eb481a 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -800,7 +800,7 @@ std::string getTransposeRuntimeRejectReason( // doing dry-run on the first traverse. Since the following twos are only // used for scheduling tiling, which is not going to cause issue, since we // are only tiling on the merged virtual innermost dimensions. - MaxRootDomainInfoSpanningTree entire_dag(reference1); + MaxLogicalDomainInfoSpanningTree entire_dag(reference1); entire_dag.traverse(&propagator); if (propagator.shouldReject()) { return "transpose scheduler could potentially trigger incoherent transform propagation"; @@ -1240,7 +1240,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) { // Propagate transformations so far to the entire DAG TransformPropagator propagator(reference1); - MaxRootDomainInfoSpanningTree entire_dag(reference1); + MaxLogicalDomainInfoSpanningTree entire_dag(reference1); entire_dag.traverse(&propagator); scheduler_utils::parallelizeAllLike(reference1); @@ -1275,7 +1275,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) { fusion, {grouped_inputs_outputs[0].begin(), grouped_inputs_outputs[0].end()}); SetSelector selector({all_tvs_except1.begin(), all_tvs_except1.end()}); - MaxRootDomainInfoSpanningTree entire_dag_except1(reference2, &selector); + MaxLogicalDomainInfoSpanningTree entire_dag_except1(reference2, &selector); TransformPropagator propagator(reference2); entire_dag_except1.traverse(&propagator); } @@ -1363,7 +1363,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) { auto all_tvs_except2 = ir_utils::allTvsExcept(fusion, group2_and_cached_inputs); SetSelector selector({all_tvs_except2.begin(), all_tvs_except2.end()}); - MaxRootDomainInfoSpanningTree entire_dag_except_outputs( + MaxLogicalDomainInfoSpanningTree entire_dag_except_outputs( reference1, &selector); TransformPropagator propagator(reference1); entire_dag_except_outputs.traverse(&propagator); diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index c6f57873b51..d77555eb7d2 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -13,9 +13,9 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -560,7 +560,7 @@ PersistentBufferInfo persistentBuffers(Fusion* fusion) { FusionGuard fg(fusion); PersistentBufferInfo persistent_buffer_info; - ComputeAtRootDomainMap root_map; + ComputeAtLogicalDomainMap root_map; root_map.build(); auto all_tvs = ir_utils::allTvs(fusion); @@ -1361,7 +1361,7 @@ void FindAllMappedDims::setUp() { void FindAllMappedDims::propagateC2P(TensorView* from, TensorView* to) { auto from_id = mapped_root_ids_.at(from); - PairwiseRootDomainMap root_map(to, from); + PairwiseLogicalDomainMap root_map(to, from); auto c2p_map = root_map.mapConsumerToProducer(); auto p_it = c2p_map.find(from_id); if (p_it != c2p_map.end()) { @@ -1376,7 +1376,7 @@ void FindAllMappedDims::propagateC2P(TensorView* from, TensorView* to) { void FindAllMappedDims::propagateP2C(TensorView* from, TensorView* to) { auto from_id = mapped_logical_ids_.at(from); - PairwiseRootDomainMap root_map(from, to); + PairwiseLogicalDomainMap root_map(from, to); auto p2c_map = root_map.mapProducerToConsumer(); auto c_it = p2c_map.find(from_id); if (c_it != p2c_map.end()) { @@ -1491,7 +1491,7 @@ std::vector getInputsOutputsWithInnerDim( FindAllMappedDims all_mapped_root_dims( reference_tv, inner_most_id, inner_only, vectorize_pass); - MaxRootDomainInfoSpanningTree tree(reference_tv); + MaxLogicalDomainInfoSpanningTree tree(reference_tv); tree.traverse(&all_mapped_root_dims); auto vectorizable_dims = all_mapped_root_dims.get(); @@ -1754,7 +1754,7 @@ BroadcastMultipleInformation getBroadcastMultiples( //! Propagate current transformations on from_tv to all graphs void transformPropagateToAllFrom(TensorView* from_tv, int64_t pos) { TransformPropagator propagator(from_tv, pos); - MaxRootDomainInfoSpanningTree(from_tv, nullptr).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(from_tv, nullptr).traverse(&propagator); } namespace { @@ -1897,7 +1897,7 @@ void BoundedDirectionalTransformPropagator::propagate( // Run transform propagation using the custom selector. SetSelector selector(included_tvs); TransformPropagator propagator(from_tv, pos); - MaxRootDomainInfoSpanningTree(from_tv, &selector).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(from_tv, &selector).traverse(&propagator); // Propagate parallel type if requested by option parameters. if (options.propagate_parallel_type) { @@ -2412,7 +2412,7 @@ void promoteProducerMemoryTypes( auto c2p_exact_map = BestEffortReplay( producer->getLoopDomain(), consumer->getLoopDomain(), - PairwiseRootDomainMap(producer, consumer) + PairwiseLogicalDomainMap(producer, consumer) .mapBroadcast(false) .mapConsumerToProducer()) .getReplay(); diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp index 770197833d8..473432f6e5c 100644 --- a/csrc/scheduler/vectorize_helper.cpp +++ b/csrc/scheduler/vectorize_helper.cpp @@ -463,7 +463,7 @@ ContiguousInnerDimensionsMapper::computeInfoC2P( // resolved broadcast iterdomain is `i2`/`b2`, which would give clear_pos=1. // So we'll skip all from_ids with index < clear_pos. see issue: // https://github.com/NVIDIA/Fuser/issues/1567#issuecomment-1894605385 - PairwiseRootDomainMap root_map(to, from); + PairwiseLogicalDomainMap root_map(to, from); auto c2p_map = root_map.mapConsumerToProducer(); // Id's in consumer to clear from the mapped set due to broadcast @@ -531,7 +531,7 @@ ContiguousInnerDimensionsMapper::computeInfoP2C( // T3[i1, i2] = T2 // Then i1 and i2 are contiguous in both T0 and T3, but due to the sum on T1 // we will have removed i1. - PairwiseRootDomainMap root_map(from, to); + PairwiseLogicalDomainMap root_map(from, to); auto p2c_map = root_map.mapProducerToConsumer(); std::vector consumer_root_ids; @@ -934,8 +934,9 @@ int64_t getVectorizationBreakPointOfReductionProducer( return break_point; } - const auto c2p = PairwiseRootDomainMap(reduction_producer, reduction_consumer) - .mapConsumerToProducer(); + const auto c2p = + PairwiseLogicalDomainMap(reduction_producer, reduction_consumer) + .mapConsumerToProducer(); // Grab all the corresponding producer IDs that are mapped with the // innermost consumer IDs diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 06ce572d1bf..33b3140d1d9 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -218,7 +218,7 @@ int64_t getConsumerPosAlignedToProducerCA( auto disjoint_sets = BestEffortReplay::replayPasC( - producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)) + producer, consumer, -1, PairwiseLogicalDomainMap(producer, consumer)) .getIterDomainEquivalence(); // Find the innermost position of consumer that has diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 4bfc39d691f..fb542428eb0 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -7,7 +7,7 @@ // clang-format on #include #include -#include +#include #include #include @@ -954,7 +954,7 @@ BestEffortReplay BestEffortReplay::replayCasP( const TensorView* consumer, const TensorView* producer, int64_t producer_compute_at_axis, - const RootDomainMap& root_map, + const LogicalDomainMap& root_map, bool skip_consumer_swizzle, bool skip_producer_swizzle, bool skip_resize) { @@ -1021,7 +1021,7 @@ BestEffortReplay BestEffortReplay::replayPasC( const TensorView* producer, const TensorView* consumer, int64_t consumer_compute_at_axis, - const RootDomainMap& root_map, + const LogicalDomainMap& root_map, bool skip_producer_swizzle, bool skip_consumer_swizzle, bool skip_resize) { diff --git a/csrc/transform_iter.h b/csrc/transform_iter.h index 242e73a9cc2..60cb0467be0 100644 --- a/csrc/transform_iter.h +++ b/csrc/transform_iter.h @@ -18,7 +18,7 @@ namespace nvfuser { -class RootDomainMap; +class LogicalDomainMap; namespace { @@ -489,7 +489,7 @@ class BestEffortReplay { const TensorView* consumer, const TensorView* producer, int64_t producer_compute_at_axis, - const RootDomainMap& root_map, + const LogicalDomainMap& root_map, bool skip_consumer_swizzle = true, bool skip_producer_swizzle = true, bool skip_resize = true); @@ -502,7 +502,7 @@ class BestEffortReplay { const TensorView* producer, const TensorView* consumer, int64_t consumer_compute_at_axis, - const RootDomainMap& root_map, + const LogicalDomainMap& root_map, bool skip_producer_swizzle = true, bool skip_consumer_swizzle = true, bool skip_resize = true); diff --git a/csrc/transform_replay.cpp b/csrc/transform_replay.cpp index 99ad8eeca3b..95ccfa570c2 100644 --- a/csrc/transform_replay.cpp +++ b/csrc/transform_replay.cpp @@ -16,10 +16,10 @@ #include #include #include +#include #include #include #include -#include #include #include @@ -292,7 +292,7 @@ std::pair TransformReplay::replayPasC( const TensorView* producer, const TensorView* consumer, int64_t consumer_pos, - const RootDomainMap& root_map, + const LogicalDomainMap& root_map, TransformReplayOptions opt) { FUSER_PERF_SCOPE("TransformReplay::replayPasC"); if (producer == consumer) { @@ -525,7 +525,7 @@ std::pair TransformReplay::replayCasP( const TensorView* consumer, const TensorView* producer, int64_t producer_pos, - const RootDomainMap& root_map, + const LogicalDomainMap& root_map, TransformReplayOptions opt) { FUSER_PERF_SCOPE("TransformReplay::replayCasP"); @@ -800,7 +800,7 @@ std::pair TransformReplay::replayPasC( int64_t compute_at_axis, TransformReplayOptions opt) { // Use the pairwise root map as a default mapper - PairwiseRootDomainMap root_map(producer, consumer); + PairwiseLogicalDomainMap root_map(producer, consumer); // Allow replay through indexing exprs root_map.mapIndexedDomains(true); return replayPasC(producer, consumer, compute_at_axis, root_map, opt); @@ -812,7 +812,7 @@ std::pair TransformReplay::replayCasP( int64_t compute_at_axis, TransformReplayOptions opt) { // Use the pairwise root map as a default mapper - PairwiseRootDomainMap root_map(producer, consumer); + PairwiseLogicalDomainMap root_map(producer, consumer); // Allow replay through indexing exprs root_map.mapIndexedDomains(true); return replayCasP(consumer, producer, compute_at_axis, root_map, opt); @@ -831,7 +831,7 @@ int64_t TransformReplay::getMatchedLeafPosWithoutReplayPasC( // Allow replay through indexing exprs const auto pairwise_map = - PairwiseRootDomainMap(producer, consumer).mapIndexedDomains(true); + PairwiseLogicalDomainMap(producer, consumer).mapIndexedDomains(true); id_map c2p_root_map = pairwise_map.mapConsumerToProducer(); // IterDomains in `consumer` root also in `producer` root @@ -903,7 +903,7 @@ int64_t TransformReplay::getMatchedLeafPosWithoutReplayCasP( // Allow replay through indexing exprs const auto pairwise_map = - PairwiseRootDomainMap(producer, consumer).mapIndexedDomains(true); + PairwiseLogicalDomainMap(producer, consumer).mapIndexedDomains(true); id_map p2c_root_map = pairwise_map.mapProducerToConsumer(); // IterDomains in `producer` root that are not reduction diff --git a/csrc/transform_replay.h b/csrc/transform_replay.h index 99b87da8a63..ec30f3a9ec5 100644 --- a/csrc/transform_replay.h +++ b/csrc/transform_replay.h @@ -127,7 +127,7 @@ namespace nvfuser { class TensorDomain; class TensorView; -class RootDomainMap; +class LogicalDomainMap; struct TransformReplayOptions { // In theory, it makes more sense to have skip_target_swizzle = true by @@ -203,7 +203,7 @@ class NVF_API TransformReplay { const TensorView* producer, const TensorView* consumer, int64_t consumer_compute_at_axis, - const RootDomainMap& root_map, + const LogicalDomainMap& root_map, TransformReplayOptions opt = {}); // Replay producer as consumer, returns {replayed_consumer_domain, @@ -219,7 +219,7 @@ class NVF_API TransformReplay { const TensorView* consumer, const TensorView* producer, int64_t producer_compute_at_axis, - const RootDomainMap& root_map, + const LogicalDomainMap& root_map, TransformReplayOptions opt = {}); // Self replay. @@ -269,7 +269,7 @@ class NVF_API TransformReplay { }; class NVF_API TransformPropagator - : public MaxRootDomainInfoSpanningTree::Propagator { + : public MaxLogicalDomainInfoSpanningTree::Propagator { protected: std::unordered_map replayed_pos_; @@ -281,7 +281,7 @@ class NVF_API TransformPropagator }; struct MostInlinedTransformPropagator - : public MaxRootDomainInfoSpanningTree::Propagator { + : public MaxLogicalDomainInfoSpanningTree::Propagator { void propagateC2P(TensorView* from, TensorView* to) override; void propagateP2C(TensorView* from, TensorView* to) override; void propagateSibling(TensorView* from, TensorView* to) override; diff --git a/tests/cpp/test_ca_root_domain_map.cpp b/tests/cpp/test_ca_root_domain_map.cpp index 3861eb780ae..f961d3cc909 100644 --- a/tests/cpp/test_ca_root_domain_map.cpp +++ b/tests/cpp/test_ca_root_domain_map.cpp @@ -8,18 +8,18 @@ #include #include +#include #include -#include #include namespace nvfuser { -using CaRootDomainMapTest = NVFuserTest; +using CaLogicalDomainMapTest = NVFuserTest; namespace { void checkIdMapped( - ComputeAtRootDomainMap& root_map, + ComputeAtLogicalDomainMap& root_map, TensorView* v0, IterDomain* id0, TensorView* v1, @@ -57,7 +57,7 @@ void checkIdMapped( TensorView* v1, const std::vector& root1, const std::vector should_map1) { - ComputeAtRootDomainMap map; + ComputeAtLogicalDomainMap map; map.build(); NVF_ERROR(root0.size() == should_map0.size()); NVF_ERROR(root1.size() == should_map1.size()); @@ -94,7 +94,7 @@ void checkIdMapped( } // namespace -TEST_F(CaRootDomainMapTest, FusionRootMappingBasic_CUDA) { +TEST_F(CaLogicalDomainMapTest, FusionRootMappingBasic_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -148,7 +148,7 @@ TEST_F(CaRootDomainMapTest, FusionRootMappingBasic_CUDA) { checkIdMapped(tv4, tv4->getLogicalDomain(), tv5, tv5->getLogicalDomain()); } -TEST_F(CaRootDomainMapTest, FusionRootMappingRfactor_CUDA) { +TEST_F(CaLogicalDomainMapTest, FusionRootMappingRfactor_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -232,7 +232,7 @@ TEST_F(CaRootDomainMapTest, FusionRootMappingRfactor_CUDA) { {true, true, false}); } -TEST_F(CaRootDomainMapTest, FusionRootMappingReductionDependency1_CUDA) { +TEST_F(CaLogicalDomainMapTest, FusionRootMappingReductionDependency1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -259,7 +259,7 @@ TEST_F(CaRootDomainMapTest, FusionRootMappingReductionDependency1_CUDA) { {true, false}); } -TEST_F(CaRootDomainMapTest, FusionRootMappingReductionDependency2_CUDA) { +TEST_F(CaLogicalDomainMapTest, FusionRootMappingReductionDependency2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -293,7 +293,7 @@ TEST_F(CaRootDomainMapTest, FusionRootMappingReductionDependency2_CUDA) { checkIdMapped(tv2, tv2->getLogicalDomain(), tv3, tv3->getLogicalDomain()); } -TEST_F(CaRootDomainMapTest, FusionRootMappingReductionDependency3_CUDA) { +TEST_F(CaLogicalDomainMapTest, FusionRootMappingReductionDependency3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -322,7 +322,7 @@ TEST_F(CaRootDomainMapTest, FusionRootMappingReductionDependency3_CUDA) { {true, false}); } -TEST_F(CaRootDomainMapTest, FusionRootMappingReductionDependency4_CUDA) { +TEST_F(CaLogicalDomainMapTest, FusionRootMappingReductionDependency4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -367,7 +367,9 @@ TEST_F(CaRootDomainMapTest, FusionRootMappingReductionDependency4_CUDA) { } // Reproducer of issue #749 -TEST_F(CaRootDomainMapTest, FusionRootMappingReductionDependency5_CUDA_CUDA) { +TEST_F( + CaLogicalDomainMapTest, + FusionRootMappingReductionDependency5_CUDA_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -425,7 +427,9 @@ TEST_F(CaRootDomainMapTest, FusionRootMappingReductionDependency5_CUDA_CUDA) { } // Similar to RootMappingReductionDependency5 but with rFactor -TEST_F(CaRootDomainMapTest, FusionRootMappingReductionDependency6_CUDA_CUDA) { +TEST_F( + CaLogicalDomainMapTest, + FusionRootMappingReductionDependency6_CUDA_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -500,7 +504,7 @@ TEST_F(CaRootDomainMapTest, FusionRootMappingReductionDependency6_CUDA_CUDA) { } TEST_F( - CaRootDomainMapTest, + CaLogicalDomainMapTest, FusionRootMappingMultipleBroadcastWithNoCommonConsumer_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -535,7 +539,7 @@ TEST_F( {false, true}); } -TEST_F(CaRootDomainMapTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) { +TEST_F(CaLogicalDomainMapTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -622,7 +626,7 @@ TEST_F(CaRootDomainMapTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) { {true, false}); } -TEST_F(CaRootDomainMapTest, FusionRootMappingBroadcast_CUDA) { +TEST_F(CaLogicalDomainMapTest, FusionRootMappingBroadcast_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -662,7 +666,7 @@ TEST_F(CaRootDomainMapTest, FusionRootMappingBroadcast_CUDA) { } // Repro of issue #1950 -TEST_F(CaRootDomainMapTest, FusionRootMappingRepro1950_CUDA) { +TEST_F(CaLogicalDomainMapTest, FusionRootMappingRepro1950_CUDA) { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeSymbolicTensor(3); @@ -687,7 +691,7 @@ TEST_F(CaRootDomainMapTest, FusionRootMappingRepro1950_CUDA) { fusion.addOutput(tv5); fusion.addOutput(tv4); - ComputeAtRootDomainMap root_map; + ComputeAtLogicalDomainMap root_map; root_map.build(); checkIdMapped(root_map, tv4, tv4->axis(-1), tv9, tv9->axis(-1), false); @@ -701,7 +705,9 @@ TEST_F(CaRootDomainMapTest, FusionRootMappingRepro1950_CUDA) { // After fix, there are two persistent buffers and can be further // reduced to one with a following step-2 to fix the issue in resolution // points detection. -TEST_F(CaRootDomainMapTest, FusionRootMappingConsumerMappedWithReductionInput) { +TEST_F( + CaLogicalDomainMapTest, + FusionRootMappingConsumerMappedWithReductionInput) { std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -732,7 +738,7 @@ TEST_F(CaRootDomainMapTest, FusionRootMappingConsumerMappedWithReductionInput) { // tv8 is a consumer of the reduction output. // If tv9 is mapped with tv2, we can't map tv8 and tv9 because tv9 is in the // pre-reduction set through tv2 and tv8 is in the post-reduction set. - ComputeAtRootDomainMap root_map; + ComputeAtLogicalDomainMap root_map; root_map.build(); checkIdMapped(root_map, tv2, tv2->axis(1), tv9, tv9->axis(1), true); checkIdMapped(root_map, tv7, tv7->axis(1), tv8, tv8->axis(1), false); diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index c49cde84c0c..379e8b50a74 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -33,7 +33,7 @@ TEST_F(CircularBufferingTest, CircularBuffering1) { tv3->split(-1, 128); tv3->split(-1, 32); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 1); @@ -70,7 +70,7 @@ TEST_F(CircularBufferingTest, CircularBuffering2) { tv3->split(-1, 128); tv3->split(-1, 32); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, -1); @@ -109,7 +109,7 @@ TEST_F(CircularBufferingTest, CircularBuffering3) { tv3->split(-1, 128); tv3->split(-1, 32); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 1); @@ -157,7 +157,7 @@ TEST_F(CircularBufferingTest, CircularBuffering4) { tv3->split(-1, 32); tv3->split(-1, 8); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 2); tv2->computeAt(tv3, -1); @@ -198,7 +198,7 @@ TEST_F(CircularBufferingTest, CircularBuffering5) { tv2->split(-1, 32); tv2->split(-1, 8); TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); tv0->computeAt(tv2, 2); tv1->computeAt(tv2, -1); @@ -241,7 +241,7 @@ TEST_F(CircularBufferingTest, CircularBuffering6) { tv3->split(-2, 4); tv3->split(-2, 2); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 1); tv2->computeAt(tv3, -1); @@ -278,7 +278,7 @@ TEST_F(CircularBufferingTest, CircularBuffering7) { tv2->split(-1, 128); tv2->split(-1, 4); TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); tv1->computeAt(tv2, 2); @@ -318,7 +318,7 @@ TEST_F(CircularBufferingTest, CircularBuffering8) { tv4->split(0, 32); tv4->split(0, 4); TransformPropagatorWithCheck propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); tv0->computeAt(tv4, 1); tv1->computeAt(tv4, 1); @@ -359,7 +359,7 @@ TEST_F(CircularBufferingTest, CircularBuffering9) { out->split(0, 32); out->split(0, 4); TransformPropagatorWithCheck propagator(out); - MaxRootDomainInfoSpanningTree(out).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(out).traverse(&propagator); tv2->setMemoryType(MemoryType::Shared); @@ -427,7 +427,7 @@ TEST_F(CircularBufferingTest, SmemBlockGemmCacheCircularBuffer) { auto tv6_rf = tv6->rFactor({-1}); TransformPropagatorWithCheck propagator(tv6_rf); - MaxRootDomainInfoSpanningTree(tv6_rf).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv6_rf).traverse(&propagator); tv0->computeAt(tv6, 3); tv1->computeAt(tv6, 3); diff --git a/tests/cpp/test_expr_sort.cpp b/tests/cpp/test_expr_sort.cpp index 6bf7dac7451..9280a9b7d00 100644 --- a/tests/cpp/test_expr_sort.cpp +++ b/tests/cpp/test_expr_sort.cpp @@ -63,7 +63,7 @@ TEST_F(ExprSortTest, IndirectNormalizationWithZeroDimTensors) { // // This fusion may appear to have the persistent pattern, but it // isn't the case. The reduction output, tv3, is never used with the - // reduciton input, tv2. So, ComputeAtRootDomainMap detects no + // reduciton input, tv2. So, ComputeAtLogicalDomainMap detects no // domains that should not be inlined, which is correct. // // However, this could turn into a persistent kernel if tv7 and tv8 @@ -101,7 +101,7 @@ TEST_F(ExprSortTest, IndirectNormalizationWithZeroDimTensors) { tv3->split(0, 4); auto tv11 = tv3->rFactor({1}); - MaxRootDomainInfoSpanningTree tree(tv11); + MaxLogicalDomainInfoSpanningTree tree(tv11); TransformPropagator tp(tv11); tree.traverse(&tp); @@ -146,7 +146,7 @@ TEST_F(ExprSortTest, IndirectInnerNormalization) { tv3->split(1, 4); auto tv11 = tv3->rFactor({-1}); - MaxRootDomainInfoSpanningTree tree(tv11); + MaxLogicalDomainInfoSpanningTree tree(tv11); TransformPropagator tp(tv11); tree.traverse(&tp); diff --git a/tests/cpp/test_gather.cpp b/tests/cpp/test_gather.cpp index 1bc606cfeb0..988beb6bc94 100644 --- a/tests/cpp/test_gather.cpp +++ b/tests/cpp/test_gather.cpp @@ -532,7 +532,7 @@ TEST_F(IndexingOpTest, TakeAlongAxisIntermediateTensorPointwise1_CUDA) { tv4->split(1, 10); TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); // All of the tensors should have the split by 2, except for tv1. for (auto tv : ir_utils::allTvsExcept(&fusion, {tv1})) { diff --git a/tests/cpp/test_gpu1.cpp b/tests/cpp/test_gpu1.cpp index f044253dd78..ba7e25abd93 100644 --- a/tests/cpp/test_gpu1.cpp +++ b/tests/cpp/test_gpu1.cpp @@ -29,8 +29,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -94,7 +94,7 @@ TEST_F(NVFuserTest, FusionIrGraphGenerator_CUDA) { tv6->merge(0); tv6->split(0, 4); TransformPropagatorWithCheck propagator(tv6); - MaxRootDomainInfoSpanningTree(tv6).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv6).traverse(&propagator); tv4->axis(2)->parallelize(ParallelType::BIDy); tv6->axis(0)->parallelize(ParallelType::BIDx); @@ -154,7 +154,7 @@ TEST_F(NVFuserTest, FusionClear_CUDA) { tv3->split(0, 4); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(1)->parallelize(ParallelType::Unroll); @@ -195,7 +195,7 @@ TEST_F(NVFuserTest, FusionClear_CUDA) { tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); // tv3 [i0outer, i0inner{4}, i1, i2] TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(1)->parallelize(ParallelType::BIDx); @@ -237,7 +237,7 @@ TEST_F(NVFuserTest, FusionCopy_CUDA) { tv3->split(-1, 4); tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); @@ -312,7 +312,7 @@ TEST_F(NVFuserTest, FusionMove_CUDA) { tv3->split(-1, 4); tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); @@ -807,7 +807,7 @@ TEST_F(NVFuserTest, FusionOuterSplit_CUDA) { tv2->reorder({{0, 1}, {1, 0}}); // I0*I1*I2o{4}i{2}, [I0*I1*I2o{4}o, I2i] TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); inlineMost(); @@ -849,7 +849,7 @@ TEST_F(NVFuserTest, FusionCodeGen_CUDA) { tv2 = tv2->reorder({{0, 1}, {1, 0}, {3, 2}}); //[I0i{4}*I1, I0o, I2i{2}, I2o] TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); inlineMost(); @@ -887,7 +887,7 @@ TEST_F(NVFuserTest, FusionCodeGen2_CUDA) { tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); // I0o, I0i{4}, I1, I2] TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); @@ -940,7 +940,7 @@ TEST_F(NVFuserTest, FusionSimplePWise_CUDA) { tv3->split(0, 128); tv3->split(0, 4); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); // Parallelize TV3 tv3->axis(0)->parallelize(ParallelType::BIDx); @@ -997,7 +997,7 @@ TEST_F(NVFuserTest, FusionSimplePWiseDtypeComplex_CUDA) { tv3->split(0, 128); tv3->split(0, 4); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); // Parallelize TV3 tv3->axis(0)->parallelize(ParallelType::BIDx); @@ -1047,7 +1047,7 @@ TEST_F(NVFuserTest, FusionExecKernel_CUDA) { tv3->split(0, 128); tv3->split(0, 4); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); // Parallelize TV3 tv3->axis(0)->parallelize(ParallelType::BIDx); diff --git a/tests/cpp/test_gpu2.cpp b/tests/cpp/test_gpu2.cpp index 972a9bb67b4..f8a9dd0bff1 100644 --- a/tests/cpp/test_gpu2.cpp +++ b/tests/cpp/test_gpu2.cpp @@ -29,8 +29,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -3471,7 +3471,7 @@ TEST_F(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { // [bidx, unswitch, vectorize{2}, unroll{2}, tidx] TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); scheduler_utils::parallelizeAllLike(tv3); tv0_cache->axis(2)->parallelize(ParallelType::Vectorize); @@ -3891,7 +3891,7 @@ TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) { tv2->split(-1, kNumElems); tv2->split(-1, kVecSize); TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); c0->computeAt(tv2, -2); c1->computeAt(tv2, -2); @@ -3952,7 +3952,7 @@ TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicFail_CUDA) { tv2->split(-1, kNumElems); tv2->split(-1, kVecSize); TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); c0->computeAt(tv2, -2); c1->computeAt(tv2, -2); @@ -4616,7 +4616,7 @@ TEST_F(NVFuserTest, FusionValidateParallelize6_CUDA) { tv4->split(0, 2); TransformPropagatorWithCheck propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); tv0->computeAt(tv2, 2); tv3->computeAt(tv4, 2); @@ -4690,7 +4690,7 @@ TEST_F(NVFuserTest, FusionValidateParallelize8_CUDA) { tv3->split(2, 16); tv3->axis(-2)->parallelize(ParallelType::TIDx); - MaxRootDomainInfoSpanningTree tree(tv3); + MaxLogicalDomainInfoSpanningTree tree(tv3); TransformPropagator tp(tv3); tree.traverse(&tp); scheduler_utils::parallelizeAllLike(tv3); @@ -4741,7 +4741,7 @@ TEST_F(NVFuserTest, FusionValidateParallelize9_CUDA) { tv4->merge(0)->split(0, 4); - MaxRootDomainInfoSpanningTree tree(tv4); + MaxLogicalDomainInfoSpanningTree tree(tv4); TransformPropagator tp(tv4); tree.traverse(&tp); @@ -4783,7 +4783,7 @@ TEST_F(NVFuserTest, FusionValidateParallelize10_CUDA) { tv5->merge(0)->split(0, 4); TransformPropagatorWithCheck propagator(tv5); - MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv5).traverse(&propagator); // tv2 has no CA tv3->computeAt(tv5, 1); @@ -4833,7 +4833,7 @@ TEST_F(NVFuserTest, FusionValidateParallelize11_CUDA) { tv5->merge(0)->split(0, 4)->split(0, 2); TransformPropagatorWithCheck propagator(tv5); - MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv5).traverse(&propagator); tv2->computeAt(tv5, 1); @@ -6081,7 +6081,7 @@ TEST_F(NVFuserTest, FusionSimpleWarp_CUDA) { tv1->split(1, 32); auto tv1_rf = tv1->rFactor({1}); TransformPropagatorWithCheck propagator(tv1_rf); - MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1_rf).traverse(&propagator); tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); @@ -6128,7 +6128,7 @@ TEST_F(NVFuserTest, FusionSimpleWarpPad_CUDA) { tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(32); TransformPropagatorWithCheck propagator(tv1_rf); - MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1_rf).traverse(&propagator); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0->axis(-1)->padToMultipleOfWarp(32); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); @@ -6177,7 +6177,7 @@ TEST_F(NVFuserTest, FusionWarpPadMergeSplit_CUDA) { tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(); TransformPropagatorWithCheck propagator(tv1_rf); - MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1_rf).traverse(&propagator); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); @@ -6219,7 +6219,7 @@ TEST_F(NVFuserTest, FusionSerialWarpReduction_CUDA) { tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(); TransformPropagatorWithCheck propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); @@ -6264,7 +6264,7 @@ TEST_F(NVFuserTest, FusionTrivialWarpReduction_CUDA) { tv1->axis(-2)->parallelize(ParallelType::TIDx); tv1->axis(-2)->padToMultipleOfWarp(); TransformPropagatorWithCheck propagator(tv1_rf); - MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1_rf).traverse(&propagator); tv0->axis(-2)->parallelize(ParallelType::TIDx); tv0_cache->axis(-2)->parallelize(ParallelType::TIDx); tv2->axis(-2)->parallelize(ParallelType::TIDx); @@ -6312,7 +6312,7 @@ TEST_F(NVFuserTest, FusionMultipleDimBinding_CUDA) { tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(32); TransformPropagatorWithCheck propagator(tv1_rf); - MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1_rf).traverse(&propagator); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0->axis(-1)->padToMultipleOfWarp(32); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); @@ -6392,7 +6392,7 @@ TEST_F(NVFuserTest, FusionWarpMutipleThreadDim_CUDA) { tv2_rf->axis(-1)->padToMultipleOfWarp(); TransformPropagatorWithCheck propagator(tv2_rf); - MaxRootDomainInfoSpanningTree(tv2_rf).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2_rf).traverse(&propagator); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); @@ -6439,7 +6439,7 @@ TEST_F(NVFuserTest, FusionWarpReduceUnrollOuterLoop_CUDA) { tv1->axis(-1)->padToMultipleOfWarp(); tv1->axis(1)->parallelize(ParallelType::Unroll); TransformPropagatorWithCheck propagator(tv1_rf); - MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1_rf).traverse(&propagator); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0->axis(1)->parallelize(ParallelType::Unroll); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); @@ -7921,7 +7921,7 @@ TEST_F(NVFuserTest, FusionFloatPow_CUDA) { tv1->axis(1)->parallelize(ParallelType::TIDx); TransformPropagatorWithCheck propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); scheduler_utils::parallelizeAllLike(tv1, {tv2, tv3, tv4, tv5, tv6}); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 6214930518b..92c2c36269f 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -31,8 +31,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -384,7 +384,7 @@ TEST_F(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { tv3->split(0, 8, false); tv3->split(1, 4); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(1)->parallelize(ParallelType::TIDx); scheduler_utils::parallelizeAllLike(tv3, {tv1, tv2}); @@ -464,7 +464,7 @@ TEST_F(NVFuserTest, FusionIntermediateTensorVectorize_CUDA) { tv3->split(-1, 4); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv1->computeAt(tv3, -2); @@ -970,7 +970,7 @@ TEST_F(NVFuserTest, FusionTestGridComm2_CUDA) { tv4->split(0, 2); TransformPropagatorWithCheck propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); tv3->computeAt(tv4, 1); @@ -1320,7 +1320,7 @@ TEST_F(NVFuserTest, FusionContigIndexingWithBroadcast_CUDA) { tv3->merge(0); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv2->setMemoryType(MemoryType::Local); @@ -1371,7 +1371,7 @@ TEST_F(NVFuserTest, FusionVectorizeContigIndexValidationFail2_CUDA) { tv4->merge(0, 1); tv4->split(0, 4); TransformPropagatorWithCheck propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); tv0->computeAt(tv4, -2); tv1->computeAt(tv4, -2); @@ -1417,7 +1417,7 @@ TEST_F(NVFuserTest, FusionVectorizeContigIndexWithBroadcast_CUDA) { // transformations. It would create temporary IterDomains, and the // validation should still be able to detect vectorization by 4 is valid. // TransformPropagatorWithCheck propagator(tv3); - // MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + // MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv2->merge(1, 2); tv2->merge(0, 1); @@ -1973,7 +1973,7 @@ TEST_F(NVFuserTest, FusionPropagateParallelTypesToSiblings_CUDA) { tv_avg->split(0, 128); TransformPropagatorWithCheck propagator(tv_avg); - MaxRootDomainInfoSpanningTree(tv_avg).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv_avg).traverse(&propagator); tv_avg->axis(0)->parallelize(ParallelType::BIDx); tv_avg->axis(1)->parallelize(ParallelType::TIDx); @@ -2017,8 +2017,8 @@ TEST_F(NVFuserTest, FusionPropagateParallelTypesToSiblings_CUDA) { testValidate(fe.kernel(), outputs, {t0}, {t0.mean({0})}, __LINE__, __FILE__); } -// Test ExactRootDomainMap -TEST_F(NVFuserTest, FusionExactRootDomainMap_CUDA) { +// Test ExactLogicalDomainMap +TEST_F(NVFuserTest, FusionExactLogicalDomainMap_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -2036,7 +2036,7 @@ TEST_F(NVFuserTest, FusionExactRootDomainMap_CUDA) { fusion.addOutput(tv5); fusion.addOutput(tv6); - const auto exact_map = ExactRootDomainMap(&fusion); + const auto exact_map = ExactLogicalDomainMap(&fusion); // In the exact mapping, the broadcast domain introduced at tv2 is // only mapped with the another one in tv3, which is just transposed @@ -2095,7 +2095,7 @@ TEST_F(NVFuserTest, FusionIncompleteConcreteID_CUDA) { tv6->merge(0); TransformPropagatorWithCheck propagator(tv6); - MaxRootDomainInfoSpanningTree(tv6).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv6).traverse(&propagator); tv0->computeAt(tv6, -1, ComputeAtMode::MostInlined); tv1->computeAt(tv6, -1, ComputeAtMode::MostInlined); @@ -2156,7 +2156,7 @@ TEST_F(NVFuserTest, FusionTestReEntrantGridWelford_CUDA) { // iV25{4}] TransformPropagatorWithCheck propagator(reduction_tv); - MaxRootDomainInfoSpanningTree(reduction_tv).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(reduction_tv).traverse(&propagator); auto rfactor_tv = ir_utils::rFactorHelper(reduction_tv, {4}); scheduler_utils::parallelizeAllLike(rfactor_tv); @@ -2887,7 +2887,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) { auto var_sum_rf = ir_utils::rFactorHelper(tvs.var_sum, {1, 4}); TransformPropagatorWithCheck propagator(var_sum_rf); - MaxRootDomainInfoSpanningTree(var_sum_rf).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(var_sum_rf).traverse(&propagator); auto rf_tvs = ir_utils::producerTvsOf(tvs.var_sum); @@ -2946,8 +2946,8 @@ TEST_F(NVFuserTest, FusionTransformPropagateSelectorSibling_CUDA) { } selector2(tv0); TransformPropagatorWithCheck propagator(var_sum_rf); - MaxRootDomainInfoSpanningTree good_path(var_sum_rf, &selector1); - MaxRootDomainInfoSpanningTree bad_path(var_sum_rf, &selector2); + MaxLogicalDomainInfoSpanningTree good_path(var_sum_rf, &selector1); + MaxLogicalDomainInfoSpanningTree bad_path(var_sum_rf, &selector2); auto rf_tvs = ir_utils::producerTvsOf(tvs.var_sum); @@ -2984,7 +2984,7 @@ TEST_F(NVFuserTest, FusionTransformPropagatePosition_CUDA) { tv0->merge(2); tv0->merge(0); TransformPropagatorWithCheck propagator(tv0); - MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv0).traverse(&propagator); NVF_CHECK(tv1->nDims() == 4); } @@ -3089,7 +3089,7 @@ TEST_F(NVFuserTest, FusionTransformPropagatorSelector_CUDA) { } selector(tv0, tv3); TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2, &selector).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2, &selector).traverse(&propagator); NVF_CHECK(tv0->nDims() == 2); NVF_CHECK(tv1->nDims() == 1); @@ -3113,14 +3113,14 @@ TEST_F(NVFuserTest, FusionTransformPropagatorPos_CUDA) { tv1->split(-1, 5); TransformPropagatorWithCheck propagator(tv1, 2); - MaxRootDomainInfoSpanningTree(tv1, 2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1, 2).traverse(&propagator); auto expect = makeConcreteTensor({22, 105}); expect->split(0, 2); NVF_CHECK(TransformReplay::fullSelfMatching(expect, tv0)); } -TEST_F(NVFuserTest, FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA) { +TEST_F(NVFuserTest, FusionMaxLogicalDomainInfoSpanningTreePrintTwice_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -3155,7 +3155,7 @@ TEST_F(NVFuserTest, FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA) { printer1.ss << std::endl; printer2.ss << std::endl; - MaxRootDomainInfoSpanningTree path(tv1); + MaxLogicalDomainInfoSpanningTree path(tv1); path.traverse(&printer1); path.traverse(&printer2); @@ -3185,11 +3185,11 @@ TEST_F(NVFuserTest, FusionTransformPropagatorNoOverwrite_CUDA) { tv2->split(1, 2); tv2->split(0, 4); - MaxRootDomainInfoSpanningTree path1(tv2); + MaxLogicalDomainInfoSpanningTree path1(tv2); TransformPropagatorWithCheck propagator1(tv2); path1.traverse(&propagator1); - MaxRootDomainInfoSpanningTree path2(tv0); + MaxLogicalDomainInfoSpanningTree path2(tv0); TransformPropagatorWithCheck propagator2(tv0); path2.traverse(&propagator2); @@ -3270,7 +3270,7 @@ TEST_F(NVFuserTest, FusionSkipReplay_CUDA) { tv3->split(1, 2, false); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); } { @@ -3287,7 +3287,7 @@ TEST_F(NVFuserTest, FusionSkipReplay_CUDA) { tv0->split(1, 2, false); TransformPropagatorWithCheck propagator(tv0); - MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv0).traverse(&propagator); } } @@ -3414,7 +3414,7 @@ TEST_F(NVFuserTest, FusionInsertMagicZero1_CUDA) { tv2->merge(0); TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); tv0->computeAt(tv2, 1); @@ -4073,7 +4073,7 @@ TEST_F(NVFuserTest, FusionMergeBroadcastingTrivialReduction1_CUDA) { tv0->merge(0); - MaxRootDomainInfoSpanningTree tree(tv0); + MaxLogicalDomainInfoSpanningTree tree(tv0); TransformPropagatorWithCheck tp(tv0); tree.traverse(&tp); @@ -4177,7 +4177,7 @@ TEST_F(NVFuserTest, FusionReplayTrivialReductionAndBroadcast2_CUDA) { tv0->merge(-2, -1)->merge(-2, -1)->split(0, 4); - MaxRootDomainInfoSpanningTree tree(tv0); + MaxLogicalDomainInfoSpanningTree tree(tv0); TransformPropagatorWithCheck tp(tv0); tree.traverse(&tp); @@ -4625,7 +4625,7 @@ TEST_F(NVFuserTest, FusionSqueezeTransformPropagation_CUDA) { tv3->merge(0); tv3->merge(0); - MaxRootDomainInfoSpanningTree tree(tv3); + MaxLogicalDomainInfoSpanningTree tree(tv3); TransformPropagatorWithCheck tp(tv3); tree.traverse(&tp); @@ -4654,7 +4654,7 @@ TEST_F(NVFuserTest, FusionSqueezeInlining_CUDA) { tv0->split(0, 128); { - MaxRootDomainInfoSpanningTree tree(tv0); + MaxLogicalDomainInfoSpanningTree tree(tv0); TransformPropagatorWithCheck tp(tv0); tree.traverse(&tp); NVF_CHECK(tv2->nDims() == 2); @@ -4665,7 +4665,7 @@ TEST_F(NVFuserTest, FusionSqueezeInlining_CUDA) { { // The propagation here should be a no-op, I am adding it here just to test // if transformation propagation works for squeeze on both direction. - MaxRootDomainInfoSpanningTree tree(tv2); + MaxLogicalDomainInfoSpanningTree tree(tv2); TransformPropagatorWithCheck tp(tv2); tree.traverse(&tp); NVF_CHECK(tv2->nDims() == 2); @@ -5000,7 +5000,7 @@ TEST_F(NVFuserTest, FusionPropagateVectorizePredicate_CUDA) { const int vec_factor = 4; tv1->split(-1, vec_factor); - MaxRootDomainInfoSpanningTree tree(tv1); + MaxLogicalDomainInfoSpanningTree tree(tv1); TransformPropagator tp(tv1); tree.traverse(&tp); @@ -5183,7 +5183,7 @@ TEST_F(NVFuserTest, FusionIssue2163ReproInvalidAlias_CUDA) { ref->split(-1, 8); ref->reorder({{0, 1}, {1, 0}, {2, 2}}); TransformPropagator propagator(ref); - MaxRootDomainInfoSpanningTree(ref).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(ref).traverse(&propagator); // Don't inline the innermost axes std::unordered_set uninlinable; @@ -5386,7 +5386,7 @@ TEST_F(NVFuserTest, FusionVectorizeWelford1_CUDA) { tv1->split(1, 4); - MaxRootDomainInfoSpanningTree tree(tv1); + MaxLogicalDomainInfoSpanningTree tree(tv1); TransformPropagator tp(tv1); tree.traverse(&tp); @@ -5456,7 +5456,7 @@ TEST_F(NVFuserTest, FusionVectorizeWelford2_CUDA) { tv1->reorder({{-2, 1}}); - MaxRootDomainInfoSpanningTree tree(tv1); + MaxLogicalDomainInfoSpanningTree tree(tv1); TransformPropagator tp(tv1); tree.traverse(&tp); @@ -5951,7 +5951,7 @@ TEST_F(NVFuserTest, FusionCompileIndexType_CUDA) { tv2->split(0, 256); tv2->split(0, 1024); - MaxRootDomainInfoSpanningTree tree(tv2); + MaxLogicalDomainInfoSpanningTree tree(tv2); TransformPropagator tp(tv2); tree.traverse(&tp); @@ -7846,7 +7846,7 @@ TEST_F(NVFuserTest, PredicateRNGOps) { tv1->axis(1)->parallelize(ParallelType::TIDx); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv9->axis(-1)->parallelize(ParallelType::Vectorize); diff --git a/tests/cpp/test_gpu_compute_with.cpp b/tests/cpp/test_gpu_compute_with.cpp index f6a8c32f136..7abf3e891f8 100644 --- a/tests/cpp/test_gpu_compute_with.cpp +++ b/tests/cpp/test_gpu_compute_with.cpp @@ -29,8 +29,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -197,7 +197,7 @@ TEST_F(NVFuserTest, FusionComputeWith2_CUDA) { input_tv0->split(-1, vec); input_tv0->split(-2, tidx); - MaxRootDomainInfoSpanningTree tree(input_tv0); + MaxLogicalDomainInfoSpanningTree tree(input_tv0); TransformPropagatorWithCheck tp(input_tv0); tree.traverse(&tp); @@ -241,7 +241,7 @@ TEST_F(NVFuserTest, FusionComputeWith3_CUDA) { tv2->split(-1, 4); tv2->split(-2, 3); - MaxRootDomainInfoSpanningTree tree(tv2); + MaxLogicalDomainInfoSpanningTree tree(tv2); TransformPropagatorWithCheck tp(tv2); tree.traverse(&tp); @@ -283,7 +283,7 @@ TEST_F(NVFuserTest, FusionComputeWith4_CUDA) { tv2->split(0, 4); tv2->split(0, 32); - MaxRootDomainInfoSpanningTree tree(tv2); + MaxLogicalDomainInfoSpanningTree tree(tv2); TransformPropagatorWithCheck tp(tv2); tree.traverse(&tp); @@ -329,7 +329,7 @@ TEST_F(NVFuserTest, FusionComputeWith5_CUDA) { tv1->split(-1, 4); - MaxRootDomainInfoSpanningTree tree(tv1); + MaxLogicalDomainInfoSpanningTree tree(tv1); TransformPropagatorWithCheck tp(tv1); tree.traverse(&tp); @@ -423,7 +423,7 @@ TEST_F(NVFuserTest, FusionComputeWith6_CUDA) { auto tv3_rf = ir_utils::rFactorHelper(tv3, {-3, -2}); TransformPropagator propagator(tv3_rf); - MaxRootDomainInfoSpanningTree(tv3_rf).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3_rf).traverse(&propagator); scheduler_utils::parallelizeAllLike(tv3_rf, ir_utils::allTvs(&fusion)); tv1->axis(-1)->parallelize(ParallelType::Vectorize); diff --git a/tests/cpp/test_gpu_fused_reduction.cpp b/tests/cpp/test_gpu_fused_reduction.cpp index e1a8d8a93e7..757a2d8ab81 100644 --- a/tests/cpp/test_gpu_fused_reduction.cpp +++ b/tests/cpp/test_gpu_fused_reduction.cpp @@ -26,8 +26,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -98,7 +98,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) { tv3->split(0, bidx); tv3->split(0, 1); // unswitch TransformPropagator propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(0)->parallelize(ParallelType::BIDy); tv3->axis(2)->parallelize(ParallelType::BIDx); @@ -146,7 +146,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce2_CUDA) { tv3->split(0, tidx); TransformPropagator propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(1)->parallelize(ParallelType::TIDx); @@ -198,7 +198,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce3_CUDA) { tv3->split(1, tidx); TransformPropagator propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 1); @@ -245,7 +245,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce4_CUDA) { tv4->split(0, tidx); TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); tv4->axis(0)->parallelize(ParallelType::BIDx); tv4->axis(1)->parallelize(ParallelType::TIDx); @@ -303,7 +303,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce5_CUDA) { // Setup the reduction tv4->split(1, tidx); TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); tv4->axis(1)->parallelize(ParallelType::BIDx); tv4->axis(2)->parallelize(ParallelType::TIDx); @@ -357,7 +357,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce6_CUDA) { tv1->split(1, tidx); tv1->split(0, tidy); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv1->axis(0)->parallelize(ParallelType::BIDy); tv1->axis(1)->parallelize(ParallelType::TIDy); @@ -405,7 +405,7 @@ TEST_F(NVFuserTest, FusionGridAllreduceWelford1_CUDA) { tv5->split(0, tidx); TransformPropagator propagator(tv5); - MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv5).traverse(&propagator); tv5->axis(0)->parallelize(ParallelType::BIDx); tv5->axis(1)->parallelize(ParallelType::TIDx); @@ -452,7 +452,7 @@ TEST_F(NVFuserTest, FusionGridAllreduceWelford2_CUDA) { tv3->split(1, tidx); TransformPropagator propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 1); @@ -558,7 +558,7 @@ TEST_F(NVFuserTest, FusionFusedReductionBatchnorm_CUDA) { {6, 9}}); TransformPropagator propagator(tv0); - MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv0).traverse(&propagator); ir_utils::rFactorHelper(tvs.avg, {-5, -4, -3, -2, -1}); @@ -683,7 +683,7 @@ TEST_F(NVFuserTest, FusionGroupedReduction2_CUDA) { tv2->split(1, 128); TransformPropagator propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); tv0->computeAt(tv4, -1, ComputeAtMode::MostInlined); @@ -727,7 +727,7 @@ TEST_F(NVFuserTest, FusionGroupedReduction3_CUDA) { groupReductions({tv1, tv3}); tv1->split(1, 128); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv0->computeAt(tv5, -1, ComputeAtMode::MostInlined); @@ -1011,7 +1011,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce1_CUDA) { tv2->split(0, 128); TransformPropagator propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); tv2->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(1)->parallelize(ParallelType::TIDx); @@ -1057,7 +1057,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce2_CUDA) { groupReductions({tv1, tv4}); tv1->split(1, tidx); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv0->computeAt(tv8, -1, ComputeAtMode::MostInlined); @@ -1112,7 +1112,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce3_CUDA) { tv1->split(0, 128); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(1)->parallelize(ParallelType::TIDx); @@ -1165,7 +1165,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce4_CUDA) { reduction_tv->split(0, 128); TransformPropagator propagator(reduction_tv); - MaxRootDomainInfoSpanningTree(reduction_tv).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(reduction_tv).traverse(&propagator); reduction_tv->axis(0)->parallelize(ParallelType::BIDx); reduction_tv->axis(1)->parallelize(ParallelType::TIDx); @@ -1238,7 +1238,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce5_CUDA) { tv1->split(0, 128); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(1)->parallelize(ParallelType::TIDx); @@ -1402,7 +1402,7 @@ TEST_F(NVFuserTest, FusionPersistentBNBackwardAllreduce_CUDA) { grad_input->axis(4)->parallelize(ParallelType::TIDx); TransformPropagator propagator(grad_input); - MaxRootDomainInfoSpanningTree(grad_input).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(grad_input).traverse(&propagator); auto rf_tensors = grad_output_sum->rFactor( {-1}, std::vector({grad_output_sum, dot_p})); @@ -1516,7 +1516,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionReEntrant1_CUDA) { tv2->split(0, tidy); TransformPropagator propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize); @@ -1617,7 +1617,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionChannelsLastBatchNormLike_CUDA) { ref->reorder({{3, 4}, {4, 3}}); TransformPropagator propagator(ref); - MaxRootDomainInfoSpanningTree(ref).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(ref).traverse(&propagator); auto rf_tvs = tv5->rFactor({-2}, {tv5, tv9}); auto tv5_rf = rf_tvs.at(0); @@ -1746,7 +1746,7 @@ TEST_F( ref->reorder({{3, 4}, {4, 3}}); TransformPropagator propagator(ref); - MaxRootDomainInfoSpanningTree(ref).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(ref).traverse(&propagator); auto rf_tvs = tv5->rFactor({-2}, {tv5, tv9}); auto tv5_rf = rf_tvs.at(0); @@ -1826,7 +1826,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce1_CUDA) { tv1->split(1, tidx); tv1->split(0, tidy); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv1->axis(0)->parallelize(ParallelType::BIDy); tv1->axis(1)->parallelize(ParallelType::TIDy); @@ -1903,7 +1903,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce2_CUDA) { tv1->split(1, tidx); tv1->split(0, tidy); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv1->axis(0)->parallelize(ParallelType::BIDy); tv1->axis(1)->parallelize(ParallelType::TIDy); @@ -1988,7 +1988,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce3_CUDA) { tv1->split(1, tidx); tv1->split(0, tidy); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv1->axis(0)->parallelize(ParallelType::BIDy); tv1->axis(1)->parallelize(ParallelType::TIDy); @@ -2067,7 +2067,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce4_CUDA) { tv2->split(-1, tidy); TransformPropagator propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); tv2->axis(2)->parallelize(ParallelType::Group); @@ -2154,7 +2154,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelford1_CUDA) { tv1->split(1, tidx); tv1->split(0, tidy); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv1->axis(0)->parallelize(ParallelType::BIDy); tv1->axis(1)->parallelize(ParallelType::TIDy); @@ -2218,7 +2218,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelford2_CUDA) { tv1->split(1, tidx); tv1->split(0, tidy); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv1->axis(0)->parallelize(ParallelType::BIDy); tv1->axis(1)->parallelize(ParallelType::TIDy); @@ -2334,7 +2334,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelfordShmoo_CUDA) { reduction_scheduler_utils::sortAndRFactor(transform_ref); TransformPropagator propagator(transform_ref_rf); - MaxRootDomainInfoSpanningTree(transform_ref_rf).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(transform_ref_rf).traverse(&propagator); int vec_id = std::distance( transform_ref_rf->getLoopDomain().begin(), @@ -2537,7 +2537,7 @@ TEST_F(NVFuserTest, FusionCrossEntropyGatherPattern_CUDA) { tv4->split(0, bidx); tv4->split(0, 1); // unswitch TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); tv4->axis(0)->parallelize(ParallelType::BIDy); tv4->axis(2)->parallelize(ParallelType::BIDx); diff --git a/tests/cpp/test_gpu_indexing.cpp b/tests/cpp/test_gpu_indexing.cpp index 6a54182bc22..1ac0ae159d1 100644 --- a/tests/cpp/test_gpu_indexing.cpp +++ b/tests/cpp/test_gpu_indexing.cpp @@ -746,7 +746,7 @@ TEST_F(NVFuserTest, FusionIndexing18_CUDA) { tv4->split(0, 4); auto tv5 = tv4->rFactor({1}); - MaxRootDomainInfoSpanningTree tree(tv5); + MaxLogicalDomainInfoSpanningTree tree(tv5); TransformPropagator tp(tv5); tree.traverse(&tp); diff --git a/tests/cpp/test_gpu_outer_reduction.cpp b/tests/cpp/test_gpu_outer_reduction.cpp index e5411d5aa09..050d4423eec 100644 --- a/tests/cpp/test_gpu_outer_reduction.cpp +++ b/tests/cpp/test_gpu_outer_reduction.cpp @@ -94,7 +94,7 @@ TEST_F(OuterReductionTest, GroupedGridWelfordOuterOpt) { auto ref_rf = ref->rFactor({-3}, {tvs.avg, tvs.var_sum, tvs.n}).at(0); TransformPropagator propagator(ref_rf); - MaxRootDomainInfoSpanningTree(ref_rf).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(ref_rf).traverse(&propagator); ref_rf->axis(1)->parallelize(ParallelType::BIDx); ref_rf->axis(2)->parallelize(ParallelType::TIDx); @@ -529,7 +529,7 @@ void scheduleNormalization(Fusion& fusion, const OuterReductionParams& params) { reduction_tv_rf->reorder(vec_reorder_map); TransformPropagator propagator(reduction_tv_rf); - MaxRootDomainInfoSpanningTree(reduction_tv_rf).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(reduction_tv_rf).traverse(&propagator); // Clear vectorization and unswitch as we want to selectively use // them diff --git a/tests/cpp/test_gpu_tensorcore.cpp b/tests/cpp/test_gpu_tensorcore.cpp index 46d93914079..77763d3887d 100644 --- a/tests/cpp/test_gpu_tensorcore.cpp +++ b/tests/cpp/test_gpu_tensorcore.cpp @@ -28,10 +28,10 @@ #include #include #include +#include #include #include #include -#include #include #include #include diff --git a/tests/cpp/test_gpu_transpose.cpp b/tests/cpp/test_gpu_transpose.cpp index 24fb03763e4..be13957c1c2 100644 --- a/tests/cpp/test_gpu_transpose.cpp +++ b/tests/cpp/test_gpu_transpose.cpp @@ -490,7 +490,7 @@ TEST_F(TransposeTest, FusionManualScheduleTransposeComplexDAG1) { // [BIDx, Unswitch, 32(N), 32(K)] // propagate to the entire DAG - MaxRootDomainInfoSpanningTree entire_dag(tv9); + MaxLogicalDomainInfoSpanningTree entire_dag(tv9); TransformPropagator tp(tv9); entire_dag.traverse(&tp); scheduler_utils::parallelizeAllLike(tv9); @@ -526,7 +526,7 @@ TEST_F(TransposeTest, FusionManualScheduleTransposeComplexDAG1) { auto all_tvs_except_ref1_set = std::unordered_set( all_tvs_except_ref1.begin(), all_tvs_except_ref1.end()); SetSelector selector(all_tvs_except_ref1_set); - MaxRootDomainInfoSpanningTree tree(tv10, &selector); + MaxLogicalDomainInfoSpanningTree tree(tv10, &selector); TransformPropagator tp(tv10); tree.traverse(&tp); scheduler_utils::parallelizeAllLike( @@ -555,7 +555,7 @@ TEST_F(TransposeTest, FusionManualScheduleTransposeComplexDAG1) { auto all_tvs_except2_set = std::unordered_set( all_tvs_except2.begin(), all_tvs_except2.end()); SetSelector selector(all_tvs_except2_set); - MaxRootDomainInfoSpanningTree tree(tv9, &selector); + MaxLogicalDomainInfoSpanningTree tree(tv9, &selector); TransformPropagator tp(tv9); tree.traverse(&tp); scheduler_utils::parallelizeAllLike( diff --git a/tests/cpp/test_gpu_view.cpp b/tests/cpp/test_gpu_view.cpp index 9604296e070..c76a75c961a 100644 --- a/tests/cpp/test_gpu_view.cpp +++ b/tests/cpp/test_gpu_view.cpp @@ -28,8 +28,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -876,7 +876,7 @@ TEST_F(GpuViewTest, FusionFlattenAfterUnsqueezeOutput) { testValidate(&fusion, outputs, aten_inputs, __LINE__, __FILE__); } -TEST_F(GpuViewTest, FusionComputeAtRootDomainMapWithView) { +TEST_F(GpuViewTest, FusionComputeAtLogicalDomainMapWithView) { Fusion fusion; FusionGuard fg(&fusion); @@ -899,17 +899,17 @@ TEST_F(GpuViewTest, FusionComputeAtRootDomainMapWithView) { auto tv5 = add(tv3, tv4); fusion.addOutput(tv5); - ComputeAtRootDomainMap map; + ComputeAtLogicalDomainMap map; map.build(); // It's not possible to compute tv1 at the -1 position of - // t2. ComputeAtRootDomainMap should tell that by not mapping the + // t2. ComputeAtLogicalDomainMap should tell that by not mapping the // second axis. auto tv1_tv2_mappable_dims = map.getMappableDims(tv1->domain(), tv2->domain()); NVF_CHECK( tv1_tv2_mappable_dims.find(tv1->axis(1)) == tv1_tv2_mappable_dims.end(), - "Invalid ComputeAtRootDomainMap. Domain should not be mappable: ", + "Invalid ComputeAtLogicalDomainMap. Domain should not be mappable: ", tv1->axis(1)->toString()); } @@ -1344,7 +1344,7 @@ TEST_F(GpuViewTest, FusionPwiseViewSchedule) { { TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); } for (auto i : c10::irange(tv5->nDims() - 1)) { @@ -1359,7 +1359,7 @@ TEST_F(GpuViewTest, FusionPwiseViewSchedule) { { TransformPropagator propagator(tv5); - MaxRootDomainInfoSpanningTree spanning_tree(tv5); + MaxLogicalDomainInfoSpanningTree spanning_tree(tv5); spanning_tree.traverse(&propagator); scheduler_utils::parallelizeAllLike(tv5); @@ -1407,7 +1407,7 @@ TEST_F(GpuViewTest, FusionSumViewSchedule) { { TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); } tv5->split(1, 128); @@ -1420,7 +1420,7 @@ TEST_F(GpuViewTest, FusionSumViewSchedule) { { TransformPropagator propagator(tv5_rf); - MaxRootDomainInfoSpanningTree spanning_tree(tv5_rf); + MaxLogicalDomainInfoSpanningTree spanning_tree(tv5_rf); spanning_tree.traverse(&propagator); scheduler_utils::parallelizeAllLike(tv5_rf); @@ -1939,7 +1939,7 @@ TEST_F(GpuViewTest, FusionReshapeMapping) { tv6->axis(2)->parallelize(ParallelType::TIDx); TransformPropagator propagator(tv6); - MaxRootDomainInfoSpanningTree spanning_tree(tv6); + MaxLogicalDomainInfoSpanningTree spanning_tree(tv6); spanning_tree.traverse(&propagator); scheduler_utils::parallelizeAllLike(tv6); @@ -1975,7 +1975,7 @@ TEST_F(GpuViewTest, FusionLowerDivisibleSplits) { tv2->merge(0)->merge(0)->merge(0)->split(0, 4)->split(0, 8, false); TransformPropagator propagator(tv2); - MaxRootDomainInfoSpanningTree spanning_tree(tv2); + MaxLogicalDomainInfoSpanningTree spanning_tree(tv2); spanning_tree.traverse(&propagator); scheduler_utils::parallelizeAllLike(tv2); diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index f84c7f00a1b..56f3288633f 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -284,7 +284,7 @@ void checkStep2Results(Fusion* fusion, const IdModelTester& tester) { ASSERT_EQ(consumers.size(), 1) << "Assumed to have one consumer"; TensorView* c_tv = consumers.at(0); const auto p2c = BestEffortReplay::replayCasP( - c_tv, tv, -1, PairwiseRootDomainMap(tv, c_tv)) + c_tv, tv, -1, PairwiseLogicalDomainMap(tv, c_tv)) .getReplay(); for (auto p_id : ir_utils::allIDsOf(tv)) { @@ -473,7 +473,7 @@ std::unique_ptr createFusionWithMultipleResolutionPaths() { // tv10[7*11*13//5//3, 3, 5] TransformPropagatorWithCheck propagator(tv10); - MaxRootDomainInfoSpanningTree(tv10).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv10).traverse(&propagator); std::vector tensors_to_inline{tv1, tv2, tv4, tv6, tv8}; for (auto tensor : tensors_to_inline) { @@ -553,7 +553,7 @@ TEST_F(IdModelTest, ValGraphStmtSort1) { // tensors. tv2->merge(0)->split(0, 4); TransformPropagator propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); // The exact graph should just map all IDs of the tensors. Ther // ordering of the exprs should be the merge and then the split. @@ -876,7 +876,7 @@ TEST_F(IdModelTest, LoopPromotion3) { tv3->merge(1); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv2->inlineAt(1); @@ -951,7 +951,7 @@ TEST_F(IdModelTest, LoopPromotion4) { // [4, i0*i1/4] TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); for (auto tv : ir_utils::allTvs(&fusion)) { tv->inlineAt(-2); @@ -1551,7 +1551,7 @@ TEST_F(IdModelTest, LoopPromotion7) { tv4->split(0, 32); TransformPropagatorWithCheck propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); tv2->inlineAt(1); tv3->inlineAt(1); @@ -1685,7 +1685,7 @@ TEST_F(IdModelTest, LoopPromotion8) { // [3, 3*5//2] TransformPropagatorWithCheck propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); tv1->inlineAt(1); tv2->inlineAt(1); @@ -1886,7 +1886,7 @@ TEST_F(IdModelTest, LoopPromotionPromoteToSameLoopGroup) { tv4->merge(1, 2); TransformPropagatorWithCheck propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); for (auto tv : {tv0, tv1, tv2, tv3}) { tv->inlineAt(1); @@ -1990,7 +1990,7 @@ TEST_F(IdModelTest, LoopPromotionTwoStepFailureReproSimple) { t4->merge(-2, -1)->merge(-2, -1)->merge(-2, -1)->merge(-2, -1)->split(0, 4); TransformPropagatorWithCheck propagator(t4); - MaxRootDomainInfoSpanningTree(t4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(t4).traverse(&propagator); for (auto tv : ir_utils::allTvs(&fusion)) { tv->inlineAt(1); @@ -2403,7 +2403,7 @@ TEST_F(IdModelTest, LoopGraphWithSibling) { avg->merge(0); avg->split(0, 8); TransformPropagatorWithCheck propagator(avg); - MaxRootDomainInfoSpanningTree(avg).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(avg).traverse(&propagator); IdModel id_model(&fusion); const auto& loop_graph = id_model.idGraph(IdMappingMode::LOOP); @@ -2545,7 +2545,7 @@ TEST_F(IdModelTest, LoopPromotionCoverage) { // there is only one loop group. tv10->flatten(); TransformPropagatorWithCheck propagator(tv10); - MaxRootDomainInfoSpanningTree(tv10).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv10).traverse(&propagator); inlineMost(); IdModel id_model(&fusion); @@ -2598,7 +2598,7 @@ TEST_F(IdModelTest, ParallelTypePropagation) { tv2->split(0, 4); TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); inlineMost(); diff --git a/tests/cpp/test_indexing.cpp b/tests/cpp/test_indexing.cpp index 682e1577fa0..1ecfd2ae382 100644 --- a/tests/cpp/test_indexing.cpp +++ b/tests/cpp/test_indexing.cpp @@ -250,7 +250,7 @@ TEST_F(IndexingTest, SimplePointwise1) { tv2->split(0, 4); TransformPropagator propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); tv1->inlineAt(1); @@ -340,7 +340,7 @@ TEST_F(IndexingTest, SimplePointwise2) { tv3->split(0, 4); TransformPropagator propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(1)->parallelize(ParallelType::TIDx); @@ -579,7 +579,7 @@ TEST_F(IndexingTest, Reshape) { fusion.addOutput(tv5); TransformPropagator propagator(tv5); - MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv5).traverse(&propagator); inlineMost(); @@ -695,7 +695,7 @@ TEST_F(IndexingTest, SimpleBroadcast2) { tv2->split(0, 4); TransformPropagator propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); // The first merge of the logical domains should be a trivial merge, // i.e., a merge with a extent-one domain. Thus, the indexing @@ -756,7 +756,7 @@ TEST_F(IndexingTest, SimpleBroadcast3) { tv3->flatten(); TransformPropagator propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); inlineMost(); @@ -826,7 +826,7 @@ TEST_F(IndexingTest, SimpleBroadcast4) { // [4, i0*i1/4] TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); for (auto tv : ir_utils::allTvs(&fusion)) { tv->inlineAt(-2); @@ -940,7 +940,7 @@ TEST_F(IndexingTest, MultiDevice2D) { tv1->split(0, num_devices, false); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv0->axis(0)->parallelize(ParallelType::DIDx); tv1->axis(0)->parallelize(ParallelType::DIDx); @@ -983,7 +983,7 @@ TEST_F(IndexingTest, MultiDevice2DLeafAllocation) { tv1->split(0, num_devices, false); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv0->axis(0)->parallelize(ParallelType::DIDx); tv1->axis(0)->parallelize(ParallelType::DIDx); @@ -1124,7 +1124,7 @@ TEST_F(IndexingTest, SimpleVectorize) { tv2->axis(2)->parallelize(ParallelType::Vectorize); TransformPropagator propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); inlineMost(); @@ -1193,7 +1193,7 @@ TEST_F(IndexingTest, NonInnermostVectorize) { tv3->reorder({{-1, -2}}); TransformPropagator propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(1)->parallelize(ParallelType::TIDx); @@ -1258,7 +1258,7 @@ TEST_F(IndexingTest, AlmostExactTraversalWithNonOneBroadcast) { tv3->merge(1); tv3->split(1, 5); - MaxRootDomainInfoSpanningTree tree(tv3); + MaxLogicalDomainInfoSpanningTree tree(tv3); TransformPropagator tp(tv3); tree.traverse(&tp); @@ -1369,7 +1369,7 @@ TEST_F(IndexingTest, SimpleUnroll) { tv2->split(0, 4); TransformPropagator propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); inlineMost(); @@ -1426,7 +1426,7 @@ TEST_F(IndexingTest, InlinedUnroll) { tv4->split(0, 1); TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); inlineMost(); @@ -1675,7 +1675,7 @@ TEST_F(IndexingTest, DoubleBuffering1) { tv3->split(-1, 128); tv3->split(-1, 32); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv1->inlineAt(-2); tv2->inlineAt(-2); @@ -1782,7 +1782,7 @@ TEST_F(IndexingTest, DoubleBuffering4) { tv3->split(-1, 32); tv3->split(-1, 8); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 2); tv2->computeAt(tv3, -1); @@ -1889,7 +1889,7 @@ TEST_F(IndexingTest, DoubleBuffering6) { tv3->split(-2, 4); tv3->split(-2, 2); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 1); tv2->computeAt(tv3, -1); @@ -2036,7 +2036,7 @@ TEST_F(IndexingTest, CircularBuffering1) { tv3->split(-1, 128); tv3->split(-1, 32); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv1->inlineAt(-2); tv2->inlineAt(-2); @@ -2156,7 +2156,7 @@ TEST_F(IndexingTest, CircularBuffering2) { tv3->split(-2, 4); tv3->split(-2, 2); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 1); tv2->computeAt(tv3, -1); diff --git a/tests/cpp/test_memory.cpp b/tests/cpp/test_memory.cpp index 877fd0e51ec..533fcd74d12 100644 --- a/tests/cpp/test_memory.cpp +++ b/tests/cpp/test_memory.cpp @@ -61,7 +61,7 @@ TEST_P(MemoryTest, LoadCache) { tv1->split(0, 4); tv1->split(0, 32); TransformPropagatorWithCheck propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); // Parallelize LoadStoreOps. Other TensorViews don't support vectorization. tv1->axis(0)->parallelize(ParallelType::BIDx); @@ -132,7 +132,7 @@ TEST_F(MemoryTest, RefineCachePolicy) { tv_a2->split(0, 4); tv_a2->split(0, 32); TransformPropagatorWithCheck propagator(tv_a2); - MaxRootDomainInfoSpanningTree(tv_a2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv_a2).traverse(&propagator); tv_a2->axis(0)->parallelize(ParallelType::BIDx); tv_a2->axis(1)->parallelize(ParallelType::TIDx); diff --git a/tests/cpp/test_multidevice_pipeline.cpp b/tests/cpp/test_multidevice_pipeline.cpp index 8c2116d51cf..8c47eaabc1d 100644 --- a/tests/cpp/test_multidevice_pipeline.cpp +++ b/tests/cpp/test_multidevice_pipeline.cpp @@ -23,8 +23,8 @@ #include #include #include +#include #include -#include #include #include #include diff --git a/tests/cpp/test_persistent_buffer.cpp b/tests/cpp/test_persistent_buffer.cpp index 9be416a0a3d..3809117af01 100644 --- a/tests/cpp/test_persistent_buffer.cpp +++ b/tests/cpp/test_persistent_buffer.cpp @@ -8,8 +8,8 @@ #include #include +#include #include -#include #include #include #include diff --git a/tests/cpp/test_predicate_elimination.cpp b/tests/cpp/test_predicate_elimination.cpp index 4d8e60c0cf8..f61c27d9f42 100644 --- a/tests/cpp/test_predicate_elimination.cpp +++ b/tests/cpp/test_predicate_elimination.cpp @@ -102,7 +102,7 @@ TEST_F(PredicateEliminationTest, 3) { tv1->split(0, 10); tv1->split(0, 33); TransformPropagatorWithCheck propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); auto tv4 = tv1->rFactor({-1}); auto tv5 = tv1->rFactor({-1}); @@ -157,7 +157,7 @@ TEST_F(PredicateEliminationTest, 4) { tv1->split(0, 11); tv1->reorder({{1, 2}, {2, 1}}); TransformPropagatorWithCheck propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv1->axis(0)->parallelize(ParallelType::TIDy); tv1->axis(1)->parallelize(ParallelType::TIDx); @@ -207,7 +207,7 @@ TEST_F(PredicateEliminationTest, 5) { tvs2.avg->split(0, 4); TransformPropagatorWithCheck propagator(tvs2.avg); - MaxRootDomainInfoSpanningTree(tvs2.avg).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tvs2.avg).traverse(&propagator); auto avg_rf = ir_utils::rFactorHelper(tvs2.avg, {1}); avg_rf->axis(0)->parallelize(ParallelType::TIDx); @@ -253,7 +253,7 @@ TEST_F(PredicateEliminationTest, 6) { tv4->split(1, 5); TransformPropagatorWithCheck propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); tv4->reorder({{0, 1}, {1, 0}}); tv3->computeAt(tv4, 1); @@ -300,7 +300,7 @@ TEST_F(PredicateEliminationTest, 7) { tv3->split(-1, 4); tv3->split(-1, 3); TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 1); diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index c36f3901847..3a2e56eba9a 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -107,7 +107,7 @@ TEST_F(ResizeTest, Pad3) { tv4->split(0, 32); TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); inlineMost(); @@ -228,7 +228,7 @@ TEST_F(ResizeTest, Pad6) { tv4->split(0, 32); TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); inlineMost(); @@ -269,7 +269,7 @@ TEST_F(ResizeTest, Pad7) { tv3->reorder({{1, 2}}); TransformPropagator propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); inlineMost(); @@ -315,7 +315,7 @@ TEST_F(ResizeTest, Pad8) { tv4->split(0, 128); TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); inlineMost(); @@ -590,7 +590,7 @@ TEST_F(ResizeTest, Cat3) { tv2->split(0, 4); TransformPropagator propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); inlineMost(); @@ -631,7 +631,7 @@ TEST_F(ResizeTest, Cat4) { tv2->split(0, 128); TransformPropagator propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); tv2->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(1)->parallelize(ParallelType::TIDx); @@ -674,7 +674,7 @@ TEST_F(ResizeTest, Cat5) { tv4->split(0, 128); TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); inlineMost(); @@ -722,7 +722,7 @@ TEST_F(ResizeTest, Cat6) { tv3->merge(0); tv3->split(0, 4); TransformPropagator propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); inlineMost(); @@ -773,7 +773,7 @@ TEST_F(ResizeTest, Cat7) { concat_tv->split(0, 128); TransformPropagator propagator(concat_tv); - MaxRootDomainInfoSpanningTree(concat_tv).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(concat_tv).traverse(&propagator); inlineMost(); @@ -1021,7 +1021,8 @@ TEST_F(ResizeTest, Slice4) { tv5->setMemoryType(MemoryType::Global); SetSelector tv5_rf_selector({tv1, tv3, tv5, tv5_cache}); TransformPropagator tv5_rf_tp(tv5_rf); - MaxRootDomainInfoSpanningTree(tv5_rf, &tv5_rf_selector).traverse(&tv5_rf_tp); + MaxLogicalDomainInfoSpanningTree(tv5_rf, &tv5_rf_selector) + .traverse(&tv5_rf_tp); inlineMost(std::vector{tv1, tv3, tv5_rf}); tv5_rf->axis(0)->parallelize(ParallelType::BIDx); tv5_rf->axis(1)->parallelize(ParallelType::TIDx); @@ -1034,7 +1035,8 @@ TEST_F(ResizeTest, Slice4) { tv6->setMemoryType(MemoryType::Global); SetSelector tv6_rf_selector({tv2, tv4, tv6, tv6_cache}); TransformPropagator tv6_rf_tp(tv6_rf); - MaxRootDomainInfoSpanningTree(tv6_rf, &tv6_rf_selector).traverse(&tv6_rf_tp); + MaxLogicalDomainInfoSpanningTree(tv6_rf, &tv6_rf_selector) + .traverse(&tv6_rf_tp); inlineMost(std::vector{tv2, tv4, tv6_rf}); tv6_rf->axis(0)->parallelize(ParallelType::BIDx); tv6_rf->axis(1)->parallelize(ParallelType::TIDx); @@ -1096,7 +1098,7 @@ TEST_F(ResizeTest, Slice5) { // tv1 to tv3 through tv0, which should work as both tensors are // sliced in the same way. TransformPropagator propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator); inlineMost(); @@ -2640,7 +2642,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual2) { tv1->split(0, 128); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(1)->parallelize(ParallelType::TIDx); @@ -2691,7 +2693,7 @@ TEST_F(ResizeTest, Slice1DVectorizeManual3) { tv1->split(0, 128); TransformPropagator propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator); tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(1)->parallelize(ParallelType::TIDx); @@ -2940,7 +2942,7 @@ TEST_F(ResizeTest, SliceAndReshapeRepro540Manual) { tv4->reorder({{1, -1}}); TransformPropagator propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); tv4->axis(0)->parallelize(ParallelType::BIDx); tv4->axis(1)->parallelize(ParallelType::Unswitch); diff --git a/tests/cpp/test_scalar_hoisting.cpp b/tests/cpp/test_scalar_hoisting.cpp index 5857e33d10f..4993d043919 100644 --- a/tests/cpp/test_scalar_hoisting.cpp +++ b/tests/cpp/test_scalar_hoisting.cpp @@ -242,7 +242,7 @@ TEST_F(ScalarHoistTest, IndexHoist2) { tv5->split(-1, 4); TransformPropagatorWithCheck propagator(tv5); - MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv5).traverse(&propagator); tv4->split(-1, 3); diff --git a/tests/cpp/test_sdpa_node.cpp b/tests/cpp/test_sdpa_node.cpp index 19fe86ce60c..f0bd18ec682 100644 --- a/tests/cpp/test_sdpa_node.cpp +++ b/tests/cpp/test_sdpa_node.cpp @@ -220,7 +220,7 @@ TEST_F(SDPATest, CausalAttn) { validateSdpaFwdOutputs(nvf_out, aten_out); } -TEST_F(SDPATest, PairwiseRootDomainMap) { +TEST_F(SDPATest, PairwiseLogicalDomainMap) { NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); auto fusion = std::make_unique(); @@ -260,7 +260,7 @@ TEST_F(SDPATest, PairwiseRootDomainMap) { for (Val* consumer : fusion->outputs()) { auto consumer_tv = consumer->as(); - auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv) + auto pairwise_map = PairwiseLogicalDomainMap(producer_tv, consumer_tv) .mapProducerToConsumer(); auto mappingExists = [&pairwise_map]( IterDomain* p_id, IterDomain* c_id) -> bool { diff --git a/tests/cpp/test_serial_gridreduce.cpp b/tests/cpp/test_serial_gridreduce.cpp index 171690d4169..732858e9fb5 100644 --- a/tests/cpp/test_serial_gridreduce.cpp +++ b/tests/cpp/test_serial_gridreduce.cpp @@ -103,7 +103,7 @@ TEST_F(SerialGridReductionTest, Scheduling) { }); TransformPropagator propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv3).traverse(&propagator); scheduler_utils::parallelizeAllLike(tv3); // Here we just transpose A and B in tv2, so that it will be partially diff --git a/tests/cpp/test_swizzle.cpp b/tests/cpp/test_swizzle.cpp index a71ddf82a0c..9cfd5caac97 100644 --- a/tests/cpp/test_swizzle.cpp +++ b/tests/cpp/test_swizzle.cpp @@ -185,7 +185,7 @@ TEST_F(SwizzleTest, SwizzleMapping) { tv1->computeAt(tv2, -1); // Check BestEffortReplay behavior with skip swizzles option on. - PairwiseRootDomainMap root_map(tv1, tv2); + PairwiseLogicalDomainMap root_map(tv1, tv2); // Check producer to consumer map, // i.e. unswizzled tensor to swizzled tensor map @@ -641,7 +641,7 @@ TEST_F(SwizzleTest, TransformPropagatorSkipSwizzleOnTarget) { tv0->merge(0); TransformPropagatorWithCheck propagator(tv0); - MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(tv0).traverse(&propagator); auto exprs = StmtSort::getExprsBetween( {tv1->getLogicalDomain().begin(), tv1->getLogicalDomain().end()}, diff --git a/tests/cpp/test_translate_mma.cpp b/tests/cpp/test_translate_mma.cpp index 20abd530a1b..c297c5f7aa2 100644 --- a/tests/cpp/test_translate_mma.cpp +++ b/tests/cpp/test_translate_mma.cpp @@ -24,12 +24,12 @@ #include #include #include +#include #include #include #include #include #include -#include #include #include #include diff --git a/tests/cpp/test_tutorial.cpp b/tests/cpp/test_tutorial.cpp index 1249d9ebb38..68e6059a20d 100644 --- a/tests/cpp/test_tutorial.cpp +++ b/tests/cpp/test_tutorial.cpp @@ -569,7 +569,7 @@ TEST_F(Tutorial, Reshape) { // Here's how we propagate the transformations of reshape_output // to all other tensors in the fusion TransformPropagatorWithCheck propagator(reshape_output); - MaxRootDomainInfoSpanningTree(reshape_output).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(reshape_output).traverse(&propagator); // Now, all tensors, including those before the reshape op, should // be transformed to 2D tensors with an inner domain of extent @@ -1308,7 +1308,7 @@ TEST_F(Tutorial, VectorizeStorePointwiseTMA) { // Transform Operations between cache operations and output reference TransformPropagator propagator(reference_tv); - MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(reference_tv).traverse(&propagator); // Propagate common parallel dimensions reference_tv->axis(1)->parallelize(ParallelType::BIDx); @@ -1409,7 +1409,7 @@ TEST_F(Tutorial, PointwiseBroadcastTMA) { // Transform Operations between cache operations and output reference TransformPropagator propagator(reference_tv); - MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator); + MaxLogicalDomainInfoSpanningTree(reference_tv).traverse(&propagator); // Define Parallelization Schema // Intermediate Tensors