Merge up-cast, ops, down-cast sequences as minimal units of segments #3699

Merged (8 commits) on Jan 16, 2025

@naoyam (Collaborator) commented Jan 13, 2025

This PR attempts to improve fusion segmentation by exploiting a common pattern we see in fusions given by Thunder. Specifically, it is quite common to see a pattern like the one below.

Background

Inputs: T0_bf16, T1_bf16

T1_fp32 = bfloat16ToFloat(T0_bf16);
T2_fp32 = bfloat16ToFloat(T1_bf16);
T3_fp32 = T1_fp32 + T2_fp32;
T4_bf16 = floatToBfloat16(T3_fp32);

Output: T4_bf16

Here, the half-precision (bf16) inputs are up-cast to float for the arithmetic op, and the result is then cast back to bf16. This pattern of up-casting to float, doing some arithmetic ops, and then casting back to half precision is quite common. The above example is just a trivial fusion, but larger, more complex fusions can have many sequences of ops like this.
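To make the pattern concrete, here is a minimal host-side C++ sketch of what the example fusion computes per element. It assumes bf16 values are stored as the upper 16 bits of an fp32 (the bfloat16 layout) and uses round-to-nearest-even for the down-cast; the names bf16, bf16ToFloat, floatToBf16, and addBf16 are hypothetical and are not nvFuser's API. NaN handling is omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical stand-in for a bf16 value: the upper 16 bits of an fp32.
using bf16 = std::uint16_t;

static_assert(sizeof(float) == sizeof(std::uint32_t), "float must be 32 bits");

// Up-cast: bf16 -> fp32 is exact; the 16 bits become the high half of an fp32.
float bf16ToFloat(bf16 h) {
  std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Down-cast: fp32 -> bf16 with round-to-nearest-even (NaN handling omitted).
bf16 floatToBf16(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  std::uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<bf16>((bits + rounding_bias) >> 16);
}

// The example fusion for one element: up-cast both inputs, add in fp32,
// then down-cast the result.
bf16 addBf16(bf16 t0_bf16, bf16 t1_bf16) {
  float t1_fp32 = bf16ToFloat(t0_bf16);
  float t2_fp32 = bf16ToFloat(t1_bf16);
  float t3_fp32 = t1_fp32 + t2_fp32;
  return floatToBf16(t3_fp32);
}
```

For example, addBf16(floatToBf16(1.5f), floatToBf16(2.25f)) yields the bf16 encoding of 3.75. The up-casts are exact, so only the fp32 arithmetic and the final down-cast can round.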

Problem

In some of the RoPE cases, I have observed inefficient segmentation results where a large fusion is segmented right after an up-cast op like the one shown above. Suppose the above pattern appears in a large fusion and no scheduler can accept the whole fusion without segmentation. The fusion segmenter may then segment the fusion like below:

// Segment 1
...
T1_fp32 = bfloat16ToFloat(T0_bf16);
T2_fp32 = bfloat16ToFloat(T1_bf16);

// Segment 2
T3_fp32 = T1_fp32 + T2_fp32;
T4_bf16 = floatToBfloat16(T3_fp32);
...

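Ideally, since this sequence is just a trivial arithmetic op, it should be treated as an atomic unit and never be split across segments. With the segmentation above, T1_fp32 and T2_fp32 cross the segment boundary, so they must be written to and read back from global memory in fp32, i.e., twice the bytes of the bf16 inputs. The segmentation should instead be either: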

// Segment 1
...
T1_fp32 = bfloat16ToFloat(T0_bf16);
T2_fp32 = bfloat16ToFloat(T1_bf16);
T3_fp32 = T1_fp32 + T2_fp32;
T4_bf16 = floatToBfloat16(T3_fp32);
...

// Segment 2
...

or

// Segment 1
...

// Segment 2
T1_fp32 = bfloat16ToFloat(T0_bf16);
T2_fp32 = bfloat16ToFloat(T1_bf16);
T3_fp32 = T1_fp32 + T2_fp32;
T4_bf16 = floatToBfloat16(T3_fp32);
...

The second option is not as good as the first, as there are two intermediate tensors between the segments, but in general, we should be able to get better segmentations by treating these patterns as non-decomposable atomic operations.
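As a back-of-the-envelope illustration of why the boundary placement matters, the sketch below counts the bytes of intermediate tensors crossing a segment boundary. It assumes a simplified cost model (element size times element count per boundary tensor; the write-then-read factor is the same for every option and is omitted) and a hypothetical helper, boundaryBytes. It is not how the segmenter actually scores candidates.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical cost model: bytes of intermediate tensors that cross a
// segment boundary. Each entry of elem_bytes is the element size of one
// boundary tensor; all tensors are assumed to have numel elements.
std::int64_t boundaryBytes(const std::vector<std::int64_t>& elem_bytes,
                           std::int64_t numel) {
  assert(numel >= 0);
  std::int64_t total = 0;
  for (std::int64_t size : elem_bytes) {
    total += size * numel;
  }
  return total;
}
```

With N elements per tensor, a boundary right after the up-casts (T1_fp32 and T2_fp32) costs boundaryBytes({4, 4}, N) = 8N bytes. The first option costs 2N bytes if only T4_bf16 crosses the boundary, and the second option costs 4N bytes, assuming its two intermediate tensors are the bf16 inputs.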

Proposal

This PR tries to exploit the above observation by adding another pre-merge step to the fusion segmenter. Specifically, MergeUpAndDownCast::run analyzes a given fusion, tries to detect several simple patterns like the above, and checks whether each can be merged into a single segment. The result of this initial merging is then used by the existing merge process.

The pattern that the current analysis considers is basically an up-cast op followed by some other ops until a down-cast op. When merging these ops, we need to be careful not to violate the DAG property: if two merged groups were also connected through a path that leaves the merged set, the merged group would be both a producer and a consumer of a group outside it, creating a cycle. To guarantee the property, the current analysis stops when an expr has a consumer tensor that is used by multiple ops. So, for example, given a fusion like:

Inputs: T0_bf16

T1_fp32 = bfloat16ToFloat(T0_bf16);
T2_fp32 = sin(T1_fp32);
T3_bf16 = floatToBfloat16(T2_fp32);
T4_fp32 = cos(T1_fp32);
T5_bf16 = floatToBfloat16(T4_fp32);

Output: T3_bf16, T5_bf16

For this fusion, we don't try to do anything, as T1_fp32 is used by both the sin and cos ops. This is a limitation, but it is not urgent for RoPE at the moment, so I kept the analysis as simple as possible.

When an expr has multiple producers, we traverse all of them, and as long as those producers are used only by that expr, we keep looking for an up-cast operation. If a producer has multiple uses, guaranteeing the DAG property becomes non-trivial, which I'm not addressing in this PR.
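The traversal rules described above can be sketched on a toy graph in which every node stands for one single-expr group. This is a simplified, hypothetical model (Kind, Node, and candidateCastGroup are made-up names); the real analysis works on SegmentedGroup edges, and a detected candidate is only merged if a scheduler accepts it.

```cpp
#include <cassert>
#include <deque>
#include <unordered_set>
#include <vector>

// Toy model (hypothetical): each node is a single-expr group.
enum class Kind { UpCast, DownCast, Other };

struct Node {
  Kind kind = Kind::Other;
  std::vector<int> producers;  // indices of producer nodes
  std::vector<int> consumers;  // indices of consumer nodes, one per edge
};

// Collects a candidate cast group starting from an up-cast node:
// - forward (consumer) traversal stops at down-casts and at fan-outs;
// - backward (producer) traversal stops at up-casts and skips producers
//   that have more than one consumer.
std::vector<int> candidateCastGroup(const std::vector<Node>& graph,
                                    int start) {
  assert(start >= 0 && start < static_cast<int>(graph.size()));
  assert(graph[start].kind == Kind::UpCast);
  std::vector<int> group;
  std::unordered_set<int> in_group;
  std::deque<int> to_visit{start};
  while (!to_visit.empty()) {
    int id = to_visit.front();
    to_visit.pop_front();
    if (!in_group.insert(id).second) {
      continue;  // already collected via another path
    }
    group.push_back(id);
    const Node& node = graph[id];
    if (node.kind != Kind::DownCast && node.consumers.size() == 1 &&
        in_group.count(node.consumers[0]) == 0) {
      to_visit.push_back(node.consumers[0]);
    }
    if (node.kind != Kind::UpCast) {
      for (int p : node.producers) {
        if (graph[p].consumers.size() == 1 && in_group.count(p) == 0) {
          to_visit.push_back(p);
        }
      }
    }
  }
  return group;
}
```

On the first example (two up-casts feeding an add and a down-cast), starting from either up-cast collects all four groups. On the sin/cos example, the up-cast has two consumers, so the candidate contains only the up-cast itself and nothing is merged, matching the limitation described above.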

Results

Here are the per-segment times of the Mistral forward RoPE before this PR, measured using ncu on an H100:

    Duration                         us         9.41
    Duration                         us        15.49
    Duration                         us         7.55
    Duration                         us         3.17
    Duration                         us         3.33
    Duration                         us         5.57
    Duration                         us        62.08
    Duration                         us        25.79

Here are the times with this PR:

    Duration                         us        15.33
    Duration                         us         2.85
    Duration                         us         3.04
    Duration                         us         5.41
    Duration                         us         5.34
    Duration                         us        25.79
    Duration                         us        14.30

Summing the per-segment durations above, the total times are roughly 132 us vs 72 us.
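As a quick arithmetic check, the sketch below sums the per-segment ncu durations listed above (8 kernels before, 7 with this PR). The constant names are hypothetical; the values are copied verbatim from the measurements. The sum covers kernel durations only, not launch overhead or gaps between kernels.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Per-segment kernel durations in us, copied from the ncu numbers above.
const std::vector<double> kDurationsBefore = {9.41, 15.49, 7.55,  3.17,
                                              3.33, 5.57,  62.08, 25.79};
const std::vector<double> kDurationsAfter = {15.33, 2.85,  3.04, 5.41,
                                             5.34,  25.79, 14.30};

// Total kernel time: the sum of the per-segment durations.
double totalUs(const std::vector<double>& durations) {
  return std::accumulate(durations.begin(), durations.end(), 0.0);
}
```

The listed durations sum to about 132.4 us before and 72.1 us after, roughly a 1.84x reduction in total kernel time.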

@naoyam (Collaborator Author) commented Jan 13, 2025

!test

@naoyam (Collaborator Author) commented Jan 13, 2025

!test --diff

@jjsjann123 (Collaborator) commented:

!test --pybench-full

@naoyam (Collaborator Author) commented Jan 16, 2025

!test

github-actions bot commented Jan 16, 2025

PR Reviewer Guide 🔍

(Review updated until commit 3b96de7)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Logic Change

The new MergeUpAndDownCast class and its usage in SegmentCandidateFinder may introduce changes to the fusion segmentation logic, potentially affecting the performance or correctness of the system.

// This preprocessing attempts to find groups of exprs consisting of an
// up-cast, followed by some ops and ending with a down-cast. It is highly
// likely that such sequences of ops should never be segmented
// out. This is particularly commonly seen in fusions given by Thunder
// as it inserts fine-grained downcasting and upcasting ops. Without
// this preprocessing, a fusion may be segmented right after an
// up-cast op, for example, and in fact it happened quite frequently
// in some of the RoPE cases. This preprocessing does not completely
// avoid such segmentation boundaries, but it should become less
// likely. See also https://github.com/NVIDIA/Fuser/pull/3699.
class MergeUpAndDownCast {
 public:
  static void run(SegmentCandidateFinder* segment_candidate_finder) {
    MergeUpAndDownCast group_cast(segment_candidate_finder);
  }

 private:
  MergeUpAndDownCast(SegmentCandidateFinder* segment_candidate_finder)
      : segment_candidate_finder_(segment_candidate_finder) {
    merge();
  }

  void merge() {
    bool merged = true;
    while (merged) {
      merged = false;
      std::unordered_set<SegmentedGroup*> considered_groups;

      for (SegmentedGroup* group : segment_candidate_finder_->groups()) {
        // If the group is an up-cast group, see if there's a
        // candidate group starting with the group.
        if (!isUpCast(group) || considered_groups.count(group)) {
          continue;
        }

        auto groups_to_merge = getCandidateCastGroup(group);
        if (groups_to_merge.size() < 2) {
          continue;
        }

        considered_groups.insert(
            groups_to_merge.begin(), groups_to_merge.end());

        // Try merging the detected group
        if (mergeCastGroup(groups_to_merge)) {
          merged = true;
          break;
        }
      }
    }
  }

  // Try to detect a set of groups that could be merged as a cast
  // group. The analysis starts with an initial group that solely
  // consists of an up-cast expression. From the initial group, it
  // traverses its neighbor groups. If the group is a down-cast group,
  // it only traverses through the producer edges. If it's an up-cast
  // group, it only traverses through the consumer edges.
  //
  // Additionally, this traversal has several safeguards to keep the
  // DAG property intact:
  //
  // - For a given group, it does not visit its consumers if it has
  //   multiple consumers, even if the group is not a down-cast
  //   group.
  // - Similarly, it does not visit a producer if the producer has
  //   multiple consumers.
  //
  // The basic form of this set of groups should look like an up-cast
  // group, followed by some op groups and ended by a down-cast
  // group. However, it is not always the case because of the above
  // safeguards. For example, the following groups would be detected
  // as a cast group.
  //
  // t1 = bf16ToFp32(t0)
  // t2 = neg(t1)
  // t3 = sin(t2)
  // t4 = cos(t2)
  // t5 = fp32ToBf16(t3)
  // t6 = fp32ToBf16(t4)
  //
  // In this case, t1 and t2 would be detected as a candidate group,
  // but t3 and t4 would not be included. While we could certainly
  // extend the analysis, it would need to make sure the DAG property
  // is not violated.
  std::vector<SegmentedGroup*> getCandidateCastGroup(
      SegmentedGroup* initial_group) {
    std::vector<SegmentedGroup*> groups_to_merge;
    std::unordered_set<SegmentedGroup*> groups_to_merge_set;

    std::deque<SegmentedGroup*> to_visit;
    to_visit.push_back(initial_group);

    while (!to_visit.empty()) {
      SegmentedGroup* group = to_visit.front();
      to_visit.pop_front();

      if (groups_to_merge_set.count(group)) {
        continue;
      }

      // For simplicity, all groups are assumed to be the initial
      // single-expr groups. Skip if not.
      if (group->exprs().size() != 1) {
        continue;
      }

      groups_to_merge.push_back(group);
      groups_to_merge_set.insert(group);

      // Consumer traversal. Stop if this group is a down cast
      // group. Also stop if there are multiple consumer edges to
      // simplify keeping the DAG property.
      if (!isDownCast(group) && group->consumer_edges.size() == 1) {
        auto consumer_edge = group->consumer_edges.at(0);
        SegmentedGroup* consumer_group = consumer_edge->to;
        if (!groups_to_merge_set.count(consumer_group)) {
          to_visit.push_back(consumer_group);
        }
      }

      if (!isUpCast(group)) {
        for (const auto producer_edge : group->producer_edges) {
          SegmentedGroup* producer_group = producer_edge->from;
          // Don't add producers that have multiple consumers
          if (producer_group->consumer_edges.size() > 1) {
            continue;
          }
          if (!groups_to_merge_set.count(producer_group)) {
            to_visit.push_back(producer_group);
          }
        }
      }
    }

    return groups_to_merge;
  }

  // Try merging a candidate cast group. Return true if merged.
  bool mergeCastGroup(const std::vector<SegmentedGroup*>& groups) {
    auto sched_type = tryMerge(
        segment_candidate_finder_->segmented_fusion_.get(),
        segment_candidate_finder_->runtimeInfo(),
        groups);

    if (sched_type == SchedulerType::None) {
      return false;
    }

    segment_candidate_finder_->mergeAllGivenGroups(groups);

    return true;
  }

  bool isUpCast(SegmentedGroup* group) const {
    if (auto precision_bits = getProducerConsumerPrecision(group);
        precision_bits.has_value()) {
      return precision_bits->first < precision_bits->second;
    } else {
      return false;
    }
  }

  bool isDownCast(SegmentedGroup* group) const {
    if (auto precision_bits = getProducerConsumerPrecision(group);
        precision_bits.has_value()) {
      return precision_bits->first > precision_bits->second;
    } else {
      return false;
    }
  }

  std::optional<std::pair<int64_t, int64_t>> getProducerConsumerPrecision(
      SegmentedGroup* group) const {
    if (group->exprs().size() != 1) {
      return std::nullopt;
    }

    auto uop = dynamic_cast<UnaryOp*>(group->exprs().front());
    if (uop == nullptr || uop->getUnaryOpType() != UnaryOpType::Cast) {
      return std::nullopt;
    }

    return ir_utils::getPrecisionOfProducerConsumerTensors(uop);
  }

 private:
  SegmentCandidateFinder* segment_candidate_finder_ = nullptr;
};

namespace {
New Function

The new getPrecisionOfProducerConsumerTensors function may have implications on the precision and correctness of the system, particularly in handling cast operations.

std::optional<std::pair<int64_t, int64_t>> getPrecisionOfProducerConsumerTensors(
    UnaryOp* uop) {
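  // Check for nullptr separately so the error message never dereferences it
  NVF_CHECK(uop != nullptr, "Invalid expr: nullptr");
  NVF_CHECK(
      uop->getUnaryOpType() == UnaryOpType::Cast,
      "Invalid expr: ",
      uop->toString());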

  auto inp_tv = ir_utils::getTvInput(uop);
  auto out_tv = ir_utils::getTvOutput(uop);
  if (inp_tv == nullptr || out_tv == nullptr) {
    return std::nullopt;
  }

  auto inp_dtype = inp_tv->dtype().type;
  auto out_dtype = out_tv->dtype().type;
  auto inp_prim_type = std::get_if<PrimDataType>(&inp_dtype);
  auto out_prim_type = std::get_if<PrimDataType>(&out_dtype);

  if (inp_prim_type == nullptr || out_prim_type == nullptr ||
      *inp_prim_type == PrimDataType::Index ||
      *out_prim_type == PrimDataType::Index) {
    return std::nullopt;
  }

  return std::make_pair(
      primDataTypeSize(*inp_prim_type), primDataTypeSize(*out_prim_type));
}
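The sketch below isolates the precision check that isUpCast and isDownCast build on, including the Index bailout discussed in review. It uses a hypothetical, simplified DType enum with hard-coded byte sizes rather than nvFuser's PrimDataType and primDataTypeSize, and treats Index as having no known size, mirroring the std::nullopt return above.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

// Hypothetical, simplified dtype enum; nvFuser's real DataType is richer.
enum class DType { BFloat16, Half, Float, Double, Index };

// Byte size of each dtype. In this sketch, Index has no fixed size, so it
// is reported as unknown instead of throwing.
std::optional<std::int64_t> byteSize(DType t) {
  switch (t) {
    case DType::BFloat16:
    case DType::Half:
      return 2;
    case DType::Float:
      return 4;
    case DType::Double:
      return 8;
    case DType::Index:
      return std::nullopt;
  }
  return std::nullopt;
}

// Same shape as getPrecisionOfProducerConsumerTensors: bail out when either
// side has no known size.
std::optional<std::pair<std::int64_t, std::int64_t>> castPrecision(DType in,
                                                                   DType out) {
  auto in_size = byteSize(in);
  auto out_size = byteSize(out);
  if (!in_size || !out_size) {
    return std::nullopt;
  }
  return std::make_pair(*in_size, *out_size);
}

bool isUpCast(DType in, DType out) {
  auto p = castPrecision(in, out);
  return p.has_value() && p->first < p->second;
}

bool isDownCast(DType in, DType out) {
  auto p = castPrecision(in, out);
  return p.has_value() && p->first > p->second;
}
```

Because the comparison is by byte size, a cast between two types of the same width (for example, bf16 to fp16) is classified as neither an up-cast nor a down-cast.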

@naoyam marked this pull request as ready for review on January 16, 2025 18:29
@naoyam (Collaborator Author) commented Jan 16, 2025

@jjsjann123 confirmed there's no performance regression in the existing benchmarks. There are a small number of cases that are indeed improved by this PR.

@jjsjann123 Please review the PR. I just updated it with some minor cleanups; there is no change to the logic. I thought about what kinds of tests we could add, but I don't have a good idea. As long as all the existing tests pass functionally with no perf regression, it seems good enough to me.

@naoyam changed the title from "[WIP] Merge up-cast, ops, down-cast sequences as minimal units of segments" to "Merge up-cast, ops, down-cast sequences as minimal units of segments" on Jan 16, 2025
@jjsjann123 (Collaborator) commented:

@xwang233 regarding the review thing.

The implementation change in SegmentCandidateFinder::findSegments is mistakenly identified as a Function Signature Change.

I'm not sure if this is just a category given by the LLM, so I don't know if there's any actionable item on this one.

@xwang233 (Collaborator) commented:

It's not a preset category. Perhaps you can help rewrite it and find the best prompt to give the LLM here:

    PR_REVIEWER__EXTRA_INSTRUCTIONS: |
      Focus on potential logic change, especially on changes to function signatures.

csrc/fusion_segmenter.cpp (outdated, resolved)
  }

  return std::make_pair(
      primDataTypeSize(*inp_prim_type), primDataTypeSize(*out_prim_type));
Collaborator

I think you need a bailout here when either inp_prim_type or out_prim_type is PrimDataType::Index, for which primDataTypeSize would throw an exception. I think Welford would hit that.

I only know this because I hit it in the presegmentation passes. :)

Collaborator

I feel we could use a test for that. We don't have to do it in this PR though.

Collaborator Author

Thanks. Updated.

Collaborator Author

Added a test

csrc/fusion_segmenter.cpp (outdated, resolved)
@jjsjann123 (Collaborator) commented:

That looks like a pretty reasonable prompt, so it's just the model not properly identifying what counts as a signature change in the diff. 😠

@jjsjann123 (Collaborator) left a comment:

🚢

csrc/fusion_segmenter.cpp (outdated, resolved)
csrc/ir/utils.cpp (outdated, resolved)
csrc/ir/utils.h (outdated, resolved)
tests/cpp/test_gpu3.cpp (outdated, resolved)
tests/cpp/test_gpu3.cpp (outdated, resolved)
tests/cpp/test_gpu3.cpp (outdated, resolved)
@naoyam (Collaborator Author) commented Jan 16, 2025

!test

@naoyam merged commit 1274da4 into main on Jan 16, 2025
40 of 41 checks passed
@naoyam deleted the merge_up_and_down_cast branch on January 16, 2025 23:47
Labels
None yet
Projects
None yet
3 participants