Cleanup all uses of ir_utils::allTvs (#2884)
Follow-up to #2873: removes all uses of the ir_utils::allTvs free function in favor of the Fusion::allTvs member function.

---------

Co-authored-by: Jacob Hinkle <[email protected]>
csarofeen and jacobhinkle authored Sep 3, 2024
1 parent dca416d commit f669fcf
Showing 35 changed files with 110 additions and 113 deletions.
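
Every hunk below applies the same mechanical substitution: the free function ir_utils::allTvs(fusion) is replaced by the Fusion member function fusion->allTvs(). A minimal before/after sketch (fusion is assumed to be a valid Fusion* inside nvFuser):

// Before this commit: free function, recomputed on every call.
std::vector<TensorView*> tvs = ir_utils::allTvs(fusion);

// After this commit: member function backed by a cached list
// (see the csrc/fusion.cpp hunk below).
std::vector<TensorView*> tvs = fusion->allTvs();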
6 changes: 3 additions & 3 deletions csrc/compute_at_map.cpp
@@ -292,7 +292,7 @@ std::optional<std::pair<IterDomain*, IterDomain*>> detectMappablePair(
// matter in practice.
std::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>>
findFirstSelfMapping(Fusion* fusion, const IterDomainGraph& id_graph) {
- for (auto tv : ir_utils::allTvs(fusion)) {
+ for (auto tv : fusion->allTvs()) {
// For each tensor, make sure root, logical and loop domains
// should not include domains that are mapped with another domain
// in the same set of domains. This may be overly conservative,
@@ -342,7 +342,7 @@ void IterDomainGraph::build(Fusion* fusion) {
FusionGuard fg(fusion);

// Initialize a node for every iteration domain
- for (auto tv : ir_utils::allTvs(fusion)) {
+ for (auto tv : fusion->allTvs()) {
const auto& domain = tv->getLoopDomain();
auto all_ids = tv->domain()->allIDs();

@@ -586,7 +586,7 @@ void IterDomainGraph::build(Fusion* fusion) {
// transformations makes it easy to check if different view operations are
// consistent with eachother.

- auto all_tvs = ir_utils::allTvs(fusion);
+ auto all_tvs = fusion->allTvs();
std::vector<TensorView*> all_consumer_tvs;
std::copy_if(
all_tvs.begin(),
2 changes: 1 addition & 1 deletion csrc/device_lower/analysis/divisible_split.cpp
@@ -25,7 +25,7 @@ std::unordered_set<Split*> getAllDivisibleSplits(
const ComputeAtMap* ca_map) {
std::unordered_set<Split*> all_divisible_splits;

- auto all_tvs = ir_utils::allTvs(fusion);
+ auto all_tvs = fusion->allTvs();
// Find all tensor views with a view like rfactor. Splits used in view
// transformations must be divisible by definition.
for (auto tv : all_tvs) {
2 changes: 1 addition & 1 deletion csrc/device_lower/analysis/thread_predicate.cpp
@@ -734,7 +734,7 @@ void ThreadPredicateMap::build(Fusion* fusion) {
updateBitSet(expr);
}

- for (auto tv : ir_utils::allTvs(fusion)) {
+ for (auto tv : fusion->allTvs()) {
if (tv->getMemoryType() == MemoryType::Global) {
avoidConcretizedBroadcastRedundantWrite(tv);
}
2 changes: 1 addition & 1 deletion csrc/device_lower/lower2device.cpp
@@ -345,7 +345,7 @@ bool requiresIdModel(Fusion* fusion) {
}
// If a tensor does not have a nice root->logical/allocation->loop
// linear transformation history, use IdModel.
- for (auto tv : ir_utils::allTvs(fusion)) {
+ for (auto tv : fusion->allTvs()) {
if (!lower_utils::hasRootToLoopLinearTransformations(tv)) {
return true;
}
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/expr_sort.cpp
@@ -1142,7 +1142,7 @@ void ExprSegmentationSorter::initializeForLoopDependencies() {
concrete_id_dependencies_.empty(),
"For loop dependencies have already been initialized.");

- for (auto tv : ir_utils::allTvs(fusion_)) {
+ for (auto tv : fusion_->allTvs()) {
std::unordered_set<IterDomain*> dependencies;
for (int64_t tv_id_i = std::max(
tv->getMaxProducerPosition(),
4 changes: 2 additions & 2 deletions csrc/device_lower/pass/loops.cpp
@@ -145,7 +145,7 @@ void LoopNestGenerator::generate(const std::vector<Expr*>& exprs) {

std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
concrete_id_dependencies;
- for (auto tv : ir_utils::allTvs(FusionGuard::getCurFusion())) {
+ for (auto tv : FusionGuard::getCurFusion()->allTvs()) {
std::unordered_set<IterDomain*> dependencies;

for (auto tv_id : tv->getLoopDomain()) {
@@ -212,7 +212,7 @@ void LoopNestGenerator::generate(const std::vector<Expr*>& exprs) {
}

// Generate loop structure for each tensor view
- for (auto tv : ir_utils::allTvs(FusionGuard::getCurFusion())) {
+ for (auto tv : FusionGuard::getCurFusion()->allTvs()) {
// Zero dim tensor support
if (tv->nDims() == 0) {
loop_structures_[tv] = std::vector<IterDomain*>();
4 changes: 2 additions & 2 deletions csrc/device_lower/validation.cpp
@@ -189,7 +189,7 @@ void validateIr(Fusion* fusion) {
"Tensor with dynamic transform must be concretized before lowering: ",
toDelimitedString(dynamic_tvs.begin(), dynamic_tvs.end()));

- auto all_tvs = ir_utils::allTvs(fusion);
+ auto all_tvs = fusion->allTvs();
validateCpAsyncBulk(all_tvs);
}

@@ -912,7 +912,7 @@ void validateSwizzle(Fusion* fusion) {
}

void validateAndConvertIterDomainGrouping(Fusion* fusion) {
- for (auto tv : ir_utils::allTvs(fusion)) {
+ for (auto tv : fusion->allTvs()) {
bool is_grouped = false;
for (const auto id_idx : c10::irange(tv->nDims())) {
const auto id = tv->axis(id_idx);
2 changes: 1 addition & 1 deletion csrc/evaluator_common.cpp
@@ -98,7 +98,7 @@ void collectBufferSizes(

std::vector<Val*> collectRuntimeUsedValues(Fusion* fusion) {
std::vector<Val*> ret;
- auto all_tvs = ir_utils::allTvs(fusion);
+ auto all_tvs = fusion->allTvs();
// Collect extent and inputs
for (auto tv : all_tvs) {
for (auto id : tv->getLoopDomain()) {
27 changes: 25 additions & 2 deletions csrc/fusion.cpp
@@ -871,10 +871,33 @@ bool isExpressionEvaluated(Fusion* fusion) {
});
}

+ namespace {
+ std::vector<TensorView*> findAllTvs(Fusion* fusion) {
+ auto used_vals = fusion->usedMathVals();
+ auto used_tvs = ir_utils::filterByType<TensorView>(used_vals);
+
+ // This shouldn't be necessary but FusionSegmentIoAlias_CUDA due to aliasing
+ // is having an input disconnected from outputs, and these iter domains are
+ // being checked in compute at maps in scheduling logic. This shouldn't hurt
+ // AFAICT.
+ auto tv_inputs = ir_utils::filterByType<TensorView>(fusion->inputs());
+
+ std::vector<TensorView*> all_tvs({used_tvs.begin(), used_tvs.end()});
+ // Sometimes inputs are not connected to outputs, however, we still include
+ // them when returning allTvs because they are registered as an input.
+ all_tvs.insert(all_tvs.end(), tv_inputs.begin(), tv_inputs.end());
+
+ VectorOfUniqueEntries<TensorView*> unique_vector(
+ all_tvs.begin(), all_tvs.end());
+
+ // all_tvs has duplicates, to deduplicate it and return
+ return unique_vector.vector();
+ }
+ } // namespace

std::vector<TensorView*> Fusion::allTvs() {
if (all_tvs_ptr_ == nullptr) {
- all_tvs_ptr_ =
- std::make_unique<std::vector<TensorView*>>(ir_utils::allTvs(this));
+ all_tvs_ptr_ = std::make_unique<std::vector<TensorView*>>(findAllTvs(this));
}
return std::vector<TensorView*>(*all_tvs_ptr_);
}
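
As context for the hunk above: findAllTvs deduplicates the combined list while keeping first-seen order via VectorOfUniqueEntries. A rough stand-alone approximation of that dedup step (a sketch with a hypothetical helper, not the actual nvFuser type):

#include <unordered_set>
#include <vector>

// Order-preserving deduplication, approximating what VectorOfUniqueEntries
// does to the all_tvs list assembled above.
template <typename T>
std::vector<T*> dedupPreservingOrder(const std::vector<T*>& in) {
  std::unordered_set<T*> seen;
  std::vector<T*> out;
  out.reserve(in.size());
  for (T* p : in) {
    if (seen.insert(p).second) { // insert() returns true only on first sight
      out.push_back(p);
    }
  }
  return out;
}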
10 changes: 5 additions & 5 deletions csrc/fusion_segmenter.cpp
@@ -2370,7 +2370,7 @@ class FusionSegmentGuard : public NonCopyable {
NVF_ERROR(fusion_ != nullptr);
#ifndef NDEBUG
num_original_exprs_ = fusion_->exprs().size();
- original_tvs_ = ir_utils::allTvs(fusion_);
+ original_tvs_ = fusion_->allTvs();
#endif // NDEBUG
narrowToNewSegment(inputs, outputs);
}
@@ -2382,7 +2382,7 @@ class FusionSegmentGuard : public NonCopyable {
FUSER_PERF_SCOPE("Segmenter::FusionSegmentGuard");
#ifndef NDEBUG
num_original_exprs_ = fusion_->exprs().size();
- original_tvs_ = ir_utils::allTvs(fusion_);
+ original_tvs_ = fusion_->allTvs();
#endif // NDEBUG
lowered_edges_ = segmented_fusion_->castInputOutputToLowerPrecision(
segmented_fusion_->edges());
@@ -2398,7 +2398,7 @@ class FusionSegmentGuard : public NonCopyable {
FUSER_PERF_SCOPE("Segmenter::FusionSegmentGuard");
#ifndef NDEBUG
num_original_exprs_ = fusion_->exprs().size();
- original_tvs_ = ir_utils::allTvs(fusion_);
+ original_tvs_ = fusion_->allTvs();
#endif // NDEBUG

// Cast inputs and outputs of a merged group consisting of a and
@@ -2427,7 +2427,7 @@ class FusionSegmentGuard : public NonCopyable {
FUSER_PERF_SCOPE("Segmenter::FusionSegmentGuard");
#ifndef NDEBUG
num_original_exprs_ = fusion_->exprs().size();
- original_tvs_ = ir_utils::allTvs(fusion_);
+ original_tvs_ = fusion_->allTvs();
#endif // NDEBUG

// Cast inputs and outputs of a merged group consisting of
@@ -2468,7 +2468,7 @@ class FusionSegmentGuard : public NonCopyable {
num_original_exprs_,
", actual: ",
num_current_exprs);
- auto current_tvs = ir_utils::allTvs(fusion_);
+ auto current_tvs = fusion_->allTvs();
NVF_ERROR(
original_tvs_ == current_tvs, "Failed to revert temporary changes.");
#endif
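
The revert check above works because allTvs() returns the tensors in a deterministic order, so plain std::vector equality exposes any TensorView added or dropped by the temporary segment edits. A simplified illustration of the pattern (a sketch, not the actual guard code):

// Snapshot before a temporary mutation, verify after the revert.
std::vector<TensorView*> before = fusion_->allTvs();
// ... temporarily narrow the fusion to a candidate segment ...
// ... undo the temporary changes ...
NVF_ERROR(
    before == fusion_->allTvs(), "Failed to revert temporary changes.");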
2 changes: 1 addition & 1 deletion csrc/id_model/validation_utils.cpp
@@ -118,7 +118,7 @@ bool exprsMap(

IdModelValidator::IdModelValidator(Fusion* fusion, bool allow_self_mapping)
: ca_map_(fusion, allow_self_mapping) {
- for (auto tv : ir_utils::allTvs(fusion)) {
+ for (auto tv : fusion->allTvs()) {
for (auto id : tv->domain()->allIDs()) {
if (id->definition() && id->definition()->isA<Swizzle2D>()) {
has_swizzle_ = true;
4 changes: 2 additions & 2 deletions csrc/inlining.cpp
@@ -29,7 +29,7 @@ void MaxPosCalculator::buildUnmappableDims(bool compute_at_only) {
}
ComputeAtLogicalDomainMap logical_map;
logical_map.build();
- auto all_tvs = ir_utils::allTvs(FusionGuard::getCurFusion());
+ auto all_tvs = FusionGuard::getCurFusion()->allTvs();
for (auto tv : all_tvs) {
auto consumers = ir_utils::consumerTvsOf(tv);
for (auto consumer : consumers) {
@@ -173,7 +173,7 @@ size_t MaxPosCalculator::getMaxPosAll(
}

void inlineMost(const std::unordered_set<IterDomain*>& uninlinable_ids) {
- inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids);
+ inlineMost(FusionGuard::getCurFusion()->allTvs(), uninlinable_ids);
}

void inlineMost(
2 changes: 1 addition & 1 deletion csrc/ir/graphviz.cpp
@@ -426,7 +426,7 @@ void TransformToDot::handle(Fusion* fusion) {
// Make sure the loop domains are ordered correctly
indent() << "graph [ordering=\"out\"];\n";

- for (const auto tv : ir_utils::allTvs(fusion)) {
+ for (const auto tv : fusion->allTvs()) {
handle(tv);
}

23 changes: 2 additions & 21 deletions csrc/ir/utils.cpp
@@ -403,25 +403,6 @@ std::vector<TensorView*> outputTvsOf(std::vector<TensorView*> tvs) {
return uniqueEntries<TensorView>(out_tvs);
}

- std::vector<TensorView*> allTvs(Fusion* fusion) {
- auto used_vals = fusion->usedMathVals();
- auto used_tvs = ir_utils::filterByType<TensorView>(used_vals);
-
- // This shouldn't be necessary but FusionSegmentIoAlias_CUDA due to aliasing
- // is having an input disconnected from outputs, and these iter domains are
- // being checked in compute at maps in scheduling logic. This shouldn't hurt
- // AFAICT.
- auto tv_inputs = ir_utils::filterByType<TensorView>(fusion->inputs());
-
- std::vector<TensorView*> all_tvs({used_tvs.begin(), used_tvs.end()});
- // Sometimes inputs are not connected to outputs, however, we still include
- // them when returning allTvs because they are registered as an input.
- all_tvs.insert(all_tvs.end(), tv_inputs.begin(), tv_inputs.end());
-
- // all_tvs has duplicates, to deduplicate it and return
- return uniqueEntries<TensorView>(all_tvs);
- }

VectorOfUniqueEntries<TensorView*> allTvsOfExprs(
const std::vector<Expr*>& exprs) {
VectorOfUniqueEntries<TensorView*> all_tvs;
@@ -438,7 +419,7 @@ VectorOfUniqueEntries<TensorView*> allTvsOfExprs(
std::vector<TensorView*> allTvsExcept(
Fusion* fusion,
const std::unordered_set<TensorView*>& except) {
- auto all_tvs = allTvs(fusion);
+ auto all_tvs = fusion->allTvs();
std::vector<TensorView*> result;
for (auto tv : all_tvs) {
if (except.count(tv) == 0) {
@@ -803,7 +784,7 @@ bool hasResizedRfactor(const TensorView* tv) {
}

std::vector<TensorView*> getTVsWithDynamicTransform(Fusion* fusion) {
- const auto all_tvs = ir_utils::allTvs(fusion);
+ const auto all_tvs = fusion->allTvs();
std::vector<TensorView*> dynamic_tvs;
std::copy_if(
all_tvs.begin(),
3 changes: 0 additions & 3 deletions csrc/ir/utils.h
@@ -383,9 +383,6 @@ std::vector<TensorView*> inputTvsOf(std::vector<TensorView*> tvs);
// Returns consumers of tvs that are outputs of fusion
std::vector<TensorView*> outputTvsOf(std::vector<TensorView*> tvs);

- // returns all tensor views in fusion that are used between outputs and inputs.
- NVF_API std::vector<TensorView*> allTvs(Fusion* fusion);

// returns all tensor views used in the provided expressions
VectorOfUniqueEntries<TensorView*> allTvsOfExprs(
const std::vector<Expr*>& exprs);
4 changes: 1 addition & 3 deletions csrc/kernel_cache.cpp
@@ -1024,8 +1024,7 @@ FusionKernelRuntime::FusionKernelRuntime(

// SchedulerRuntimeInfo modifies the fusion, so it is required for both
// compile paths.
- std::vector<TensorView*> all_tvs =
- fusion->allTvs(); // ir_utils::allTvs(fusion.get());
+ std::vector<TensorView*> all_tvs = fusion->allTvs();
SchedulerRuntimeInfo runtime_info(
fusion.get(), args, nullptr, all_tvs, forced_index_type);

@@ -1491,7 +1490,6 @@ std::optional<FusionKernelRuntime::HeuristicsPtr> FusionKernelRuntime::
// Get all tensorviews for segmented fusion
std::vector<TensorView*> all_tvs_for_fusion_to_run =
fusion_to_run->allTvs();
- // ir_utils::allTvs(fusion_to_run);

SchedulerRuntimeInfo fusion_to_run_info(
fusion_to_run,
4 changes: 2 additions & 2 deletions csrc/multidevice/utils.cpp
@@ -233,7 +233,7 @@ void shardAllLike(TensorView* ref, std::vector<TensorView*> tvs) {

int64_t requestedNumberOfDevices(Fusion* fusion) {
DeviceIdxType max_index = 0;
- for (auto tv : ir_utils::allTvs(fusion)) {
+ for (auto tv : fusion->allTvs()) {
if (tv->hasDeviceMesh()) {
for (auto d_id : tv->getDeviceMesh().vector()) {
max_index = std::max(max_index, d_id);
@@ -253,7 +253,7 @@ void unshard(TensorView* tv) {
}

void unshard(Fusion* fusion) {
- for (auto tv : ir_utils::allTvs(fusion)) {
+ for (auto tv : fusion->allTvs()) {
unshard(tv);
}
}
4 changes: 2 additions & 2 deletions csrc/preseg_passes/mark_aliases_prepare.cpp
@@ -56,7 +56,7 @@ std::unordered_set<Expr*> exprsDependedByNonAliases(
const AliasAnalysisResult& analysis,
Fusion* fusion) {
std::vector<Val*> non_aliases;
- for (TensorView* tv : ir_utils::allTvs(fusion)) {
+ for (TensorView* tv : fusion->allTvs()) {
if (analysis.getRoot(tv) == nullptr) {
non_aliases.push_back(tv);
}
@@ -129,7 +129,7 @@ void MarkAliasesPreparePass::runPass(Fusion* fusion) {
}

// Materialize the alias-enabling allocation domain.
- for (TensorView* tv : ir_utils::allTvs(fusion)) {
+ for (TensorView* tv : fusion->allTvs()) {
if (analysis.getRoot(tv) == nullptr) {
continue;
}
2 changes: 1 addition & 1 deletion csrc/preseg_passes/propagate_shardings.cpp
@@ -42,7 +42,7 @@ void PropagateShardingsPass::runPass(Fusion* fusion) {
// Validate that meshes are assigned to all TensorViews or none.
TensorView* tv_with_mesh = nullptr;
TensorView* tv_without_mesh = nullptr;
- for (TensorView* tv : ir_utils::allTvs(fusion)) {
+ for (TensorView* tv : fusion->allTvs()) {
auto update_if_null = [](TensorView*& lhs, TensorView* rhs) {
if (lhs == nullptr) {
lhs = rhs;
4 changes: 2 additions & 2 deletions csrc/python_frontend/fusion_definition.cpp
@@ -230,7 +230,7 @@ void FusionDefinition::setupSchedule(const at::ArrayRef<c10::IValue>& inputs) {
user_schedule_fusion,
args,
/*precomuted_values=*/nullptr,
- ir_utils::allTvs(user_schedule_fusion));
+ user_schedule_fusion->allTvs());

// Manually setting the fusion guard as there is not a good way of using a
// guard in a local scope across the schedule function
@@ -243,7 +243,7 @@ void FusionDefinition::finalizeSchedule(
FUSER_PERF_SCOPE("FusionDefinition::finalizeSchedule");
// TODO: remove when multidevice executor integration is done natively
Fusion* fusion = user_sched_->schedule.get();
- std::vector<TensorView*> tvs = ir_utils::allTvs(fusion);
+ std::vector<TensorView*> tvs = fusion->allTvs();
if (std::any_of(tvs.begin(), tvs.end(), [](Val* v) {
return v->isA<TensorView>() && v->as<TensorView>()->hasDeviceMesh();
})) {
2 changes: 1 addition & 1 deletion csrc/scheduler/normalization_inner_outer.cpp
@@ -238,7 +238,7 @@ std::vector<TensorView*> getOuterBroadcastTvs(

// find the broadcast tensor whose broadcast mask is same to the reference
std::vector<TensorView*> outer_broadcast_tvs;
- for (auto tv : ir_utils::allTvs(fusion)) {
+ for (auto tv : fusion->allTvs()) {
if (std::any_of(
tv->getLoopDomain().begin(),
tv->getLoopDomain().end(),
2 changes: 1 addition & 1 deletion csrc/scheduler/pointwise.cpp
@@ -901,7 +901,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
// unrolling manually.
inlineAllAt(reference_tv, unswitch_pos, true);

- auto all_tvs = ir_utils::allTvs(fusion);
+ auto all_tvs = fusion->allTvs();

// Inline at the inner most position. The CA position of all tensors except
// inputs, cached inputs and outputs will be updated.
2 changes: 1 addition & 1 deletion csrc/scheduler/registry.cpp
@@ -39,7 +39,7 @@ SchedulerRuntimeInfo::SchedulerRuntimeInfo(
} else {
index_type_ = registry_utils::getIndexTypeOfKernel(
complete_fusion_,
- all_tvs.empty() ? ir_utils::allTvs(complete_fusion_) : all_tvs,
+ all_tvs.empty() ? complete_fusion_->allTvs() : all_tvs,
args,
*expression_evaluator_);
}
2 changes: 1 addition & 1 deletion csrc/scheduler/registry_utils.cpp
@@ -62,7 +62,7 @@ bool checkPatternEquivalence(
bool hasNonUniqueBcast(Fusion* fusion) {
ConcretizedBroadcastDomains concretize_info(fusion);

- for (auto tv : ir_utils::allTvs(fusion)) {
+ for (auto tv : fusion->allTvs()) {
for (auto id : tv->getMaybeRootDomain()) {
if (concretize_info.maybeNonUniquelyConcretized(id)) {
return true;