Resolve conflicts by recomputation
naoyam committed Dec 20, 2024
1 parent ac5a1bc commit 5d1d07e
Showing 4 changed files with 243 additions and 66 deletions.
61 changes: 46 additions & 15 deletions csrc/scheduler/resize.cpp
@@ -73,19 +73,6 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {

auto resize_based_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);

if (auto non_exclusive_resizes = scheduler_tools::getNonExclusiveResizeInfo(
resize_based_tensor_ops, id_model.idGraph(IdMappingMode::EXACT));
!non_exclusive_resizes.empty()) {
std::stringstream msg;
msg << "Propagation of resizes would affect fusion outputs.";
for (const auto& [tv, resize_ids] : non_exclusive_resizes) {
msg << " Resize input tv: " << tv->toString()
<< ", resize input ID groups: " << nvfuser::toString(resize_ids);
}
scheduler_debug_utils::canScheduleRejectReason(schedulerType(), msg.str());
return false;
}

// Slicing of or to a broadcast ID is not allowed yet.
for (auto tensor_op : resize_based_tensor_ops) {
TensorView* out_tv = tensor_op->output(0)->as<TensorView>();
@@ -133,6 +120,30 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
return false;
}

for (auto out_tv : ir_utils::filterByType<TensorView>(fusion->outputs())) {
if (out_tv == ref_tv) {
continue;
}
auto exprs = ValGraphBFS::getExprGroupsBetween(
broadcast_graph,
broadcast_graph.toGroups(ref_tv->getLogicalDomain()),
broadcast_graph.toGroups(out_tv->getLogicalDomain()),
/*require_all_to_visited=*/false)
.first;
for (const auto& [expr_g, dir] : exprs) {
if (expr_g->front()->isA<Resize>()) {
std::stringstream msg;
msg << "Resize between reference and output not allowed.";
msg << " Reference: " << ref_tv->toString()
<< ". Output: " << out_tv->toString()
<< ". Resize: " << expr_g->front()->toString();
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(), msg.str());
return false;
}
}
}

// Disable the scheduler if there's a squeeze op. The loop option
// may also need to be enabled in that case, but that option is not
// turned on automatically yet.
@@ -163,6 +174,21 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
scheduler_utils::cacheInputs(fusion, true);
scheduler_utils::cacheAndForkOutputs(fusion, true);

auto resize_based_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);

IdModel id_model(fusion, /*build_graphs=*/false);
const auto& exact_graph = id_model.buildExactGraph();

// Replicate resize inputs if necessary to avoid conflicting propagations
for (const auto& [out_tv, exclusivity_info] :
scheduler_tools::getNonExclusiveResizeInfo(
resize_based_tensor_ops, exact_graph)) {
auto resize_based_op = out_tv->definition();
auto inp_tv = resize_based_op->input(0)->as<TensorView>();
auto inp_tv_copy = RecomputeTv::recompute(inp_tv);
ir_utils::replaceValInExprInputs(resize_based_op, inp_tv, inp_tv_copy);
}

for (auto expr : fusion->exprs()) {
if (!expr->isOneOf<SliceOp, PadOp>()) {
continue;
@@ -186,9 +212,14 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
ref_tv->axis(-1)->parallelize(ParallelType::TIDx);
ref_tv->axis(-2)->parallelize(ParallelType::BIDx);

// Propagate the reference to the other tensors
// Propagate the reference to the other tensors. Note that the
// update flag is enabled to work around the resize propagation
// issue. This may not work if there's a tensor that is reshaped
// from the reference tensor, but that should not be the case, as the
// reference is picked by the same routine used for the pointwise
// scheduler.
scheduler_tools::scheduleLoopDomainsLike(
fusion->allTvs(), ref_tv->getLoopDomain());
fusion->allTvs(), ref_tv->getLoopDomain(), true);

inlineMost();

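For context, here is a minimal sketch of the kind of fusion this change targets. It is hypothetical: helper names such as makeContigTensor come from nvFuser's C++ tests, and the exact pad overload is assumed. tv1 feeds both a pad and an ordinary pointwise op, so the pad's resize is non-exclusive with respect to tv1; instead of rejecting the fusion at canScheduleCompileTime as before, schedule() now recomputes tv1 for the pad.

Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigTensor(2);
fusion.addInput(tv0);

auto tv1 = sin(tv0);

// Non-resize consumer of tv1: if the pad's resize were propagated into tv1,
// this path would observe it as well, making the resize non-exclusive.
auto tv2 = cos(tv1);
fusion.addOutput(tv2);

// Resize-based consumer of tv1: pad both dimensions by one element on each
// side.
auto tv3 = pad(
    tv1,
    {IrBuilder::create<Val>(1L),
     IrBuilder::create<Val>(1L),
     IrBuilder::create<Val>(1L),
     IrBuilder::create<Val>(1L)});
fusion.addOutput(tv3);

// With this commit, ResizeScheduler::schedule() resolves the conflict by
// recomputing tv1 for the pad rather than rejecting the fusion.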
21 changes: 17 additions & 4 deletions csrc/scheduler/tools/resize_utils.cpp
@@ -66,13 +66,14 @@ void propagateResizeToInputs(Expr* resize_tensor_op) {
}
}

std::unordered_map<TensorView*, ValGroups> getNonExclusiveResizeInfo(
std::unordered_map<TensorView*, ResizeExclusivityInfo> getNonExclusiveResizeInfo(
const std::vector<Expr*>& ordered_resize_tensor_ops,
const ValGraph& exact_graph) {
const ValGraph& exact_graph,
bool ignore_fusion_inputs) {
NVF_ERROR(!ordered_resize_tensor_ops.empty());
Fusion* fusion = ordered_resize_tensor_ops[0]->fusion();

std::unordered_map<TensorView*, ValGroups> non_exclusive_resizes;
std::unordered_map<TensorView*, ResizeExclusivityInfo> non_exclusive_resizes;

std::unordered_set<Val*> inputs{
fusion->inputs().begin(), fusion->inputs().end()};
@@ -95,6 +96,8 @@ std::unordered_map<TensorView*, ValGroups> getNonExclusiveResizeInfo(
auto inp_tv = dynamic_cast<TensorView*>(resize_tensor_op->inputs().at(0));
auto out_tv = dynamic_cast<TensorView*>(resize_tensor_op->outputs().at(0));

ResizeExclusivityInfo info;

ValGroups resize_inp_ids = get_root_to_logical_resizes(out_tv);
NVF_ERROR(!resize_inp_ids.empty());

@@ -107,6 +110,10 @@ std::unordered_map<TensorView*, ValGroups> getNonExclusiveResizeInfo(
// visible changes through the tensor, the resize is considered
// non-exclusive.
for (auto dep_tv : ir_utils::filterByType<TensorView>(dep_vals)) {
if (ignore_fusion_inputs && dep_tv->isFusionInput()) {
continue;
}

bool maybe_non_exclusive = false;

if (dep_tv->isFusionOutput()) {
@@ -159,10 +166,16 @@ std::unordered_map<TensorView*, ValGroups> getNonExclusiveResizeInfo(
}

// This resize input ID is not exclusively used
non_exclusive_resizes[inp_tv].pushBack(resize_inp_id);
info.shared_tvs.push_back(dep_tv);
info.resized_ids.pushBack(resize_inp_id);
}
}

if (!info.shared_tvs.empty()) {
NVF_ERROR(non_exclusive_resizes.emplace(out_tv, info).second);
}

// Analysis of exclusiveness up to inp_tv is done. Following
// resize-based tensor ops do not need to check the same section
// of the fusion and can start from out_tv.
19 changes: 17 additions & 2 deletions csrc/scheduler/tools/resize_utils.h
@@ -94,9 +94,24 @@ void propagateResizeToInputs(Expr* resize_op);
// The function returns a map from the output tensor of each
// non-exclusive resize op to its exclusivity info: the tensors that
// share the resized input and the resize input ID groups. This map
// will be used to resolve the non-exclusiveness by replication.
std::unordered_map<TensorView*, ValGroups> getNonExclusiveResizeInfo(
struct ResizeExclusivityInfo {
// Tensors that also consume the resized input and would observe the
// resize if it were propagated to that input
std::vector<TensorView*> shared_tvs;
// Exact ValGroups of the resize input IDs that are not exclusively used
ValGroups resized_ids;

bool operator==(const ResizeExclusivityInfo& other) const {
return shared_tvs == other.shared_tvs && resized_ids == other.resized_ids;
}

bool operator!=(const ResizeExclusivityInfo& other) const {
return !(*this == other);
}
};

std::unordered_map<TensorView*, ResizeExclusivityInfo> getNonExclusiveResizeInfo(
const std::vector<Expr*>& ordered_resize_tensor_ops,
const ValGraph& exact_graph);
const ValGraph& exact_graph,
bool ignore_fusion_inputs = false);

} // namespace scheduler_tools
} // namespace nvfuser
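
A sketch of how the new return type is meant to be consumed, mirroring the ResizeScheduler::schedule() change above (illustrative rather than a drop-in snippet):

IdModel id_model(fusion, /*build_graphs=*/false);
const auto& exact_graph = id_model.buildExactGraph();

auto resize_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);

for (const auto& [out_tv, info] :
     scheduler_tools::getNonExclusiveResizeInfo(resize_ops, exact_graph)) {
  // info.shared_tvs: tensors that also consume the resized input
  // info.resized_ids: exact-graph ValGroups of the resize input IDs
  auto resize_op = out_tv->definition();
  auto inp_tv = resize_op->input(0)->as<TensorView>();
  auto inp_tv_copy = RecomputeTv::recompute(inp_tv);
  ir_utils::replaceValInExprInputs(resize_op, inp_tv, inp_tv_copy);
}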
