Resolve conflicts by recomputation #3625
@@ -73,19 +73,6 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {

   auto resize_based_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);

-  if (auto non_exclusive_resizes = scheduler_tools::getNonExclusiveResizeInfo(
-          resize_based_tensor_ops, id_model.idGraph(IdMappingMode::EXACT));
-      !non_exclusive_resizes.empty()) {
-    std::stringstream msg;
-    msg << "Propagation of resizes would affect fusion outputs.";
-    for (const auto& [tv, resize_ids] : non_exclusive_resizes) {
-      msg << " Resize input tv: " << tv->toString()
-          << ", resize input ID groups: " << nvfuser::toString(resize_ids);
-    }
-    scheduler_debug_utils::canScheduleRejectReason(schedulerType(), msg.str());
-    return false;
-  }
-
   // Slicing of or to a broadcast ID is not allowed yet.
   for (auto tensor_op : resize_based_tensor_ops) {
     TensorView* out_tv = tensor_op->output(0)->as<TensorView>();
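For illustration, here is a minimal sketch (not part of the PR) of the pattern the removed compile-time check used to reject: a producer consumed both by a resize-based op and by a non-resize op, so propagating the slice's resize would also affect the other consumer. The helpers (`makeContigTensor`, the integer `slice` overload) are assumed from nvFuser's C++ test utilities.

    // Hypothetical example, assuming nvFuser's C++ test setup.
    Fusion fusion;
    FusionGuard fg(&fusion);

    auto tv0 = makeContigTensor(2);
    fusion.addInput(tv0);

    // tv1 has two uses: a slice (resize-based) and an add (non-resize), so the
    // slice's resize is "non-exclusive": propagating it through tv1 would also
    // change tv2.
    auto tv1 = sin(tv0);
    auto tv2 = add(tv1, tv1);
    auto tv3 = slice(tv1, {0, 0}, {4, 4});

    fusion.addOutput(tv2);
    fusion.addOutput(tv3);

    // Before this PR, canScheduleCompileTime() rejected such a fusion; now the
    // conflict is resolved in schedule() by recomputing tv1 for the slice.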
@@ -133,6 +120,30 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
     return false;
   }

+  for (auto out_tv : ir_utils::filterByType<TensorView>(fusion->outputs())) {
+    if (out_tv == ref_tv) {
+      continue;
+    }
+    auto exprs = ValGraphBFS::getExprGroupsBetween(
+                     broadcast_graph,
+                     broadcast_graph.toGroups(ref_tv->getLogicalDomain()),
+                     broadcast_graph.toGroups(out_tv->getLogicalDomain()),
+                     /*require_all_to_visited=*/false)
+                     .first;
+    for (const auto& [expr_g, dir] : exprs) {
+      if (expr_g->front()->isA<Resize>()) {
+        std::stringstream msg;
+        msg << "Resize between reference and output not allowed.";
+        msg << " Reference: " << ref_tv->toString()
+            << ". Output: " << out_tv->toString()
+            << ". Resize: " << expr_g->front()->toString();
+        scheduler_debug_utils::canScheduleRejectReason(
+            schedulerType(), msg.str());
+        return false;
+      }
+    }
+  }
+
   // Disable the scheduler if there's a squeeze op. The loop option
   // may also need to be enabled in that case, but that option is not
   // turned on automatically yet.

Review thread on the new `isA<Resize>()` check:

Reviewer: Is my understanding correct that we will segment if we have one resized output and one not resized? In that case we will still have a resize between the …

Reply: Yes, it is definitely possible, but it's a non-trivial problem to pick a reference tensor. If the dimension of …
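To make the question above concrete, here is a hypothetical fusion (again using assumed test-utility helpers, not code from the PR) with one resized output and one full-size output. Per the review exchange, such a fusion can be rejected by this new check and left to the segmenter, since a Resize expression lies between the reference tensor and the other output.

    // Hypothetical example, same assumptions as the earlier sketch.
    Fusion fusion;
    FusionGuard fg(&fusion);

    auto tv0 = makeContigTensor(2);
    fusion.addInput(tv0);

    auto tv1 = slice(tv0, {0, 0}, {4, 4});  // resized output
    auto tv2 = neg(tv0);                    // full-size output

    fusion.addOutput(tv1);
    fusion.addOutput(tv2);

    // Whichever output is picked as the reference, the path between its
    // logical domain and the other output's logical domain crosses a Resize,
    // so canScheduleCompileTime() now returns false for this scheduler.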
@@ -163,6 +174,27 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
   scheduler_utils::cacheInputs(fusion, true);
   scheduler_utils::cacheAndForkOutputs(fusion, true);

+  auto resize_based_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);
+
+  IdModel id_model(fusion, /*build_graphs=*/false);
+  const auto& exact_graph = id_model.buildExactGraph();
+
+  // Replicate resize inputs if necessary to avoid conflicting propagations
+  for (const auto& [out_tv, exlusivity_info] :
+       scheduler_tools::getNonExclusiveResizeInfo(
+           resize_based_tensor_ops, exact_graph)) {
+    auto resize_based_op = out_tv->definition();
+    auto inp_tv = resize_based_op->input(0)->as<TensorView>();
+    // Since cacheInput may skip caching if an input is used by
+    // slice/pad, inp_tv may be a fusion input, in which case it is
+    // not necessary to recompute the tensor.
+    if (inp_tv->isFusionInput()) {
+      continue;
+    }
+    auto inp_tv_copy = RecomputeTv::recompute(inp_tv);
+    ir_utils::replaceValInExprInputs(resize_based_op, inp_tv, inp_tv_copy);
+  }
+
   for (auto expr : fusion->exprs()) {
     if (!expr->isOneOf<SliceOp, PadOp>()) {
       continue;
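Conceptually, the recomputation gives each conflicting resize op a private copy of its input, so resize propagation no longer touches the original tensor's other consumers. A sketch of the before/after IR, with hypothetical tensor names:

    // Before recomputation:            After RecomputeTv::recompute(inp_tv):
    //
    //   tv1 = sin(tv0)                   tv1   = sin(tv0)      // other uses keep tv1
    //   tv2 = add(tv1, tv1)              tv2   = add(tv1, tv1)
    //   tv3 = slice(tv1, ...)            tv1_r = sin(tv0)      // recomputed copy
    //                                    tv3   = slice(tv1_r, ...)
    //
    // replaceValInExprInputs() then rewires the slice to consume the copy, so
    // propagating the slice's resize no longer affects tv1 or tv2.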
@@ -186,9 +218,14 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
   ref_tv->axis(-1)->parallelize(ParallelType::TIDx);
   ref_tv->axis(-2)->parallelize(ParallelType::BIDx);

-  // Propagate the reference to the other tensors
+  // Propagate the reference to the other tensors. Note that the
+  // update flag is enabled so to workaround the resize propagation
+  // issue. This may not work if there's a tensor that is reshaped
+  // from the reference tensor, but that should not be the case as the
+  // reference is picked by the same routine used for the pointwise
+  // scheduler.
   scheduler_tools::scheduleLoopDomainsLike(
-      fusion->allTvs(), ref_tv->getLoopDomain());
+      fusion->allTvs(), ref_tv->getLoopDomain(), true);

   inlineMost();
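For reference, a sketch of how the updated call is presumably declared. Only the existence of the third boolean argument is confirmed by this diff; the parameter name and the described behavior below are assumptions.

    // Assumed declaration in scheduler_tools (parameter name is a guess):
    // void scheduleLoopDomainsLike(
    //     const std::vector<TensorView*>& tvs,
    //     const std::vector<IterDomain*>& ref_loop_dom,
    //     bool update_loop_domain_only = false);
    //
    // With the flag set to true, the pass presumably updates the tensors'
    // existing loop domains rather than rebuilding them from the reference,
    // which is the workaround described in the code comment above.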
Additional review comment on the diff: This check is needed since the non-exclusivity check is dropped. It was redundant before.