Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial resize scheduler #3556

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
11f5dce
Always enable IdModel-based indexing when resize is used
naoyam Dec 11, 2024
05ea88f
Don't run the tests without IdModel
naoyam Dec 11, 2024
0ad9fea
fix
naoyam Dec 11, 2024
4f14988
Allocation ordering fix
naoyam Dec 11, 2024
f3ce2d9
Merge remote-tracking branch 'origin/enable_id_model_for_resize' into…
naoyam Dec 11, 2024
0d35147
rotation + residual
naoyam Dec 9, 2024
7934e63
wip
naoyam Dec 10, 2024
9e71bc5
move DomainMap to its own file
naoyam Dec 10, 2024
57600bd
Use the reference finder of pointwise scheduler
naoyam Dec 10, 2024
25ebe94
use the scheduler in the resize test
naoyam Dec 11, 2024
55b8499
WIP
naoyam Dec 11, 2024
791c85b
Merge remote-tracking branch 'origin/main' into resize_scheduler_init…
naoyam Dec 11, 2024
839f23c
cleanup
naoyam Dec 11, 2024
7a10f02
WAR
naoyam Dec 11, 2024
c80dd91
Fix the failed alias test thanks to @wujingyue
naoyam Dec 11, 2024
9167cf0
cleanup
naoyam Dec 11, 2024
df63df2
cleanup
naoyam Dec 11, 2024
52acb42
cleanup
naoyam Dec 11, 2024
6363298
Merge branch 'main' into resize_scheduler_initial_version
naoyam Dec 13, 2024
ca09b93
Merge branch 'main' into resize_scheduler_initial_version
naoyam Dec 13, 2024
0d0a4d6
PR feedback
naoyam Dec 13, 2024
be3aee9
fix
naoyam Dec 13, 2024
4368e80
Rename DomainMap to PointwiseDomainMap
naoyam Dec 13, 2024
2a6f059
Merge remote-tracking branch 'origin/main' into resize_scheduler_init…
naoyam Dec 13, 2024
91e7d3e
Merge remote-tracking branch 'origin/main' into resize_scheduler_init…
naoyam Dec 13, 2024
96ac0fa
merge fix
naoyam Dec 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,10 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/registry.cpp
${NVFUSER_SRCS_DIR}/scheduler/registry_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/resize.cpp
${NVFUSER_SRCS_DIR}/scheduler/runtime_info.cpp
${NVFUSER_SRCS_DIR}/scheduler/scheduler_types.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/domain_map.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/inlining.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/loop_domain_scheduler.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/maxinfo_propagator.cpp
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/expr_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ bool ExprSegmentationSorter::interIterUpdate() {
NVF_ERROR(
!fallback_mode_enabled_,
"Couldn't succcessfully sort out the fusion expressions. ",
"There are remaining connections of the heirarchical segmentation which should have been ",
"There are remaining connections of the hierarchical segmentation which should have been ",
"flattened to a single ordered group, or disjoint ordered groups.\n",
toString());
// We didn't finish, but we haven't tried the fallback, try again with that.
Expand Down
5 changes: 3 additions & 2 deletions csrc/scheduler/compile_time_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <fusion.h>
#include <scheduler/pointwise_utils.h>
#include <scheduler/scheduler_types.h>
#include <scheduler/tools/domain_map.h>
#include <scheduler/utils.h>
#include <scheduler/vectorize_helper.h>

Expand Down Expand Up @@ -54,7 +55,7 @@ enum class CompileTimeEntryType {
//! stores the domain map of a fusion.
class DomainMap {
public:
using DataType = pointwise_utils::DomainMap;
using DataType = scheduler_tools::DomainMap;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::DOMAIN_MAP;
};
Expand All @@ -63,7 +64,7 @@ class DomainMap {
//! stores the domain map of a fusion, used by transpose scheduler.
class TransposeDomainMap {
public:
using DataType = pointwise_utils::DomainMap;
using DataType = scheduler_tools::DomainMap;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::TRANSPOSE_DOMAIN_MAP;
};
Expand Down
53 changes: 8 additions & 45 deletions csrc/scheduler/pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,37 +29,6 @@ namespace {
// Unused at the moment, commenting for clang tidy
constexpr int64_t kThreadX = 128;

class DomainMap : public pointwise_utils::DomainMap {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is moved to pointwise_utils.h so that it can be also used from the resize scheduler

public:
using pointwise_utils::DomainMap::DomainMap;

// The pointwise scheduler heuristics requires a minimum number of axes.
// The output reference tensor should respect this requirement.
TensorView* findReferenceTensorView(int64_t minimum_num_axes = 0) const {
TensorView* result = nullptr;
int64_t max_dims = -1;
for (auto output_tv :
ir_utils::filterByType<TensorView>(fusion_->outputs())) {
if (isValidReference(output_tv) &&
hasMinimumSize(output_tv, minimum_num_axes) &&
!output_tv->isFusionInput()) {
int64_t n_dims = pointwise_utils::nRootDims(output_tv);
if (n_dims > max_dims) {
result = output_tv;
max_dims = n_dims;
}
}
}
return result;
}

private:
bool hasMinimumSize(TensorView* tv, int64_t num_axes) const {
NVF_ERROR(tv != nullptr);
return (num_axes == 0 || (int64_t)tv->getLogicalDomain().size() > num_axes);
}
};

} // namespace

std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
Expand All @@ -79,14 +48,16 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(

auto domain_map_entry =
HeuristicDataCacheEntry<HeuristicCompileTime::DomainMap>(
data_cache,
[fusion]() { return std::make_unique<DomainMap>(fusion); });
const auto& domain_map = dynamic_cast<DomainMap&>(domain_map_entry.get());
data_cache, [fusion]() {
return std::make_unique<pointwise_utils::DomainMap>(fusion);
});
const auto& domain_map =
dynamic_cast<pointwise_utils::DomainMap&>(domain_map_entry.get());

auto largest_out_entry =
HeuristicDataCacheEntry<HeuristicCompileTime::ReferenceTensors>(
data_cache, [&domain_map]() {
std::vector<TensorView*> data{domain_map.findReferenceTensorView()};
std::vector<TensorView*> data{domain_map.findReferenceTensor()};
return std::make_unique<std::vector<TensorView*>>(std::move(data));
});
TensorView* largest_out = largest_out_entry.get()[0];
Expand Down Expand Up @@ -432,19 +403,11 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
return params;
}

// Return reference tensor view.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just moved to pointwise_utils

TensorView* getReferenceTensorView(Fusion* fusion) {
FusionGuard fg(fusion);
DomainMap domain_map(fusion);
auto reference_tv = domain_map.findReferenceTensorView();
return reference_tv;
}

//! Utility for canSchedule interface to check if this fusion has
//! a fully broadcasted reference tensor, which is necessary for
//! the pointwise scheduler.
bool hasReferenceTensorView(Fusion* fusion) {
return getReferenceTensorView(fusion) != nullptr;
return pointwise_utils::getReferenceTensor(fusion) != nullptr;
}

bool PointWiseScheduler::canScheduleCompileTime(Fusion* fusion) {
Expand Down Expand Up @@ -541,7 +504,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
return;
}

TensorView* reference_tv = getReferenceTensorView(fusion);
TensorView* reference_tv = pointwise_utils::getReferenceTensor(fusion);

NVF_ERROR(
reference_tv != nullptr,
Expand Down
Loading
Loading