diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 840ec90e69a..e1bd220d047 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -403,19 +403,11 @@ std::unique_ptr getPointwiseHeuristics( return params; } -// Return reference tensor view. -TensorView* getReferenceTensorView(Fusion* fusion) { - FusionGuard fg(fusion); - pointwise_utils::DomainMap domain_map(fusion); - auto reference_tv = domain_map.findReferenceTensorView(); - return reference_tv; -} - //! Utility for canSchedule interface to check if this fusion has //! a fully broadcasted reference tensor, which is necessary for //! the pointwise scheduler. bool hasReferenceTensorView(Fusion* fusion) { - return getReferenceTensorView(fusion) != nullptr; + return pointwise_utils::getReferenceTensor(fusion) != nullptr; } bool PointWiseScheduler::canScheduleCompileTime(Fusion* fusion) { @@ -512,7 +504,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { return; } - TensorView* reference_tv = getReferenceTensorView(fusion); + TensorView* reference_tv = pointwise_utils::getReferenceTensor(fusion); NVF_ERROR( reference_tv != nullptr, diff --git a/csrc/scheduler/pointwise_utils.h b/csrc/scheduler/pointwise_utils.h index 5812b9de19c..4b3cf6d60fc 100644 --- a/csrc/scheduler/pointwise_utils.h +++ b/csrc/scheduler/pointwise_utils.h @@ -61,5 +61,13 @@ class DomainMap : public scheduler_tools::DomainMap { } }; +// Return reference tensor view. +inline TensorView* getReferenceTensor(Fusion* fusion) { + FusionGuard fg(fusion); + DomainMap domain_map(fusion); + auto reference_tv = domain_map.findReferenceTensorView(); + return reference_tv; +} + } // namespace pointwise_utils } // namespace nvfuser diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp index 08f2aa04afd..0c55e6c8b8f 100644 --- a/csrc/scheduler/resize.cpp +++ b/csrc/scheduler/resize.cpp @@ -93,14 +93,6 @@ std::unique_ptr ResizeScheduler::computeHeuristics( return params; } -namespace { - -TensorView* getReferenceTensor(Fusion* fusion) { - return nullptr; -} - -} // namespace - void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { FUSER_PERF_SCOPE("ResizeScheduler::schedule"); @@ -126,7 +118,8 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { scheduler_tools::propagateResizeToInputs(expr); } - auto ref_tv = getReferenceTensor(fusion); + // Just use the pointwise version for now + auto ref_tv = pointwise_utils::getReferenceTensor(fusion); std::cerr << "Reference: " << ref_tv->toString() << "\n";