Skip to content

Commit

Permalink
Rename RootPredicateInfo to just PredicateInfo (NVIDIA#2604)
Browse files Browse the repository at this point in the history
Used to predicate root domains but not true anymore. No logic change.
  • Loading branch information
naoyam authored Jul 16, 2024
1 parent 8342c6e commit 77e0317
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
12 changes: 6 additions & 6 deletions csrc/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2422,7 +2422,7 @@ std::unordered_map<IterDomain*, Val*> updateInitialLoopIndexMap(
} // namespace

// Returns predicates and the concrete (by loop map) root domains they cover
std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(
std::vector<PredicateInfo> Index::getReferenceRootPredicates(
TensorView* consumer_tv,
const std::vector<ForLoop*>& loops,
const std::unordered_set<ForLoop*>& rotated_loops,
Expand Down Expand Up @@ -2474,7 +2474,7 @@ std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(
non_divisible_splits.begin(),
non_divisible_splits.end());

std::vector<RootPredicateInfo> pred_info_vec;
std::vector<PredicateInfo> pred_info_vec;

for (const auto& contig_id_entry : contig_id_infos) {
auto contig_id = contig_id_entry.id;
Expand All @@ -2499,7 +2499,7 @@ std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(
continue;
}

RootPredicateInfo info;
PredicateInfo info;

// The final predicates will look like:
// (index + start_offset) >= 0 && (index + stop_offset) < extent.
Expand Down Expand Up @@ -2552,16 +2552,16 @@ std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(
info.stop_predicate_ = stop_pred;

for (auto consumer_id : contig_id_entry.covered_ids) {
info.root_ids_.insert(consumer_id);
info.predicated_domains_.insert(consumer_id);
}
pred_info_vec.emplace_back(info);
}

return pred_info_vec;
}

RootPredicateInfo RootPredicateInfo::getFalseInfo() {
RootPredicateInfo info;
PredicateInfo PredicateInfo::getFalseInfo() {
PredicateInfo info;
info.start_predicate_ = GpuLower::current()->kernel()->falseVal();
info.stop_predicate_ = GpuLower::current()->kernel()->falseVal();

Expand Down
18 changes: 10 additions & 8 deletions csrc/index_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,10 @@ class IndexSwizzle : public IndexCompute {
std::unordered_set<IterDomain*> swizzled_ids_;
};

//! Predicate information of a root or contiguous merged domain
class RootPredicateInfo {
//! Information about a predicate. By default, it corresponds to a
//! single logical domain but may cover multiple logial domains due to
//! contigous indexing.
class PredicateInfo {
friend class Index;

public:
Expand All @@ -382,13 +384,13 @@ class RootPredicateInfo {
return stop_offset_;
}

const auto& rootIds() const {
return root_ids_;
const auto& predicatedDomains() const {
return predicated_domains_;
}

//! Return a false RootPredicateInfo, i.e., both start and stop
//! predicates are false.
static RootPredicateInfo getFalseInfo();
static PredicateInfo getFalseInfo();

private:
// prdicate for lower end
Expand All @@ -399,8 +401,8 @@ class RootPredicateInfo {
Val* start_offset_ = nullptr;
// Offset of the stop predicate
Val* stop_offset_ = nullptr;
// Track which roots have been handled by the generated predicates
std::unordered_set<IterDomain*> root_ids_;
// Track which domains are covered by the generated predicates
std::unordered_set<IterDomain*> predicated_domains_;
};

// Simple interface for IndexCompute
Expand Down Expand Up @@ -557,7 +559,7 @@ class Index {
//! predicate, this is not a bool value as if we have an unswitch loop
//! with a vectorized loop inside, we only want to base the "unswitch"
//! like predicate on the vectorized loop.
static std::vector<RootPredicateInfo> getReferenceRootPredicates(
static std::vector<PredicateInfo> getReferenceRootPredicates(
TensorView* consumer_tv,
const std::vector<ForLoop*>& loops,
const std::unordered_set<ForLoop*>& rotated_loops,
Expand Down
4 changes: 2 additions & 2 deletions csrc/predicate_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ Val* PredicateCompute::getInlinePredicate(
bool non_zero_start_found = false;
for (const auto& pred_info : pred_info_vec) {
if (pred_type == PredicateType::ReductionWrite) {
const auto& consumer_ids = pred_info.rootIds();
const auto& consumer_ids = pred_info.predicatedDomains();
bool pred_for_reduction_axis = false;
for (auto consumer_id : consumer_ids) {
if (consumer_id->isReduction()) {
Expand Down Expand Up @@ -489,7 +489,7 @@ void UnswitchPredicate::predicateOn(Expr* tv_expr) {
NVF_ERROR(pred_info.startPredicate() != nullptr);
NVF_ERROR(pred_info.stopPredicate() != nullptr);

const auto& root_ids = pred_info.rootIds();
const auto& root_ids = pred_info.predicatedDomains();

bool add_pred = false;

Expand Down

0 comments on commit 77e0317

Please sign in to comment.