Skip to content

Commit

Permalink
Review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Jul 27, 2024
1 parent 20f5b73 commit ec79fc5
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 67 deletions.
131 changes: 72 additions & 59 deletions csrc/preseg_passes/mark_aliases_prepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,22 @@ struct Use {
Use findUseToSegment(
TensorView* out,
const AliasAnalysisResult& analysis,
const std::unordered_set<Expr*>& used_by_non_aliases) {
const std::unordered_set<Expr*>& depended_by_non_aliases) {
Expr* user = nullptr;
while (true) {
Expr* def = out->definition();
if (analysis.getRoot(out) == nullptr || used_by_non_aliases.count(def)) {
if (analysis.getRoot(out) == nullptr ||
depended_by_non_aliases.count(def)) {
return {out, user};
}
out = def->input(0)->as<TensorView>();
user = def;
}
}

// Collects all expressions that are (transitively) used by non-alias
// TensorViews.
std::unordered_set<Expr*> exprsUsedByNonAliases(
// Collects all expressions that are depended on (i.e. transitively used) by
// non-alias TensorViews.
std::unordered_set<Expr*> exprsDependedByNonAliases(
const AliasAnalysisResult& analysis,
Fusion* fusion) {
std::vector<Val*> non_aliases;
Expand All @@ -59,8 +60,51 @@ std::unordered_set<Expr*> exprsUsedByNonAliases(
non_aliases.push_back(tv);
}
}
std::vector<Expr*> used_by_non_aliases = StmtSort::getExprsTo(non_aliases);
return {used_by_non_aliases.begin(), used_by_non_aliases.end()};
std::vector<Expr*> depended_by_non_aliases =
StmtSort::getExprsTo(non_aliases);
return {depended_by_non_aliases.begin(), depended_by_non_aliases.end()};
}

// Inserts a `segment_set` after `use_of` and redirect aliasing users to
// use the `segment_set`.
void insertSegmentSetAfter(
std::vector<Use>::const_iterator first_user,
std::vector<Use>::const_iterator last_user) {
TensorView* use_of = first_user->use_of;

// There are a few corner cases where we don't need to add a
// `segment_set`. If `use_of` is only used by aliases, ...
if (static_cast<size_t>(std::distance(first_user, last_user)) ==
use_of->uses().size()) {
if (use_of->isFusionInput()) {
// Putting a `segment_set` between a fusion input and its users is
// unnecessary.
return;
}

// Rarely, if `use_of` is already defined by `segment_set`, don't
// create another `segment_set`.
if (ir_utils::isSegmentSet(use_of->definition())) {
return;
}
}

// If all aliasing users are `segment_set`, don't create another
// `segment_set`.
if (std::all_of(first_user, last_user, [](const Use& use) {
return ir_utils::isSegmentSet(use.user);
})) {
return;
}

// The general case.
TensorView* copy = segment_set(use_of);
std::for_each(first_user, last_user, [&](const Use& use) {
ir_utils::replaceValInExprInputs(use.user, use_of, copy);
});
if (use_of->isFusionOutput()) {
use_of->fusion()->replaceOutput(use_of, copy);
}
}

} // namespace
Expand Down Expand Up @@ -116,13 +160,14 @@ void MarkAliasesPreparePass::runPass(Fusion* fusion) {
//
// we want to avoid putting a `segment_set` before M1, a meta op, because
// that would lead to two kernels. See AliasTest.DoNotOverSegment_* for more
// examples.
const std::unordered_set<Expr*>& used_by_non_aliases =
exprsUsedByNonAliases(analysis, fusion);
// examples. This is the reason behind `depended_by_non_aliases`.
const std::unordered_set<Expr*>& depended_by_non_aliases =
exprsDependedByNonAliases(analysis, fusion);
std::vector<Use> uses_to_segment;
uses_to_segment.reserve(fusion->outputs().size());
for (auto* out : ir_utils::filterByType<TensorView>(fusion->outputs())) {
Use use_to_segment = findUseToSegment(out, analysis, used_by_non_aliases);
Use use_to_segment =
findUseToSegment(out, analysis, depended_by_non_aliases);
if (use_to_segment.use_of != out) {
uses_to_segment.push_back(use_to_segment);
}
Expand All @@ -143,54 +188,22 @@ void MarkAliasesPreparePass::runPass(Fusion* fusion) {
}
}

auto i = uses_to_segment.begin();
while (i != uses_to_segment.end()) {
TensorView* use_of = i->use_of;
auto j = i;
while (j != uses_to_segment.end() && j->use_of == use_of) {
j++;
}

auto insert_segment_set = [&]() {
// Put a `segment_set` after `use_of` and redirect aliasing users to
// use the `segment_set`.

// There are a few corner cases where we don't need to add a
// `segment_set`. If `use_of` is only used by aliases, ...
if (static_cast<size_t>(std::distance(i, j)) == use_of->uses().size()) {
if (use_of->isFusionInput()) {
// Putting a `segment_set` between a fusion input and its users is
// unnecessary.
return;
}

// Rarely, if `use_of` is already defined by `segment_set`, don't
// create another `segment_set`.
if (ir_utils::isSegmentSet(use_of->definition())) {
return;
}
}

// If all aliasing users are `segment_set`, don't create another
// `segment_set`.
if (std::all_of(i, j, [](const Use& use) {
return ir_utils::isSegmentSet(use.user);
})) {
return;
}

// The general case.
TensorView* copy = segment_set(use_of);
std::for_each(i, j, [&](const Use& use) {
ir_utils::replaceValInExprInputs(use.user, use_of, copy);
});
if (use_of->isFusionOutput()) {
fusion->replaceOutput(use_of, copy);
}
};
insert_segment_set();

i = j;
// Because `uses_to_segment` has been sorted by the TensorView being used, we
// use a doubly nested loop to find and process all the users of each
// TensorView.
auto first_user = uses_to_segment.begin();
while (first_user != uses_to_segment.end()) {
TensorView* use_of = first_user->use_of;
auto last_user = first_user;
do {
last_user++;
} while (last_user != uses_to_segment.end() && last_user->use_of == use_of);
// At this point, [first_user, last_user) spans the users of `use_of`:
// `first_user` points to the first user and `last_user` to one past the last.

insertSegmentSetAfter(first_user, last_user);

first_user = last_user;
}

if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) {
Expand Down
3 changes: 1 addition & 2 deletions csrc/scheduler/mark_aliases.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ void markAliases(Fusion* fusion) {
vlog("Alias analysis result:\n", analysis.toString(/*indent_size=*/1));
}

for (TensorView* out :
ir_utils::filterByType<TensorView>(fusion->outputs())) {
for (auto* out : ir_utils::filterByType<TensorView>(fusion->outputs())) {
// AllocationType::ReuseBuffer requires the output to be updated in place
// so it can't be computed as an alias.
if (fusion->getOutputAlias(out).type == AllocationType::ReuseBuffer) {
Expand Down
2 changes: 2 additions & 0 deletions tests/cpp/test_resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2001,6 +2001,8 @@ TEST_F(ResizeTest, ResizeReshapeAndSlice) {
tv1,
{{IrBuilder::create<Val>(0L), IrBuilder::create<Val>(2L)},
{IrBuilder::create<Val>(0L), IrBuilder::create<Val>(2L)}});
// Without the `add`, the fusion will be accepted by NoOp, defeating the
// purpose of testing PointWise.
auto tv3 = add(tv2, tv2);
fusion->addOutput(tv3);

Expand Down
10 changes: 4 additions & 6 deletions tests/cpp/test_scatter_gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
#include <ir/builder.h>
#include <kernel_cache.h>
#include <ops/all_ops.h>
#include <preseg_passes/mark_aliases_prepare.h>
#include <preseg_passes/optimization_pass.h>
#include <scheduler/all_schedulers.h>
#include <tests/cpp/utils.h>
#include <tests/cpp/validator.h>
Expand Down Expand Up @@ -1102,9 +1100,6 @@ TEST_F(ScatterGatherTest, TakeAlongAxisIntermediateTensorTranspose2) {
// transpose the dimension produced by take_along_axis. Currently not
// supported by the transpose scheduler
TEST_F(ScatterGatherTest, TakeAlongAxisIntermediateTensorTranspose3) {
preseg_passes::OptimizationPassGuard<preseg_passes::MarkAliasesPreparePass>
optimization_guard(false);

auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
FusionGuard fg(&fusion);
Expand All @@ -1127,7 +1122,10 @@ TEST_F(ScatterGatherTest, TakeAlongAxisIntermediateTensorTranspose3) {
auto tv3 = broadcast(tv1, {true, false, false});
auto tv4 = take_along_axis(tv2, tv3, 2);
auto tv5 = transpose(tv4, 1, 2);
fusion.addOutput(tv5);
// Without the `add`, the transpose will be taken by NoOp, defeating the
// purpose of testing PointWise.
auto tv6 = add(tv5, tv5);
fusion.addOutput(tv6);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_i = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
Expand Down

0 comments on commit ec79fc5

Please sign in to comment.