Merge up-cast, ops, down-cast sequences as minimal units of segments #3699

Merged: 8 commits, Jan 16, 2025
Changes from 4 commits
207 changes: 207 additions & 0 deletions csrc/fusion_segmenter.cpp
@@ -3586,6 +3586,210 @@ bool CombineReductions::shouldRun(
return false;
}

// This preprocessing attempts to find groups of exprs that consist of an
// up-cast, followed by some ops and ended by a down-cast. It is highly
// likely that such sequences of ops should never be segmented
// out. This is particularly commonly seen in fusions given by Thunder
// as it inserts fine-grained downcasting and upcasting ops. Without
// this preprocessing, a fusion may be segmented right after an
// up-cast op, for example, and in fact it happened quite frequently
// in some of the RoPE cases. This preprocessing does not completely
// avoid such segmentation boundaries, but it should make them less
// likely. See also https://github.com/NVIDIA/Fuser/pull/3699.
class MergeUpAndDownCast {
public:
static void run(SegmentCandidateFinder* segment_candidate_finder) {
MergeUpAndDownCast group_cast(segment_candidate_finder);
}

private:
MergeUpAndDownCast(SegmentCandidateFinder* segment_candidate_finder)
: segment_candidate_finder_(segment_candidate_finder) {
merge();
}

void merge() {
bool merged = true;
while (merged) {
merged = false;
std::unordered_set<SegmentedGroup*> considered_groups;

for (SegmentedGroup* group : segment_candidate_finder_->groups()) {
// If the group is an up-cast group, see if there's a
// candidate group starting with the group.
if (!isUpCast(group) || considered_groups.count(group)) {
continue;
}

auto groups_to_merge = getCandidateCastGroup(group);
if (groups_to_merge.size() < 2) {
continue;
}

for (auto group : groups_to_merge) {
considered_groups.insert(group);
}

// Try merging the detected group
if (mergeCastGroup(groups_to_merge)) {
merged = true;
break;
}
}
}
}

// Try to detect a set of groups that could be merged as a cast
// group. The analysis starts with an initial group that solely
// consists of an up-cast expression. From the initial group, it
// traverses its neighbor groups. If the group is a down-cast group,
// it only traverses through the producer edges. If it's an up-cast
// group, it only traverses through the consumer edges.
//
// Additionally, this traversal has several safeguards to keep the
// DAG property intact:
//
// - For a given group, it does not visit its consumers if it has
// multiple consumers, even if the group is not a down-cast
// group.
// - Similarly, it does not visit a producer if the producer has
// multiple consumers.
//
// The basic form of this set of groups should look like an up-cast
// group, followed by some op groups and ended by a down-cast
// group. However, that is not always the case because of the above
// safeguards. For example, only part of the following sequence
// would be detected as a cast group.
//
// t1 = bf16ToFp32(t0)
// t2 = neg(t1)
// t3 = sin(t2)
// t4 = cos(t2)
// t5 = fp32ToBf16(t3)
// t6 = fp32ToBf16(t4)
//
// In this case, t1 and t2 would be detected as a candidate group,
// but t3 and t4 would not be included. While we could certainly
// extend the analysis, it would need to make sure the DAG property
// is not violated.
std::vector<SegmentedGroup*> getCandidateCastGroup(
SegmentedGroup* initial_group) {
std::vector<SegmentedGroup*> groups_to_merge;
std::unordered_set<SegmentedGroup*> groups_to_merge_set;

std::deque<SegmentedGroup*> to_visit;
to_visit.push_back(initial_group);

while (!to_visit.empty()) {
SegmentedGroup* group = to_visit.front();
to_visit.pop_front();

if (groups_to_merge_set.count(group)) {
continue;
}

// For simplicity, all groups are assumed to be the initial
// single-expr groups. Skip if not

groups_to_merge.push_back(group);
groups_to_merge_set.insert(group);

// Consumer traversal. Stop if this group is a down cast
// group. Also stop if there are multiple consumer edges to
// simplify keeping the DAG property.
if (!isDownCast(group) && group->consumer_edges.size() == 1) {
auto consumer_edge = group->consumer_edges.at(0);
SegmentedGroup* consumer_group = consumer_edge->to;
if (!groups_to_merge_set.count(consumer_group)) {
to_visit.push_back(consumer_group);
}
}

if (!isUpCast(group)) {
for (const auto producer_edge : group->producer_edges) {
SegmentedGroup* producer_group = producer_edge->from;
// Don't add producers that have multiple consumers
if (producer_group->consumer_edges.size() > 1) {
continue;
}
if (!groups_to_merge_set.count(producer_group)) {
to_visit.push_back(producer_group);
}
}
}
}

return groups_to_merge;
}

// Try merging a candidate cast group. Return a merged group if merged.
SegmentedGroup* mergeCastGroup(const std::vector<SegmentedGroup*>& groups) {
auto sched_type = tryMerge(
segment_candidate_finder_->segmented_fusion_.get(),
segment_candidate_finder_->runtimeInfo(),
groups);

if (sched_type == SchedulerType::None) {
return nullptr;
}

auto joined_group = segment_candidate_finder_->mergeAllGivenGroups(groups);

return joined_group;
}

bool isUpCast(SegmentedGroup* group) const {
if (auto precision_bits = getProducerConsumerPrecision(group);
precision_bits.has_value()) {
return precision_bits->first < precision_bits->second;
} else {
return false;
}
}

bool isDownCast(SegmentedGroup* group) const {
if (auto precision_bits = getProducerConsumerPrecision(group);
precision_bits.has_value()) {
return precision_bits->first > precision_bits->second;
} else {
return false;
}
}

std::optional<std::pair<int64_t, int64_t>> getProducerConsumerPrecision(
SegmentedGroup* group) const {
if (group->exprs().size() != 1) {
return std::nullopt;
}

auto uop = dynamic_cast<UnaryOp*>(group->exprs().front());
if (uop == nullptr || uop->getUnaryOpType() != UnaryOpType::Cast) {
return std::nullopt;
}

auto inp_tv = ir_utils::getTvInput(uop);
auto out_tv = ir_utils::getTvOutput(uop);
if (inp_tv == nullptr || out_tv == nullptr) {
return std::nullopt;
}

auto inp_dtype = inp_tv->dtype().type;
auto out_dtype = out_tv->dtype().type;
auto inp_prim_type = std::get_if<PrimDataType>(&inp_dtype);
auto out_prim_type = std::get_if<PrimDataType>(&out_dtype);

if (inp_prim_type == nullptr || out_prim_type == nullptr) {
return std::nullopt;
}

return std::make_pair(
primDataTypeSize(*inp_prim_type), primDataTypeSize(*out_prim_type));
Collaborator:

I think you need a bail-out here for when either inp_prim_type or out_prim_type is PrimDataType::Index, for which primDataTypeSize would throw an exception. I think welford would have that.

I only know that when I hit it in the presegmentation passes. :)

Collaborator:

I feel we could use a test for that. We don't have to do it in this PR though.

Collaborator Author:

Thanks. Updated.

Collaborator Author:

Added a test
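
The bail-out discussed in this thread can be sketched in isolation as follows. This is only an illustration under assumptions: `ToyDataType`, `toyDataTypeSize`, and `castPrecision` are hypothetical stand-ins, not nvFuser's `PrimDataType` or `primDataTypeSize`, and the actual change made in the PR is not shown on this page. The toy size lookup throws for `Index` to mimic the behavior the review comment describes.

```cpp
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <utility>

// Hypothetical stand-in for a primitive data type enum (not nvFuser's).
enum class ToyDataType { Half, BFloat16, Float, Double, Index };

// Toy size lookup in bytes. It throws for Index, mimicking the
// behavior described in the review comment.
int64_t toyDataTypeSize(ToyDataType t) {
  switch (t) {
    case ToyDataType::Half:
    case ToyDataType::BFloat16:
      return 2;
    case ToyDataType::Float:
      return 4;
    case ToyDataType::Double:
      return 8;
    case ToyDataType::Index:
      throw std::runtime_error("size of Index is not available");
  }
  throw std::logic_error("unknown type");
}

// Returns (input size, output size), or nullopt when either side is
// Index, so the throwing lookup is never reached for that type.
std::optional<std::pair<int64_t, int64_t>> castPrecision(
    ToyDataType in, ToyDataType out) {
  if (in == ToyDataType::Index || out == ToyDataType::Index) {
    return std::nullopt;
  }
  return std::make_pair(toyDataTypeSize(in), toyDataTypeSize(out));
}

// A cast is an up-cast if the output type is wider than the input type.
bool isUpCast(ToyDataType in, ToyDataType out) {
  auto p = castPrecision(in, out);
  return p.has_value() && p->first < p->second;
}

// A cast is a down-cast if the output type is narrower than the input type.
bool isDownCast(ToyDataType in, ToyDataType out) {
  auto p = castPrecision(in, out);
  return p.has_value() && p->first > p->second;
}
```

With this guard, a cast involving `Index` is classified as neither an up-cast nor a down-cast instead of raising an exception.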

}

private:
SegmentCandidateFinder* segment_candidate_finder_ = nullptr;
};

namespace {

//! Returns true if group1 and group2 are an immediate producer-consumer pair.
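
The traversal in `getCandidateCastGroup` can be tried out on a toy graph. The sketch below is an assumption-laden simplification: `ToyGroup`, `Kind`, `connect`, and `candidateCastGroup` are hypothetical names, each group holds direct pointers to its neighbors instead of edge objects, and the cast kind is stored as a tag instead of being derived from dtypes. It reproduces the `t0` to `t6` example from the code comment, with each group named after the tensor it produces.

```cpp
#include <deque>
#include <string>
#include <unordered_set>
#include <vector>

enum class Kind { UpCast, DownCast, Other };

// Hypothetical single-expr group with direct neighbor pointers.
struct ToyGroup {
  std::string name;
  Kind kind;
  std::vector<ToyGroup*> producers;
  std::vector<ToyGroup*> consumers;
};

void connect(ToyGroup& from, ToyGroup& to) {
  from.consumers.push_back(&to);
  to.producers.push_back(&from);
}

// Breadth-first collection mirroring the rules in the diff: follow the
// consumer only from non-down-cast groups with exactly one consumer;
// follow producers only from non-up-cast groups, skipping producers
// that have multiple consumers.
std::vector<ToyGroup*> candidateCastGroup(ToyGroup* initial) {
  std::vector<ToyGroup*> result;
  std::unordered_set<ToyGroup*> seen;
  std::deque<ToyGroup*> to_visit;
  to_visit.push_back(initial);
  while (!to_visit.empty()) {
    ToyGroup* g = to_visit.front();
    to_visit.pop_front();
    if (!seen.insert(g).second) {
      continue;  // already collected
    }
    result.push_back(g);
    if (g->kind != Kind::DownCast && g->consumers.size() == 1) {
      ToyGroup* c = g->consumers.front();
      if (!seen.count(c)) {
        to_visit.push_back(c);
      }
    }
    if (g->kind != Kind::UpCast) {
      for (ToyGroup* p : g->producers) {
        if (p->consumers.size() > 1 || seen.count(p)) {
          continue;
        }
        to_visit.push_back(p);
      }
    }
  }
  return result;
}

std::vector<std::string> names(const std::vector<ToyGroup*>& groups) {
  std::vector<std::string> out;
  for (const ToyGroup* g : groups) {
    out.push_back(g->name);
  }
  return out;
}

// The example from the code comment: neg (t2) feeds both sin and cos.
std::vector<std::string> exampleFromComment() {
  ToyGroup t1{"t1", Kind::UpCast}, t2{"t2", Kind::Other};
  ToyGroup t3{"t3", Kind::Other}, t4{"t4", Kind::Other};
  ToyGroup t5{"t5", Kind::DownCast}, t6{"t6", Kind::DownCast};
  connect(t1, t2);
  connect(t2, t3);
  connect(t2, t4);
  connect(t3, t5);
  connect(t4, t6);
  return names(candidateCastGroup(&t1));
}

// A plain chain: up-cast, one op, down-cast.
std::vector<std::string> linearChainExample() {
  ToyGroup t1{"t1", Kind::UpCast}, t2{"t2", Kind::Other};
  ToyGroup t3{"t3", Kind::DownCast};
  connect(t1, t2);
  connect(t2, t3);
  return names(candidateCastGroup(&t1));
}
```

In `exampleFromComment`, `t2` has two consumers, so the traversal stops there and only `t1` and `t2` are collected, as the comment states. In `linearChainExample`, all three groups are collected.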
@@ -3945,6 +4149,9 @@ void SegmentCandidateFinder::findSegments() {
removeScalarEdges();

// Run pre-merge heuristics
MergeUpAndDownCast::run(this);
segmented_fusion_->validateIfDebug(true);

if (options_.run_combine_reductions && CombineReductions::shouldRun(this)) {
CombineReductions::run(this);
}
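
The `merge()` loop in `MergeUpAndDownCast` stops scanning as soon as one merge succeeds and then starts a fresh pass, repeating until a pass makes no merge. A minimal sketch of that restart-until-fixed-point pattern is below, assuming a flat list of op names instead of a graph of groups. `mergeOnce`, `mergeToFixedPoint`, and the `"up"`, `"down"`, and `"fused"` tokens are hypothetical, and the reason given in the comments for restarting (the container being modified during the scan) is a presumption, not something the PR states.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Fuses the first "up", ops..., "down" run into a single "fused" token.
// Returns true if the list was changed.
bool mergeOnce(std::vector<std::string>& ops) {
  for (std::size_t i = 0; i < ops.size(); ++i) {
    if (ops[i] != "up") {
      continue;
    }
    // Scan forward over plain ops until the next cast or the end.
    std::size_t j = i + 1;
    while (j < ops.size() && ops[j] != "up" && ops[j] != "down") {
      ++j;
    }
    if (j == ops.size() || ops[j] != "down") {
      continue;  // no closing down-cast for this up-cast
    }
    auto first = ops.begin() + static_cast<std::ptrdiff_t>(i);
    auto last = ops.begin() + static_cast<std::ptrdiff_t>(j + 1);
    // erase returns an iterator to the position where the run started.
    ops.insert(ops.erase(first, last), "fused");
    return true;  // the list changed, so stop this scan and restart
  }
  return false;
}

// Repeats single merges until a full scan finds nothing to merge.
// Returns the number of merges performed.
int mergeToFixedPoint(std::vector<std::string>& ops) {
  int merges = 0;
  while (mergeOnce(ops)) {
    ++merges;
  }
  return merges;
}
```

For the input `{"up", "neg", "down", "add", "up", "sin", "down"}`, two merges are performed and the list becomes `{"fused", "add", "fused"}`.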
2 changes: 2 additions & 0 deletions csrc/fusion_segmenter.h
@@ -488,6 +488,7 @@ class GroupDependencyAnalysis;

// Manual node merging passes
class CombineReductions;
class MergeUpAndDownCast;

//! Options to configure/debug candidate finder
struct SegmentCandidateFinderOptions {
@@ -691,6 +692,7 @@ class SegmentCandidateFinder {
//! eventually should have a dedicated interface
//! instead of keeping adding friends
friend class CombineReductions;
friend class MergeUpAndDownCast;

//! options to configure and debug the segment process
SegmentCandidateFinderOptions options_;