-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initial resize scheduler #3556
base: main
Are you sure you want to change the base?
Initial resize scheduler #3556
Changes from 12 commits
11f5dce
05ea88f
0ad9fea
4f14988
f3ce2d9
0d35147
7934e63
9e71bc5
57600bd
25ebe94
55b8499
791c85b
839f23c
7a10f02
c80dd91
9167cf0
df63df2
52acb42
6363298
ca09b93
0d0a4d6
be3aee9
4368e80
2a6f059
91e7d3e
96ac0fa
7e9413a
40dd2c2
8056cfa
b9415e1
aebfd51
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,37 +29,6 @@ namespace { | |
// Unused at the moment, commenting for clang tidy | ||
constexpr int64_t kThreadX = 128; | ||
|
||
class DomainMap : public pointwise_utils::DomainMap { | ||
public: | ||
using pointwise_utils::DomainMap::DomainMap; | ||
|
||
// The pointwise scheduler heuristics requires a minimum number of axes. | ||
// The output reference tensor should respect this requirement. | ||
TensorView* findReferenceTensorView(int64_t minimum_num_axes = 0) const { | ||
TensorView* result = nullptr; | ||
int64_t max_dims = -1; | ||
for (auto output_tv : | ||
ir_utils::filterByType<TensorView>(fusion_->outputs())) { | ||
if (isValidReference(output_tv) && | ||
hasMinimumSize(output_tv, minimum_num_axes) && | ||
!output_tv->isFusionInput()) { | ||
int64_t n_dims = pointwise_utils::nRootDims(output_tv); | ||
if (n_dims > max_dims) { | ||
result = output_tv; | ||
max_dims = n_dims; | ||
} | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
private: | ||
bool hasMinimumSize(TensorView* tv, int64_t num_axes) const { | ||
NVF_ERROR(tv != nullptr); | ||
return (num_axes == 0 || (int64_t)tv->getLogicalDomain().size() > num_axes); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<PointwiseParams> getPointwiseHeuristics( | ||
|
@@ -79,9 +48,11 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics( | |
|
||
auto domain_map_entry = | ||
HeuristicDataCacheEntry<HeuristicCompileTime::DomainMap>( | ||
data_cache, | ||
[fusion]() { return std::make_unique<DomainMap>(fusion); }); | ||
const auto& domain_map = dynamic_cast<DomainMap&>(domain_map_entry.get()); | ||
data_cache, [fusion]() { | ||
return std::make_unique<pointwise_utils::DomainMap>(fusion); | ||
}); | ||
const auto& domain_map = | ||
dynamic_cast<pointwise_utils::DomainMap&>(domain_map_entry.get()); | ||
|
||
auto largest_out_entry = | ||
HeuristicDataCacheEntry<HeuristicCompileTime::ReferenceTensors>( | ||
|
@@ -432,19 +403,11 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics( | |
return params; | ||
} | ||
|
||
// Return reference tensor view. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just moved to pointwise_utils. |
||
TensorView* getReferenceTensorView(Fusion* fusion) { | ||
FusionGuard fg(fusion); | ||
DomainMap domain_map(fusion); | ||
auto reference_tv = domain_map.findReferenceTensorView(); | ||
return reference_tv; | ||
} | ||
|
||
//! Utility for canSchedule interface to check if this fusion has | ||
//! a fully broadcasted reference tensor, which is necessary for | ||
//! the pointwise scheduler. | ||
bool hasReferenceTensorView(Fusion* fusion) { | ||
return getReferenceTensorView(fusion) != nullptr; | ||
return pointwise_utils::getReferenceTensor(fusion) != nullptr; | ||
} | ||
|
||
bool PointWiseScheduler::canScheduleCompileTime(Fusion* fusion) { | ||
|
@@ -541,7 +504,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { | |
return; | ||
} | ||
|
||
TensorView* reference_tv = getReferenceTensorView(fusion); | ||
TensorView* reference_tv = pointwise_utils::getReferenceTensor(fusion); | ||
|
||
NVF_ERROR( | ||
reference_tv != nullptr, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,56 +11,12 @@ | |
#include <exceptions.h> | ||
#include <ir/all_nodes.h> | ||
#include <ir/utils.h> | ||
#include <scheduler/tools/domain_map.h> | ||
#include <scheduler/utils.h> | ||
|
||
namespace nvfuser { | ||
namespace pointwise_utils { | ||
|
||
// DomainMap uses the ComputeAtMap to find a reference TensorView | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This part is moved to scheduler/tools/domain_map.h. |
||
// that maps to all IterDomains in the fusion. | ||
class DomainMap { | ||
public: | ||
DomainMap(Fusion* fusion); | ||
virtual ~DomainMap() = default; | ||
|
||
const ComputeAtMap& getComputeAtMap() const { | ||
return ca_map_; | ||
} | ||
|
||
// Determine if a TensorView is a valid reference tensor for this fusion. | ||
// The reference tensor must map to all the iterDomains in each input. | ||
bool isValidReference(TensorView* tv) const; | ||
|
||
protected: | ||
// Determine if all IterDomains are mapped between input and the given tvs | ||
bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv) | ||
const; | ||
|
||
virtual IterDomain* getMappedInputConcreteID( | ||
const std::unordered_set<IterDomain*>& in_concrete_ids, | ||
IterDomain* out_id) const; | ||
|
||
// Erase input concrete ID if it is mapped to output ID | ||
bool eraseIfMapped( | ||
std::unordered_set<IterDomain*>& in_concrete_ids, | ||
IterDomain* out_id) const; | ||
|
||
// Check if in_ids are mapped to ids through any root domain as | ||
// well as indirectly accessed domains with ops like torchGather | ||
void eraseifInputMappedThroughRootDomainAndIndexing( | ||
std::unordered_set<IterDomain*>& in_ids, | ||
const std::vector<IterDomain*>& ids) const; | ||
|
||
// Find any id in domain that maps with target id | ||
IterDomain* anyMapped( | ||
const std::vector<IterDomain*>& domain, | ||
IterDomain* target) const; | ||
|
||
Fusion* fusion_ = nullptr; | ||
ComputeAtMap ca_map_; | ||
std::vector<TensorView*> tvs_with_rfactor_; | ||
}; | ||
|
||
// Returns number of non-reduction/non-broadcast/non-device dims in logical | ||
// domain | ||
inline int64_t nRootDims(const TensorView* tv) { | ||
|
@@ -74,5 +30,44 @@ inline int64_t nRootDims(const TensorView* tv) { | |
return tv_n_dims; | ||
} | ||
|
||
class DomainMap : public scheduler_tools::DomainMap { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is moved from pointwise.cpp. |
||
public: | ||
using scheduler_tools::DomainMap::DomainMap; | ||
|
||
// The pointwise scheduler heuristics requires a minimum number of axes. | ||
// The output reference tensor should respect this requirement. | ||
TensorView* findReferenceTensorView(int64_t minimum_num_axes = 0) const { | ||
TensorView* result = nullptr; | ||
int64_t max_dims = -1; | ||
for (auto output_tv : | ||
ir_utils::filterByType<TensorView>(fusion_->outputs())) { | ||
if (isValidReference(output_tv) && | ||
hasMinimumSize(output_tv, minimum_num_axes) && | ||
!output_tv->isFusionInput()) { | ||
int64_t n_dims = pointwise_utils::nRootDims(output_tv); | ||
if (n_dims > max_dims) { | ||
result = output_tv; | ||
max_dims = n_dims; | ||
} | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: I'd put long implementations like this in the .cpp file for faster compilation. Otherwise, it would have to be compiled for each compilation unit including this header, which seems unnecessary. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, agreed, the long compilation time is painful... |
||
return result; | ||
} | ||
|
||
private: | ||
bool hasMinimumSize(TensorView* tv, int64_t num_axes) const { | ||
NVF_ERROR(tv != nullptr); | ||
return (num_axes == 0 || (int64_t)tv->getLogicalDomain().size() > num_axes); | ||
} | ||
}; | ||
|
||
// Return reference tensor view. | ||
inline TensorView* getReferenceTensor(Fusion* fusion) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Moved from pointwise.cpp. Also shortened the name a bit (was getReferenceTensorView). |
||
FusionGuard fg(fusion); | ||
DomainMap domain_map(fusion); | ||
auto reference_tv = domain_map.findReferenceTensorView(); | ||
return reference_tv; | ||
} | ||
|
||
} // namespace pointwise_utils | ||
} // namespace nvfuser |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This part is moved to pointwise_utils.h so that it can also be used from the resize scheduler.