Skip to content

Commit

Permalink
Allow silently ignoring missing root-to-logical ops in BestEffortReplay (
Browse files Browse the repository at this point in the history
…#2901)

Extracted from and required for #2875
  • Loading branch information
naoyam authored Sep 4, 2024
1 parent 892b7ac commit 7be78f8
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 20 deletions.
44 changes: 25 additions & 19 deletions csrc/transform_iter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,15 @@ void ReplayTransformations::runReplay() {
[](std::pair<IterDomain*, size_t> entry) { return entry.first; });
}

// Conditional assertion helper used while replaying transformations.
//
// When the enclosing object's `error_on_failure_` member is true, `cond` is
// checked with NVF_ERROR and the fixed "best effort replay" message below is
// reported on failure. When `error_on_failure_` is false the macro does
// nothing and `cond` is not evaluated, so a mismatch is silently ignored.
//
// The body is wrapped in `do { ... } while (false)` so that an invocation
// followed by a semicolon behaves as a single statement. The macro reads
// `error_on_failure_` by name, so it can only be used where that name is in
// scope.
#define ERROR_ON_FAILURE(cond) \
do { \
if (error_on_failure_) { \
NVF_ERROR( \
(cond), \
"Error during best effort replay, a transformation was called that conflicts with an root-to-logical call."); \
} \
} while (false)

