Skip to content

Commit

Permalink
Preserve allocation domain when inserting segment_set.
Browse files Browse the repository at this point in the history
For #2599.
  • Loading branch information
wujingyue committed Aug 19, 2024
1 parent 46b583a commit a2e3f05
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
11 changes: 11 additions & 0 deletions csrc/preseg_passes/mark_aliases_prepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <ops/alias.h>
#include <options.h>
#include <preseg_passes/mark_aliases_prepare.h>
#include <transform_replay.h>

namespace nvfuser::preseg_passes {

Expand Down Expand Up @@ -99,6 +100,16 @@ void insertSegmentSetAfter(

// The general case.
TensorView* copy = segment_set(use_of);
// Inherit the allocation domain from `use_of`. This is important to pass
// AliasTest.Bookend_SegmentSetPreservesAllocation.
TensorDomain* replayed_domain =
TransformReplay::replayCasP(
copy, use_of, -1, TransformReplayOptions().replayAllocation())
.first;
if (replayed_domain->hasAllocation()) {
copy->setAllocationDomain(
replayed_domain->allocation(), replayed_domain->contiguity());
}
std::for_each(first_user, last_user, [&](const Use& use) {
ir_utils::replaceValInExprInputs(use.user, use_of, copy);
});
Expand Down
24 changes: 24 additions & 0 deletions tests/cpp/test_alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,30 @@ TEST_F(AliasTest, InplaceUpdate) {
UnorderedElementsAre(HeuristicIs(ScheduleHeuristic::PointWise)));
}

// Builds a fusion whose single input feeds both a permute (returned as the
// first output) and an elementwise multiply (returned as the second output),
// then checks that the permuted output tensor still aliases the input tensor
// after running through FusionExecutorCache.
TEST_F(AliasTest, Bookend_SegmentSetPreservesAllocation) {
  auto owned_fusion = std::make_unique<Fusion>();
  FusionGuard guard(owned_fusion.get());

  // Logical shape is [2, 3]; the permute swaps the two axes.
  TensorView* input_tv = makeContigConcreteTensor({2, 3});
  TensorView* transposed_tv = permute(input_tv, {1, 0});
  TensorView* squared_tv = mul(input_tv, input_tv);
  owned_fusion->addInput(input_tv);
  owned_fusion->addOutput(transposed_tv);
  owned_fusion->addOutput(squared_tv);

  // The input is allocated with its axes swapped, while the permuted output
  // is allocated in its own logical axis order; both are marked contiguous.
  input_tv->setAllocationDomain({input_tv->axis(1), input_tv->axis(0)}, true);
  transposed_tv->setAllocationDomain(
      {transposed_tv->axis(0), transposed_tv->axis(1)}, true);

  FusionExecutorCache executor_cache(std::move(owned_fusion));

  // A [3, 2] CUDA tensor viewed as [2, 3] via transpose, so its layout matches
  // the swapped allocation order declared on the input above.
  at::Tensor base_tensor = at::randn({3, 2}).cuda();
  at::Tensor in_tensor = base_tensor.transpose(0, 1);
  std::vector<at::Tensor> outputs =
      executor_cache.runFusionWithInputs({in_tensor});
  testValidate(
      executor_cache.fusion(), outputs, {in_tensor}, __LINE__, __FILE__);

  // Output 0 is the permute result; it is expected to share storage with the
  // input rather than being a fresh copy.
  EXPECT_TRUE(outputs[0].is_alias_of(in_tensor));
}

TEST_F(AliasTest, Bookend_InputsAndOutputs) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
Expand Down

0 comments on commit a2e3f05

Please sign in to comment.