Add a repro.
For #2599.
wujingyue committed Aug 29, 2024
1 parent 58dfdc1 commit 26f1a73
Showing 1 changed file with 85 additions and 3 deletions.
tests/cpp/test_alias.cpp: 85 additions & 3 deletions
@@ -1216,7 +1216,7 @@ TEST_F(AliasTest, Bookend_SegmentSetPreservesAllocation) {
EXPECT_TRUE(permute_out_tensor.is_alias_of(in_tensor));
}

TEST_F(AliasTest, Bookend_InputsAndOutputs) {
TEST_F(AliasTest, Bookend_OneOutputAliasesTheOtherNot) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

@@ -1239,8 +1239,8 @@ TEST_F(AliasTest, Bookend_InputsAndOutputs) {
FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime();
// MarkAliasesPrepare adds a `segment_set` between `in` and `permute`, which
// leads to three segments:
// 1. segment_set`, a no-op segment,
// 2. permute`, a no-op segment,
// 1. `segment_set`, a no-op segment,
// 2. `permute`, a no-op segment,
// 3. `mul` and `add`, a pointwise segment.
EXPECT_THAT(
runtime->fusionSegments()->groups(),
@@ -1276,6 +1276,7 @@ TEST_F(AliasTest, Bookend_IntermediateTensors) {
EXPECT_THAT(
runtime->fusionSegments()->groups(),
UnorderedElementsAre(
HeuristicIs(ScheduleHeuristic::NoOp),
HeuristicIs(ScheduleHeuristic::NoOp),
HeuristicIs(ScheduleHeuristic::PointWise)));
for (SegmentedGroup* group : runtime->fusionSegments()->groups()) {
@@ -1357,6 +1358,87 @@ TEST_F(AliasTest, Bookend_ReuseSegmentSet) {
}
}

// Repro for #2599.
TEST_F(AliasTest, Bookend_Rope) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

constexpr int64_t kNumHeads = 40;
constexpr int64_t kHiddenSize = 128;

TensorView* qkv = TensorViewBuilder()
.ndims(3)
.dtype(DataType::Half)
.contiguity(true)
.shape({-1, -1, kNumHeads * kHiddenSize * 3})
.build();
TensorView* cos = TensorViewBuilder()
.ndims(2)
.dtype(DataType::Half)
.contiguity(true)
.shape({-1, kHiddenSize})
.build();
TensorView* sin = TensorViewBuilder()
.ndims(2)
.dtype(DataType::Half)
.contiguity(true)
.shape({-1, kHiddenSize})
.build();
fusion->addInput(qkv);
fusion->addInput(cos);
fusion->addInput(sin);

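// Unpack the fused QKV projection: reshape the last dimension into
// [kNumHeads, 3, kHiddenSize] and permute so Q, K and V can be chunked out
// along dimension 2.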
qkv = reshape(
qkv,
{qkv->axis(0)->extent(),
qkv->axis(1)->extent(),
IrBuilder::create<Val>(kNumHeads),
IrBuilder::create<Val>(3),
IrBuilder::create<Val>(kHiddenSize)});
qkv = permute(qkv, {0, 2, 3, 1, 4});

std::vector<TensorView*> slices = chunk(qkv, /*chunks=*/3, /*dim=*/2);
auto* q = squeeze(slices[0], {2});
auto* k = squeeze(slices[1], {2});
auto* v = squeeze(slices[2], {2});

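// Rotary position embedding in rotate-half form: rotate_half(x) negates the
// second half of the last dimension and swaps it in front of the first half,
// and the output is x * cos + rotate_half(x) * sin.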
auto apply_rope = [cos, sin](TensorView* x) {
std::vector<TensorView*> slices = chunk(x, /*chunks=*/2, /*dim=*/-1);
auto* real = slices[0];
auto* imag = castOp(DataType::Half, neg(slices[1]));
TensorView* rotated = cat({imag, real}, -1);
TensorView* out = add(mul(x, cos), mul(rotated, sin));
out = castOp(DataType::Half, out);
return out;
};
q = apply_rope(q);
k = apply_rope(k);

fusion->addOutput(q);
fusion->addOutput(k);
fusion->addOutput(v);

constexpr int64_t kSeqLen = 4096;
auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
auto qkv_tensor =
at::randn({2, kSeqLen, kNumHeads * kHiddenSize * 3}, options);
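// cos/sin are generated from half-width frequencies and duplicated along the
// last dimension, matching the rotate-half convention in apply_rope above.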
auto freq_tensor = at::randn({kSeqLen, kHiddenSize / 2}, options);
auto cos_tensor = freq_tensor.cos();
auto sin_tensor = freq_tensor.sin();
std::vector<c10::IValue> in_tensors = {
qkv_tensor,
at::cat({cos_tensor, cos_tensor}, -1),
at::cat({sin_tensor, sin_tensor}, -1)};

FusionExecutorCache fec(std::move(fusion));
std::vector<at::Tensor> out_tensors = fec.runFusionWithInputs(in_tensors);
testValidate(fec.fusion(), out_tensors, in_tensors, __LINE__, __FILE__);

EXPECT_THAT(
fec.getMostRecentKernelRuntime()->fusionSegments()->groups(),
Contains(HeuristicIs(ScheduleHeuristic::PointWise)).Times(4));
}

TEST_F(AliasTest, QKVSplitBackprop) {
// A subgraph of MoveSplitCatTest.Cancellable_Issue1768.
constexpr int b = 16;
