From f669fcf78b5c5dee6c08715aeb6f2e36a6af964b Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 3 Sep 2024 00:51:14 -0400 Subject: [PATCH] Cleanup all uses of ir_utils::allTvs (#2884) Follow up to https://github.com/NVIDIA/Fuser/pull/2873 removes all uses of the ir_utils variant in favor of the Fusion variant. --------- Co-authored-by: Jacob Hinkle --- csrc/compute_at_map.cpp | 6 ++-- .../device_lower/analysis/divisible_split.cpp | 2 +- .../analysis/thread_predicate.cpp | 2 +- csrc/device_lower/lower2device.cpp | 2 +- csrc/device_lower/pass/expr_sort.cpp | 2 +- csrc/device_lower/pass/loops.cpp | 4 +-- csrc/device_lower/validation.cpp | 4 +-- csrc/evaluator_common.cpp | 2 +- csrc/fusion.cpp | 27 +++++++++++++++-- csrc/fusion_segmenter.cpp | 10 +++---- csrc/id_model/validation_utils.cpp | 2 +- csrc/inlining.cpp | 4 +-- csrc/ir/graphviz.cpp | 2 +- csrc/ir/utils.cpp | 23 ++------------ csrc/ir/utils.h | 3 -- csrc/kernel_cache.cpp | 4 +-- csrc/multidevice/utils.cpp | 4 +-- csrc/preseg_passes/mark_aliases_prepare.cpp | 4 +-- csrc/preseg_passes/propagate_shardings.cpp | 2 +- csrc/python_frontend/fusion_definition.cpp | 4 +-- csrc/scheduler/normalization_inner_outer.cpp | 2 +- csrc/scheduler/pointwise.cpp | 2 +- csrc/scheduler/registry.cpp | 2 +- csrc/scheduler/registry_utils.cpp | 2 +- csrc/scheduler/utils.cpp | 18 +++++------ tests/cpp/test_gpu3.cpp | 12 ++++---- tests/cpp/test_gpu_compute_with.cpp | 4 +-- tests/cpp/test_gpu_fused_reduction.cpp | 5 ++-- tests/cpp/test_gpu_outer_reduction.cpp | 5 ++-- tests/cpp/test_gpu_utils.cpp | 2 +- tests/cpp/test_id_model.cpp | 30 +++++++++---------- tests/cpp/test_indexing.cpp | 16 +++++----- tests/cpp/test_matmul.cpp | 6 ++-- tests/cpp/test_scatter_gather.cpp | 2 +- tests/cpp/utils.h | 2 +- 35 files changed, 110 insertions(+), 113 deletions(-) diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index 0176ec7563a..3b159e696c5 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -292,7 +292,7 @@ std::optional> detectMappablePair( // matter in practice. std::optional> findFirstSelfMapping(Fusion* fusion, const IterDomainGraph& id_graph) { - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { // For each tensor, make sure root, logical and loop domains // should not include domains that are mapped with another domain // in the same set of domains. This may be overly conservative, @@ -342,7 +342,7 @@ void IterDomainGraph::build(Fusion* fusion) { FusionGuard fg(fusion); // Initialize a node for every iteration domain - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { const auto& domain = tv->getLoopDomain(); auto all_ids = tv->domain()->allIDs(); @@ -586,7 +586,7 @@ void IterDomainGraph::build(Fusion* fusion) { // transformations makes it easy to check if different view operations are // consistent with eachother. - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); std::vector all_consumer_tvs; std::copy_if( all_tvs.begin(), diff --git a/csrc/device_lower/analysis/divisible_split.cpp b/csrc/device_lower/analysis/divisible_split.cpp index cbf251e5d35..a4844d8c388 100644 --- a/csrc/device_lower/analysis/divisible_split.cpp +++ b/csrc/device_lower/analysis/divisible_split.cpp @@ -25,7 +25,7 @@ std::unordered_set getAllDivisibleSplits( const ComputeAtMap* ca_map) { std::unordered_set all_divisible_splits; - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); // Find all tensor views with a view like rfactor. Splits used in view // transformations must be divisible by definition. for (auto tv : all_tvs) { diff --git a/csrc/device_lower/analysis/thread_predicate.cpp b/csrc/device_lower/analysis/thread_predicate.cpp index 408140590b1..0c9b4413c9b 100644 --- a/csrc/device_lower/analysis/thread_predicate.cpp +++ b/csrc/device_lower/analysis/thread_predicate.cpp @@ -734,7 +734,7 @@ void ThreadPredicateMap::build(Fusion* fusion) { updateBitSet(expr); } - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { if (tv->getMemoryType() == MemoryType::Global) { avoidConcretizedBroadcastRedundantWrite(tv); } diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 03d5834c0e8..3aa025b4f8e 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -345,7 +345,7 @@ bool requiresIdModel(Fusion* fusion) { } // If a tensor does not have a nice root->logical/allocation->loop // linear transformation history, use IdModel. - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { if (!lower_utils::hasRootToLoopLinearTransformations(tv)) { return true; } diff --git a/csrc/device_lower/pass/expr_sort.cpp b/csrc/device_lower/pass/expr_sort.cpp index bb15739af23..b7144d3577d 100644 --- a/csrc/device_lower/pass/expr_sort.cpp +++ b/csrc/device_lower/pass/expr_sort.cpp @@ -1142,7 +1142,7 @@ void ExprSegmentationSorter::initializeForLoopDependencies() { concrete_id_dependencies_.empty(), "For loop dependencies have already been initialized."); - for (auto tv : ir_utils::allTvs(fusion_)) { + for (auto tv : fusion_->allTvs()) { std::unordered_set dependencies; for (int64_t tv_id_i = std::max( tv->getMaxProducerPosition(), diff --git a/csrc/device_lower/pass/loops.cpp b/csrc/device_lower/pass/loops.cpp index bd8c1a60271..30e30e063bb 100644 --- a/csrc/device_lower/pass/loops.cpp +++ b/csrc/device_lower/pass/loops.cpp @@ -145,7 +145,7 @@ void LoopNestGenerator::generate(const std::vector& exprs) { std::unordered_map> concrete_id_dependencies; - for (auto tv : ir_utils::allTvs(FusionGuard::getCurFusion())) { + for (auto tv : FusionGuard::getCurFusion()->allTvs()) { std::unordered_set dependencies; for (auto tv_id : tv->getLoopDomain()) { @@ -212,7 +212,7 @@ void LoopNestGenerator::generate(const std::vector& exprs) { } // Generate loop structure for each tensor view - for (auto tv : ir_utils::allTvs(FusionGuard::getCurFusion())) { + for (auto tv : FusionGuard::getCurFusion()->allTvs()) { // Zero dim tensor support if (tv->nDims() == 0) { loop_structures_[tv] = std::vector(); diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 5091e655556..2b465405406 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -189,7 +189,7 @@ void validateIr(Fusion* fusion) { "Tensor with dynamic transform must be concretized before lowering: ", toDelimitedString(dynamic_tvs.begin(), dynamic_tvs.end())); - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); validateCpAsyncBulk(all_tvs); } @@ -912,7 +912,7 @@ void validateSwizzle(Fusion* fusion) { } void validateAndConvertIterDomainGrouping(Fusion* fusion) { - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { bool is_grouped = false; for (const auto id_idx : c10::irange(tv->nDims())) { const auto id = tv->axis(id_idx); diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index c7a45c79359..56d9c868f90 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -98,7 +98,7 @@ void collectBufferSizes( std::vector collectRuntimeUsedValues(Fusion* fusion) { std::vector ret; - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); // Collect extent and inputs for (auto tv : all_tvs) { for (auto id : tv->getLoopDomain()) { diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index bd864c9f881..222a3b1afb6 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -871,10 +871,33 @@ bool isExpressionEvaluated(Fusion* fusion) { }); } +namespace { +std::vector findAllTvs(Fusion* fusion) { + auto used_vals = fusion->usedMathVals(); + auto used_tvs = ir_utils::filterByType(used_vals); + + // This shouldn't be necessary but FusionSegmentIoAlias_CUDA due to aliasing + // is having an input disconnected from outputs, and these iter domains are + // being checked in compute at maps in scheduling logic. This shouldn't hurt + // AFAICT. + auto tv_inputs = ir_utils::filterByType(fusion->inputs()); + + std::vector all_tvs({used_tvs.begin(), used_tvs.end()}); + // Sometimes inputs are not connected to outputs, however, we still include + // them when returning allTvs because they are registered as an input. + all_tvs.insert(all_tvs.end(), tv_inputs.begin(), tv_inputs.end()); + + VectorOfUniqueEntries unique_vector( + all_tvs.begin(), all_tvs.end()); + + // all_tvs has duplicates, to deduplicate it and return + return unique_vector.vector(); +} +} // namespace + std::vector Fusion::allTvs() { if (all_tvs_ptr_ == nullptr) { - all_tvs_ptr_ = - std::make_unique>(ir_utils::allTvs(this)); + all_tvs_ptr_ = std::make_unique>(findAllTvs(this)); } return std::vector(*all_tvs_ptr_); } diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index ec10307bbe3..8ef0319b12c 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -2370,7 +2370,7 @@ class FusionSegmentGuard : public NonCopyable { NVF_ERROR(fusion_ != nullptr); #ifndef NDEBUG num_original_exprs_ = fusion_->exprs().size(); - original_tvs_ = ir_utils::allTvs(fusion_); + original_tvs_ = fusion_->allTvs(); #endif // NDEBUG narrowToNewSegment(inputs, outputs); } @@ -2382,7 +2382,7 @@ class FusionSegmentGuard : public NonCopyable { FUSER_PERF_SCOPE("Segmenter::FusionSegmentGuard"); #ifndef NDEBUG num_original_exprs_ = fusion_->exprs().size(); - original_tvs_ = ir_utils::allTvs(fusion_); + original_tvs_ = fusion_->allTvs(); #endif // NDEBUG lowered_edges_ = segmented_fusion_->castInputOutputToLowerPrecision( segmented_fusion_->edges()); @@ -2398,7 +2398,7 @@ class FusionSegmentGuard : public NonCopyable { FUSER_PERF_SCOPE("Segmenter::FusionSegmentGuard"); #ifndef NDEBUG num_original_exprs_ = fusion_->exprs().size(); - original_tvs_ = ir_utils::allTvs(fusion_); + original_tvs_ = fusion_->allTvs(); #endif // NDEBUG // Cast inputs and outputs of a merged group consisting of a and @@ -2427,7 +2427,7 @@ class FusionSegmentGuard : public NonCopyable { FUSER_PERF_SCOPE("Segmenter::FusionSegmentGuard"); #ifndef NDEBUG num_original_exprs_ = fusion_->exprs().size(); - original_tvs_ = ir_utils::allTvs(fusion_); + original_tvs_ = fusion_->allTvs(); #endif // NDEBUG // Cast inputs and outputs of a merged group consisting of @@ -2468,7 +2468,7 @@ class FusionSegmentGuard : public NonCopyable { num_original_exprs_, ", actual: ", num_current_exprs); - auto current_tvs = ir_utils::allTvs(fusion_); + auto current_tvs = fusion_->allTvs(); NVF_ERROR( original_tvs_ == current_tvs, "Failed to revert temporary changes."); #endif diff --git a/csrc/id_model/validation_utils.cpp b/csrc/id_model/validation_utils.cpp index c8c116df8e8..6dd6e520f7c 100644 --- a/csrc/id_model/validation_utils.cpp +++ b/csrc/id_model/validation_utils.cpp @@ -118,7 +118,7 @@ bool exprsMap( IdModelValidator::IdModelValidator(Fusion* fusion, bool allow_self_mapping) : ca_map_(fusion, allow_self_mapping) { - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { for (auto id : tv->domain()->allIDs()) { if (id->definition() && id->definition()->isA()) { has_swizzle_ = true; diff --git a/csrc/inlining.cpp b/csrc/inlining.cpp index e308183cc10..d71fc059846 100644 --- a/csrc/inlining.cpp +++ b/csrc/inlining.cpp @@ -29,7 +29,7 @@ void MaxPosCalculator::buildUnmappableDims(bool compute_at_only) { } ComputeAtLogicalDomainMap logical_map; logical_map.build(); - auto all_tvs = ir_utils::allTvs(FusionGuard::getCurFusion()); + auto all_tvs = FusionGuard::getCurFusion()->allTvs(); for (auto tv : all_tvs) { auto consumers = ir_utils::consumerTvsOf(tv); for (auto consumer : consumers) { @@ -173,7 +173,7 @@ size_t MaxPosCalculator::getMaxPosAll( } void inlineMost(const std::unordered_set& uninlinable_ids) { - inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids); + inlineMost(FusionGuard::getCurFusion()->allTvs(), uninlinable_ids); } void inlineMost( diff --git a/csrc/ir/graphviz.cpp b/csrc/ir/graphviz.cpp index 4e7413eb148..7cbd23f7dd3 100644 --- a/csrc/ir/graphviz.cpp +++ b/csrc/ir/graphviz.cpp @@ -426,7 +426,7 @@ void TransformToDot::handle(Fusion* fusion) { // Make sure the loop domains are ordered correctly indent() << "graph [ordering=\"out\"];\n"; - for (const auto tv : ir_utils::allTvs(fusion)) { + for (const auto tv : fusion->allTvs()) { handle(tv); } diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 5d52a898e84..ebdcf699f33 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -403,25 +403,6 @@ std::vector outputTvsOf(std::vector tvs) { return uniqueEntries(out_tvs); } -std::vector allTvs(Fusion* fusion) { - auto used_vals = fusion->usedMathVals(); - auto used_tvs = ir_utils::filterByType(used_vals); - - // This shouldn't be necessary but FusionSegmentIoAlias_CUDA due to aliasing - // is having an input disconnected from outputs, and these iter domains are - // being checked in compute at maps in scheduling logic. This shouldn't hurt - // AFAICT. - auto tv_inputs = ir_utils::filterByType(fusion->inputs()); - - std::vector all_tvs({used_tvs.begin(), used_tvs.end()}); - // Sometimes inputs are not connected to outputs, however, we still include - // them when returning allTvs because they are registered as an input. - all_tvs.insert(all_tvs.end(), tv_inputs.begin(), tv_inputs.end()); - - // all_tvs has duplicates, to deduplicate it and return - return uniqueEntries(all_tvs); -} - VectorOfUniqueEntries allTvsOfExprs( const std::vector& exprs) { VectorOfUniqueEntries all_tvs; @@ -438,7 +419,7 @@ VectorOfUniqueEntries allTvsOfExprs( std::vector allTvsExcept( Fusion* fusion, const std::unordered_set& except) { - auto all_tvs = allTvs(fusion); + auto all_tvs = fusion->allTvs(); std::vector result; for (auto tv : all_tvs) { if (except.count(tv) == 0) { @@ -803,7 +784,7 @@ bool hasResizedRfactor(const TensorView* tv) { } std::vector getTVsWithDynamicTransform(Fusion* fusion) { - const auto all_tvs = ir_utils::allTvs(fusion); + const auto all_tvs = fusion->allTvs(); std::vector dynamic_tvs; std::copy_if( all_tvs.begin(), diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 913df3773f4..46225feb240 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -383,9 +383,6 @@ std::vector inputTvsOf(std::vector tvs); // Returns consumers of tvs that are outputs of fusion std::vector outputTvsOf(std::vector tvs); -// returns all tensor views in fusion that are used between outputs and inputs. -NVF_API std::vector allTvs(Fusion* fusion); - // returns all tensor views used in the provided expressions VectorOfUniqueEntries allTvsOfExprs( const std::vector& exprs); diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 3c83ed5642a..90d498103e2 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -1024,8 +1024,7 @@ FusionKernelRuntime::FusionKernelRuntime( // SchedulerRuntimeInfo modifies the fusion, so it is required for both // compile paths. - std::vector all_tvs = - fusion->allTvs(); // ir_utils::allTvs(fusion.get()); + std::vector all_tvs = fusion->allTvs(); SchedulerRuntimeInfo runtime_info( fusion.get(), args, nullptr, all_tvs, forced_index_type); @@ -1491,7 +1490,6 @@ std::optional FusionKernelRuntime:: // Get all tensorviews for segmented fusion std::vector all_tvs_for_fusion_to_run = fusion_to_run->allTvs(); - // ir_utils::allTvs(fusion_to_run); SchedulerRuntimeInfo fusion_to_run_info( fusion_to_run, diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index b5b8c7f1725..9b32ba9c690 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -233,7 +233,7 @@ void shardAllLike(TensorView* ref, std::vector tvs) { int64_t requestedNumberOfDevices(Fusion* fusion) { DeviceIdxType max_index = 0; - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { if (tv->hasDeviceMesh()) { for (auto d_id : tv->getDeviceMesh().vector()) { max_index = std::max(max_index, d_id); @@ -253,7 +253,7 @@ void unshard(TensorView* tv) { } void unshard(Fusion* fusion) { - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { unshard(tv); } } diff --git a/csrc/preseg_passes/mark_aliases_prepare.cpp b/csrc/preseg_passes/mark_aliases_prepare.cpp index 2478105e33e..6afdbca299c 100644 --- a/csrc/preseg_passes/mark_aliases_prepare.cpp +++ b/csrc/preseg_passes/mark_aliases_prepare.cpp @@ -56,7 +56,7 @@ std::unordered_set exprsDependedByNonAliases( const AliasAnalysisResult& analysis, Fusion* fusion) { std::vector non_aliases; - for (TensorView* tv : ir_utils::allTvs(fusion)) { + for (TensorView* tv : fusion->allTvs()) { if (analysis.getRoot(tv) == nullptr) { non_aliases.push_back(tv); } @@ -129,7 +129,7 @@ void MarkAliasesPreparePass::runPass(Fusion* fusion) { } // Materialize the alias-enabling allocation domain. - for (TensorView* tv : ir_utils::allTvs(fusion)) { + for (TensorView* tv : fusion->allTvs()) { if (analysis.getRoot(tv) == nullptr) { continue; } diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index e3e4f39f8d9..997566cfaf1 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -42,7 +42,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Validate that meshes are assigned to all TensorViews or none. TensorView* tv_with_mesh = nullptr; TensorView* tv_without_mesh = nullptr; - for (TensorView* tv : ir_utils::allTvs(fusion)) { + for (TensorView* tv : fusion->allTvs()) { auto update_if_null = [](TensorView*& lhs, TensorView* rhs) { if (lhs == nullptr) { lhs = rhs; diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index 5f71d7a4604..d6926065d67 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -230,7 +230,7 @@ void FusionDefinition::setupSchedule(const at::ArrayRef& inputs) { user_schedule_fusion, args, /*precomuted_values=*/nullptr, - ir_utils::allTvs(user_schedule_fusion)); + user_schedule_fusion->allTvs()); // Manually setting the fusion guard as there is not a good way of using a // guard in a local scope across the schedule function @@ -243,7 +243,7 @@ void FusionDefinition::finalizeSchedule( FUSER_PERF_SCOPE("FusionDefinition::finalizeSchedule"); // TODO: remove when multidevice executor integration is done natively Fusion* fusion = user_sched_->schedule.get(); - std::vector tvs = ir_utils::allTvs(fusion); + std::vector tvs = fusion->allTvs(); if (std::any_of(tvs.begin(), tvs.end(), [](Val* v) { return v->isA() && v->as()->hasDeviceMesh(); })) { diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index a29cb3af26a..51562346e72 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -238,7 +238,7 @@ std::vector getOuterBroadcastTvs( // find the broadcast tensor whose broadcast mask is same to the reference std::vector outer_broadcast_tvs; - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { if (std::any_of( tv->getLoopDomain().begin(), tv->getLoopDomain().end(), diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index df6ba64499e..2445ab79afa 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -901,7 +901,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // unrolling manually. inlineAllAt(reference_tv, unswitch_pos, true); - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); // Inline at the inner most position. The CA position of all tensors except // inputs, cached inputs and outputs will be updated. diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 9800576758d..047d9e479b1 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -39,7 +39,7 @@ SchedulerRuntimeInfo::SchedulerRuntimeInfo( } else { index_type_ = registry_utils::getIndexTypeOfKernel( complete_fusion_, - all_tvs.empty() ? ir_utils::allTvs(complete_fusion_) : all_tvs, + all_tvs.empty() ? complete_fusion_->allTvs() : all_tvs, args, *expression_evaluator_); } diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index b12475b208e..2af092c08f2 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -62,7 +62,7 @@ bool checkPatternEquivalence( bool hasNonUniqueBcast(Fusion* fusion) { ConcretizedBroadcastDomains concretize_info(fusion); - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { for (auto id : tv->getMaybeRootDomain()) { if (concretize_info.maybeNonUniquelyConcretized(id)) { return true; diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index db40a638068..6013f2838e6 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -311,7 +311,7 @@ void parallelizeAllLike( } if (selected_tvs.empty()) { - selected_tvs = ir_utils::allTvs(reference_tv->fusion()); + selected_tvs = reference_tv->fusion()->allTvs(); } for (auto tv : selected_tvs) { if (tv->isFusionInput()) { @@ -564,7 +564,7 @@ PersistentBufferInfo persistentBuffers(Fusion* fusion) { ComputeAtLogicalDomainMap logical_map; logical_map.build(); - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); for (auto producer : all_tvs) { // Are all producer ids mappable to all consumers @@ -1063,7 +1063,7 @@ std::pair canonicalDimReduction( } std::vector getReductionTvs(Fusion* fusion) { - auto all_tvs = ir_utils::allTvs(fusion); + auto all_tvs = fusion->allTvs(); std::vector reduction_tvs; for (auto tv : all_tvs) { if (!tv->isFusionInput() && @@ -1130,7 +1130,7 @@ std::vector getTVsWithNonReductionRFactor(Fusion* fusion) { // Reset inputs and outputs to global memory, everything else to local. void clearMemorySpace(Fusion* fusion) { - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { if (tv->isFusionInput() || tv->isFusionOutput()) { tv->setMemoryType(MemoryType::Global); } else { @@ -1986,7 +1986,7 @@ DisjointSets disjointLogicalSets(Fusion* fusion) { // If iter domains are involved in any transformation from root domains to // logical domains they should be considered "contaminated". - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { for (auto expr : StmtSort::getExprsTo( {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()})) { if (expr->isA()) { @@ -2146,7 +2146,7 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { // If iter domains are involved in any transformation from root domains to // logical domains they should be considered "contaminated". - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { for (auto expr : StmtSort::getExprsBetween( {tv->getMaybeRootDomain().begin(), tv->getMaybeRootDomain().end()}, {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()})) { @@ -2183,7 +2183,7 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { // If iter domains are involved in any transformation from root domains to // logical domains they should be considered "contaminated". - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { if (!tv->hasRoot()) { continue; } @@ -2249,7 +2249,7 @@ std::vector> getNonPointwiseProducerConsumerPairs(Fusion* fusion) { std::vector> tvs; - for (auto consumer : ir_utils::allTvs(fusion)) { + for (auto consumer : fusion->allTvs()) { if (consumer->isFusionInput()) { continue; } @@ -2570,7 +2570,7 @@ void moveNonConcretizedBroadcastInnermost( } } - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { std::vector broadcast_to_move; for (const auto i : c10::irange(tv->getLoopDomain().size())) { auto loop_id = tv->getLoopDomain().at(i); diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 4cc1ac113f4..eecc11cf03f 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -772,7 +772,7 @@ TEST_F(NVFuserTest, FusionIssue1430_CUDA) { scheduler_utils::parallelizeAllLike(rfactor); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv != tv1 || tv != tv3) { for (auto i : c10::irange(tv->nDims())) { if (isParallelTypeVectorize(tv->axis(i)->getParallelType())) { @@ -2054,7 +2054,7 @@ TEST_F(NVFuserTest, FusionExactLogicalDomainMap_CUDA) { exact_map.toString()); // They must not be mapped with anything else. - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { for (auto logical_id : tv->getLogicalDomain()) { if (logical_id == tv2_bc || logical_id == tv3_bc) { continue; @@ -2167,7 +2167,7 @@ TEST_F(NVFuserTest, FusionTestReEntrantGridWelford_CUDA) { cached_input->computeAt(rfactor_tv, 4, ComputeAtMode::BestEffort); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv == cached_input || tv == tv_avg || tv == tv_M2) { continue; } @@ -8535,7 +8535,7 @@ TEST_F(NVFuserTest, MoveNonConcretizedBroadcastInNormalization) { auto ref_outermost = tv7->getLoopDomain().at(0); IdModel id_model(&fusion); const auto& exact_graph = id_model.idGraph(IdMappingMode::EXACT); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } @@ -8603,7 +8603,7 @@ TEST_F(NVFuserTest, MoveNonConcretizedBroadcastInPointwise) { auto ref_outermost = tv5->getLoopDomain().at(0); IdModel id_model(&fusion); const auto& exact_graph = id_model.idGraph(IdMappingMode::EXACT); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } @@ -8670,7 +8670,7 @@ TEST_F(NVFuserTest, MoveNonConcretizedBroadcastInReduction) { auto ref_outermost = tv6->getLoopDomain().at(0); IdModel id_model(&fusion); const auto& exact_graph = id_model.idGraph(IdMappingMode::EXACT); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } diff --git a/tests/cpp/test_gpu_compute_with.cpp b/tests/cpp/test_gpu_compute_with.cpp index 7abf3e891f8..b2d872308bb 100644 --- a/tests/cpp/test_gpu_compute_with.cpp +++ b/tests/cpp/test_gpu_compute_with.cpp @@ -130,7 +130,7 @@ TEST_F(NVFuserTest, FusionComputeWith1_CUDA) { // Set the global inlining only with the outer axis std::unordered_set uninlinable; - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->nDims() == 2) { uninlinable.insert(tv->axis(1)); } @@ -424,7 +424,7 @@ TEST_F(NVFuserTest, FusionComputeWith6_CUDA) { TransformPropagator propagator(tv3_rf); MaxLogicalDomainInfoSpanningTree(tv3_rf).traverse(&propagator); - scheduler_utils::parallelizeAllLike(tv3_rf, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv3_rf, fusion.allTvs()); tv1->axis(-1)->parallelize(ParallelType::Vectorize); tv7->axis(-1)->parallelize(ParallelType::Vectorize); diff --git a/tests/cpp/test_gpu_fused_reduction.cpp b/tests/cpp/test_gpu_fused_reduction.cpp index bf0bb8ec877..e67875f4a1a 100644 --- a/tests/cpp/test_gpu_fused_reduction.cpp +++ b/tests/cpp/test_gpu_fused_reduction.cpp @@ -2085,7 +2085,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce4_CUDA) { tv4->axis(0)->parallelize(ParallelType::BIDx); tv4->axis(1)->parallelize(ParallelType::TIDx); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { tv->axis(-2)->parallelize(ParallelType::BIDy); tv->axis(-1)->parallelize(ParallelType::TIDy); } @@ -2355,8 +2355,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelfordShmoo_CUDA) { })); transform_ref_rf->axis(unswitch_id)->parallelize(ParallelType::Serial); - scheduler_utils::parallelizeAllLike( - transform_ref_rf, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(transform_ref_rf, fusion.allTvs()); ParallelType vec_pt = ParallelType::Vectorize; tv1->axis(vec_id)->parallelize(vec_pt); diff --git a/tests/cpp/test_gpu_outer_reduction.cpp b/tests/cpp/test_gpu_outer_reduction.cpp index f6c120c5aba..afbe2eb5d5d 100644 --- a/tests/cpp/test_gpu_outer_reduction.cpp +++ b/tests/cpp/test_gpu_outer_reduction.cpp @@ -101,7 +101,7 @@ TEST_F(OuterReductionTest, GroupedGridWelfordOuterOpt) { ref_rf->axis(3)->parallelize(ParallelType::BIDy); ref_rf->axis(5)->parallelize(ParallelType::TIDy); - scheduler_utils::parallelizeAllLike(ref_rf, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(ref_rf, fusion.allTvs()); tv1->axis(-1)->parallelize(ParallelType::Vectorize); tv3->axis(-1)->parallelize(ParallelType::Group); @@ -552,8 +552,7 @@ void scheduleNormalization(Fusion& fusion, const OuterReductionParams& params) { unswitch_id->parallelize(ParallelType::Serial); } - scheduler_utils::parallelizeAllLike( - reduction_tv_rf, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(reduction_tv_rf, fusion.allTvs()); // Vectorize inputs for (auto input_cache : input_caches) { diff --git a/tests/cpp/test_gpu_utils.cpp b/tests/cpp/test_gpu_utils.cpp index 908272aa1e2..f7bc304e47a 100644 --- a/tests/cpp/test_gpu_utils.cpp +++ b/tests/cpp/test_gpu_utils.cpp @@ -1058,7 +1058,7 @@ TEST_F(VectorizeHelperTest, SpanningTree_CUDA) { auto mapper = vectorize_helper::ContiguousInnerDimensionsMapper::map( out, {out->axis(0), out->axis(1)}); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->name() == 0 || tv->name() == 1) { continue; } diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index 6f6f1d72c6f..899fc657a88 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -261,7 +261,7 @@ void checkStep2Results(Fusion* fusion, const IdModelTester& tester) { } }; - for (auto tv : ir_utils::allTvs(fusion)) { + for (auto tv : fusion->allTvs()) { // If there's no broadcast or it isn't inlined, there's no // promotion if (std::none_of( @@ -591,7 +591,7 @@ TEST_F(IdModelTest, ValGraphStmtSort2) { // Note that the two groups of tensors, {tv0, tv1} and {tv2, tv3}, // are not connected - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { tv->merge(0)->split(0, 4); } @@ -674,7 +674,7 @@ TEST_F(IdModelTest, ValGraphStmtSort3) { TEST_F(IdModelTest, ValGraphStmtSort4) { auto fusion = createFusionWithMultipleResolutionPaths(); FusionGuard fg(fusion.get()); - auto all_tvs = ir_utils::allTvs(fusion.get()); + auto all_tvs = fusion->allTvs(); // Since this fusion is not supported by ComputeAtMap, the // validation flag must be false @@ -953,14 +953,14 @@ TEST_F(IdModelTest, LoopPromotion4) { TransformPropagator propagator(tv4); MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { tv->inlineAt(-2); } IdModelTester tester(&fusion); // Verify all tensors with root broadcast have correct resolutions - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { // Skip tensors with no broadcast or non-inlined if (std::none_of( tv->getLogicalDomain().begin(), @@ -1078,7 +1078,7 @@ TEST_F(IdModelTest, LoopPromotion5) { tv2->axis(1)->parallelize(ParallelType::Unroll); tv2->axis(2)->parallelize(ParallelType::TIDx); - auto all_tvs = ir_utils::allTvs(&fusion); + auto all_tvs = fusion.allTvs(); IdModelTester tester(&fusion); @@ -1225,7 +1225,7 @@ TEST_F(IdModelTest, LoopPromotion5) { TEST_F(IdModelTest, LoopPromotion6) { auto fusion = createFusionWithMultipleResolutionPaths(); FusionGuard fg(fusion.get()); - auto all_tvs = ir_utils::allTvs(fusion.get()); + auto all_tvs = fusion->allTvs(); IdModelTester tester(fusion.get()); @@ -1558,7 +1558,7 @@ TEST_F(IdModelTest, LoopPromotion7) { tv2->split(-1, 8); - auto all_tvs = ir_utils::allTvs(&fusion); + auto all_tvs = fusion.allTvs(); IdModelTester tester(&fusion); @@ -1698,7 +1698,7 @@ TEST_F(IdModelTest, LoopPromotion8) { // [2, 4, (3*5//2)*7//4] tv5->inlineAt(2); - auto all_tvs = ir_utils::allTvs(&fusion); + auto all_tvs = fusion.allTvs(); IdModelTester tester(&fusion); @@ -1992,7 +1992,7 @@ TEST_F(IdModelTest, LoopPromotionTwoStepFailureReproSimple) { TransformPropagatorWithCheck propagator(t4); MaxLogicalDomainInfoSpanningTree(t4).traverse(&propagator); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { tv->inlineAt(1); } @@ -2044,7 +2044,7 @@ TEST_F(IdModelTest, ComplimentMappingCausingLoopSelfMapping) { fusion.addOutput(tv11); // Merge all domains except for tv10 and tv11 - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv == tv10 || tv == tv11) { continue; } @@ -2054,7 +2054,7 @@ TEST_F(IdModelTest, ComplimentMappingCausingLoopSelfMapping) { } // Fully inline all tensors up until tv10 - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv == tv9 || tv == tv10 || tv == tv11) { continue; } @@ -2446,7 +2446,7 @@ TEST_F(IdModelTest, LoopPromotionWithViewRFactor1) { // All of the inlined tensors (i.e., all tensors except for the // inputs) should be grouped together. - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } @@ -2496,7 +2496,7 @@ TEST_F(IdModelTest, LoopPromotionWithLogicalDomains2) { // All of the inlined tensors (i.e., all tensors except for the // inputs) should be grouped together. - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } @@ -2560,7 +2560,7 @@ TEST_F(IdModelTest, LoopPromotionCoverage) { // All tvs except for inptus should be just a 1D tensor and be // promoted to a domain that is exactly mappd with the loop domain // of tv10. - for (const auto tv : ir_utils::allTvs(&fusion)) { + for (const auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } diff --git a/tests/cpp/test_indexing.cpp b/tests/cpp/test_indexing.cpp index b1eb4ea80e1..9f9d78e26b0 100644 --- a/tests/cpp/test_indexing.cpp +++ b/tests/cpp/test_indexing.cpp @@ -561,7 +561,7 @@ TEST_F(IndexingTest, SimplePointwise2) { tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv3, fusion.allTvs()); // Test shared memory indexing tv2->setMemoryType(MemoryType::Shared); @@ -1044,7 +1044,7 @@ TEST_F(IndexingTest, SimpleBroadcast4) { TransformPropagator propagator(tv4); MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { tv->inlineAt(-2); } @@ -1344,7 +1344,7 @@ TEST_F(IndexingTest, SimpleVectorize) { inlineMost(); - scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv2, fusion.allTvs()); struct GetReference : AbstractGetReference { GetReference(const TensorIndexer& indexer, const IdModel& id_model) @@ -1413,7 +1413,7 @@ TEST_F(IndexingTest, NonInnermostVectorize) { tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv3, fusion.allTvs()); tv1->axis(2)->parallelize(ParallelType::Vectorize); tv3->axis(2)->parallelize(ParallelType::Vectorize); @@ -1648,7 +1648,7 @@ TEST_F(IndexingTest, InlinedUnroll) { tv4->axis(1)->parallelize(ParallelType::Unroll); - scheduler_utils::parallelizeAllLike(tv4, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv4, fusion.allTvs()); // The CA position of tv2 is 1 as shown below: // @@ -1704,7 +1704,7 @@ TEST_F(IndexingTest, SmemAllocationDomainForTranspose) { } // [I0, I1] -> [(I0/32 * I1/32), (32 * 32) / 4, 4] - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { tv->split(0, 32); tv->split(2, 32); tv->reorder({{1, 2}}); @@ -2834,7 +2834,7 @@ TEST_F(PredicateIndexingTest, SimpleVectorize) { inlineMost(); - scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv2, fusion.allTvs()); // T1_l[ iblockIdx.x9{( ceilDiv(( ceilDiv(i0, 4) ), 128) )}, // ithreadIdx.x10{128}, iV8{4} ] ca_pos( 2 ) T2_g[ iblockIdx.x5{( ceilDiv(( @@ -2904,7 +2904,7 @@ TEST_F(PredicateIndexingTest, NonInnermostVectorize) { tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + scheduler_utils::parallelizeAllLike(tv3, fusion.allTvs()); tv1->axis(2)->parallelize(ParallelType::Vectorize); tv3->axis(2)->parallelize(ParallelType::Vectorize); diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 862c1f99fc7..c6da1d7d849 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -2406,7 +2406,7 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogue) { // for prologue and 1 for epilogue. int num_shared_mem_tensors = 0; int expected_num_shared_mem_tensors = params.use_smem_epilogue ? 3 : 2; - for (const auto& tv : ir_utils::allTvs(&fusion)) { + for (const auto& tv : fusion.allTvs()) { if (tv->getMemoryType() == MemoryType::Shared) { num_shared_mem_tensors++; } @@ -2640,7 +2640,7 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogueCast) { // for prologue and 1 for epilogue. int num_shared_mem_tensors = 0; int expected_num_shared_mem_tensors = params.use_smem_epilogue ? 3 : 2; - for (const auto& tv : ir_utils::allTvs(&fusion)) { + for (const auto& tv : fusion.allTvs()) { if (tv->getMemoryType() == MemoryType::Shared) { num_shared_mem_tensors++; } @@ -2733,7 +2733,7 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulSmemEpilogueRelu) { // for prologue and 1 for epilogue. int num_shared_mem_tensors = 0; int expected_num_shared_mem_tensors = params.use_smem_epilogue ? 3 : 2; - for (const auto& tv : ir_utils::allTvs(&fusion)) { + for (const auto& tv : fusion.allTvs()) { if (tv->getMemoryType() == MemoryType::Shared) { num_shared_mem_tensors++; } diff --git a/tests/cpp/test_scatter_gather.cpp b/tests/cpp/test_scatter_gather.cpp index c8a39e88b01..fbac505ff0d 100644 --- a/tests/cpp/test_scatter_gather.cpp +++ b/tests/cpp/test_scatter_gather.cpp @@ -561,7 +561,7 @@ TEST_F(ScatterGatherTest, TakeAlongAxisIntermediateTensorPointwise1) { tv4->axis(-1)->parallelize(ParallelType::TIDx); scheduler_utils::parallelizeAllLike(tv4); - for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto tv : fusion.allTvs()) { if (tv->isFusionInput()) { continue; } diff --git a/tests/cpp/utils.h b/tests/cpp/utils.h index a0ee6764d06..4ae88f4064c 100644 --- a/tests/cpp/utils.h +++ b/tests/cpp/utils.h @@ -112,7 +112,7 @@ inline void clearL2Cache() { }; inline TensorView* loweredTv(TensorView* tv, kir::Kernel* kernel) { - auto used_tvs = ir_utils::allTvs(kernel); + auto used_tvs = kernel->allTvs(); TensorView* matching_tv = nullptr; for (auto lowered_tv : used_tvs) { if (lowered_tv->name() == tv->name()) {