Skip to content

Commit

Permalink
move DomainMap to its own file
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Dec 10, 2024
1 parent 0800b6d commit ea00ddf
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 408 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,14 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/normalization_outer.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/pointwise.cpp
${NVFUSER_SRCS_DIR}/scheduler/pointwise_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/reduction.cpp
${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/registry.cpp
${NVFUSER_SRCS_DIR}/scheduler/registry_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/resize.cpp
${NVFUSER_SRCS_DIR}/scheduler/runtime_info.cpp
${NVFUSER_SRCS_DIR}/scheduler/scheduler_types.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/domain_map.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/inlining.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/loop_domain_scheduler.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/maxinfo_propagator.cpp
Expand Down
5 changes: 3 additions & 2 deletions csrc/scheduler/compile_time_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <fusion.h>
#include <scheduler/pointwise_utils.h>
#include <scheduler/scheduler_types.h>
#include <scheduler/tools/domain_map.h>
#include <scheduler/utils.h>
#include <scheduler/vectorize_helper.h>

Expand Down Expand Up @@ -54,7 +55,7 @@ enum class CompileTimeEntryType {
//! stores the domain map of a fusion.
class DomainMap {
public:
using DataType = pointwise_utils::DomainMap;
using DataType = scheduler_tools::DomainMap;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::DOMAIN_MAP;
};
Expand All @@ -63,7 +64,7 @@ class DomainMap {
//! stores the domain map of a fusion, used by transpose scheduler.
class TransposeDomainMap {
public:
using DataType = pointwise_utils::DomainMap;
using DataType = scheduler_tools::DomainMap;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::TRANSPOSE_DOMAIN_MAP;
};
Expand Down
41 changes: 6 additions & 35 deletions csrc/scheduler/pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,37 +29,6 @@ namespace {
// Unused at the moment, commenting for clang tidy
constexpr int64_t kThreadX = 128;

// Pointwise-scheduler flavor of pointwise_utils::DomainMap that adds selection
// of a reference output TensorView.
class DomainMap : public pointwise_utils::DomainMap {
 public:
  using pointwise_utils::DomainMap::DomainMap;

  // The pointwise scheduler heuristics requires a minimum number of axes.
  // The output reference tensor should respect this requirement.
  //
  // Walks the TensorView outputs of the fusion and keeps the candidate that
  //  - passes isValidReference(),
  //  - passes hasMinimumSize() for minimum_num_axes, and
  //  - is not itself a fusion input,
  // preferring the one with the largest pointwise_utils::nRootDims(). On a
  // tie the earliest candidate in output order wins (strict '>' comparison).
  // Returns nullptr when no output qualifies.
  TensorView* findReferenceTensorView(int64_t minimum_num_axes = 0) const {
    TensorView* result = nullptr;
    int64_t max_dims = -1;
    for (auto output_tv :
         ir_utils::filterByType<TensorView>(fusion_->outputs())) {
      if (isValidReference(output_tv) &&
          hasMinimumSize(output_tv, minimum_num_axes) &&
          !output_tv->isFusionInput()) {
        const int64_t n_dims = pointwise_utils::nRootDims(output_tv);
        if (n_dims > max_dims) {
          result = output_tv;
          max_dims = n_dims;
        }
      }
    }
    return result;
  }

 private:
  // Returns true when no minimum is requested (num_axes == 0) or when the
  // logical domain of tv has strictly more than num_axes entries. tv must be
  // non-null; this is enforced with NVF_ERROR.
  bool hasMinimumSize(TensorView* tv, int64_t num_axes) const {
    NVF_ERROR(tv != nullptr);
    // Named cast instead of a C-style cast for the size_t -> int64_t
    // conversion.
    return num_axes == 0 ||
        static_cast<int64_t>(tv->getLogicalDomain().size()) > num_axes;
  }
};

} // namespace

std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
Expand All @@ -79,9 +48,11 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(

auto domain_map_entry =
HeuristicDataCacheEntry<HeuristicCompileTime::DomainMap>(
data_cache,
[fusion]() { return std::make_unique<DomainMap>(fusion); });
const auto& domain_map = dynamic_cast<DomainMap&>(domain_map_entry.get());
data_cache, [fusion]() {
return std::make_unique<pointwise_utils::DomainMap>(fusion);
});
const auto& domain_map =
dynamic_cast<pointwise_utils::DomainMap&>(domain_map_entry.get());

auto largest_out_entry =
HeuristicDataCacheEntry<HeuristicCompileTime::ReferenceTensors>(
Expand Down Expand Up @@ -435,7 +406,7 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
// Return reference tensor view.
TensorView* getReferenceTensorView(Fusion* fusion) {
FusionGuard fg(fusion);
DomainMap domain_map(fusion);
pointwise_utils::DomainMap domain_map(fusion);
auto reference_tv = domain_map.findReferenceTensorView();
return reference_tv;
}
Expand Down
77 changes: 32 additions & 45 deletions csrc/scheduler/pointwise_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,56 +11,12 @@
#include <exceptions.h>
#include <ir/all_nodes.h>
#include <ir/utils.h>
#include <scheduler/tools/domain_map.h>
#include <scheduler/utils.h>

namespace nvfuser {
namespace pointwise_utils {

// DomainMap uses the ComputeAtMap to find a reference TensorView
// that maps to all IterDomains in the fusion.
//
// The class is meant to be subclassed: it has a virtual destructor, one
// virtual hook (getMappedInputConcreteID) and exposes its helpers and state
// as protected members. Only getComputeAtMap() is defined inline; every
// other member function is defined out of line.
class DomainMap {
public:
// Constructs the map for `fusion`. The definition is out of line.
DomainMap(Fusion* fusion);
// Virtual so that derived maps can be destroyed through a base pointer.
virtual ~DomainMap() = default;

// Read-only access to the ComputeAtMap held by this object.
const ComputeAtMap& getComputeAtMap() const {
return ca_map_;
}

// Determine if a TensorView is a valid reference tensor for this fusion.
// The reference tensor must map to all the IterDomains in each input.
bool isValidReference(TensorView* tv) const;

protected:
// Determine if all IterDomains are mapped between input and the given tvs
bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv)
const;

// Customization point for derived classes.
// NOTE(review): presumably returns the member of in_concrete_ids that maps
// to out_id (or nullptr when there is none) -- confirm against the
// out-of-line definition.
virtual IterDomain* getMappedInputConcreteID(
const std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const;

// Erase input concrete ID if it is mapped to output ID
// NOTE(review): the bool result presumably reports whether an entry was
// erased -- confirm against the definition.
bool eraseIfMapped(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const;

// Check if in_ids are mapped to ids through any root domain as
// well as indirectly accessed domains with ops like torchGather
// NOTE(review): the name spells "eraseif" with a lowercase 'i', unlike
// eraseIfMapped above; renaming it would require updating the definition
// and all callers.
void eraseifInputMappedThroughRootDomainAndIndexing(
std::unordered_set<IterDomain*>& in_ids,
const std::vector<IterDomain*>& ids) const;

// Find any id in domain that maps with target id
IterDomain* anyMapped(
const std::vector<IterDomain*>& domain,
IterDomain* target) const;

// Fusion pointer, null by default; presumably set from the constructor
// argument (definition not shown here).
Fusion* fusion_ = nullptr;
// ComputeAtMap held by value and exposed through getComputeAtMap().
ComputeAtMap ca_map_;
// NOTE(review): presumably the TensorViews in the fusion that have an
// rfactor domain; it is populated out of line -- confirm.
std::vector<TensorView*> tvs_with_rfactor_;
};

// Returns number of non-reduction/non-broadcast/non-device dims in logical
// domain
inline int64_t nRootDims(const TensorView* tv) {
Expand All @@ -74,5 +30,36 @@ inline int64_t nRootDims(const TensorView* tv) {
return tv_n_dims;
}

// Pointwise-scheduler specialization of scheduler_tools::DomainMap that adds
// selection of a reference output TensorView.
class DomainMap : public scheduler_tools::DomainMap {
 public:
  using scheduler_tools::DomainMap::DomainMap;

  // The pointwise scheduler heuristics requires a minimum number of axes.
  // The output reference tensor should respect this requirement.
  //
  // Scans the TensorView outputs of the fusion and returns the one that
  //  - passes isValidReference(),
  //  - passes hasMinimumSize() for minimum_num_axes, and
  //  - is not also a fusion input,
  // choosing the candidate with the largest pointwise_utils::nRootDims().
  // Ties keep the earliest candidate in output order (strict '>' comparison).
  // Returns nullptr if no output qualifies.
  TensorView* findReferenceTensorView(int64_t minimum_num_axes = 0) const {
    TensorView* result = nullptr;
    int64_t max_dims = -1;
    for (auto output_tv :
         ir_utils::filterByType<TensorView>(fusion_->outputs())) {
      if (isValidReference(output_tv) &&
          hasMinimumSize(output_tv, minimum_num_axes) &&
          !output_tv->isFusionInput()) {
        const int64_t n_dims = pointwise_utils::nRootDims(output_tv);
        if (n_dims > max_dims) {
          result = output_tv;
          max_dims = n_dims;
        }
      }
    }
    return result;
  }

 private:
  // Returns true when no minimum is requested (num_axes == 0) or when the
  // logical domain of tv has strictly more than num_axes entries. tv must be
  // non-null; this is enforced with NVF_ERROR.
  bool hasMinimumSize(TensorView* tv, int64_t num_axes) const {
    NVF_ERROR(tv != nullptr);
    // Named cast instead of a C-style cast for the size_t -> int64_t
    // conversion.
    return num_axes == 0 ||
        static_cast<int64_t>(tv->getLogicalDomain().size()) > num_axes;
  }
};

} // namespace pointwise_utils
} // namespace nvfuser
Loading

0 comments on commit ea00ddf

Please sign in to comment.