Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Dec 11, 2024
1 parent 25ebe94 commit 55b8499
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions csrc/scheduler/resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
FUSER_PERF_SCOPE("ResizeScheduler::schedule");

FusionGuard fg(fusion);

#if 0
fusion->printMath();
fusion->print();
std::cout << std::endl;

#endif
scheduler_utils::clearMemorySpace(fusion);

scheduler_utils::cacheInputs(fusion, true);
Expand All @@ -153,7 +153,7 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {

scheduler_tools::propagateResizeToInputs(expr);
}

#if 0
fusion->print();
std::cout << std::endl;

Expand All @@ -173,14 +173,28 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
std::cerr << expr->toString(4);
}
}

#endif
auto ref_tv = getReferenceTensor(fusion);

// Just simple scheduling for now.
// TODO: Do something smarter. Can just use the pointwise scheduler?
ref_tv->flatten();
ref_tv->split(0, 128);
ref_tv->split(0, 1 << 14);

// Make sure the device-parallel (DID) IDs are located at the outermost positions
int64_t reorder_pos = 0;
std::unordered_map<int64_t, int64_t> old2new;
for (const auto i : c10::irange(ref_tv->getLoopDomain().size())) {
if (isParallelTypeDeviceDim(ref_tv->axis((int64_t)i)->getParallelType())) {
old2new.emplace((int64_t)i, reorder_pos);
++reorder_pos;
}
}
ref_tv->reorder(old2new);

// Schedule only the remaining IDs
const auto outermost_pos = (int64_t)old2new.size();
ref_tv->flatten(outermost_pos);
ref_tv->split(outermost_pos, 128);
ref_tv->split(outermost_pos, 1 << 14);
ref_tv->axis(-1)->parallelize(ParallelType::TIDx);
ref_tv->axis(-2)->parallelize(ParallelType::BIDx);

Expand Down

0 comments on commit 55b8499

Please sign in to comment.