Skip to content

Commit

Permalink
Add a repro.
Browse files Browse the repository at this point in the history
For #2599.
  • Loading branch information
wujingyue committed Aug 20, 2024
1 parent 90623fe commit 37af073
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 3 deletions.
25 changes: 25 additions & 0 deletions csrc/ops/alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,4 +840,29 @@ TensorView* slice(
return slice(inp, slices);
}

std::vector<TensorView*> split(
TensorView* in,
int64_t dim,
const int64_t num_slices) {
const auto in_logical = TensorDomain::noReductions(in->getLogicalDomain());
const auto num_dims = static_cast<int64_t>(in_logical.size());
dim = wrapDim(dim, num_dims);
Val* dim_size = in_logical[dim]->extent();
Val* slice_size = SimplifyingIrBuilder::divExpr(
dim_size, IrBuilder::create<Val>(num_slices));

std::vector<TensorView*> slices;
slices.reserve(num_slices);
std::vector<Slice> ranges(num_dims);
for (auto i : c10::irange(num_slices)) {
ranges[dim].start = ranges[dim].stop;
ranges[dim].stop =
(i == num_slices - 1 ? nullptr
: SimplifyingIrBuilder::mulExpr(
slice_size, IrBuilder::create<Val>(i + 1)));
slices.push_back(slice(in, ranges));
}
return slices;
}

} // namespace nvfuser
4 changes: 4 additions & 0 deletions csrc/ops/alias.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,8 @@ NVF_API TensorView* slice(
const std::vector<int64_t>& starts,
const std::vector<int64_t>& stops);

// Splits `in`'s dimension `dim` into `num_slices` slices. All but the last
// slice will be of size `floor(dim_size/num_slices)`.
std::vector<TensorView*> split(TensorView* in, int64_t dim, int64_t num_slices);

} // namespace nvfuser
88 changes: 85 additions & 3 deletions tests/cpp/test_alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,7 @@ TEST_F(AliasTest, Bookend_SegmentSetPreservesAllocation) {
EXPECT_TRUE(permute_out_tensor.is_alias_of(in_tensor));
}

TEST_F(AliasTest, Bookend_InputsAndOutputs) {
TEST_F(AliasTest, Bookend_OneOutputAliasesTheOtherNot) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

Expand All @@ -1243,8 +1243,8 @@ TEST_F(AliasTest, Bookend_InputsAndOutputs) {
FusionKernelRuntime* runtime = fec.getMostRecentKernelRuntime();
// MarkAliasesPrepare adds a `segment_set` between `in` and `permute`, which
// leads to three segments:
// 1. segment_set`, a no-op segment,
// 2. permute`, a no-op segment,
// 1. `segment_set`, a no-op segment,
// 2. `permute`, a no-op segment,
// 3. `mul` and `add`, a pointwise segment.
EXPECT_THAT(
runtime->fusionSegments()->groups(),
Expand Down Expand Up @@ -1280,6 +1280,7 @@ TEST_F(AliasTest, Bookend_IntermediateTensors) {
EXPECT_THAT(
runtime->fusionSegments()->groups(),
UnorderedElementsAre(
HeuristicIs(ScheduleHeuristic::NoOp),
HeuristicIs(ScheduleHeuristic::NoOp),
HeuristicIs(ScheduleHeuristic::PointWise)));
for (SegmentedGroup* group : runtime->fusionSegments()->groups()) {
Expand Down Expand Up @@ -1361,6 +1362,87 @@ TEST_F(AliasTest, Bookend_ReuseSegmentSet) {
}
}

// Repro for #2599.
TEST_F(AliasTest, Bookend_Rope) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

constexpr int64_t kNumHeads = 40;
constexpr int64_t kHiddenSize = 128;

TensorView* qkv = TensorViewBuilder()
.ndims(3)
.dtype(DataType::Half)
.contiguity(true)
.shape({-1, -1, kNumHeads * kHiddenSize * 3})
.build();
TensorView* cos = TensorViewBuilder()
.ndims(2)
.dtype(DataType::Half)
.contiguity(true)
.shape({-1, kHiddenSize})
.build();
TensorView* sin = TensorViewBuilder()
.ndims(2)
.dtype(DataType::Half)
.contiguity(true)
.shape({-1, kHiddenSize})
.build();
fusion->addInput(qkv);
fusion->addInput(cos);
fusion->addInput(sin);

qkv = reshape(
qkv,
{qkv->axis(0)->extent(),
qkv->axis(1)->extent(),
IrBuilder::create<Val>(kNumHeads),
IrBuilder::create<Val>(3),
IrBuilder::create<Val>(kHiddenSize)});
qkv = permute(qkv, {0, 2, 3, 1, 4});

std::vector<TensorView*> slices = split(qkv, /*dim=*/2, /*num_slices=*/3);
auto* q = squeeze(slices[0], {2});
auto* k = squeeze(slices[1], {2});
auto* v = squeeze(slices[2], {2});

auto apply_rope = [cos, sin](TensorView* x) {
std::vector<TensorView*> slices = split(x, /*dim=*/-1, /*num_slices=*/2);
auto* real = slices[0];
auto* imag = castOp(DataType::Half, neg(slices[1]));
TensorView* rotated = cat({imag, real}, -1);
TensorView* out = add(mul(x, cos), mul(rotated, sin));
out = castOp(DataType::Half, out);
return out;
};
q = apply_rope(q);
k = apply_rope(k);

fusion->addOutput(q);
fusion->addOutput(k);
fusion->addOutput(v);

constexpr int64_t kSeqLen = 4096;
auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
auto qkv_tensor =
at::randn({2, kSeqLen, kNumHeads * kHiddenSize * 3}, options);
auto freq_tensor = at::randn({kSeqLen, kHiddenSize / 2}, options);
auto cos_tensor = freq_tensor.cos();
auto sin_tensor = freq_tensor.sin();
std::vector<c10::IValue> in_tensors = {
qkv_tensor,
at::cat({cos_tensor, cos_tensor}, -1),
at::cat({sin_tensor, sin_tensor}, -1)};

FusionExecutorCache fec(std::move(fusion));
std::vector<at::Tensor> out_tensors = fec.runFusionWithInputs(in_tensors);
testValidate(fec.fusion(), out_tensors, in_tensors, __LINE__, __FILE__);

EXPECT_THAT(
fec.getMostRecentKernelRuntime()->fusionSegments()->groups(),
Contains(HeuristicIs(ScheduleHeuristic::PointWise)).Times(4));
}

TEST_F(AliasTest, QKVSplitBackprop) {
// A subgraph of MoveSplitCatTest.Cancellable_Issue1768.
constexpr int b = 16;
Expand Down

0 comments on commit 37af073

Please sign in to comment.