Commit

cleanup

naoyam committed Dec 10, 2024
1 parent d0d9143 commit 8b8679a
Showing 1 changed file with 8 additions and 32 deletions.
40 changes: 8 additions & 32 deletions csrc/scheduler/resize.cpp
@@ -24,8 +24,6 @@
 namespace nvfuser {
 
 bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
-  std::cerr << "ResizeScheduler::canScheduleCompileTime\n";
-
   if (!ir_utils::hasOpsOfType<SliceOp, PadOp>(fusion)) {
     scheduler_debug_utils::canScheduleRejectReason(
         schedulerType(), "No resize op to schedule");
@@ -51,6 +49,10 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
     return false;
   }
 
+  // For now, the resize scheduler is only allowed for a limited set
+  // of fusion patterns. The restrictions are planned to be
+  // incrementally relaxed.
+
   // For now, only a single resize op is allowed to exist.
   auto resize_based_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);
   if (resize_based_tensor_ops.size() != 1) {
@@ -98,61 +100,35 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {

   FusionGuard fg(fusion);
 
-  std::cerr << "ResizeScheduler::schedule\n";
-
   scheduler_utils::clearMemorySpace(fusion);
 
   scheduler_utils::cacheInputs(fusion, true);
 
-  const auto exprs = fusion->exprs();
-  for (auto expr : exprs) {
+  for (auto expr : fusion->exprs()) {
     if (!expr->isOneOf<SliceOp, PadOp>()) {
       continue;
     }
 
-    std::cerr << "Propagating resize tensor op: " << expr->toString();
     scheduler_tools::propagateResizeToInputs(expr);
   }
 
   // Just use the pointwise version for now
   auto ref_tv = pointwise_utils::getReferenceTensor(fusion);
 
-  std::cerr << "Reference: " << ref_tv->toString() << "\n";
-
   // Just simple scheduling for now.
   // TODO: Do something smarter. Can just use the pointwise scheduler?
   ref_tv->flatten();
   ref_tv->split(0, 128);
   ref_tv->split(0, 1 << 14);
   ref_tv->axis(-1)->parallelize(ParallelType::TIDx);
   ref_tv->axis(-2)->parallelize(ParallelType::BIDx);
 
-  std::cerr << "Scheduled reference:\n";
-
   // Propagate the reference to the other tensors
   scheduler_tools::scheduleLoopDomainsLike(
       fusion->allTvs(), ref_tv->getLoopDomain());
 
-  {
-    std::cerr << "All done\n";
-    for (auto tv : fusion->allTvs()) {
-      std::cerr << "Final scheduled T" << tv->name() << "\n";
-      if (tv->hasRoot()) {
-        std::cerr << "\tRoot: " << toDelimitedString(tv->getRootDomain())
-                  << "\n";
-      }
-      std::cerr << "\tLogical: " << toDelimitedString(tv->getLogicalDomain())
-                << "\n";
-      std::cerr << "\tLoop: " << toDelimitedString(tv->getLoopDomain()) << "\n";
-      std::cerr << "\tAdditional ids: "
-                << toDelimitedString(tv->domain()->additionalIDs()) << "\n";
-      for (auto expr : tv->domain()->allExprs()) {
-        std::cerr << expr->toString(4);
-      }
-    }
-  }
-
   inlineMost();
 
-  fusion->printMath();
-
-  return;
 }
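For context, the schedule() body above applies a standard flatten-then-split pattern to the reference tensor. A minimal sketch of how the reference loop domain evolves, assuming a flattened extent N (the domain annotations are illustrative comments, not nvFuser output):

    ref_tv->flatten();          // collapse all iteration domains: [N]
    ref_tv->split(0, 128);      // [ceilDiv(N, 128), 128]
    ref_tv->split(0, 1 << 14);  // [ceilDiv(ceilDiv(N, 128), 16384), 16384, 128]
    ref_tv->axis(-1)->parallelize(ParallelType::TIDx);  // 128 threads per block
    ref_tv->axis(-2)->parallelize(ParallelType::BIDx);  // grid capped at 16384 blocks
    // resulting loop structure: [serial outer loop, BIDx, TIDx]

scheduleLoopDomainsLike then replays this loop structure onto every other tensor in the fusion, and inlineMost() inlines producers into consumers as deeply as possible, so the whole fusion follows the reference's grid and block mapping.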


0 comments on commit 8b8679a
