Skip to content

Commit

Permalink
fix pattern match
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Dec 15, 2024
1 parent 7e9413a commit 40dd2c2
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions csrc/scheduler/resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,29 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
return false;
}

auto all_dep_vals = DependencyCheck::getAllValsBetween(
{fusion->inputs().begin(), fusion->inputs().end()},
{resize_based_tensor_ops.at(0)->output(0)});
for (auto tv : ir_utils::filterByType<TensorView>(all_dep_vals)) {
if (tv->isFusionOutput()) {
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(),
"Dependency to fusion output not allowed: ",
tv->toString());
return false;
}
for (auto consumer_of_tv : ir_utils::consumerTvsOf(tv)) {
if (std::find(all_dep_vals.begin(), all_dep_vals.end(), consumer_of_tv) ==
all_dep_vals.end()) {
scheduler_debug_utils::canScheduleRejectReason(
schedulerType(),
"Resize inputs must be exclusively consumed by resize: ",
consumer_of_tv->toString());
return false;
}
}
}

// Slicing of or to a broadcast ID is not allowed yet.
for (auto tensor_op : resize_based_tensor_ops) {
TensorView* out_tv = tensor_op->output(0)->as<TensorView>();
Expand Down

0 comments on commit 40dd2c2

Please sign in to comment.