diff --git a/csrc/transform_iter.cpp b/csrc/transform_iter.cpp index 236b0aae987..a5c2c22b6e2 100644 --- a/csrc/transform_iter.cpp +++ b/csrc/transform_iter.cpp @@ -357,6 +357,15 @@ void ReplayTransformations::runReplay() { [](std::pair entry) { return entry.first; }); } +#define ERROR_ON_FAILURE(cond) \ + do { \ + if (error_on_failure_) { \ + NVF_ERROR( \ + (cond), \ + "Error during best effort replay, a transformation was called that conflicts with an root-to-logical call."); \ + } \ + } while (false) + BestEffortReplay::BestEffortReplay( const std::vector& replay_domain, const std::vector& target_domain, @@ -365,12 +374,14 @@ BestEffortReplay::BestEffortReplay( std::unordered_map target_forward_id_map, bool skip_replay_swizzle, bool skip_target_swizzle, - bool skip_resize) + bool skip_resize, + bool error_on_failure) : target2replay_id_map_(std::move(target2replay_map)), replay_forward_id_map_(std::move(replay_forward_id_map)), target_forward_id_map_(std::move(target_forward_id_map)), skip_replay_swizzle_(skip_replay_swizzle), - skip_target_swizzle_(skip_target_swizzle) { + skip_target_swizzle_(skip_target_swizzle), + error_on_failure_(error_on_failure) { for (auto entry : target2replay_id_map_) { loop_ids_[entry.second] = counter++; } @@ -517,7 +528,7 @@ BestEffortReplay::BestEffortReplay( // If some replay id inputs are part of rfactor, make sure all target // expression inputs map to a replay input - if (replay_has_logical_inp) { + if (error_on_failure_ && replay_has_logical_inp) { bool no_missing_exprs = std::none_of( replay_inps.begin(), replay_inps.end(), @@ -539,9 +550,8 @@ BestEffortReplay::BestEffortReplay( // If any inputs are missing, continue as this expr doesn't match. if (missing_replay_input) { - NVF_ERROR( - !replay_has_logical_inp || any_target_expr_contains_broadcast_id, - err_str); + ERROR_ON_FAILURE( + !replay_has_logical_inp || any_target_expr_contains_broadcast_id); continue; } @@ -568,7 +578,7 @@ BestEffortReplay::BestEffortReplay( // If expressions of mapped inputs don't match, then continue to next target // expr if (mismatched_replay_exprs || replay_expr == nullptr) { - NVF_ERROR(!replay_has_logical_inp, err_str); + ERROR_ON_FAILURE(!replay_has_logical_inp); continue; } @@ -581,27 +591,21 @@ BestEffortReplay::BestEffortReplay( // If there isn't an logical id in the replay's inputs and there's a // mismatched input, continue if (mismatched_inputs) { - NVF_ERROR(!replay_has_logical_inp, err_str); + ERROR_ON_FAILURE(!replay_has_logical_inp); continue; } // If there isn't an logical id in the replay's inputs and there's a // mismatch in replay_expr's and target_expr's outputs, continue if (target_expr->outputs().size() != replay_expr->outputs().size()) { - NVF_ERROR( - !replay_has_logical_inp, - err_str, - ". Target: ", - target_expr->toString(), - ", repaly: ", - replay_expr->toString()); + ERROR_ON_FAILURE(!replay_has_logical_inp); continue; } // If there isn't an logical id in the replay's inputs and there's a // mismatch in replay_expr's and target_expr's expression type, continue if (typeid(*replay_expr) != typeid(*target_expr)) { - NVF_ERROR(!replay_has_logical_inp, err_str); + ERROR_ON_FAILURE(!replay_has_logical_inp); continue; } @@ -613,7 +617,7 @@ BestEffortReplay::BestEffortReplay( auto t_split = target_expr->as(); if (!r_split->factor()->sameAs(t_split->factor()) || r_split->innerSplit() != t_split->innerSplit()) { - NVF_ERROR(!replay_has_logical_inp, err_str); + ERROR_ON_FAILURE(!replay_has_logical_inp); continue; } } @@ -625,7 +629,7 @@ BestEffortReplay::BestEffortReplay( auto r_swizzle_2d = replay_expr->as(); auto t_swizzle_2d = target_expr->as(); if (!(r_swizzle_2d->swizzleType() == t_swizzle_2d->swizzleType())) { - NVF_ERROR(!replay_has_logical_inp, err_str); + ERROR_ON_FAILURE(!replay_has_logical_inp); continue; } } @@ -635,7 +639,7 @@ BestEffortReplay::BestEffortReplay( auto t_resize = target_expr->as(); if (!r_resize->leftExpand()->sameAs(t_resize->leftExpand()) || !r_resize->rightExpand()->sameAs(t_resize->rightExpand())) { - NVF_ERROR(!replay_has_logical_inp, err_str); + ERROR_ON_FAILURE(!replay_has_logical_inp); continue; } } @@ -681,6 +685,8 @@ BestEffortReplay::BestEffortReplay( } } +#undef ERROR_ON_FAILURE + // Find the first position i where td1[i] is not the same as td2[i]. // "Same" means the DAG to generate td1[i] and td2[i] are the // equivelent. diff --git a/csrc/transform_iter.h b/csrc/transform_iter.h index b3253dc6a17..c5a26b7c089 100644 --- a/csrc/transform_iter.h +++ b/csrc/transform_iter.h @@ -382,6 +382,8 @@ class BestEffortReplay { bool skip_replay_swizzle_ = true; bool skip_target_swizzle_ = true; + bool error_on_failure_ = true; + bool inReplayForwardMap(IterDomain* id) const { return replay_forward_id_map_.find(id) != replay_forward_id_map_.end(); } @@ -438,7 +440,8 @@ class BestEffortReplay { std::unordered_map target_forward_id_map = {}, bool skip_replay_swizzle = true, bool skip_target_swizzle = true, - bool skip_resize = false); + bool skip_resize = false, + bool error_on_failure = true); // Return iter domain map from target_domain IDs to their "replayed" // replay_domain IDs. If not in map, was not replayed. diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index eecc11cf03f..521c1c6d4b7 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -8815,6 +8815,54 @@ TEST_F(NVFuserTest, ReplaceSymbolicSizes) { EXPECT_EQ(tv5->axis(0)->extent()->toInlineString(), "5"); } +// Make sure BestEffortReplay with error_on_failure=false does not +// complain about missing root-to-logical IterDomain ops +TEST_F(NVFuserTest, BestEffortReplayWithMismatchedRootToLogical) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 4}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = reshape(tv1, {2, 4}, {8}); + fusion.addOutput(tv2); + + // This split does not exist in tv2 + tv1->split(0, 1); + + // Due to the split of tv1, BestEffortReplay would not find any + // matching transformations. If error_on_failure is true, it would + // result in an error. + EXPECT_THAT( + [&]() { + BestEffortReplay replay( + tv2->getLoopDomain(), + tv1->getLoopDomain(), + PairwiseLogicalDomainMap(tv1, tv2).mapProducerToConsumer(), + /*replay_forward_id_map=*/{}, + /*target_forward_id_map=*/{}, + /*skip_replay_swizzle=*/false, + /*skip_target_swizzle=*/false, + /*skip_resize=*/false, + /*error_on_failure=*/true); + }, + ::testing::ThrowsMessage( + ::testing::HasSubstr("conflicts with an root-to-logical call"))); + + // Should not result in an error as error_on_failure is false + BestEffortReplay replay( + tv2->getLoopDomain(), + tv1->getLoopDomain(), + PairwiseLogicalDomainMap(tv1, tv2).mapProducerToConsumer(), + /*replay_forward_id_map=*/{}, + /*target_forward_id_map=*/{}, + /*skip_replay_swizzle=*/false, + /*skip_target_swizzle=*/false, + /*skip_resize=*/false, + /*error_on_failure=*/false); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser