Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial resize scheduler #3556

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
11f5dce
Always enable IdModel-based indexing when resize is used
naoyam Dec 11, 2024
05ea88f
Don't run the tests without IdModel
naoyam Dec 11, 2024
0ad9fea
fix
naoyam Dec 11, 2024
4f14988
Allocation ordering fix
naoyam Dec 11, 2024
f3ce2d9
Merge remote-tracking branch 'origin/enable_id_model_for_resize' into…
naoyam Dec 11, 2024
0d35147
rotation + residual
naoyam Dec 9, 2024
7934e63
wip
naoyam Dec 10, 2024
9e71bc5
move DomainMap to its own file
naoyam Dec 10, 2024
57600bd
Use the reference finder of pointwise scheduler
naoyam Dec 10, 2024
25ebe94
use the scheduler in the resize test
naoyam Dec 11, 2024
55b8499
WIP
naoyam Dec 11, 2024
791c85b
Merge remote-tracking branch 'origin/main' into resize_scheduler_init…
naoyam Dec 11, 2024
839f23c
cleanup
naoyam Dec 11, 2024
7a10f02
WAR
naoyam Dec 11, 2024
c80dd91
Fix the failed alias test thanks to @wujingyue
naoyam Dec 11, 2024
9167cf0
cleanup
naoyam Dec 11, 2024
df63df2
cleanup
naoyam Dec 11, 2024
52acb42
cleanup
naoyam Dec 11, 2024
6363298
Merge branch 'main' into resize_scheduler_initial_version
naoyam Dec 13, 2024
ca09b93
Merge branch 'main' into resize_scheduler_initial_version
naoyam Dec 13, 2024
0d0a4d6
PR feedback
naoyam Dec 13, 2024
be3aee9
fix
naoyam Dec 13, 2024
4368e80
Rename DomainMap to PointwiseDomainMap
naoyam Dec 13, 2024
2a6f059
Merge remote-tracking branch 'origin/main' into resize_scheduler_init…
naoyam Dec 13, 2024
91e7d3e
Merge remote-tracking branch 'origin/main' into resize_scheduler_init…
naoyam Dec 13, 2024
96ac0fa
merge fix
naoyam Dec 13, 2024
7e9413a
python frontend fix
naoyam Dec 15, 2024
40dd2c2
fix pattern match
naoyam Dec 15, 2024
8056cfa
fix
naoyam Dec 15, 2024
b9415e1
test fix
naoyam Dec 15, 2024
aebfd51
Disable segmentation
naoyam Dec 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,14 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/normalization_outer.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/pointwise.cpp
${NVFUSER_SRCS_DIR}/scheduler/pointwise_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/reduction.cpp
${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/registry.cpp
${NVFUSER_SRCS_DIR}/scheduler/registry_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/resize.cpp
${NVFUSER_SRCS_DIR}/scheduler/runtime_info.cpp
${NVFUSER_SRCS_DIR}/scheduler/scheduler_types.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/domain_map.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/inlining.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/loop_domain_scheduler.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/maxinfo_propagator.cpp
Expand Down
5 changes: 3 additions & 2 deletions csrc/scheduler/compile_time_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <fusion.h>
#include <scheduler/pointwise_utils.h>
#include <scheduler/scheduler_types.h>
#include <scheduler/tools/domain_map.h>
#include <scheduler/utils.h>
#include <scheduler/vectorize_helper.h>

Expand Down Expand Up @@ -54,7 +55,7 @@ enum class CompileTimeEntryType {
//! stores the domain map of a fusion.
class DomainMap {
public:
using DataType = pointwise_utils::DomainMap;
using DataType = scheduler_tools::DomainMap;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::DOMAIN_MAP;
};
Expand All @@ -63,7 +64,7 @@ class DomainMap {
//! stores the domain map of a fusion, used by transpose scheduler.
class TransposeDomainMap {
public:
using DataType = pointwise_utils::DomainMap;
using DataType = scheduler_tools::DomainMap;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::TRANSPOSE_DOMAIN_MAP;
};
Expand Down
51 changes: 7 additions & 44 deletions csrc/scheduler/pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,37 +29,6 @@ namespace {
// Unused at the moment, commenting for clang tidy
constexpr int64_t kThreadX = 128;

class DomainMap : public pointwise_utils::DomainMap {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is moved to pointwise_utils.h so that it can be also used from the resize scheduler

public:
using pointwise_utils::DomainMap::DomainMap;

// The pointwise scheduler heuristics requires a minimum number of axes.
// The output reference tensor should respect this requirement.
TensorView* findReferenceTensorView(int64_t minimum_num_axes = 0) const {
TensorView* result = nullptr;
int64_t max_dims = -1;
for (auto output_tv :
ir_utils::filterByType<TensorView>(fusion_->outputs())) {
if (isValidReference(output_tv) &&
hasMinimumSize(output_tv, minimum_num_axes) &&
!output_tv->isFusionInput()) {
int64_t n_dims = pointwise_utils::nRootDims(output_tv);
if (n_dims > max_dims) {
result = output_tv;
max_dims = n_dims;
}
}
}
return result;
}

private:
bool hasMinimumSize(TensorView* tv, int64_t num_axes) const {
NVF_ERROR(tv != nullptr);
return (num_axes == 0 || (int64_t)tv->getLogicalDomain().size() > num_axes);
}
};

} // namespace

std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
Expand All @@ -79,9 +48,11 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(

auto domain_map_entry =
HeuristicDataCacheEntry<HeuristicCompileTime::DomainMap>(
data_cache,
[fusion]() { return std::make_unique<DomainMap>(fusion); });
const auto& domain_map = dynamic_cast<DomainMap&>(domain_map_entry.get());
data_cache, [fusion]() {
return std::make_unique<pointwise_utils::DomainMap>(fusion);
});
const auto& domain_map =
dynamic_cast<pointwise_utils::DomainMap&>(domain_map_entry.get());

auto largest_out_entry =
HeuristicDataCacheEntry<HeuristicCompileTime::ReferenceTensors>(
Expand Down Expand Up @@ -432,19 +403,11 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
return params;
}

// Return reference tensor view.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just moved to pointwise_utils

TensorView* getReferenceTensorView(Fusion* fusion) {
FusionGuard fg(fusion);
DomainMap domain_map(fusion);
auto reference_tv = domain_map.findReferenceTensorView();
return reference_tv;
}

//! Utility for canSchedule interface to check if this fusion has
//! a fully broadcasted reference tensor, which is necessary for
//! the pointwise scheduler.
bool hasReferenceTensorView(Fusion* fusion) {
return getReferenceTensorView(fusion) != nullptr;
return pointwise_utils::getReferenceTensor(fusion) != nullptr;
}

bool PointWiseScheduler::canScheduleCompileTime(Fusion* fusion) {
Expand Down Expand Up @@ -541,7 +504,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
return;
}

TensorView* reference_tv = getReferenceTensorView(fusion);
TensorView* reference_tv = pointwise_utils::getReferenceTensor(fusion);

NVF_ERROR(
reference_tv != nullptr,
Expand Down
85 changes: 40 additions & 45 deletions csrc/scheduler/pointwise_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,56 +11,12 @@
#include <exceptions.h>
#include <ir/all_nodes.h>
#include <ir/utils.h>
#include <scheduler/tools/domain_map.h>
#include <scheduler/utils.h>

namespace nvfuser {
namespace pointwise_utils {

// DomainMap uses the ComputeAtMap to find a reference TensorView
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is moved to scheduler/tools/domain_map.h

// that maps to all IterDomains in the fusion.
class DomainMap {
public:
DomainMap(Fusion* fusion);
virtual ~DomainMap() = default;

const ComputeAtMap& getComputeAtMap() const {
return ca_map_;
}

// Determine if a TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the iterDomains in each input.
bool isValidReference(TensorView* tv) const;

protected:
// Determine if all IterDomains are mapped between input and the given tvs
bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv)
const;

virtual IterDomain* getMappedInputConcreteID(
const std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const;

// Erase input concrete ID if it is mapped to output ID
bool eraseIfMapped(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const;

// Check if in_ids are mapped to ids through any root domain as
// well as indirectly accessed domains with ops like torchGather
void eraseifInputMappedThroughRootDomainAndIndexing(
std::unordered_set<IterDomain*>& in_ids,
const std::vector<IterDomain*>& ids) const;

// Find any id in domain that maps with target id
IterDomain* anyMapped(
const std::vector<IterDomain*>& domain,
IterDomain* target) const;

Fusion* fusion_ = nullptr;
ComputeAtMap ca_map_;
std::vector<TensorView*> tvs_with_rfactor_;
};

// Returns number of non-reduction/non-broadcas/non-device dims in logical
// domain
inline int64_t nRootDims(const TensorView* tv) {
Expand All @@ -74,5 +30,44 @@ inline int64_t nRootDims(const TensorView* tv) {
return tv_n_dims;
}

class DomainMap : public scheduler_tools::DomainMap {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is moved from pointwise.cpp

public:
using scheduler_tools::DomainMap::DomainMap;

// The pointwise scheduler heuristics requires a minimum number of axes.
// The output reference tensor should respect this requirement.
TensorView* findReferenceTensorView(int64_t minimum_num_axes = 0) const {
TensorView* result = nullptr;
int64_t max_dims = -1;
for (auto output_tv :
ir_utils::filterByType<TensorView>(fusion_->outputs())) {
if (isValidReference(output_tv) &&
hasMinimumSize(output_tv, minimum_num_axes) &&
!output_tv->isFusionInput()) {
int64_t n_dims = pointwise_utils::nRootDims(output_tv);
if (n_dims > max_dims) {
result = output_tv;
max_dims = n_dims;
}
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I'd put long implementations like this to cpp for faster compilation. Otherwise, it would have to be compiled for each compilation unit including this header, which seems unnecessary.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, agreed, the long compilation time is painful...

return result;
}

private:
bool hasMinimumSize(TensorView* tv, int64_t num_axes) const {
NVF_ERROR(tv != nullptr);
return (num_axes == 0 || (int64_t)tv->getLogicalDomain().size() > num_axes);
}
};

// Return reference tensor view.
inline TensorView* getReferenceTensor(Fusion* fusion) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved from pointwise.cpp. Also shortened the name a bit (was getReferenceTensorView)

FusionGuard fg(fusion);
DomainMap domain_map(fusion);
auto reference_tv = domain_map.findReferenceTensorView();
return reference_tv;
}

} // namespace pointwise_utils
} // namespace nvfuser
3 changes: 3 additions & 0 deletions csrc/scheduler/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <scheduler/matmul_utils.h>
#include <scheduler/registry.h>
#include <scheduler/registry_utils.h>
#include <scheduler/resize.h>
#include <scheduler/runtime_info.h>
#include <scheduler/utils.h>

Expand Down Expand Up @@ -90,6 +91,8 @@ std::unique_ptr<SchedulerEntry> SchedulerEntry::makeSchedulerInstance(
return std::make_unique<MatmulScheduler>();
case SchedulerType::ExprEval:
return std::make_unique<ExprEvalScheduler>();
case SchedulerType::Resize:
return std::make_unique<ResizeScheduler>();
default:
NVF_THROW("unreachable");
}
Expand Down
Loading
Loading