diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index 1042931914e..b17cfea9e8a 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -348,19 +348,21 @@ void PrecomputedValues::bindTensorMetaData( for (const auto dim : c10::irange(logical_domain.size())) { IterDomain* id = logical_domain[dim]; - auto dim_size = tensor.size(static_cast(dim)); - if (id->isDeviceDim()) { - dim_size = tv->getDeviceMesh().size(id->getParallelType()); - } - - if (id->hasExpandedExtent()) { - Val* extent = id->extent(); - Val* expanded_extent = id->expandedExtent(); - bindValue(extent->evaluatorIndex(), 1L); - bindValue(expanded_extent->evaluatorIndex(), dim_size); + const auto dim_size = tensor.size(static_cast(dim)); + if (id->isBroadcast()) { + // DIDs are ignored. + bindValue(id->extent()->evaluatorIndex(), 1L); + if (id->hasExpandedExtent()) { + bindValue(id->expandedExtent()->evaluatorIndex(), dim_size); + } } else { - Val* extent = id->extent(); - bindValue(extent->evaluatorIndex(), dim_size); + if (id->isDeviceDim()) { + bindValue( + id->extent()->evaluatorIndex(), + tv->getDeviceMesh().size(id->getParallelType())); + } else { + bindValue(id->extent()->evaluatorIndex(), dim_size); + } } } diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index c59842b2037..a360950354b 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -174,53 +174,61 @@ void ExpressionEvaluator::bind_( t.dim()); for (auto i : c10::irange(t.dim())) { auto id = logical_domain[i]; - if (id->hasExpandedExtent()) { - // Verify that t is also expanded - NVF_ERROR( - t.size(i) == 1 || t.stride(i) == 0, - "IterDomain ", - id->toString(), - " in ", - getInputPosString(tv), - "TensorView ", - tv->toString(), - " has expanded extent but input tensor has size ", - t.size(i), - " and stride ", - t.stride(i), - " in dimension ", - i); - bind_( - logical_domain[i]->expandedExtent(), t.size(i), evaluate_validate); - } else if (logical_domain[i]->isDeviceDim()) { - // Currently we have the restrictions: - // (1) Devices parallelized axis extent == DeviceMesh's extent - // (2) Device parallelized axis cannot be split or merged - // Therefore, the device parallelized extents will always be allocated - // with size 1, but the symbolic axis extent is binded with the extent - // of the DeviceMesh - NVF_CHECK( - 1 == t.size(i), - "TensorView ", - tv->toString(), - getInputPosString(tv), - " IterDomain ", - id->toString(), - "is sharded and must have size 1, but input tensor has size ", - t.size(i)); - NVF_CHECK( - tv->hasDeviceMesh(), - "TV ", - tv->toString(), - getInputPosString(tv), - " has an empty DeviceMesh with DID parallelization") - bind_( - logical_domain[i]->extent(), - static_cast( - tv->getDeviceMesh().size(logical_domain[i]->getParallelType())), - evaluate_validate); + if (id->isBroadcast()) { + // DIDs are ignored. + bind_(logical_domain[i]->extent(), 1, evaluate_validate); + if (id->hasExpandedExtent()) { + // Verify that t is also expanded + NVF_ERROR( + t.size(i) == 1 || t.stride(i) == 0, + "IterDomain ", + id->toString(), + " in ", + getInputPosString(tv), + "TensorView ", + tv->toString(), + " has expanded extent but input tensor has size ", + t.size(i), + " and stride ", + t.stride(i), + " in dimension ", + i); + bind_( + logical_domain[i]->expandedExtent(), + t.size(i), + evaluate_validate); + } } else { - bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate); + if (logical_domain[i]->isDeviceDim()) { + // Currently we have the restrictions: + // (1) Devices parallelized axis extent == DeviceMesh's extent + // (2) Device parallelized axis cannot be split or merged + // Therefore, the device parallelized extents will always be allocated + // with size 1, but the symbolic axis extent is binded with the extent + // of the DeviceMesh + NVF_CHECK( + 1 == t.size(i), + "TensorView ", + tv->toString(), + getInputPosString(tv), + " IterDomain ", + id->toString(), + "is sharded and must have size 1, but input tensor has size ", + t.size(i)); + NVF_CHECK( + tv->hasDeviceMesh(), + "TV ", + tv->toString(), + getInputPosString(tv), + " has an empty DeviceMesh with DID parallelization") + bind_( + logical_domain[i]->extent(), + static_cast(tv->getDeviceMesh().size( + logical_domain[i]->getParallelType())), + evaluate_validate); + } else { + bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate); + } } } } diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 8b911d623c0..3b6bf6d561f 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -323,6 +323,13 @@ IterDomain* newOutputIterDomain( continue; } + NVF_ERROR( + id->getParallelType() == ParallelType::Serial || + isParallelTypeDeviceDim(id->getParallelType()), + id->getParallelType(), + " is not expected when building ops."); + parallel_type = promoteParallelType(parallel_type, id->getParallelType()); + if (id->isBroadcast()) { if (id->hasExpandedExtent()) { expanded_extent_val = @@ -331,13 +338,6 @@ IterDomain* newOutputIterDomain( continue; } - NVF_ERROR( - id->getParallelType() == ParallelType::Serial || - isParallelTypeDeviceDim(id->getParallelType()), - id->getParallelType(), - " is not expected when building ops."); - parallel_type = promoteParallelType(parallel_type, id->getParallelType()); - if (extent_is_from_symbolic && !id->isSymbolic()) { // We prefer to use extents from non-Symbolic inputs if there are any // because they might indicate a broadcast axis that is resolved in this diff --git a/csrc/preseg_passes/mark_aliases_prepare.cpp b/csrc/preseg_passes/mark_aliases_prepare.cpp index 2478105e33e..1a3cffc641e 100644 --- a/csrc/preseg_passes/mark_aliases_prepare.cpp +++ b/csrc/preseg_passes/mark_aliases_prepare.cpp @@ -18,6 +18,24 @@ namespace nvfuser::preseg_passes { namespace { +// Represents a use of `use_of` by `user`. This is to mark locations to segment +// so meta ops form a no-op region. When `user` is not null, we expect to +// segment between `use_of` and `user`, e.g., +// +// use_of -> [segment_set] -> copy of use_of -> [user] +// +// This happens due to bookending from outputs. +// +// When `user` is null, we expect to segment between `use_of` and all its +// users, e.g., +// +// use_of -> [segment_set] -> copy of use_of -> [user_0] +// | +// +> [user_1] +// | +// +> [user_2] +// +// This happens due to bookending from inputs. struct Use { TensorView* use_of; Expr* user; @@ -32,6 +50,14 @@ struct Use { } }; +std::ostream& operator<<(std::ostream& os, const Use& use) { + os << use.use_of; + if (use.user != nullptr) { + os << " used by " << use.user; + } + return os; +} + // A helper function that walks up from `out` until reaching a non-meta op or a // fusion input. Returns where it stops. Use findUseToSegment( @@ -66,19 +92,60 @@ std::unordered_set exprsDependedByNonAliases( return {depended_by_non_aliases.begin(), depended_by_non_aliases.end()}; } -// Inserts a `segment_set` after `use_of` and redirect aliasing users to -// use the `segment_set`. -void insertSegmentSetAfter( - std::vector::const_iterator first_user, - std::vector::const_iterator last_user) { +// Inserts a `segment_set` after `use_of` to separate meta and non-meta ops. +template +void insertSegmentSetAfter(InputIter first_user, InputIter last_user) { TensorView* use_of = first_user->use_of; - // There are a few corner cases where we don't need to add a - // `segment_set`. If `use_of` is only used by aliases, ... - if (static_cast(std::distance(first_user, last_user)) == - use_of->uses().size()) { + std::vector users; + users.reserve(use_of->uses().size()); + // `uses_to_segment` is sorted so `nullptr` if exists appears first. + if (first_user->user == nullptr) { + // This is an optimization to make fewer segments. In the + // following example, if bookending wants to segment (a) between `use_of` + // and all its users, (b) between `use_of` and `user_0`, and (c) between + // `use_of` and `user_1`. We can instead segment only between `use_of` and + // `user_2`, the complement set of [`first_user`, `last_user`). This is + // valid because the ops before `use_of`, those after `user_0`, and those + // after `user_1` are all meta ops that can be merged into one no-op + // segment. + // + // use_of | -> | [user 0] + // | + // +> | [user 1] + // | + // +> [user 2] + // + // ==> + // + // use_of -> [user 0] + // | + // +> [user 1] + // | + // +> | [user 2] + first_user++; + std::unordered_set to_remove; + std::for_each(first_user, last_user, [&](const Use& use) { + to_remove.insert(use.user); + }); + std::copy_if( + use_of->uses().begin(), + use_of->uses().end(), + std::back_inserter(users), + [&](Expr* user) { return to_remove.count(user) == 0; }); + } else { + std::transform( + first_user, last_user, std::back_inserter(users), [](const Use& use) { + return use.user; + }); + } + + // There are a few corner cases where we can avoid adding a + // `segment_set`. If a segment_set is to be added between `use_of` and all + // its users, ... + if (users.size() == use_of->uses().size()) { if (use_of->isFusionInput()) { - // Putting a `segment_set` between a fusion input and its users is + // Putting a `segment_set` between a fusion input and all its users is // unnecessary. return; } @@ -90,17 +157,15 @@ void insertSegmentSetAfter( } } - // If all aliasing users are `segment_set`, don't create another + // If all users to segment are `segment_set`, don't create another // `segment_set`. - if (std::all_of(first_user, last_user, [](const Use& use) { - return ir_utils::isSegmentSet(use.user); - })) { + if (std::all_of(users.begin(), users.end(), ir_utils::isSegmentSet)) { return; } // The general case. TensorView* copy = segment_set(use_of); - // Inherit the allocation domain from `use_of`. This is important to pass + // Inherit the allocation domain from `use_of`. This is needed for cases like // AliasTest.Bookend_SegmentSetPreservesAllocation. TensorDomain* replayed_domain = TransformReplay::replayCasP( @@ -110,14 +175,43 @@ void insertSegmentSetAfter( copy->setAllocationDomain( replayed_domain->allocation(), replayed_domain->contiguity()); } - std::for_each(first_user, last_user, [&](const Use& use) { - ir_utils::replaceValInExprInputs(use.user, use_of, copy); + // This is an optimization to make fewer segments. In the following example, + // we could literally add two `segment_set`s, one before `user_0` and the + // other before `user_1`. However, because these `segment_set`s are implied + // by bookending, the ops after `user_0` and those after `user_1` are all + // meta and can be merged into one no-op segment. + // + // use_of -> | [user_0] + // | + // +> | [user_1] + // | + // +> [user_2] + // + // => + // + // use_of -> [segment_set] -> copy -> [user_0] + // | | + // | +> [user 1] + // | + // +> [user_2] + std::for_each(users.begin(), users.end(), [&](Expr* user) { + ir_utils::replaceValInExprInputs(user, use_of, copy); }); if (use_of->isFusionOutput()) { use_of->fusion()->replaceOutput(use_of, copy); } } +bool isMetaOp(const AliasAnalysisResult& analysis, Expr* e) { + return std::all_of( + e->outputs().begin(), e->outputs().end(), [&analysis](Val* out) { + if (auto* out_tv = dynamic_cast(out)) { + return analysis.getRoot(out_tv) != nullptr; + } + return false; + }); +} + } // namespace void MarkAliasesPreparePass::runPass(Fusion* fusion) { @@ -156,8 +250,12 @@ void MarkAliasesPreparePass::runPass(Fusion* fusion) { } } - // The following emulates the bookend optimization. Only the output end is - // implemented at this moment. In general, the algorithm tries to walk up + // The following emulates the bookend optimization. This is done in two + // steps: the first step bookends the outputs and the second step does the + // inputs. TODO(wujingyue): extract this into a function. I'm adding the new + // logic in place just to make review easier. + // + // Step 1: for outputs, the algorithm tries to walk up // from each fusion output until reaching a non-alias, and put a // `segment_set` there so the meta ops that are skipped form a no-op segment. // @@ -174,28 +272,78 @@ void MarkAliasesPreparePass::runPass(Fusion* fusion) { // examples. This is the reason behind `depended_by_non_aliases`. const std::unordered_set& depended_by_non_aliases = exprsDependedByNonAliases(analysis, fusion); - std::vector uses_to_segment; - uses_to_segment.reserve(fusion->outputs().size()); + std::set uses_to_segment; for (auto* out : ir_utils::filterByType(fusion->outputs())) { Use use_to_segment = findUseToSegment(out, analysis, depended_by_non_aliases); if (use_to_segment.use_of != out) { - uses_to_segment.push_back(use_to_segment); + uses_to_segment.insert(use_to_segment); } } - // The remaining are optimizations to reduce the number of `segment_set`s - // inserted. + // Step 2: for inputs, the algorithm tries to walk down from each fusion + // input until reaching a non-meta or a fork. Stopping at the fork is to + // avoid feeding the same data via multiple inputs, e.g., // - // Group `uses_to_segment` by `use_of` and remove duplicates. - std::sort(uses_to_segment.begin(), uses_to_segment.end()); - uses_to_segment.erase( - std::unique(uses_to_segment.begin(), uses_to_segment.end()), - uses_to_segment.end()); + // in -> reshape_0 -> mul + // | ^ + // +--> reshape_1 ----+ + // + // If we separate `reshape_0` and `reshape_1` from `mul`, the pointwise + // kernel would take double the input. + std::queue frontier; + for (auto* tv : ir_utils::filterByType(fusion->inputs())) { + frontier.push(tv); + } + while (!frontier.empty()) { + TensorView* tv = frontier.front(); + frontier.pop(); + + auto should_enqueue_users = [&analysis](TensorView* tv) { + // Stop at a non-meta op. + if (!std::all_of( + tv->uses().begin(), tv->uses().end(), [&analysis](Expr* e) { + return isMetaOp(analysis, e); + })) { + return false; + } + + // Stop at a fork. + if (tv->uses().size() > 1 && + // The only exception is when the fork happens to be a split, which + // is a common pattern in RoPE. + !std::all_of( + tv->uses().begin(), + tv->uses().end(), + std::mem_fn(&Expr::isA))) { + return false; + } + + return true; + }; + if (should_enqueue_users(tv)) { + for (Expr* user : tv->uses()) { + // If the use of `tv` by `user` is going to be segmented due to + // bookending outputs, stop there. We could keep bookending but further + // segmenting a meta-op region is useless. + if (uses_to_segment.count(Use{tv, user})) { + continue; + } + for (auto* user_out : + ir_utils::filterByType(user->outputs())) { + frontier.push(user_out); + } + } + } else { + // Will insert a segment_set between `tv` and all its users. See + // Use::user for more details. + uses_to_segment.insert(Use{tv, nullptr}); + } + } if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { for (const auto& use : uses_to_segment) { - debug() << "Will put a segment_set at " << use.user << std::endl; + debug() << "Will put a segment_set at " << use << std::endl; } } diff --git a/tests/cpp/test_alias.cpp b/tests/cpp/test_alias.cpp index 0a8a427a6b8..5c21aa10723 100644 --- a/tests/cpp/test_alias.cpp +++ b/tests/cpp/test_alias.cpp @@ -503,11 +503,12 @@ TEST_F(AliasTest, AliasOutputBeforeNonAliasOutput) { at::Tensor slice_out_tensor = out_tensors[0]; EXPECT_TRUE(slice_out_tensor.is_alias_of(in_tensor)); - const FusionExecutor& fe = onlyExecutorInMostRecentRuntime(fec); - EXPECT_FALSE(storesToOutput(fe, /*out_index=*/0)) - << "The generated CUDA kernel shouldn't store data to output 0:" - << std::endl - << fe.kernelString(); + FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); + EXPECT_THAT( + runtime->fusionSegments()->groups(), + UnorderedElementsAre( + HeuristicIs(ScheduleHeuristic::NoOp), + HeuristicIs(ScheduleHeuristic::PointWise))); } TEST_F(AliasTest, Set_NoAliasForIncompatibleLayout) { @@ -794,6 +795,33 @@ TEST_F(AliasTest, DoNotOverSegment_WithForks) { // EXPECT_TRUE(out_tensors[1].is_alias_of(out_tensors[0])); } +TEST_F(AliasTest, DoNotOverSegment_AllAliasWithForks) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigTensor(2); + TensorView* t = flatten(in); + TensorView* out0 = reshape(t, {in->axis(0)->extent(), in->axis(1)->extent()}); + TensorView* out1 = reshape(t, {in->axis(1)->extent(), in->axis(0)->extent()}); + + fusion->addInput(in); + fusion->addOutput(out0); + fusion->addOutput(out1); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor in_tensor = at::randn({2, 3}).cuda(); + std::vector out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); + EXPECT_THAT( + runtime->fusionSegments()->groups(), + ElementsAre(HeuristicIs(ScheduleHeuristic::NoOp))); + + EXPECT_TRUE(out_tensors[0].is_alias_of(in_tensor)); + EXPECT_TRUE(out_tensors[1].is_alias_of(in_tensor)); +} + TEST_F(AliasTest, Broadcast) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1216,7 +1244,7 @@ TEST_F(AliasTest, Bookend_SegmentSetPreservesAllocation) { EXPECT_TRUE(permute_out_tensor.is_alias_of(in_tensor)); } -TEST_F(AliasTest, Bookend_InputsAndOutputs) { +TEST_F(AliasTest, Bookend_OneOutputAliasesTheOtherNot) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1239,8 +1267,8 @@ TEST_F(AliasTest, Bookend_InputsAndOutputs) { FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); // MarkAliasesPrepare adds a `segment_set` between `in` and `permute`, which // leads to three segments: - // 1. segment_set`, a no-op segment, - // 2. permute`, a no-op segment, + // 1. `segment_set`, a no-op segment, + // 2. `permute`, a no-op segment, // 3. `mul` and `add`, a pointwise segment. EXPECT_THAT( runtime->fusionSegments()->groups(), @@ -1276,6 +1304,7 @@ TEST_F(AliasTest, Bookend_IntermediateTensors) { EXPECT_THAT( runtime->fusionSegments()->groups(), UnorderedElementsAre( + HeuristicIs(ScheduleHeuristic::NoOp), HeuristicIs(ScheduleHeuristic::NoOp), HeuristicIs(ScheduleHeuristic::PointWise))); for (SegmentedGroup* group : runtime->fusionSegments()->groups()) { @@ -1357,6 +1386,87 @@ TEST_F(AliasTest, Bookend_ReuseSegmentSet) { } } +// Repro for #2599. +TEST_F(AliasTest, Bookend_Rope) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + constexpr int64_t kNumHeads = 40; + constexpr int64_t kHiddenSize = 128; + + TensorView* qkv = TensorViewBuilder() + .ndims(3) + .dtype(DataType::Half) + .contiguity(true) + .shape({-1, -1, kNumHeads * kHiddenSize * 3}) + .build(); + TensorView* cos = TensorViewBuilder() + .ndims(2) + .dtype(DataType::Half) + .contiguity(true) + .shape({-1, kHiddenSize}) + .build(); + TensorView* sin = TensorViewBuilder() + .ndims(2) + .dtype(DataType::Half) + .contiguity(true) + .shape({-1, kHiddenSize}) + .build(); + fusion->addInput(qkv); + fusion->addInput(cos); + fusion->addInput(sin); + + qkv = reshape( + qkv, + {qkv->axis(0)->extent(), + qkv->axis(1)->extent(), + IrBuilder::create(kNumHeads), + IrBuilder::create(3), + IrBuilder::create(kHiddenSize)}); + qkv = permute(qkv, {0, 2, 3, 1, 4}); + + std::vector slices = chunk(qkv, /*chunks=*/3, /*dim=*/2); + auto* q = squeeze(slices[0], {2}); + auto* k = squeeze(slices[1], {2}); + auto* v = squeeze(slices[2], {2}); + + auto apply_rope = [cos, sin](TensorView* x) { + std::vector slices = chunk(x, /*chunks=*/2, /*dim=*/-1); + auto* real = slices[0]; + auto* imag = castOp(DataType::Half, neg(slices[1])); + TensorView* rotated = cat({imag, real}, -1); + TensorView* out = add(mul(x, cos), mul(rotated, sin)); + out = castOp(DataType::Half, out); + return out; + }; + q = apply_rope(q); + k = apply_rope(k); + + fusion->addOutput(q); + fusion->addOutput(k); + fusion->addOutput(v); + + constexpr int64_t kSeqLen = 4096; + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto qkv_tensor = + at::randn({2, kSeqLen, kNumHeads * kHiddenSize * 3}, options); + auto freq_tensor = at::randn({kSeqLen, kHiddenSize / 2}, options); + auto cos_tensor = freq_tensor.cos(); + auto sin_tensor = freq_tensor.sin(); + std::vector in_tensors = { + qkv_tensor, + at::cat({cos_tensor, cos_tensor}, -1), + at::cat({sin_tensor, sin_tensor}, -1)}; + + FusionExecutorCache fec(std::move(fusion)); + std::vector out_tensors = fec.runFusionWithInputs(in_tensors); + testValidate(fec.fusion(), out_tensors, in_tensors, __LINE__, __FILE__); + + EXPECT_THAT( + fec.getMostRecentKernelRuntime()->fusionSegments()->groups(), + Contains(HeuristicIs(ScheduleHeuristic::PointWise)).Times(4)); +} + TEST_F(AliasTest, QKVSplitBackprop) { // A subgraph of MoveSplitCatTest.Cancellable_Issue1768. constexpr int b = 16; diff --git a/tests/cpp/test_allocation_domain.cpp b/tests/cpp/test_allocation_domain.cpp index 4efc6e3adf8..db99af0959c 100644 --- a/tests/cpp/test_allocation_domain.cpp +++ b/tests/cpp/test_allocation_domain.cpp @@ -24,7 +24,8 @@ namespace nvfuser { using AllocationDomainTest = NVFuserTest; -using ::testing::ElementsAre; +using testing::Contains; +using testing::ElementsAre; // A global->shared->global copy kernel, shared memory allocated transposed to // avoid bank conflict. @@ -1251,16 +1252,14 @@ TEST_F(AllocationDomainTest, Issue1290_ContiguityWasMissing) { at::Tensor in_tensor = at::randn({2 * 4}).cuda().as_strided({2, 3}, {4, 1}); FusionExecutorCache fec(std::move(fusion)); - fec.runFusionWithInputs({in_tensor}); + auto out_tensors = fec.runFusionWithInputs({in_tensor}); + testValidate(fec.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); // The initial issue was detected in the pointwise scheduler, so I added these - // checks to make sure it's a valid regression test. The transpose scheduler - // could accept this but decided not to because of a small problem size. - const std::vector& groups = - fec.getMostRecentKernelRuntime()->fusionSegments()->groups(); - ASSERT_EQ(groups.size(), 1); - SegmentedGroup* group = groups[0]; - EXPECT_EQ(group->heuristic(), ScheduleHeuristic::PointWise); + // checks to make sure it's a valid regression test. + EXPECT_THAT( + fec.getMostRecentKernelRuntime()->fusionSegments()->groups(), + Contains(HeuristicIs(ScheduleHeuristic::PointWise))); } TEST_F(AllocationDomainTest, Issue1290_ReplayCasPFailedDueToDifferentRanks) { diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp index 55873d66bb1..0f2e82a2038 100644 --- a/tests/cpp/test_combined_inner_outer_reduction.cpp +++ b/tests/cpp/test_combined_inner_outer_reduction.cpp @@ -5,10 +5,21 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include +#include +#include +#include +#include + #include +#include #include +#include +#include +#include +#include + +#include #include #include #include @@ -21,18 +32,10 @@ #include #include -#include -#include -#include -#include -#include -#include -#include -#include - namespace nvfuser { -using namespace at::indexing; +using testing::Contains; +using testing::Not; // tuple of data type, batch size (outer dim), hidden size (inner dim) using CombinedSchedulerParams = std::tuple; @@ -458,10 +461,14 @@ TEST_F(NVFuserTest, CombinedSchedulerSharedProducer_CUDA) { case 0: case 1: case 3: - EXPECT_TRUE(runtime->isSegmented()); + EXPECT_THAT( + runtime->fusionSegments()->groups(), + Contains(Not(HeuristicIs(ScheduleHeuristic::NoOp))).Times(2)); break; case 2: - EXPECT_FALSE(runtime->isSegmented()); + EXPECT_THAT( + runtime->fusionSegments()->groups(), + Contains(Not(HeuristicIs(ScheduleHeuristic::NoOp))).Times(1)); break; default: NVF_ERROR(false, "Invalid case id"); diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 4cc1ac113f4..5a258a5c371 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -35,6 +35,8 @@ #include #include #include +#include +#include #include #include #include @@ -6253,19 +6255,24 @@ TEST_F(NVFuserTest, FusionAvoidRedundantWriteBroadcastedSoftmaxInput_CUDA) { auto cg_outputs = fec.runFusionWithInputs(inputs); // check thread_pred and write_stride - const auto& fe = fec.getMostRecentKernelRuntime()->executors().at(0); - auto kernel = fe.kernel(); - const auto& thread_pred_map = fe.threadPredMap(); - for (const auto expr : kernel->exprs()) { - auto tv = ir_utils::getTvOutput(expr); - if (tv && tv->name() == 15 && tv->getMemoryType() == MemoryType::Global) { - const auto& thread_pred = thread_pred_map.getPredicateInfo(tv); - bool predicted = thread_pred.redundant_types.get(ParallelType::BIDx) && - thread_pred.broadcast_ld_indices_map.count(ParallelType::BIDx); - NVF_CHECK( - predicted, - "Tv15 should be predicted by ParallelType::BIDx with a broadcast_ld_indices_map!"); - break; + for (const FusionExecutor& fe : + fec.getMostRecentKernelRuntime()->executors()) { + if (!fe.hasCompiledKernel()) { + continue; + } + auto kernel = fe.kernel(); + const auto& thread_pred_map = fe.threadPredMap(); + for (const auto expr : kernel->exprs()) { + auto tv = ir_utils::getTvOutput(expr); + if (tv && tv->name() == 15 && tv->getMemoryType() == MemoryType::Global) { + const auto& thread_pred = thread_pred_map.getPredicateInfo(tv); + bool predicted = thread_pred.redundant_types.get(ParallelType::BIDx) && + thread_pred.broadcast_ld_indices_map.count(ParallelType::BIDx); + NVF_CHECK( + predicted, + "Tv15 should be predicted by ParallelType::BIDx with a broadcast_ld_indices_map!"); + break; + } } } @@ -6308,20 +6315,27 @@ TEST_F(NVFuserTest, FusionAvoidRedundantWrite_CUDA) { auto cg_outputs = fec.runFusionWithInputs(inputs); // check thread_pred and write_stride - const auto& fe = fec.getMostRecentKernelRuntime()->executors().at(0); - auto kernel = fe.kernel(); - const auto& thread_pred_map = fe.threadPredMap(); - - for (const auto expr : kernel->exprs()) { - auto tv = ir_utils::getTvOutput(expr); - if (tv && tv->name() == 8 && tv->getMemoryType() == MemoryType::Global) { - const auto& thread_pred = thread_pred_map.getPredicateInfo(tv); - bool predicted = thread_pred.redundant_types.get(ParallelType::BIDx) && - thread_pred.broadcast_ld_indices_map.count(ParallelType::BIDx); - NVF_CHECK( - predicted, - "Tv8 should be predicted by ParallelType::BIDx with a broadcast_ld_indices_map!"); - break; + for (const FusionExecutor& fe : + fec.getMostRecentKernelRuntime()->executors()) { + if (!fe.hasCompiledKernel()) { + continue; + } + auto kernel = fe.kernel(); + const auto& thread_pred_map = fe.threadPredMap(); + + for (const auto expr : kernel->exprs()) { + auto tv = ir_utils::getTvOutput(expr); + if (tv && tv->name() == 8 && + tv->getMemoryType() == MemoryType::Global) { + const auto& thread_pred = thread_pred_map.getPredicateInfo(tv); + bool predicted = + thread_pred.redundant_types.get(ParallelType::BIDx) && + thread_pred.broadcast_ld_indices_map.count(ParallelType::BIDx); + NVF_CHECK( + predicted, + "Tv8 should be predicted by ParallelType::BIDx with a broadcast_ld_indices_map!"); + break; + } } } @@ -7822,6 +7836,12 @@ TEST_F(NVFuserTest, AvoidCachingSliceInput) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); + // Avoid `slice`s from being bookended. This test is to exercise kernel + // caching when a `SliceOp` is applied on a segment input. Bookending + // `slice`s would defeat that purpose. + preseg_passes::OptimizationPassGuard + optimization_guard(false); + // values to trigger the original bug. const int64_t eight = 8; const int64_t twenty = 20; diff --git a/tests/cpp/test_gpu_outer_reduction.cpp b/tests/cpp/test_gpu_outer_reduction.cpp index f6c120c5aba..d11d2df4b5a 100644 --- a/tests/cpp/test_gpu_outer_reduction.cpp +++ b/tests/cpp/test_gpu_outer_reduction.cpp @@ -5,10 +5,20 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include #include +#include #include +#include +#include +#include +#include + +#include +#include +#include + +#include #include #include #include @@ -21,20 +31,14 @@ #include #include -#include -#include -#include - -#include -#include -#include -#include - namespace nvfuser { using OuterReductionTest = NVFuserTest; -using namespace at::indexing; +using testing::Contains; +using testing::IsSupersetOf; + +// using namespace at::indexing; // Shmoo testing of the optimized grouped grid welford TEST_F(OuterReductionTest, GroupedGridWelfordOuterOpt) { @@ -1522,20 +1526,19 @@ void grid_persistent_welford_outer_norm_like_scheduler( bool use_weights = false, DataType weights_dtype = DataType::Float) { const bool benchmark_mode = isBenchmarkMode(); - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); std::vector bcast_pattern{true, true, true, false}; std::vector reduction_dims{2, 1, 0}; auto inp = makeContigTensor(4, dtype); - fusion.addInput(inp); + fusion->addInput(inp); TensorView* weights = nullptr; if (use_weights) { weights = makeContigTensor(1, weights_dtype); - fusion.addInput(weights); + fusion->addInput(weights); } auto inp_cast = cast(inp, DataType::Float); @@ -1548,7 +1551,7 @@ void grid_persistent_welford_outer_norm_like_scheduler( } out = cast(out, dtype); - fusion.addOutput(out); + fusion->addOutput(out); auto options = at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); @@ -1564,25 +1567,15 @@ void grid_persistent_welford_outer_norm_like_scheduler( aten_inputs.push_back(t1); } - FusionExecutorCache executor_cache(std::move(fusion_ptr)); + FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); auto runtime = executor_cache.getMostRecentKernelRuntime(); - if (!shouldBePersistent(N, HW, dtype, false, use_weights, weights_dtype)) { - NVF_CHECK(runtime->isSegmented(), "Expected to be segmented"); - } else { - NVF_CHECK( - !runtime->isSegmented(), - "Unexpected number of segments: ", - runtime->fusionSegments()->groups().size()); - - const auto& scheduler_entry = - runtime->schedulerHeuristics()->heuristicsList().at(0); - NVF_CHECK( - scheduler_entry->heuristic() == ScheduleHeuristic::OuterPersistent, - "Unexpected heuristic was chosen: ", - scheduler_entry->heuristic()); + if (shouldBePersistent(N, HW, dtype, false, use_weights, weights_dtype)) { + EXPECT_THAT( + runtime->fusionSegments()->groups(), + Contains(HeuristicIs(ScheduleHeuristic::OuterPersistent)).Times(1)); if (benchmark_mode) { for (int i = 0; i < 10; ++i) { @@ -1590,17 +1583,16 @@ void grid_persistent_welford_outer_norm_like_scheduler( cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); } } + } else { + EXPECT_THAT( + runtime->fusionSegments()->groups(), + IsSupersetOf( + {HeuristicIs(ScheduleHeuristic::Reduction), + HeuristicIs(ScheduleHeuristic::PointWise)})); } - auto t0_cast = t0.to(at::kFloat); - auto t0_allreduce = - t0_cast.mean({0, 1, 2}).unsqueeze(0).unsqueeze(0).unsqueeze(0); - auto ref = t0_cast - t0_allreduce; - if (use_weights) { - ref = ref + t1.to(at::kFloat); - } - - testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__, ""); + testValidate( + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } } // namespace @@ -1649,7 +1641,7 @@ TEST_F( TEST_F( OuterReductionTest, - GridPersistentWelfordOuterNormWithWeithtsLikeHalf256x7x512Scheduler) { + GridPersistentWelfordOuterNormWithWeightsLikeHalf256x7x512Scheduler) { grid_persistent_welford_outer_norm_like_scheduler( 256, 7, 512, DataType::Half, true, DataType::Float); } @@ -1677,9 +1669,8 @@ void grid_persistent_batchnorm_scheduler( int64_t C, DataType dtype) { const bool benchmark_mode = isBenchmarkMode(); - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(fusion_ptr.get()); + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); const bool kTraining = true; const float kMomentum = 0.1; @@ -1692,11 +1683,11 @@ void grid_persistent_batchnorm_scheduler( auto running_mean = makeContigTensor(1, DataType::Float); auto running_var = makeContigTensor(1, DataType::Float); - fusion_ptr->addInput(input); - fusion_ptr->addInput(weight); - fusion_ptr->addInput(bias); - fusion_ptr->addInput(running_mean); - fusion_ptr->addInput(running_var); + fusion->addInput(input); + fusion->addInput(weight); + fusion->addInput(bias); + fusion->addInput(running_mean); + fusion->addInput(running_var); if (dtype == DataType::Half) { input = castOp(DataType::Float, input); @@ -1724,7 +1715,7 @@ void grid_persistent_batchnorm_scheduler( output = castOp(DataType::Half, output); } - fusion_ptr->addOutput(output); + fusion->addOutput(output); auto options_float = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -1743,25 +1734,15 @@ void grid_persistent_batchnorm_scheduler( std::vector aten_inputs( {at_input_nvfuser, at_weight, at_bias, at_running_mean, at_running_var}); - FusionExecutorCache executor_cache(std::move(fusion_ptr)); + FusionExecutorCache executor_cache(std::move(fusion)); auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto runtime = executor_cache.getMostRecentKernelRuntime(); - - if (!shouldBePersistent(N, HW, dtype, false, true, DataType::Float)) { - NVF_CHECK(runtime->isSegmented(), "Expected to be segmented"); - } else { - NVF_CHECK( - !runtime->isSegmented(), - "Unexpected number of segments: ", - runtime->fusionSegments()->groups().size()); + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); - const auto& scheduler_entry = - runtime->schedulerHeuristics()->heuristicsList().at(0); - NVF_CHECK( - scheduler_entry->heuristic() == ScheduleHeuristic::OuterPersistent, - "Unexpected heuristic was chosen: ", - scheduler_entry->heuristic()); + if (shouldBePersistent(N, HW, dtype, false, true, DataType::Float)) { + EXPECT_THAT( + runtime->fusionSegments()->groups(), + Contains(HeuristicIs(ScheduleHeuristic::OuterPersistent)).Times(1)); if (benchmark_mode) { for (int i = 0; i < 10; ++i) { @@ -1769,23 +1750,16 @@ void grid_persistent_batchnorm_scheduler( cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); } } + } else { + EXPECT_THAT( + runtime->fusionSegments()->groups(), + IsSupersetOf( + {HeuristicIs(ScheduleHeuristic::Reduction), + HeuristicIs(ScheduleHeuristic::PointWise)})); } - auto at_output = at::batch_norm( - at_input, - at_weight, - at_bias, - at_running_mean, - at_running_var, - kTraining, - kMomentum, - kEps, - true); - - cg_outputs.at(0) = cg_outputs.at(0).permute({0, 3, 1, 2}); - testValidate( - &fusion, cg_outputs, aten_inputs, {at_output}, __LINE__, __FILE__, ""); + executor_cache.fusion(), cg_outputs, aten_inputs, __LINE__, __FILE__); } } // namespace diff --git a/tests/cpp/test_no_op.cpp b/tests/cpp/test_no_op.cpp index f0f163be54e..31a28a172b0 100644 --- a/tests/cpp/test_no_op.cpp +++ b/tests/cpp/test_no_op.cpp @@ -20,6 +20,7 @@ namespace nvfuser { using NoOpTest = NVFuserTest; +using testing::Each; using testing::IsEmpty; using testing::UnorderedElementsAre; @@ -228,9 +229,12 @@ TEST_F(NoOpTest, ExpandedReduction) { FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); EXPECT_THAT( runtime->fusionSegments()->groups(), - UnorderedElementsAre(HeuristicIs(ScheduleHeuristic::NoOp))); - const auto& executor = runtime->executors().front(); - EXPECT_THAT(executor.kernel()->summary().global_allocations, IsEmpty()); + Each(HeuristicIs(ScheduleHeuristic::NoOp))); + for (const auto& fe : runtime->executors()) { + if (fe.hasCompiledKernel()) { + EXPECT_THAT(fe.kernel()->summary().global_allocations, IsEmpty()); + } + } } } // namespace nvfuser diff --git a/tests/cpp/test_persistent_buffer.cpp b/tests/cpp/test_persistent_buffer.cpp index 1a17cb8902c..1a970d2ec7b 100644 --- a/tests/cpp/test_persistent_buffer.cpp +++ b/tests/cpp/test_persistent_buffer.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include +#include #include #include @@ -15,10 +16,13 @@ #include #include #include + namespace nvfuser { using PersistentBufferTest = NVFuserTest; +using testing::Contains; + TEST_F(PersistentBufferTest, FusionPersistentBufferCalculation1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -1170,9 +1174,8 @@ TEST_F(NVFuserTest, AvoidProjectingToInputsIfRecomputeHasDropout) { // From T7, the backward search can find the corresponding reduction // input ID, which is {I2} in T2. TEST_F(PersistentBufferTest, PostReductionBroadcastCheck) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); const int dim0 = 128; const int dim1 = 256; @@ -1195,22 +1198,20 @@ TEST_F(PersistentBufferTest, PostReductionBroadcastCheck) { .device(at::kCUDA, 0); auto t0 = at::randn({dim0, dim1}, options); auto t1 = at::randn({dim0, dim1}, options); - auto t2 = at::sum(t0, {1}).unsqueeze(1) + t0; - auto t4 = t2 + t1; - FusionExecutorCache fec(std::move(fusion_ptr)); + FusionExecutorCache fec(std::move(fusion)); auto cg_outputs = fec.runFusionWithInputs({t0, t1}); - NVF_CHECK( - !fec.getMostRecentKernelRuntime()->isSegmented(), - "unexpected segmentation!"); + testValidate(fec.fusion(), cg_outputs, {t0, t1}, __LINE__, __FILE__); - testValidate(fusion, cg_outputs, {t0, t1}, {t4}, __LINE__, __FILE__); + FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); + EXPECT_THAT( + runtime->fusionSegments()->groups(), + Contains(HeuristicIs(ScheduleHeuristic::InnerPersistent)).Times(1)); } // Cases with two broadcast IDs TEST_F(PersistentBufferTest, PostReductionBroadcastCheckMultiBcastDims) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); const int dim0 = 16; const int dim1 = 32; @@ -1234,15 +1235,14 @@ TEST_F(PersistentBufferTest, PostReductionBroadcastCheckMultiBcastDims) { .device(at::kCUDA, 0); auto t0 = at::randn({dim0, dim1, dim2}, options); auto t1 = at::randn({dim0, dim1, dim2}, options); - auto t2 = at::sum(t0, {1, 2}).unsqueeze(-1).unsqueeze(-1) + t0; - auto t4 = t2 + t1; - FusionExecutorCache fec(std::move(fusion_ptr)); + FusionExecutorCache fec(std::move(fusion)); auto cg_outputs = fec.runFusionWithInputs({t0, t1}); - NVF_CHECK( - !fec.getMostRecentKernelRuntime()->isSegmented(), - "unexpected segmentation!"); + testValidate(fec.fusion(), cg_outputs, {t0, t1}, __LINE__, __FILE__); - testValidate(fusion, cg_outputs, {t0, t1}, {t4}, __LINE__, __FILE__); + FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); + EXPECT_THAT( + runtime->fusionSegments()->groups(), + Contains(HeuristicIs(ScheduleHeuristic::InnerPersistent)).Times(1)); } TEST_F(PersistentBufferTest, SmemPersistentNotSupportedIn3DReduction) { diff --git a/tests/cpp/test_predicate_elimination.cpp b/tests/cpp/test_predicate_elimination.cpp index 8b941f7e0e6..9ae7c0a7b44 100644 --- a/tests/cpp/test_predicate_elimination.cpp +++ b/tests/cpp/test_predicate_elimination.cpp @@ -8,12 +8,13 @@ #include #include -#include -#include - #include #include #include +#include +#include +#include +#include namespace nvfuser { @@ -326,6 +327,10 @@ TEST_F(PredicateEliminationTest, 8) { std::unique_ptr fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); + // Disable bookending to avoid segmentation. Validation code below assumes + // one segment. + preseg_passes::OptimizationPassGuard + optimization_guard(false); const int64_t channel_size = 16; const int64_t batch_size = 8; diff --git a/tests/cpp/test_remove_bcast_squeeze.cpp b/tests/cpp/test_remove_bcast_squeeze.cpp index fd4e45dafe7..c42cc87aa46 100644 --- a/tests/cpp/test_remove_bcast_squeeze.cpp +++ b/tests/cpp/test_remove_bcast_squeeze.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -69,6 +70,10 @@ TEST_F(RemoveBcastSqueezeTest, BcastSqueeze) { TEST_F(RemoveBcastSqueezeTest, BcastSqueezeMultipleUses) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); + // Prevent meta ops from being bookended. + preseg_passes::OptimizationPassGuard + optimization_guard(false); + DataType input_dtype = DataType::Float; const std::vector is_broadcast_dim{false, false, true}; auto tv0 = makeContigTensor(2, input_dtype); @@ -78,7 +83,7 @@ TEST_F(RemoveBcastSqueezeTest, BcastSqueezeMultipleUses) { auto tv1 = set(tv0); auto tv2 = broadcast(tv1, is_broadcast_dim); auto tv3 = squeeze(tv2, is_broadcast_dim); - auto tv4 = add(tv3, tv3); + auto tv4 = set(tv3); auto tv5 = add(tv2, tvb); fusion->addOutput(tv4); fusion->addOutput(tv5); @@ -93,11 +98,13 @@ TEST_F(RemoveBcastSqueezeTest, BcastSqueezeMultipleUses) { // run fusion auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::ones({3, 4}, options); - at::Tensor t1 = t0.unsqueeze(-1); + at::Tensor tb = t0.unsqueeze(-1); + std::vector in_tensors({t0, tb}); FusionExecutorCache executor_cache(std::move(fusion)); - std::vector outputs = - executor_cache.runFusionWithInputs({t0, t1}); - testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__); + std::vector out_tensors = + executor_cache.runFusionWithInputs(in_tensors); + testValidate( + executor_cache.fusion(), out_tensors, in_tensors, __LINE__, __FILE__); } TEST_F(RemoveBcastSqueezeTest, BcastSqueezeUnmatchedDim) { diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index b74348f15d9..3d64703f768 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -1457,6 +1457,10 @@ TEST_F(ResizeTest, SliceReduceScheduler1) { auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); + // This is to prevent `slice` from being bookended. + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); @@ -1498,6 +1502,10 @@ TEST_F(ResizeTest, SliceReduceScheduler2) { auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); + // This is to prevent `slice`s from being bookended. + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto tv0 = makeContigTensor(2); fusion.addInput(tv0); @@ -1537,11 +1545,15 @@ TEST_F(ResizeTest, SliceReduceScheduler2) { } // Multiple slice+reduction. Same slices. Should be segmented at the moment. -TEST_F(ResizeTest, FusionSliceReduceScheduler3) { +TEST_F(ResizeTest, SliceReduceScheduler3) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); + // This is to prevent `slice`s from being bookended. + preseg_passes::OptimizationPassGuard + optimization_guard(false); + auto tv0 = makeSymbolicTensor(2); fusion.addInput(tv0); @@ -2055,8 +2067,12 @@ TEST_F(ResizeTest, ResizeReshapeAndSlice) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - EnableOptionsGuard opt_guard; + EnableOptionsGuard enable_guard; EnableOptionsGuard::getCurOptions().set(EnableOption::MemoryPromotion); + // This is to prevent the fusion from being accepted by NoOp, which would + // defeat the purpose of testing PointWise. + preseg_passes::OptimizationPassGuard + optimization_guard(false); auto tv0 = makeSymbolicTensor(2); fusion->addInput(tv0); @@ -2066,10 +2082,7 @@ TEST_F(ResizeTest, ResizeReshapeAndSlice) { tv1, {{IrBuilder::create(0L), IrBuilder::create(2L)}, {IrBuilder::create(0L), IrBuilder::create(2L)}}); - // Without the `add`, the fusion will be accepted by NoOp, defeating the - // purpose of testing PointWise. - auto tv3 = add(tv2, tv2); - fusion->addOutput(tv3); + fusion->addOutput(tv2); std::vector shape({4, 8}); @@ -2095,6 +2108,10 @@ TEST_F(ResizeTest, ResizePermuteAndSlice) { EnableOptionsGuard opt_guard; EnableOptionsGuard::getCurOptions().set(EnableOption::MemoryPromotion); + // This is to prevent `tv3->definition()` from being bookended. + preseg_passes::OptimizationPassGuard + optimization_guard(false); + // Set the problem size so that it can trigger the transpose // scheduler. The scheduler selection is validated below. auto num_sms = @@ -2110,8 +2127,7 @@ TEST_F(ResizeTest, ResizePermuteAndSlice) { {{IrBuilder::create(1L), IrBuilder::create(shape.at(0) - 1)}, {IrBuilder::create(2L), IrBuilder::create(shape.at(1) - 2)}}); auto tv3 = transpose(tv2, 0, 1); - auto tv5 = add(tv3, tv3); - fusion->addOutput(tv5); + fusion->addOutput(tv3); auto tv4 = add(tv2, IrBuilder::create(1.0)); fusion->addOutput(tv4); diff --git a/tests/cpp/test_scatter_gather.cpp b/tests/cpp/test_scatter_gather.cpp index c8a39e88b01..7613ab654cf 100644 --- a/tests/cpp/test_scatter_gather.cpp +++ b/tests/cpp/test_scatter_gather.cpp @@ -17,13 +17,25 @@ #include #include #include +#include +#include #include #include #include namespace nvfuser { -using ScatterGatherTest = NVFuserTest; +class ScatterGatherTest : public NVFuserTest { + protected: + // For convenience, disable MarkAliasesPreparePass. Many tests in this file + // start and/or end with meta ops. MarkAliasesPreparePass + // would bookend them, making tests less interesting. + ScatterGatherTest() : optimization_guard_(false) {} + + private: + preseg_passes::OptimizationPassGuard + optimization_guard_; +}; namespace { auto randomVector(int64_t low, int64_t high, int rank) { @@ -1122,10 +1134,7 @@ TEST_F(ScatterGatherTest, TakeAlongAxisIntermediateTensorTranspose3) { auto tv3 = broadcast(tv1, {true, false, false}); auto tv4 = take_along_axis(tv2, tv3, 2); auto tv5 = transpose(tv4, 1, 2); - // Without the `add`, the transpose will be taken by NoOp, defeating the - // purpose of testing PointWise. - auto tv6 = add(tv5, tv5); - fusion.addOutput(tv6); + fusion.addOutput(tv5); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto options_i = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);