Commit 2a6f059

Merge remote-tracking branch 'origin/main' into resize_scheduler_initial_version

naoyam committed Dec 13, 2024
2 parents 4368e80 + 558bf2e commit 2a6f059
Showing 9 changed files with 284 additions and 12 deletions.
1 change: 0 additions & 1 deletion csrc/device_lower/pass/replace_size.cpp
@@ -266,7 +266,6 @@ void replaceSymbolicSizes(Fusion* fusion) {
}
}

// Run mutation on the fusion with the tensor_dim_map
ir_utils::replaceValue(fusion, extent_simplification_map);
}

3 changes: 3 additions & 0 deletions csrc/ir/internal_base_nodes.h
@@ -610,6 +610,9 @@ class TensorDomain : public Val {
// Similar to allIDs but returns all ID expressions.
std::vector<Expr*> allExprs() const;

// Combines allIDs and allExprs into a single topologically ordered list
std::vector<Statement*> allStatements() const;

const std::vector<IterDomain*>& maybeAllocation() const {
return hasAllocation() ? allocation_domain_ : logical();
};
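For illustration, a minimal usage sketch of the new accessor (not part of the diff; it assumes nvFuser's usual Fusion/FusionGuard test setup and helpers such as makeSymbolicTensor and add from its test utilities):

// Hypothetical sketch; assumes it runs inside an nvFuser C++ test with the
// appropriate headers included.
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* in = makeSymbolicTensor(2);
fusion.addInput(in);
TensorView* out = add(in, in);
fusion.addOutput(out);

// Introduce an ID expression (a Split) into out's TensorDomain.
out->split(0, 4);

// allStatements() returns the iter domains of allIDs() and the ID exprs of
// allExprs() combined, in a topological order.
for (Statement* stmt : out->domain()->allStatements()) {
  if (auto* id = dynamic_cast<IterDomain*>(stmt)) {
    std::cout << "IterDomain: " << id->toString() << "\n";
  } else {
    std::cout << "Expr: " << stmt->as<Expr>()->toString() << "\n";
  }
}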
28 changes: 28 additions & 0 deletions csrc/ir/nodes.cpp
@@ -3803,6 +3803,34 @@ std::vector<Expr*> TensorDomain::allExprs() const {
return exprs.vector();
}

std::vector<Statement*> TensorDomain::allStatements() const {
auto all_ids = allIDs();
std::unordered_set<Val*> all_id_set{all_ids.begin(), all_ids.end()};

VectorOfUniqueEntries<Statement*> stmts;
for (auto id : all_ids) {
// Visit definition if available and all inputs are already visited
auto def = id->definition();
if (def != nullptr) {
if (std::all_of(
def->inputs().begin(), def->inputs().end(), [&](Val* inp) {
return all_id_set.find(inp) != all_id_set.end();
})) {
stmts.pushBack(def);
} else {
NVF_ERROR(std::none_of(
def->inputs().begin(), def->inputs().end(), [&](Val* inp) {
return all_id_set.find(inp) != all_id_set.end();
}));
}
}

stmts.pushBack(id);
}

return stmts.vector();
}

Split::Split(
IrBuilderPasskey passkey,
IterDomain* outer,
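The ordering rule above — push an ID's definition only once every input of that definition has been collected — can be illustrated with a self-contained toy sketch (plain structs, not nvFuser's IR; all names are made up):

// Toy illustration of the "definition after its inputs" ordering used by
// TensorDomain::allStatements(). Nodes stand in for IterDomains.
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Node {
  std::string name;
  std::vector<const Node*> inputs; // inputs of the node's defining expr
};

int main() {
  Node i0{"i0", {}};
  Node i1{"i1", {}};
  Node i2{"i2", {&i0, &i1}}; // i2 plays the role of a merged IterDomain

  // Visit order mirrors allIDs(): producers before consumers.
  std::vector<const Node*> all_ids{&i0, &i1, &i2};
  std::unordered_set<const Node*> all_id_set(all_ids.begin(), all_ids.end());

  std::vector<std::string> stmts;
  for (const Node* id : all_ids) {
    // Emit the defining expr only if it exists and every input belongs to the
    // visited ID set (the real code additionally asserts that a definition
    // with an out-of-set input has *no* in-set inputs).
    if (!id->inputs.empty() &&
        std::all_of(id->inputs.begin(), id->inputs.end(), [&](const Node* inp) {
          return all_id_set.count(inp) != 0;
        })) {
      stmts.push_back("def(" + id->name + ")");
    }
    stmts.push_back(id->name);
  }

  for (const std::string& s : stmts) {
    std::cout << s << ' '; // prints: i0 i1 def(i2) i2
  }
  std::cout << '\n';
  return 0;
}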
20 changes: 19 additions & 1 deletion csrc/ir/utils.cpp
@@ -451,7 +451,7 @@ class ValReplacementMutator : private OptOutMutator {
// typically not used by anything else. If we don't grab that count, then it
// would be a tensorview that doesn't get updated extents. Therefore, first
// grab all leaves towards outputs and grab stmts from there.
auto stmts = StmtSort::getStmtsTo(allLeafOuts(fusion), true, true);
auto stmts = StmtSort::getAllStmtsTo(allLeafOuts(fusion), true, true);

// Some fusions, such as standalone rand_like, can have disconnected DAG, so
// we need some mechanism to make sure our replacement set is as complete as
@@ -501,6 +501,24 @@
std::unordered_set<Val*> outputs;
std::vector<Val*> ordered_outputs;
for (auto expr : exprs) {
// Iter domains and their exprs are taken care of by traversing
// from TensorDomain with TensorDomain::allStatements, so they
// don't need to be included here
if (std::any_of(
expr->outputs().begin(), expr->outputs().end(), [](Val* output) {
return output->isA<IterDomain>();
})) {
NVF_ERROR(std::all_of(
expr->outputs().begin(), expr->outputs().end(), [](Val* output) {
return output->isA<IterDomain>();
}));
NVF_ERROR(std::all_of(
expr->inputs().begin(), expr->inputs().end(), [](Val* input) {
return input->isA<IterDomain>();
}));
continue;
}

inputs.insert(expr->inputs().begin(), expr->inputs().end());
outputs.insert(expr->outputs().begin(), expr->outputs().end());
ordered_outputs.insert(
92 changes: 92 additions & 0 deletions csrc/iter_visitor.cpp
@@ -938,6 +938,98 @@ std::vector<Statement*> StmtSort::getStmtsTo(
return es.stmts;
}

std::vector<Statement*> StmtSort::getAllStmts(
Fusion* fusion,
bool traverse_members,
bool traverse_attributes,
bool traverse_siblings) {
return getAllStmtsTo(
fusion->getTerminatingOutputs(),
traverse_members,
traverse_attributes,
traverse_siblings);
}

std::vector<Statement*> StmtSort::getAllStmtsTo(
const std::vector<Val*>& to,
bool traverse_members,
bool traverse_attributes,
bool traverse_siblings) {
// If members are not traversed, this can just be handled by getStmtsTo
if (!traverse_members) {
return getStmtsTo(
to, traverse_members, traverse_attributes, traverse_siblings);
}

// to is assumed to include only scalar or TensorView
NVF_ERROR(std::all_of(to.begin(), to.end(), [](Val* to_val) {
return to_val->vtype() == ValType::TensorView ||
to_val->vtype() == ValType::Others;
}));

// First, grab all statements without traversing tensor members
auto stmts = getStmtsTo(to, false, traverse_attributes, traverse_siblings);

VectorOfUniqueEntries<Statement*> all_stmts;

// For TensorView, traverse further into its members. Note that
// traverse_members is always true here
for (auto stmt : stmts) {
auto tv = dynamic_cast<TensorView*>(stmt);
if (tv == nullptr) {
all_stmts.pushBack(stmt);
continue;
}

// Instead of using MemberStatements, grab the iter domains and
// their exprs with TensorDomain::allStatements(), which
// internally uses IRBFS.
auto all_id_stmts = tv->domain()->allStatements();

// For iter domains, traverse their members first and then visit the
// iter domains themselves. For ID exprs, traverse their attributes and
// then the exprs themselves.
for (auto id_stmt : all_id_stmts) {
if (auto id = dynamic_cast<IterDomain*>(id_stmt)) {
auto id_members = MemberStatements::next(id);
// Note that traverse_members is always true at this point
for (auto id_member : id_members) {
for (auto stmt_dep : StmtSort::getStmtsTo(
{id_member->as<Val>()},
/*traverse_members=*/true,
traverse_attributes,
traverse_siblings)) {
all_stmts.pushBack(stmt_dep);
}
}
all_stmts.pushBack(id);
} else {
auto expr = dynamic_cast<Expr*>(id_stmt);
NVF_ERROR(expr != nullptr);
if (traverse_attributes) {
for (auto attr : expr->attributes()) {
for (auto stmt_dep : StmtSort::getStmtsTo(
{attr->as<Val>()},
/*traverse_members=*/true,
traverse_attributes,
traverse_siblings)) {
all_stmts.pushBack(stmt_dep);
}
}
}
all_stmts.pushBack(expr);
}
}

// All dependent vals and exprs for this TensorDomain are now in
// all_stmts. Append the TensorDomain and then the TensorView.
all_stmts.pushBack(tv->domain());
all_stmts.pushBack(tv);
}

return all_stmts.vector();
}

std::vector<Statement*> StmtSort::getStmtsBetween(
const std::vector<Val*>& from,
const std::vector<Val*>& to,
22 changes: 22 additions & 0 deletions csrc/iter_visitor.h
@@ -316,6 +316,28 @@ class StmtSort : public IterVisitor {
bool traverse_attributes = false,
bool traverse_siblings = false);

// Returns all ordered Statements of a given fusion. Unlike
// getStmts, for TensorDomain, all of its iter domains and exprs are
// grabbed and returned in a topological order.
NVF_API static std::vector<Statement*> getAllStmts(
Fusion* fusion,
bool traverse_members = false,
bool traverse_attributes = false,
bool traverse_siblings = false);

// Returns ordered Statements required to produce 'to', including
// 'to'. Unlike getStmtsTo, for TensorDomain, all of its iter domains and
// exprs are grabbed and returned in a topological order, regardless of
// `traverse_members`.
//
// The to vals are assumed to be either TensorView or scalar
// Val. This assumption could be removed if desired.
NVF_API static std::vector<Statement*> getAllStmtsTo(
const std::vector<Val*>& to,
bool traverse_members = false,
bool traverse_attributes = false,
bool traverse_siblings = false);

// Returns ordered Statements required to produce from, including from.
// Stops traversal once hitting any Statements in to. Includes Statements in
// to.
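For illustration, a hedged usage sketch of the new entry points (not part of the diff; it assumes an already-built Fusion with an output TensorView* named out_tv and the nvFuser headers):

// Hypothetical sketch; out_tv is an assumed output TensorView*.
std::vector<Statement*> member_stmts = StmtSort::getStmtsTo(
    {out_tv},
    /*traverse_members=*/true,
    /*traverse_attributes=*/true,
    /*traverse_siblings=*/true);

// getAllStmtsTo additionally returns, for every visited TensorDomain, all of
// its iter domains and the ID exprs between them, in a topological order.
// getAllStmts(fusion, ...) is the fusion-wide variant starting from the
// terminating outputs.
std::vector<Statement*> all_stmts = StmtSort::getAllStmtsTo(
    {out_tv},
    /*traverse_members=*/true,
    /*traverse_attributes=*/true,
    /*traverse_siblings=*/true);

// A pass that rewrites extents (e.g. ValReplacementMutator in csrc/ir/utils.cpp)
// needs the complete set, so it iterates over all_stmts rather than
// member_stmts.
for (Statement* stmt : all_stmts) {
  // ... inspect or mutate stmt ...
}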
53 changes: 43 additions & 10 deletions doc/dev/python_scheduling/autotune_inner_reduction.py
@@ -72,8 +72,10 @@ class FUSION(Enum):
class InnerReductionConfiguration:
# The vectorization factor for inner reduction domain.
vectorize_factor: int = 1
# The unroll factor for the inner reduction domain.
reduction_unroll_factor: int = 1
# The unroll factor for the outer iteration domain.
unroll_factor: int = 1
iteration_unroll_factor: int = 1
# The grid size for the outer iteration domain.
# If grdim > 1, then godim corresponds with y axis of the grid.
# Otherwise, it is the x axis of the grid.
@@ -121,11 +123,16 @@ def convert_to_inner_reduction_params(self, scheduler_config, reduction_params):
reduction_params.vectorize_inner_reduction = (
scheduler_config.vectorize_factor > 1
)
reduction_params.unroll_factor_top_of_vectorization = (
scheduler_config.reduction_unroll_factor
)

if scheduler_config.bdimy > 1:
reduction_params.block_dim_iter_dom = ParallelType.block_y

reduction_params.unroll_factor_iter_dom = scheduler_config.unroll_factor
reduction_params.unroll_factor_iter_dom = (
scheduler_config.iteration_unroll_factor
)

gdimx = -1
gdimy = -1
@@ -161,16 +168,27 @@ def convert_to_inner_reduction_params(self, scheduler_config, reduction_params):
def generate_scheduler_configurations(self, input_shape):
threads_per_cta_options = [128, 256, 512, 1024]
vectorization_factor_options = [1, 2, 4, 8]
unroll_factor_options = list(range(1, 11))
reduction_unroll_factor_options = list(range(1, 6))
iteration_unroll_factor_options = list(range(1, 6))
warp_size = 32

num_iterations, num_reductions = input_shape

for threads_per_cta, vectorize_factor, unroll_factor in itertools.product(
threads_per_cta_options, vectorization_factor_options, unroll_factor_options
for (
threads_per_cta,
vectorize_factor,
reduction_unroll_factor,
iteration_unroll_factor,
) in itertools.product(
threads_per_cta_options,
vectorization_factor_options,
reduction_unroll_factor_options,
iteration_unroll_factor_options,
):
scheduler_config = self.InnerReductionConfiguration(
vectorize_factor=vectorize_factor, unroll_factor=unroll_factor
vectorize_factor=vectorize_factor,
reduction_unroll_factor=reduction_unroll_factor,
iteration_unroll_factor=iteration_unroll_factor,
)
scheduler_config.bdimx = min(
threads_per_cta,
@@ -184,20 +202,35 @@ def generate_scheduler_configurations(self, input_shape):
max(1, floor_div(threads_per_cta, scheduler_config.bdimx)),
)
scheduler_config.godim = ceil_div(
num_iterations, scheduler_config.bdimy * scheduler_config.unroll_factor
num_iterations, scheduler_config.bdimy * iteration_unroll_factor
)

# number of reduction elements not handled by a CTA
remaining_reduction = ceil_div(
num_reductions,
(scheduler_config.bdimx * scheduler_config.vectorize_factor),
ceil_div(
ceil_div(num_reductions, vectorize_factor), scheduler_config.bdimx
),
reduction_unroll_factor,
)

if unroll_factor == 1 and remaining_reduction > 1:
if iteration_unroll_factor == 1 and remaining_reduction > 1:
# all remaining reduction goes to grdim
scheduler_config.grdim = remaining_reduction
yield scheduler_config

# When iteration dim is small, there may be unused SMs. We need
# to shift work from block reduction to grid reduction to
# increase SM usage.
godim = scheduler_config.godim
grdim = 1
while (
godim * grdim * 2 <= self.gpu_properties.multi_processor_count
and (remaining_reduction / grdim) >= 2
):
grdim *= 2
scheduler_config.grdim = grdim
yield scheduler_config

# grid stride across reduction iterDomain is 1
scheduler_config.grdim = 1
yield scheduler_config
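A quick worked example of the new search space (numbers are illustrative, not from the diff): with num_reductions = 4096, vectorize_factor = 8, bdimx = 128, and reduction_unroll_factor = 2, remaining_reduction = ceil_div(ceil_div(ceil_div(4096, 8), 128), 2) = ceil_div(ceil_div(512, 128), 2) = ceil_div(4, 2) = 2. For the SM-usage loop, suppose instead godim = 4, remaining_reduction = 64, and multi_processor_count = 108: grdim doubles to 2, 4, 8, and then 16, yielding one configuration per doubling, and the loop stops because 4 * 16 * 2 = 128 exceeds 108; each step shifts more of the reduction work from the block to the grid.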