diff --git a/tests/cpp/test_alias.cpp b/tests/cpp/test_alias.cpp index 0a8a427a6b8..b7e605ded76 100644 --- a/tests/cpp/test_alias.cpp +++ b/tests/cpp/test_alias.cpp @@ -1216,7 +1216,7 @@ TEST_F(AliasTest, Bookend_SegmentSetPreservesAllocation) { EXPECT_TRUE(permute_out_tensor.is_alias_of(in_tensor)); } -TEST_F(AliasTest, Bookend_InputsAndOutputs) { +TEST_F(AliasTest, Bookend_OneOutputAliasesTheOtherNot) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1239,8 +1239,8 @@ TEST_F(AliasTest, Bookend_InputsAndOutputs) { FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime(); // MarkAliasesPrepare adds a `segment_set` between `in` and `permute`, which // leads to three segments: - // 1. segment_set`, a no-op segment, - // 2. permute`, a no-op segment, + // 1. `segment_set`, a no-op segment, + // 2. `permute`, a no-op segment, // 3. `mul` and `add`, a pointwise segment. EXPECT_THAT( runtime->fusionSegments()->groups(), @@ -1276,6 +1276,7 @@ TEST_F(AliasTest, Bookend_IntermediateTensors) { EXPECT_THAT( runtime->fusionSegments()->groups(), UnorderedElementsAre( + HeuristicIs(ScheduleHeuristic::NoOp), HeuristicIs(ScheduleHeuristic::NoOp), HeuristicIs(ScheduleHeuristic::PointWise))); for (SegmentedGroup* group : runtime->fusionSegments()->groups()) { @@ -1357,6 +1358,87 @@ TEST_F(AliasTest, Bookend_ReuseSegmentSet) { } } +// Repro for #2599. 
+TEST_F(AliasTest, Bookend_Rope) {
+  // Builds a fusion that unpacks a fused `qkv` tensor into `q`, `k` and `v`
+  // through reshape/permute/chunk/squeeze, applies `apply_rope` to `q` and
+  // `k` only, and returns all three. The test validates the numerics and then
+  // checks how many segments were scheduled as PointWise.
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  constexpr int64_t kNumHeads = 40;
+  constexpr int64_t kHiddenSize = 128;
+
+  // Fusion inputs: a 3-D `qkv` whose innermost extent is
+  // kNumHeads * kHiddenSize * 3, and 2-D `cos`/`sin` whose innermost extent is
+  // kHiddenSize. All three are declared as contiguous Half tensors; the
+  // remaining extents are symbolic (-1).
+  TensorView* qkv = TensorViewBuilder()
+                        .ndims(3)
+                        .dtype(DataType::Half)
+                        .contiguity(true)
+                        .shape({-1, -1, kNumHeads * kHiddenSize * 3})
+                        .build();
+  TensorView* cos = TensorViewBuilder()
+                        .ndims(2)
+                        .dtype(DataType::Half)
+                        .contiguity(true)
+                        .shape({-1, kHiddenSize})
+                        .build();
+  TensorView* sin = TensorViewBuilder()
+                        .ndims(2)
+                        .dtype(DataType::Half)
+                        .contiguity(true)
+                        .shape({-1, kHiddenSize})
+                        .build();
+  fusion->addInput(qkv);
+  fusion->addInput(cos);
+  fusion->addInput(sin);
+
+  // Split the innermost axis into (kNumHeads, 3, kHiddenSize), keeping the two
+  // leading extents, then reorder the axes so that the size-3 axis lands at
+  // position 2, where it is chunked and squeezed below.
+  qkv = reshape(
+      qkv,
+      {qkv->axis(0)->extent(),
+       qkv->axis(1)->extent(),
+       IrBuilder::create<Val>(kNumHeads),
+       IrBuilder::create<Val>(3),
+       IrBuilder::create<Val>(kHiddenSize)});
+  qkv = permute(qkv, {0, 2, 3, 1, 4});
+
+  // One chunk each for q, k and v; drop the chunked axis from each.
+  std::vector<TensorView*> slices = chunk(qkv, /*chunks=*/3, /*dim=*/2);
+  auto* q = squeeze(slices[0], {2});
+  auto* k = squeeze(slices[1], {2});
+  auto* v = squeeze(slices[2], {2});
+
+  // Computes x * cos + rotated * sin, where `rotated` is the concatenation of
+  // the negated second half of x's last axis followed by its first half. The
+  // result is cast back to Half.
+  auto apply_rope = [cos, sin](TensorView* x) {
+    // Named `halves` rather than `slices` to avoid shadowing the outer
+    // variable.
+    std::vector<TensorView*> halves = chunk(x, /*chunks=*/2, /*dim=*/-1);
+    auto* real = halves[0];
+    auto* imag = castOp(DataType::Half, neg(halves[1]));
+    TensorView* rotated = cat({imag, real}, -1);
+    TensorView* out = add(mul(x, cos), mul(rotated, sin));
+    out = castOp(DataType::Half, out);
+    return out;
+  };
+  q = apply_rope(q);
+  k = apply_rope(k);
+
+  // `v` is returned without going through `apply_rope`.
+  fusion->addOutput(q);
+  fusion->addOutput(k);
+  fusion->addOutput(v);
+
+  constexpr int64_t kSeqLen = 4096;
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto qkv_tensor =
+      at::randn({2, kSeqLen, kNumHeads * kHiddenSize * 3}, options);
+  // The random table is only kHiddenSize / 2 wide; its cos and sin are each
+  // concatenated with themselves so the inputs are kHiddenSize wide.
+  auto freq_tensor = at::randn({kSeqLen, kHiddenSize / 2}, options);
+  auto cos_tensor = freq_tensor.cos();
+  auto sin_tensor = freq_tensor.sin();
+  std::vector<at::Tensor> in_tensors = {
+      qkv_tensor,
+      at::cat({cos_tensor, cos_tensor}, -1),
+      at::cat({sin_tensor, sin_tensor}, -1)};
+
+  FusionExecutorCache fec(std::move(fusion));
+  std::vector<at::Tensor> out_tensors = fec.runFusionWithInputs(in_tensors);
+  testValidate(fec.fusion(), out_tensors, in_tensors, __LINE__, __FILE__);
+
+  // Exactly four of the resulting segments must use the PointWise heuristic;
+  // segments with other heuristics are not constrained by this check.
+  EXPECT_THAT(
+      fec.getMostRecentKernelRuntime()->fusionSegments()->groups(),
+      Contains(HeuristicIs(ScheduleHeuristic::PointWise)).Times(4));
+}
+
 TEST_F(AliasTest, QKVSplitBackprop) {
   // A subgraph of MoveSplitCatTest.Cancellable_Issue1768.
   constexpr int b = 16;