From 8b8679a7d3963b6da8272bbff8f56c166aba7519 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Tue, 10 Dec 2024 14:14:49 -0800
Subject: [PATCH] cleanup

---
 csrc/scheduler/resize.cpp | 40 ++++++++-------------------------------
 1 file changed, 8 insertions(+), 32 deletions(-)

diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp
index 8b23c22a4f9..3813d4fc33f 100644
--- a/csrc/scheduler/resize.cpp
+++ b/csrc/scheduler/resize.cpp
@@ -24,8 +24,6 @@ namespace nvfuser {
 bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
-  std::cerr << "ResizeScheduler::canScheduleCompileTime\n";
-
   if (!ir_utils::hasOpsOfType<SliceOp, PadOp>(fusion)) {
     scheduler_debug_utils::canScheduleRejectReason(
         schedulerType(), "No resize op to schedule");
     return false;
   }
@@ -51,6 +49,10 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
     return false;
   }
 
+  // For now, the resize scheduler is only allowed for a limited set
+  // of fusion patterns. The restrictions are planned to be
+  // incrementally relaxed.
+
   // For now, only a single resize op is allowed to exist.
   auto resize_based_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);
   if (resize_based_tensor_ops.size() != 1) {
@@ -98,61 +100,35 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
 
   FusionGuard fg(fusion);
 
-  std::cerr << "ResizeScheduler::schedule\n";
-
   scheduler_utils::clearMemorySpace(fusion);
 
   scheduler_utils::cacheInputs(fusion, true);
 
-  const auto exprs = fusion->exprs();
-  for (auto expr : exprs) {
+  for (auto expr : fusion->exprs()) {
     if (!expr->isOneOf<SliceOp, PadOp>()) {
       continue;
     }
 
-    std::cerr << "Propagating resize tensor op: " << expr->toString();
     scheduler_tools::propagateResizeToInputs(expr);
   }
 
   // Just use the pointwise version for now
   auto ref_tv = pointwise_utils::getReferenceTensor(fusion);
 
-  std::cerr << "Reference: " << ref_tv->toString() << "\n";
-
+  // Just simple scheduling for now.
+  // TODO: Do something smarter. Can just use the pointwise scheduler?
 
   ref_tv->flatten();
   ref_tv->split(0, 128);
   ref_tv->split(0, 1 << 14);
   ref_tv->axis(-1)->parallelize(ParallelType::TIDx);
   ref_tv->axis(-2)->parallelize(ParallelType::BIDx);
 
-  std::cerr << "Scheduled reference:\n";
-
+  // Propagate the reference to the other tensors
   scheduler_tools::scheduleLoopDomainsLike(
       fusion->allTvs(), ref_tv->getLoopDomain());
 
-  {
-    std::cerr << "All done\n";
-    for (auto tv : fusion->allTvs()) {
-      std::cerr << "Final scheduled T" << tv->name() << "\n";
-      if (tv->hasRoot()) {
-        std::cerr << "\tRoot: " << toDelimitedString(tv->getRootDomain())
-                  << "\n";
-      }
-      std::cerr << "\tLogical: " << toDelimitedString(tv->getLogicalDomain())
-                << "\n";
-      std::cerr << "\tLoop: " << toDelimitedString(tv->getLoopDomain()) << "\n";
-      std::cerr << "\tAdditional ids: "
-                << toDelimitedString(tv->domain()->additionalIDs()) << "\n";
-      for (auto expr : tv->domain()->allExprs()) {
-        std::cerr << expr->toString(4);
-      }
-    }
-  }
-
   inlineMost();
 
-  fusion->printMath();
-
   return;
 }