BestEffortReplay::BestEffortReplay(
const std::vector<IterDomain*>& replay_domain,
const std::vector<IterDomain*>& target_domain,
Expand All @@ -365,12 +374,14 @@ BestEffortReplay::BestEffortReplay(
std::unordered_map<IterDomain*, IterDomain*> target_forward_id_map,
bool skip_replay_swizzle,
bool skip_target_swizzle,
bool skip_resize)
bool skip_resize,
bool error_on_failure)
: target2replay_id_map_(std::move(target2replay_map)),
replay_forward_id_map_(std::move(replay_forward_id_map)),
target_forward_id_map_(std::move(target_forward_id_map)),
skip_replay_swizzle_(skip_replay_swizzle),
skip_target_swizzle_(skip_target_swizzle) {
skip_target_swizzle_(skip_target_swizzle),
error_on_failure_(error_on_failure) {
for (auto entry : target2replay_id_map_) {
loop_ids_[entry.second] = counter++;
}
Expand Down Expand Up @@ -517,7 +528,7 @@ BestEffortReplay::BestEffortReplay(

// If some replay id inputs are part of rfactor, make sure all target
// expression inputs map to a replay input
if (replay_has_logical_inp) {
if (error_on_failure_ && replay_has_logical_inp) {
bool no_missing_exprs = std::none_of(
replay_inps.begin(),
replay_inps.end(),
Expand All @@ -539,9 +550,8 @@ BestEffortReplay::BestEffortReplay(

// If any inputs are missing, continue as this expr doesn't match.
if (missing_replay_input) {
NVF_ERROR(
!replay_has_logical_inp || any_target_expr_contains_broadcast_id,
err_str);
ERROR_ON_FAILURE(
!replay_has_logical_inp || any_target_expr_contains_broadcast_id);
continue;
}

Expand All @@ -568,7 +578,7 @@ BestEffortReplay::BestEffortReplay(
// If expressions of mapped inputs don't match, then continue to next target
// expr
if (mismatched_replay_exprs || replay_expr == nullptr) {
NVF_ERROR(!replay_has_logical_inp, err_str);
ERROR_ON_FAILURE(!replay_has_logical_inp);
continue;
}

Expand All @@ -581,27 +591,21 @@ BestEffortReplay::BestEffortReplay(
// If there isn't a logical id in the replay's inputs and there's a
// mismatched input, continue
if (mismatched_inputs) {
NVF_ERROR(!replay_has_logical_inp, err_str);
ERROR_ON_FAILURE(!replay_has_logical_inp);
continue;
}

// If there isn't a logical id in the replay's inputs and there's a
// mismatch in replay_expr's and target_expr's outputs, continue
if (target_expr->outputs().size() != replay_expr->outputs().size()) {
NVF_ERROR(
!replay_has_logical_inp,
err_str,
". Target: ",
target_expr->toString(),
", repaly: ",
replay_expr->toString());
ERROR_ON_FAILURE(!replay_has_logical_inp);
continue;
}

// If there isn't a logical id in the replay's inputs and there's a
// mismatch in replay_expr's and target_expr's expression type, continue
if (typeid(*replay_expr) != typeid(*target_expr)) {
NVF_ERROR(!replay_has_logical_inp, err_str);
ERROR_ON_FAILURE(!replay_has_logical_inp);
continue;
}

Expand All @@ -613,7 +617,7 @@ BestEffortReplay::BestEffortReplay(
auto t_split = target_expr->as<Split>();
if (!r_split->factor()->sameAs(t_split->factor()) ||
r_split->innerSplit() != t_split->innerSplit()) {
NVF_ERROR(!replay_has_logical_inp, err_str);
ERROR_ON_FAILURE(!replay_has_logical_inp);
continue;
}
}
Expand All @@ -625,7 +629,7 @@ BestEffortReplay::BestEffortReplay(
auto r_swizzle_2d = replay_expr->as<Swizzle2D>();
auto t_swizzle_2d = target_expr->as<Swizzle2D>();
if (!(r_swizzle_2d->swizzleType() == t_swizzle_2d->swizzleType())) {
NVF_ERROR(!replay_has_logical_inp, err_str);
ERROR_ON_FAILURE(!replay_has_logical_inp);
continue;
}
}
Expand All @@ -635,7 +639,7 @@ BestEffortReplay::BestEffortReplay(
auto t_resize = target_expr->as<Resize>();
if (!r_resize->leftExpand()->sameAs(t_resize->leftExpand()) ||
!r_resize->rightExpand()->sameAs(t_resize->rightExpand())) {
NVF_ERROR(!replay_has_logical_inp, err_str);
ERROR_ON_FAILURE(!replay_has_logical_inp);
continue;
}
}
Expand Down Expand Up @@ -681,6 +685,8 @@ BestEffortReplay::BestEffortReplay(
}
}

#undef ERROR_ON_FAILURE

// Find the first position i where td1[i] is not the same as td2[i].
// "Same" means the DAGs that generate td1[i] and td2[i] are
// equivalent.
Expand Down
5 changes: 4 additions & 1 deletion csrc/transform_iter.h
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ class BestEffortReplay {
bool skip_replay_swizzle_ = true;
bool skip_target_swizzle_ = true;

bool error_on_failure_ = true;

// Returns true when `id` is a key of replay_forward_id_map_.
bool inReplayForwardMap(IterDomain* id) const {
  return replay_forward_id_map_.count(id) != 0;
}
Expand Down Expand Up @@ -438,7 +440,8 @@ class BestEffortReplay {
std::unordered_map<IterDomain*, IterDomain*> target_forward_id_map = {},
bool skip_replay_swizzle = true,
bool skip_target_swizzle = true,
bool skip_resize = false);
bool skip_resize = false,
bool error_on_failure = true);

// Return iter domain map from target_domain IDs to their "replayed"
// replay_domain IDs. If not in map, was not replayed.
Expand Down
48 changes: 48 additions & 0 deletions tests/cpp/test_gpu3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8815,6 +8815,54 @@ TEST_F(NVFuserTest, ReplaceSymbolicSizes) {
EXPECT_EQ(tv5->axis(0)->extent()->toInlineString(), "5");
}

// BestEffortReplay must throw on a missing root-to-logical IterDomain op
// when error_on_failure is true, and must stay silent when it is false.
TEST_F(NVFuserTest, BestEffortReplayWithMismatchedRootToLogical) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeConcreteTensor({2, 4});
  fusion.addInput(tv0);

  auto tv1 = set(tv0);
  auto tv2 = reshape(tv1, {2, 4}, {8});
  fusion.addOutput(tv2);

  // tv2 has no counterpart for this split.
  tv1->split(0, 1);

  // Builds the replay of tv2's loop domain against tv1's loop domain; only
  // the error_on_failure flag differs between the two checks below.
  auto run_replay = [&](bool error_on_failure) {
    BestEffortReplay replay(
        tv2->getLoopDomain(),
        tv1->getLoopDomain(),
        PairwiseLogicalDomainMap(tv1, tv2).mapProducerToConsumer(),
        /*replay_forward_id_map=*/{},
        /*target_forward_id_map=*/{},
        /*skip_replay_swizzle=*/false,
        /*skip_target_swizzle=*/false,
        /*skip_resize=*/false,
        /*error_on_failure=*/error_on_failure);
  };

  // Because of the extra split on tv1, no matching transformation is found,
  // which is reported as an error when error_on_failure is true.
  EXPECT_THAT(
      [&]() { run_replay(true); },
      ::testing::ThrowsMessage<nvfuser::nvfError>(
          ::testing::HasSubstr("conflicts with an root-to-logical call")));

  // With error_on_failure disabled the same replay must not throw.
  run_replay(false);
}

// Test file size should be up to 10K LoC. Create a new file for more tests.

} // namespace nvfuser

0 comments on commit 7be78f8

Please sign in to comment.