Skip to content

Commit

Permalink
s/DependencyCheck::getAllExprsBetween/StmtSort::getExprsBetween (#1413)
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue authored and jacobhinkle committed Dec 6, 2023
1 parent d50da38 commit f5e4326
Show file tree
Hide file tree
Showing 41 changed files with 119 additions and 222 deletions.
1 change: 0 additions & 1 deletion csrc/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,6 @@ void IterDomainGraph::build(Fusion* fusion) {
// Grab all the rfactor ids.
for (auto consumer_tv : all_consumer_tvs) {
auto exprs = StmtSort::getExprsTo(
fusion,
{consumer_tv->getMaybeRFactorDomain().begin(),
consumer_tv->getMaybeRFactorDomain().end()});
for (auto expr : exprs) {
Expand Down
12 changes: 3 additions & 9 deletions csrc/contiguity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ OrderedIdInformation::OrderedIdInformation(
// consistently_ordered_ids_, id_to_alloc_ids_, and
// exclusively_consumes_allocs_ for all the IDs
auto exprs = StmtSort::getExprsBetween(
ids[0]->fusion(),
{alloc_domain.begin(), alloc_domain.end()},
{ids.begin(), ids.end()});
{alloc_domain.begin(), alloc_domain.end()}, {ids.begin(), ids.end()});

for (auto expr : exprs) {
OptInDispatch::dispatch(expr);
Expand Down Expand Up @@ -386,9 +384,7 @@ NonDivisibleSplitDependencies::NonDivisibleSplitDependencies(
return;
}
auto transforms = StmtSort::getExprsBetween(
ids[0]->fusion(),
{alloc_domain.begin(), alloc_domain.end()},
{ids.begin(), ids.end()});
{alloc_domain.begin(), alloc_domain.end()}, {ids.begin(), ids.end()});
for (auto transform : transforms) {
auto inp_ids = ir_utils::filterByType<IterDomain>(transform->inputs());
for (auto inp_id : inp_ids) {
Expand Down Expand Up @@ -545,9 +541,7 @@ void ContigIDs::build(const std::vector<IterDomain*>& ids) {

if (!contig_ids_.empty()) {
auto exprs = StmtSort::getExprsBetween(
ids.at(0)->fusion(),
{alloc_domain_.begin(), alloc_domain_.end()},
{ids.begin(), ids.end()});
{alloc_domain_.begin(), alloc_domain_.end()}, {ids.begin(), ids.end()});
for (auto expr : exprs) {
if (auto resize = dynamic_cast<Resize*>(expr)) {
resize_deps_.insert(resize->out());
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/analysis/divisible_split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ std::unordered_set<Split*> getAllDivisibleSplits(
// Take the view transformations and add all the splits. Those splits are
// the only divisible splits.
auto view_exprs =
StmtSort::getExprsTo(fusion, {rfactor_dom.begin(), rfactor_dom.end()});
StmtSort::getExprsTo({rfactor_dom.begin(), rfactor_dom.end()});
auto split_exprs = ir_utils::filterByType<Split>(view_exprs);
all_divisible_splits.insert(split_exprs.begin(), split_exprs.end());
}
Expand Down
4 changes: 2 additions & 2 deletions csrc/device_lower/analysis/predicate_elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ class PredicateChcker : public IterVisitor {
// provided.
bool predicateNonDivisibleRootDomains(Expr* expr) const {
for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
const auto all_exprs = DependencyCheck::getAllExprsBetween(
const auto all_exprs = StmtSort::getExprsBetween(
{output->getMaybeRFactorDomain().begin(),
output->getMaybeRFactorDomain().end()},
{output->getLeafDomain().begin(), output->getLeafDomain().end()});
Expand Down Expand Up @@ -863,7 +863,7 @@ class PredicateChcker : public IterVisitor {
} // namespace

PredicateElimination::PredicateElimination(Fusion* fusion) {
traverseTo(fusion, fusion->outputs());
traverseTo(fusion->outputs());
}

bool PredicateElimination::needsPredicate(Expr* expr) const {
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/analysis/shift.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ void HaloInfo::setHaloWidth(IterDomain* id, int halo_width) {

// Propagate extent information from root axes to descendants
void HaloInfo::build(TensorDomain* td) {
auto exprs = DependencyCheck::getAllExprsBetween(
auto exprs = StmtSort::getExprsBetween(
{td->maybeRFactor().begin(), td->maybeRFactor().end()},
{td->leaf().begin(), td->leaf().end()});

Expand Down
4 changes: 1 addition & 3 deletions csrc/device_lower/analysis/sync_information.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ struct ProducerConsumerIndexingInfoCache {
const auto& consumer_leaf_ids_shared_with_producer =
getConsumerLeafIDsSharedWithProducer();
consumer_root_ids_shared_with_producer_ = InputsOf::outputs(
producer_tv_->fusion(),
{consumer_leaf_ids_shared_with_producer.begin(),
consumer_leaf_ids_shared_with_producer.end()});
}
Expand Down Expand Up @@ -261,10 +260,9 @@ bool useSameIndex(
// consumer_id. The goal of the analysis below is to find out if all
// of the root IDs are indexed in the same way between the producer
// and consumer tensors.
auto consumer_root_ids = InputsOf::output(consumer_id->fusion(), consumer_id);
auto consumer_root_ids = InputsOf::output(consumer_id);

auto producer_root_vals = StmtSort::getStmtsBetween(
producer_id->fusion(),
{producer_tv->getMaybeRFactorDomain().begin(),
producer_tv->getMaybeRFactorDomain().end()},
{producer_id});
Expand Down
15 changes: 8 additions & 7 deletions csrc/device_lower/analysis/thread_predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,9 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {

// Run through inputs and update bitsets
for (const auto* inp : expr->inputs()) {
if (!ir_utils::isTV(inp))
if (!ir_utils::isTV(inp)) {
continue;
}

auto tv_inp = inp->as<TensorView>();

Expand Down Expand Up @@ -365,7 +366,7 @@ class RedundantUseAnalysis : BackwardVisitor {
public:
RedundantUseAnalysis(Fusion* fusion, const ThreadPredicateMap& pred_map)
: fusion_(fusion), pred_map_(pred_map) {
traverseTo(fusion, fusion->terminatingMathVals());
traverseTo(fusion->terminatingMathVals());
}

//! Returns a bit map signifying the parallel dimensions
Expand Down Expand Up @@ -619,14 +620,14 @@ class ConcretizedBroadcastRedundantWriteRemover {

// Find all the root domains that are merged to the leaf domain.
// e.g. Root: [I1,B2,B3] -> Leaf: [I1*B2*B3]
std::vector<IterDomain*> getRootDomainsMergedToLeaf(IterDomain* ld) {
std::vector<IterDomain*> getRootDomainsMergedToLeaf(IterDomain* id) {
std::vector<IterDomain*> merged_root_domains;
std::vector<int> index_root_domain;
std::vector<IterDomain*> intermediate_domains = root_domain_;
auto all_exp = DependencyCheck::getAllExprsBetween(
{root_domain_.begin(), root_domain_.end()}, {ld});
for (auto expr : all_exp) {
if (auto merge = dynamic_cast<Merge*>(expr)) {
auto all_exp = StmtSort::getExprsBetween(
{root_domain_.begin(), root_domain_.end()}, {id});
for (Expr* expr : all_exp) {
if (auto* merge = dynamic_cast<Merge*>(expr)) {
auto outer_iter =
std::find(root_domain_.begin(), root_domain_.end(), merge->outer());
auto inner_iter =
Expand Down
3 changes: 1 addition & 2 deletions csrc/device_lower/pass/alias_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ bool isSerialBroadcastResolution(
// traverse across view boundaries as we do in indexing. This
// should not result in false aliasing but may miss safe aliasing
// opportunities.
auto serial_loop_roots =
InputsOf::outputs(FusionGuard::getCurFusion(), serial_loop_concrete_ids);
auto serial_loop_roots = InputsOf::outputs(serial_loop_concrete_ids);

// Collect exact concrete id's in producer's root domain
std::unordered_set<IterDomain*> producer_exact_concrete_root_ids;
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class AllocationInserter : public kir::ExprMutator {
[](IterDomain* dom) { return dom->as<Val>(); });

// Get all exprs involved in generating the allocation IDs
auto exprs = StmtSort::getExprsTo(tv->fusion(), start_vals);
auto exprs = StmtSort::getExprsTo(start_vals);

// Get the halo extent if found
auto getExtent = [this](IterDomain* id) {
Expand Down
10 changes: 5 additions & 5 deletions csrc/device_lower/pass/expr_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,14 +406,16 @@ std::string ExprGroup::toString() const {
os << " ca_ids {";
for (size_t i = 0; i < payload()->ca_domains.size(); i++) {
os << payload()->ca_domains[i];
if (i + 1 != payload()->ca_domains.size())
if (i + 1 != payload()->ca_domains.size()) {
os << ", ";
}
}
os << "} pa_ids {";
for (size_t i = 0; i < payload()->pa_domains.size(); i++) {
os << payload()->pa_domains[i];
if (i + 1 != payload()->pa_domains.size())
if (i + 1 != payload()->pa_domains.size()) {
os << ", ";
}
}
os << "}";
os << "\nExprs {\n";
Expand Down Expand Up @@ -1507,9 +1509,7 @@ void ExprSegmentationSorter::sort() {
// Not putting the exprs between allKnownVals() and fusion inputs here
// because they are computed using the expr evaluator.
auto all_exprs = StmtSort::getExprsBetween(
fusion_,
GpuLower::current()->allKnownVals(),
fusion_->getTerminatingOutputs());
GpuLower::current()->allKnownVals(), fusion_->getTerminatingOutputs());

// Figure out all the values used as inputs to the expressions we're sorting
// (to find terminating expressions). There could be branches of expressions
Expand Down
3 changes: 1 addition & 2 deletions csrc/device_lower/pass/warp_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ class EliminateDeadBroadcastAndAllocate {
// Also find any TVs used in index expressions.
// These expressions will likely not be in the Expr tree we are
// provided, so we need to traverse to find them.
auto all_index_roots =
InputsOf::outputs(FusionGuard::getCurFusion(), {ti->index()});
auto all_index_roots = InputsOf::outputs({ti->index()});
auto index_root_tis =
ir_utils::filterByType<kir::TensorIndex>(all_index_roots);
for (auto rootti : index_root_tis) {
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ std::vector<Expr*> replaceInputsInExpr(
std::vector<Expr*> getAllSwizzlesBetween(
std::vector<IterDomain*> from,
std::vector<IterDomain*> to) {
auto all_expr = DependencyCheck::getAllExprsBetween(
auto all_expr = StmtSort::getExprsBetween(
{from.begin(), from.end()}, {to.begin(), to.end()});

std::vector<Expr*> all_swizzles;
Expand Down
5 changes: 2 additions & 3 deletions csrc/device_lower/validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ class VectorizeValidator : public OptInDispatch {
IterDomain* v_id,
TensorView* tv,
std::string name) {
auto replay_exprs = DependencyCheck::getAllExprsBetween(
auto replay_exprs = StmtSort::getExprsBetween(
{tv->getMaybeAllocationDomain().begin(),
tv->getMaybeAllocationDomain().end()},
{v_id});
Expand Down Expand Up @@ -836,7 +836,7 @@ void validatePartialSplit(Fusion* fusion) {

for (auto tv : ir_utils::allTvs(fusion)) {
auto exprs = StmtSort::getExprsTo(
tv->fusion(), {tv->getLeafDomain().begin(), tv->getLeafDomain().end()});
{tv->getLeafDomain().begin(), tv->getLeafDomain().end()});
for (auto split : ir_utils::filterByType<Split>(exprs)) {
// When the start and stop offsets are not zero, make sure the
// range defined by the split includes the required range to
Expand Down Expand Up @@ -1276,7 +1276,6 @@ void validateResize(Fusion* fusion) {
for (auto tv : ir_utils::filterByType<TensorView>(fusion_vals)) {
// Make sure resize is only used as part of rfactor transformations
auto rf_to_leaf_exprs = StmtSort::getExprsBetween(
fusion,
{tv->getMaybeRFactorDomain().begin(),
tv->getMaybeRFactorDomain().end()},
{tv->getLeafDomain().begin(), tv->getLeafDomain().end()});
Expand Down
5 changes: 2 additions & 3 deletions csrc/dynamic_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor {
!fusion->isA<kir::Kernel>(),
"Invalid container. Kernel container not allowed.\n");

traverseTo(fusion, fusion->getTerminatingOutputs(), false, false);
traverseTo(fusion->getTerminatingOutputs(), false, false);

finalizeDynamicVals();

Expand Down Expand Up @@ -147,7 +147,7 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor {
//! Process vector of leaf dynamic values by finding inputs and recording the
//! result into info_
void finalizeDynamicVals() {
const auto inputs = InputsOf::outputs(info_.fusion(), leaf_dynamic_vals_);
const auto inputs = InputsOf::outputs(leaf_dynamic_vals_);
info_.root_dynamic_vals_.insert(inputs.begin(), inputs.end());

// initial_info_ provides a set of Vals that are used for concretization.
Expand Down Expand Up @@ -621,7 +621,6 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {
// Note that it is assumed that theres's no further expression
// beyond the rfactor domain as asserted above
auto all_id_exprs = StmtSort::getExprsBetween(
tv->fusion(),
{tv->getRootDomain().begin(), tv->getRootDomain().end()},
{tv->getMaybeRFactorDomain().begin(),
tv->getMaybeRFactorDomain().end()});
Expand Down
22 changes: 7 additions & 15 deletions csrc/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ void FusionExecutor::compileFusion(
}
output_extents.emplace_back(extent);
}
auto dependencies = InputsOf::outputs(fusion, output_extents);
auto dependencies = InputsOf::outputs(output_extents);
if (std::any_of(dependencies.begin(), dependencies.end(), [](Val* val) {
return val->isFusionInput();
})) {
Expand Down Expand Up @@ -607,7 +607,6 @@ std::pair<std::vector<int64_t>, std::vector<int64_t>> inferShapeOfOutput(

class ForwardTraverseFromAllocToRFactor {
at::Tensor tensor_;
TensorView* tv_;
ExpressionEvaluator& ee_;
std::list<IterDomain*>& frontier_;

Expand Down Expand Up @@ -725,18 +724,15 @@ class ForwardTraverseFromAllocToRFactor {
public:
ForwardTraverseFromAllocToRFactor(
at::Tensor tensor,
TensorView* tv,
ExpressionEvaluator& ee,
std::list<IterDomain*>& frontier)
: tensor_(std::move(tensor)), tv_(tv), ee_(ee), frontier_(frontier) {}
: tensor_(std::move(tensor)), ee_(ee), frontier_(frontier) {}

at::Tensor run(
const std::vector<IterDomain*>& rfactor,
const std::vector<IterDomain*>& alloc) {
auto forward_exprs = StmtSort::getExprsBetween(
tv_->fusion(),
{alloc.begin(), alloc.end()},
{rfactor.begin(), rfactor.end()});
{alloc.begin(), alloc.end()}, {rfactor.begin(), rfactor.end()});
for (auto expr : forward_exprs) {
handle(expr);
}
Expand All @@ -748,7 +744,6 @@ class ForwardTraverseFromAllocToRFactor {
// transformations.
class BackwardTraverseFromAllocToRFactor {
at::Tensor tensor_;
TensorView* tv_;
ExpressionEvaluator& ee_;
std::list<IterDomain*>& frontier_;

Expand Down Expand Up @@ -853,18 +848,15 @@ class BackwardTraverseFromAllocToRFactor {
public:
BackwardTraverseFromAllocToRFactor(
at::Tensor tensor,
TensorView* tv,
ExpressionEvaluator& ee,
std::list<IterDomain*>& frontier)
: tensor_(std::move(tensor)), tv_(tv), ee_(ee), frontier_(frontier) {}
: tensor_(std::move(tensor)), ee_(ee), frontier_(frontier) {}

at::Tensor run(
const std::vector<IterDomain*>& rfactor,
const std::vector<IterDomain*>& alloc) {
auto backward_exprs = StmtSort::getExprsBetween(
tv_->fusion(),
{rfactor.begin(), rfactor.end()},
{alloc.begin(), alloc.end()});
{rfactor.begin(), rfactor.end()}, {alloc.begin(), alloc.end()});
std::reverse(backward_exprs.begin(), backward_exprs.end());
for (auto expr : backward_exprs) {
handle(expr);
Expand Down Expand Up @@ -894,9 +886,9 @@ at::Tensor transformOutputFromAllocationToRFactor(
// forward and a backward traverse.
std::list<IterDomain*> frontier(alloc.begin(), alloc.end());
NVF_ERROR(tensor.dim() == (int64_t)frontier.size());
tensor = ForwardTraverseFromAllocToRFactor(tensor, tv, ee, frontier)
tensor = ForwardTraverseFromAllocToRFactor(tensor, ee, frontier)
.run(rfactor, alloc);
tensor = BackwardTraverseFromAllocToRFactor(tensor, tv, ee, frontier)
tensor = BackwardTraverseFromAllocToRFactor(tensor, ee, frontier)
.run(rfactor, alloc);
NVF_ERROR(frontier.size() == rfactor.size());
// Now that all affine transformations are handled, and frontiers should
Expand Down
6 changes: 3 additions & 3 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ bool Fusion::isNoOp() {
}

std::vector<Val*> Fusion::inputsOf(Val* val) {
return InputsOf::output(this, val);
return InputsOf::output(val);
}

void Fusion::validateInputs() {
Expand Down Expand Up @@ -528,7 +528,7 @@ void Fusion::printMath(bool from_outputs_only) {
leaf_vals.push_back(val);
}
}
exprs_for_print = StmtSort::getExprsTo(this, leaf_vals);
exprs_for_print = StmtSort::getExprsTo(leaf_vals);
}

debug() << "\n%kernel_math {\n";
Expand Down Expand Up @@ -649,7 +649,7 @@ std::vector<Val*> Fusion::usedMathVals() {
// there can be vals that are created inside a fusion without using
// anything from inputs. See, for example, tv0 in the
// FusionOuterSplit test.
const auto inputs = InputsOf::outputs(this, outputs());
const auto inputs = InputsOf::outputs(outputs());
auto used_math_vals = DependencyCheck::getAllValsBetween(
{inputs.begin(), inputs.end()}, outputs());
// When an expre has multiple outputs and only some of them are
Expand Down
11 changes: 3 additions & 8 deletions csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2207,12 +2207,7 @@ std::optional<std::unique_ptr<SchedulerEntry>> SegmentedGroup::
}

void SegmentedGroup::resetExprList() {
auto input_group_vec = getAllInputs(this);
std::unordered_set<Val*> input_group_set(
input_group_vec.begin(), input_group_vec.end());
auto expr_set =
DependencyCheck::getAllExprsBetween(input_group_set, getAllOutputs(this));
exprs_ = std::vector<Expr*>(expr_set.begin(), expr_set.end());
exprs_ = StmtSort::getExprsBetween(getAllInputs(this), getAllOutputs(this));
}

// Custom merge node passes:
Expand Down Expand Up @@ -3703,7 +3698,7 @@ void SegmentCandidateFinder::resolveInputsInGroup(SegmentedGroup* group) {
group->input_vals = IterVisitor::getInputsTo(group->inputs());

// Grab all expressions needed to produce to_visit
auto input_exprs = StmtSort::getExprsTo(completeFusion(), to_visit);
auto input_exprs = StmtSort::getExprsTo(to_visit);

// Insert those expressions at the beginning of the group
group->exprs_.insert(
Expand Down Expand Up @@ -3963,7 +3958,7 @@ class ForceHalfAnnotation : public IterVisitor {
val->getDataType().value() == DataType::BFloat16);
});

annotation.traverseTo(fusion, fp16_outputs);
annotation.traverseTo(fp16_outputs);
return annotation.force_fp16_tv_set_;
}

Expand Down
Loading

0 comments on commit f5e4326

Please sign in to comment.