diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp
index c98543a179a..1c49713eaab 100644
--- a/csrc/fusion_segmenter.cpp
+++ b/csrc/fusion_segmenter.cpp
@@ -3586,6 +3586,194 @@ bool CombineReductions::shouldRun(
   return false;
 }
 
+// This preprocessing attempts to find groups of exprs consisting of an
+// up-cast, followed by some ops, and ended by a down-cast. It is highly
+// likely that such sequences of ops should never be segmented
+// out. This pattern is particularly common in fusions given by Thunder,
+// as it inserts fine-grained down-casting and up-casting ops. Without
+// this preprocessing, a fusion may be segmented right after an
+// up-cast op, for example, and in fact this happened quite frequently
+// in some of the RoPE cases. This preprocessing does not completely
+// avoid such segmentation boundaries, but it should make them less
+// likely. See also https://github.com/NVIDIA/Fuser/pull/3699.
+class MergeUpAndDownCast {
+ public:
+  static void run(SegmentCandidateFinder* segment_candidate_finder) {
+    MergeUpAndDownCast group_cast(segment_candidate_finder);
+  }
+
+ private:
+  MergeUpAndDownCast(SegmentCandidateFinder* segment_candidate_finder)
+      : segment_candidate_finder_(segment_candidate_finder) {
+    merge();
+  }
+
+  void merge() {
+    bool merged = true;
+    while (merged) {
+      merged = false;
+      std::unordered_set<SegmentedGroup*> considered_groups;
+
+      for (SegmentedGroup* group : segment_candidate_finder_->groups()) {
+        // If the group is an up-cast group, see if there's a
+        // candidate group starting with the group.
+        if (!isUpCast(group) || considered_groups.count(group)) {
+          continue;
+        }
+
+        auto groups_to_merge = getCandidateCastGroup(group);
+        if (groups_to_merge.size() < 2) {
+          continue;
+        }
+
+        for (auto group : groups_to_merge) {
+          considered_groups.insert(group);
+        }
+
+        // Try merging the detected group
+        if (mergeCastGroup(groups_to_merge)) {
+          merged = true;
+          break;
+        }
+      }
+    }
+  }
+
+  // Try to detect a set of groups that could be merged as a cast
+  // group. The analysis starts with an initial group that solely
+  // consists of an up-cast expression. From the initial group, it
+  // traverses its neighbor groups. If a group is a down-cast group,
+  // it only traverses through the producer edges. If it's an up-cast
+  // group, it only traverses through the consumer edges.
+  //
+  // Additionally, this traversal has several safeguards to keep the
+  // DAG property intact:
+  //
+  // - For a given group, it does not visit its consumers if it has
+  //   multiple consumers, even if the group is not a down-cast
+  //   group.
+  // - Similarly, it does not visit a producer if the producer has
+  //   multiple consumers.
+  //
+  // The basic form of this set of groups should look like an up-cast
+  // group, followed by some op groups and ended by a down-cast
+  // group. However, that is not always the case because of the above
+  // safeguards. For example, the following groups would be detected
+  // as a cast group.
+  //
+  // t1 = bf16ToFp32(t0)
+  // t2 = neg(t1)
+  // t3 = sin(t2)
+  // t4 = cos(t2)
+  // t5 = fp32ToBf16(t3)
+  // t6 = fp32ToBf16(t4)
+  //
+  // In this case, t1 and t2 would be detected as a candidate group,
+  // but t3 and t4 would not be included. While we could certainly
+  // extend the analysis, it would need to make sure the DAG property
+  // is not violated.
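+  // (In the example above, the group of t2 has two consumer edges, so the
+  // consumer traversal stops at t2, which is why t3 and t4 are left out.)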
+  std::vector<SegmentedGroup*> getCandidateCastGroup(
+      SegmentedGroup* initial_group) {
+    std::vector<SegmentedGroup*> groups_to_merge;
+    std::unordered_set<SegmentedGroup*> groups_to_merge_set;
+
+    std::deque<SegmentedGroup*> to_visit;
+    to_visit.push_back(initial_group);
+
+    while (!to_visit.empty()) {
+      SegmentedGroup* group = to_visit.front();
+      to_visit.pop_front();
+
+      if (groups_to_merge_set.count(group)) {
+        continue;
+      }
+
+      // For simplicity, all groups are assumed to be the initial
+      // single-expr groups. Skip if not.
+
+      groups_to_merge.push_back(group);
+      groups_to_merge_set.insert(group);
+
+      // Consumer traversal. Stop if this group is a down-cast
+      // group. Also stop if there are multiple consumer edges to
+      // simplify keeping the DAG property.
+      if (!isDownCast(group) && group->consumer_edges.size() == 1) {
+        auto consumer_edge = group->consumer_edges.at(0);
+        SegmentedGroup* consumer_group = consumer_edge->to;
+        if (!groups_to_merge_set.count(consumer_group)) {
+          to_visit.push_back(consumer_group);
+        }
+      }
+
+      if (!isUpCast(group)) {
+        for (const auto producer_edge : group->producer_edges) {
+          SegmentedGroup* producer_group = producer_edge->from;
+          // Don't add producers that have more than one consumer
+          if (producer_group->consumer_edges.size() > 1) {
+            continue;
+          }
+          if (!groups_to_merge_set.count(producer_group)) {
+            to_visit.push_back(producer_group);
+          }
+        }
+      }
+    }
+
+    return groups_to_merge;
+  }
+
+  // Try merging a candidate cast group. Return true if merged.
+  bool mergeCastGroup(const std::vector<SegmentedGroup*>& groups) {
+    auto sched_type = tryMerge(
+        segment_candidate_finder_->segmented_fusion_.get(),
+        segment_candidate_finder_->runtimeInfo(),
+        groups);
+
+    if (sched_type == SchedulerType::None) {
+      return false;
+    }
+
+    segment_candidate_finder_->mergeAllGivenGroups(groups);
+
+    return true;
+  }
+
+  bool isUpCast(SegmentedGroup* group) const {
+    if (auto precision_bits = getProducerConsumerPrecision(group);
+        precision_bits.has_value()) {
+      return precision_bits->first < precision_bits->second;
+    } else {
+      return false;
+    }
+  }
+
+  bool isDownCast(SegmentedGroup* group) const {
+    if (auto precision_bits = getProducerConsumerPrecision(group);
+        precision_bits.has_value()) {
+      return precision_bits->first > precision_bits->second;
+    } else {
+      return false;
+    }
+  }
+
+  std::optional<std::pair<int64_t, int64_t>> getProducerConsumerPrecision(
+      SegmentedGroup* group) const {
+    if (group->exprs().size() != 1) {
+      return std::nullopt;
+    }
+
+    auto uop = dynamic_cast<UnaryOp*>(group->exprs().front());
+    if (uop == nullptr || uop->getUnaryOpType() != UnaryOpType::Cast) {
+      return std::nullopt;
+    }
+
+    return ir_utils::getPrecisionOfProducerConsumerTensors(uop);
+  }
+
+ private:
+  SegmentCandidateFinder* segment_candidate_finder_ = nullptr;
+};
+
 namespace {
 
 //! Returns true if group1 and group2 are an immediate producer-consumer pair.
@@ -3945,6 +4133,9 @@ void SegmentCandidateFinder::findSegments() {
   removeScalarEdges();
 
   // Run pre-merge heuristics
+  MergeUpAndDownCast::run(this);
+  segmented_fusion_->validateIfDebug(true);
+
   if (options_.run_combine_reductions && CombineReductions::shouldRun(this)) {
     CombineReductions::run(this);
   }
diff --git a/csrc/fusion_segmenter.h b/csrc/fusion_segmenter.h
index 59f8ff2d574..c70aab19e49 100644
--- a/csrc/fusion_segmenter.h
+++ b/csrc/fusion_segmenter.h
@@ -488,6 +488,7 @@ class GroupDependencyAnalysis;
 
 // Manual node merging passes
 class CombineReductions;
+class MergeUpAndDownCast;
 
 //! Options to configure/debug candidate finder
 struct SegmentCandidateFinderOptions {
@@ -691,6 +692,7 @@ class SegmentCandidateFinder {
   //! eventually should have a dedicated interface
   //! instead of keeping adding friends
   friend class CombineReductions;
+  friend class MergeUpAndDownCast;
 
   //! options to configure and debug the segment process
   SegmentCandidateFinderOptions options_;
diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp
index 4de5d7c8097..107ab898453 100644
--- a/csrc/ir/utils.cpp
+++ b/csrc/ir/utils.cpp
@@ -1524,4 +1524,32 @@ std::vector<IterDomain*> strideOrderToAllocation(
   return allocation_domain;
 }
 
+std::optional<std::pair<int64_t, int64_t>> getPrecisionOfProducerConsumerTensors(
+    UnaryOp* uop) {
+  NVF_CHECK(
+      uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Cast,
+      "Invalid expr: ",
+      uop->toString());
+
+  auto inp_tv = ir_utils::getTvInput(uop);
+  auto out_tv = ir_utils::getTvOutput(uop);
+  if (inp_tv == nullptr || out_tv == nullptr) {
+    return std::nullopt;
+  }
+
+  auto inp_dtype = inp_tv->dtype().type;
+  auto out_dtype = out_tv->dtype().type;
+  auto inp_prim_type = std::get_if<PrimDataType>(&inp_dtype);
+  auto out_prim_type = std::get_if<PrimDataType>(&out_dtype);
+
+  if (inp_prim_type == nullptr || out_prim_type == nullptr ||
+      *inp_prim_type == PrimDataType::Index ||
+      *out_prim_type == PrimDataType::Index) {
+    return std::nullopt;
+  }
+
+  return std::make_pair(
+      primDataTypeSize(*inp_prim_type), primDataTypeSize(*out_prim_type));
+}
+
 } // namespace nvfuser::ir_utils
diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h
index 4ac93824037..37b53a8df36 100644
--- a/csrc/ir/utils.h
+++ b/csrc/ir/utils.h
@@ -803,4 +803,9 @@ std::vector<IterDomain*> strideOrderToAllocation(
     const std::vector<IterDomain*>& logical_domain,
     const std::vector<int64_t>& stride_order);
 
+// Returns the sizes in bytes of the data types of the producer and
+// consumer tensors of a cast unary op
+std::optional<std::pair<int64_t, int64_t>> getPrecisionOfProducerConsumerTensors(
+    UnaryOp* cast_op);
+
 } // namespace nvfuser::ir_utils
diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp
index d6e715f1fe8..c7eba2b1fed 100644
--- a/tests/cpp/test_gpu3.cpp
+++ b/tests/cpp/test_gpu3.cpp
@@ -9349,6 +9349,41 @@ TEST_F(NVFuserTest, RepeatBroadcastAndNonBroadcast) {
   testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
 }
 
+TEST_F(NVFuserTest, CastPrecision) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(2, DataType::Half);
+  fusion.addInput(tv0);
+
+  auto tv1 = castOp(DataType::Float, tv0);
+  auto tv2 = castOp(DataType::BFloat16, tv1);
+  fusion.addOutput(tv2);
+
+  auto tv3 = makeSymbolicTensor(2, DataType::Index);
+  fusion.addInput(tv3);
+
+  auto tv4 = castOp(DataType::Int, tv3);
+  fusion.addOutput(tv4);
+
+  auto tv1_precision = ir_utils::getPrecisionOfProducerConsumerTensors(
+      tv1->definition()->as<UnaryOp>());
+  ASSERT_TRUE(tv1_precision.has_value());
+  EXPECT_EQ(tv1_precision->first, 2);
+  EXPECT_EQ(tv1_precision->second, 4);
+
+  auto tv2_precision = ir_utils::getPrecisionOfProducerConsumerTensors(
+      tv2->definition()->as<UnaryOp>());
+  ASSERT_TRUE(tv2_precision.has_value());
+  EXPECT_EQ(tv2_precision->first, 4);
+  EXPECT_EQ(tv2_precision->second, 2);
+
+  // Precision of the Index type cannot be determined until lowering
+  auto tv4_precision = ir_utils::getPrecisionOfProducerConsumerTensors(
+      tv4->definition()->as<UnaryOp>());
+  ASSERT_FALSE(tv4_precision.has_value());
+}
+
 // Test file size should be up to 10K LoC. Create a new file for more tests.
 } // namespace nvfuser
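
Below is a minimal sketch, not part of the diff above, of the kind of cast-bracketed fusion the new MergeUpAndDownCast pass is meant to keep in a single segment. It only reuses constructs that already appear in this diff (NVFuserTest, Fusion, FusionGuard, makeSymbolicTensor, castOp, DataType::BFloat16) plus the pointwise ops neg and sin referenced in the pass's comment; the test name and the absence of assertions are placeholders, so treat it as an illustration rather than an actual test in the PR.

// Hypothetical illustration only; not included in the PR.
TEST_F(NVFuserTest, MergeUpAndDownCastSketch) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // BFloat16 input is up-cast to Float, processed, and down-cast back.
  auto tv0 = makeSymbolicTensor(2, DataType::BFloat16);
  fusion.addInput(tv0);

  auto tv1 = castOp(DataType::Float, tv0); // up-cast group
  auto tv2 = neg(tv1); // op group
  auto tv3 = sin(tv2); // op group
  auto tv4 = castOp(DataType::BFloat16, tv3); // down-cast group
  fusion.addOutput(tv4);

  // With the new preprocessing, the four single-expr groups above are merged
  // before the regular segmentation heuristics run, so a segmentation
  // boundary right after the up-cast (tv1) should no longer be introduced.
}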