From 55b84997e3caaf08bfc89a4993d98c9580b1e16d Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 11 Dec 2024 09:22:01 -0800 Subject: [PATCH] WIP --- csrc/scheduler/resize.cpp | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp index 1fb190ab1d9..52565fb7e69 100644 --- a/csrc/scheduler/resize.cpp +++ b/csrc/scheduler/resize.cpp @@ -137,11 +137,11 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { FUSER_PERF_SCOPE("ResizeScheduler::schedule"); FusionGuard fg(fusion); - +#if 0 fusion->printMath(); fusion->print(); std::cout << std::endl; - +#endif scheduler_utils::clearMemorySpace(fusion); scheduler_utils::cacheInputs(fusion, true); @@ -153,7 +153,7 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { scheduler_tools::propagateResizeToInputs(expr); } - +#if 0 fusion->print(); std::cout << std::endl; @@ -173,14 +173,28 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { std::cerr << expr->toString(4); } } - +#endif auto ref_tv = getReferenceTensor(fusion); // Just simple scheduling for now. // TODO: Do something smarter. Can just use the pointwise scheduler? - ref_tv->flatten(); - ref_tv->split(0, 128); - ref_tv->split(0, 1 << 14); + + // Make sure the DID ID located at the outermost position + int64_t reorder_pos = 0; + std::unordered_map old2new; + for (const auto i : c10::irange(ref_tv->getLoopDomain().size())) { + if (isParallelTypeDeviceDim(ref_tv->axis((int64_t)i)->getParallelType())) { + old2new.emplace((int64_t)i, reorder_pos); + ++reorder_pos; + } + } + ref_tv->reorder(old2new); + + // Schedule only the remaining IDs + const auto outermost_pos = (int64_t)old2new.size(); + ref_tv->flatten(outermost_pos); + ref_tv->split(outermost_pos, 128); + ref_tv->split(outermost_pos, 1 << 14); ref_tv->axis(-1)->parallelize(ParallelType::TIDx); ref_tv->axis(-2)->parallelize(ParallelType::BIDx);