From 2f4f8bcd55a8db3884251fa1af095cf707943e7d Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Tue, 5 Nov 2024 12:47:22 -0500 Subject: [PATCH 01/19] Add failing rfactor tests --- test/error/CMakeLists.txt | 2 ++ .../rfactor_after_var_and_rvar_fusion.cpp | 25 ++++++++++++++++++ test/error/rfactor_fused_var_and_rvar.cpp | 26 +++++++++++++++++++ 3 files changed, 53 insertions(+) create mode 100644 test/error/rfactor_after_var_and_rvar_fusion.cpp create mode 100644 test/error/rfactor_fused_var_and_rvar.cpp diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 0478c3b11087..5272b2717de7 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -94,6 +94,8 @@ tests(GROUPS error require_fail.cpp reuse_var_in_schedule.cpp reused_args.cpp + rfactor_after_var_and_rvar_fusion.cpp + rfactor_fused_var_and_rvar.cpp rfactor_inner_dim_non_commutative.cpp round_up_and_blend_race.cpp run_with_large_stack_throws.cpp diff --git a/test/error/rfactor_after_var_and_rvar_fusion.cpp b/test/error/rfactor_after_var_and_rvar_fusion.cpp new file mode 100644 index 000000000000..8b94fbdd1b16 --- /dev/null +++ b/test/error/rfactor_after_var_and_rvar_fusion.cpp @@ -0,0 +1,25 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + Func f{"f"}; + RDom r({{0, 5}, {0, 5}, {0, 5}}, "r"); + Var x{"x"}, y{"y"}; + f(x, y) = 0; + f(x, y) += r.x + r.y + r.z; + + RVar rxy{"rxy"}, yrz{"yrz"}; + Var z{"z"}; + + // Error: In schedule for f.update(0), can't perform rfactor() after fusing y and r$z + f.update() + .fuse(r.x, r.y, rxy) + .fuse(r.z, y, yrz) + .rfactor(rxy, z); + + f.print_loop_nest(); + + printf("Success!\n"); + return 0; +} \ No newline at end of file diff --git a/test/error/rfactor_fused_var_and_rvar.cpp b/test/error/rfactor_fused_var_and_rvar.cpp new file mode 100644 index 000000000000..a167ca543c47 --- /dev/null +++ b/test/error/rfactor_fused_var_and_rvar.cpp @@ -0,0 +1,26 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + Func f{"f"}; + RDom r({{0, 5}, {0, 5}, {0, 5}}, "r"); + Var x{"x"}, y{"y"}; + f(x, y) = 0; + f(x, y) += r.x + r.y + r.z; + + RVar rxy{"rxy"}, yrz{"yrz"}, yr{"yr"}; + Var z{"z"}; + + // Error: In schedule for f.update(0), can't perform rfactor() after fusing r$z and y + f.update() + .fuse(r.x, r.y, rxy) + .fuse(y, r.z, yrz) + .fuse(rxy, yrz, yr) + .rfactor(yr, z); + + f.print_loop_nest(); + + printf("Success!\n"); + return 0; +} \ No newline at end of file From 47cdbccefe49c898317269506d43045775946701 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Thu, 21 Nov 2024 16:20:34 -0500 Subject: [PATCH 02/19] Rewrite rfactor() --- src/Func.cpp | 796 +++++++++++++++++++++-------------------------- src/Func.h | 2 +- src/Substitute.h | 10 + 3 files changed, 369 insertions(+), 439 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index c243e6950f3f..9b03c64aa5ae 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #ifdef _MSC_VER @@ -609,522 +610,441 @@ Expr substitute_self_reference(Expr val, const string &func, const Function &sub return val; } -// Substitute the occurrence of 'name' in 'exprs' with 'value'. -void substitute_var_in_exprs(const string &name, const Expr &value, vector &exprs) { - for (auto &expr : exprs) { - expr = substitute(name, value, expr); - } +} // anonymous namespace + +Func Stage::rfactor(const RVar &r, const Var &v) { + definition.schedule().touched() = true; + return rfactor({{r, v}}); } -void apply_split_result(const vector> &bounds_let_stmts, - const vector &splits_result, - vector &predicates, vector &args, - vector &values) { +// Helpers for rfactor implementation +namespace { - for (const auto &res : splits_result) { - switch (res.type) { - case ApplySplitResult::Substitution: - case ApplySplitResult::LetStmt: - // Apply substitutions to the list of predicates, args, and values. - // Make sure we substitute in all the let stmts as well since we are - // not going to add them to the exprs. - substitute_var_in_exprs(res.name, res.value, predicates); - substitute_var_in_exprs(res.name, res.value, args); - substitute_var_in_exprs(res.name, res.value, values); - break; - default: - internal_assert(res.type == ApplySplitResult::Predicate); - predicates.push_back(res.value); - break; - } +struct DimHash { + std::size_t operator()(const Dim &s) const noexcept { + return std::hash{}(s.var); } - - // Make sure we substitute in all the let stmts from 'bounds_let_stmts' - // since we are not going to add them to the exprs. - for (const auto &let : bounds_let_stmts) { - substitute_var_in_exprs(let.first, let.second, predicates); - substitute_var_in_exprs(let.first, let.second, args); - substitute_var_in_exprs(let.first, let.second, values); +}; +struct DimEq { + bool operator()(const Dim &lhs, const Dim &rhs) const noexcept { + return lhs.var == rhs.var; } -} +}; +using DimSet = std::unordered_set; -/** Apply split directives on the reduction variables. Remove the old RVar from - * the list and add the split result (inner and outer RVars) to the list. Add - * new predicates corresponding to the TailStrategy to the RDom predicate list. */ -bool apply_split(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values, map &dim_extent_alignment) { - internal_assert(s.split_type == Split::SplitVar); - const auto it = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); }); +template +struct has_name : std::false_type {}; - Expr old_max, old_min, old_extent; +template +struct has_name> : std::true_type {}; - if (it != rvars.end()) { - debug(4) << " Splitting " << it->var << " into " << s.outer << " and " << s.inner << "\n"; +template +std::optional match_by_dim_name(const std::string &dim, const std::vector &items) { + const auto has_v = std::find_if(items.begin(), items.end(), [&dim](auto &x) { + if constexpr (has_name::value) { + return var_name_match(dim, x.name()); + } else { + return var_name_match(dim, x.var); + } + }); + return has_v == items.end() ? std::nullopt : std::make_optional(*has_v); +} - old_max = simplify(it->min + it->extent - 1); - old_min = it->min; - old_extent = it->extent; +template +std::optional find_by_var_name(const std::vector &items, const std::string &name) { + const auto has_v = std::find_if(items.begin(), items.end(), [&name](auto &x) { + if constexpr (has_name::value) { + return var_name_match(x.name(), name); + } else { + return var_name_match(x.var, name); + } + }); + return has_v == items.end() ? std::nullopt : std::make_optional(*has_v); +} - it->var = s.inner; - it->min = 0; - it->extent = s.factor; +template +std::optional find_by_var_name(const std::vector &items, const VarOrRVar &v) { + return find_by_var_name(items, v.name()); +} - rvars.insert(it + 1, {s.outer, 0, simplify((old_extent - 1 + s.factor) / s.factor)}); +std::optional rfactor_validate_args(const vector> &preserved, const AssociativeOp &prover_result, const std::vector &dims) { + if (!prover_result.associative()) { + return "can't perform rfactor() because we can't prove associativity of the operator"; + } - vector splits_result = apply_split(s, "", dim_extent_alignment); - vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); - apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); + DimSet is_rfactored; + for (const auto &[rv, v] : preserved) { + // Check that the RVar are in the dims list + const auto &rv_dim = find_by_var_name(dims, rv); + if (!(rv_dim && rv_dim->is_rvar())) { + std::stringstream s; + s << "can't perform rfactor() on " << rv.name() << " since it is not in the reduction domain"; + return s.str(); + } + is_rfactored.insert(*rv_dim); - return true; + // Check that the new pure Vars we used to rename the RVar aren't already in the dims list + const auto &v_dim = find_by_var_name(dims, v); + if (v_dim) { + std::stringstream s; + s << "can't rename the rvars " << rv.name() << " into " << v.name() + << ", since it is already used in this Func's schedule elsewhere."; + return s.str(); + } } - return false; -} -/** Apply fuse directives on the reduction variables. Remove the - * fused RVars from the list and add the fused RVar to the list. */ -bool apply_fuse(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values, map &dim_extent_alignment) { - internal_assert(s.split_type == Split::FuseVars); - const auto &iter_outer = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.outer == rv.var); }); - const auto &iter_inner = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.inner == rv.var); }); + // If the operator is associative but non-commutative, rfactor() on inner + // dimensions (excluding the outer dimensions) is not valid. + if (!prover_result.commutative()) { + std::optional last_rvar; + for (const auto &d : reverse_view(dims)) { + if (is_rfactored.count(d) && last_rvar && !is_rfactored.count(*last_rvar)) { + std::stringstream s; + s << "can't rfactor an inner dimension " << d.var + << " without rfactoring the outer dimensions, since the " + << "operator is non-commutative."; + return s.str(); + } + if (d.is_rvar()) { + last_rvar = d; + } + } + } - Expr inner_min, inner_extent, outer_min, outer_extent; - if ((iter_outer != rvars.end()) && (iter_inner != rvars.end())) { - debug(4) << " Fusing " << s.outer << " and " << s.inner << " into " << s.old_var << "\n"; + return std::nullopt; +} - inner_min = iter_inner->min; - inner_extent = iter_inner->extent; - outer_min = iter_outer->min; - outer_extent = iter_outer->extent; +template +std::vector &operator+=(std::vector &base, const std::vector &other) { + base.insert(base.end(), other.begin(), other.end()); + return base; +} - Expr extent = iter_outer->extent * iter_inner->extent; - iter_outer->var = s.old_var; - iter_outer->min = 0; - iter_outer->extent = extent; - rvars.erase(iter_inner); +template +std::vector &operator+=(std::vector &base, const std::vector &other) { + std::copy(other.begin(), other.end(), std::back_inserter(base)); + return base; +} - vector splits_result = apply_split(s, "", dim_extent_alignment); - vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); - apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); +template +std::vector operator+(const std::vector &base, const std::vector &other) { + std::vector merged = base; + merged += other; + return merged; +} - return true; - } - return false; +template +std::vector operator+(const std::vector &base, const std::vector &other) { + std::vector merged = base; + merged += other; + return merged; } -/** Apply purify directives on the reduction variables and predicates. Purify - * replace a RVar with a Var, thus, the RVar needs to be removed from the list. - * Any reference to the RVar in the predicates will be replaced with reference - * to a Var. */ -bool apply_purify(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values, map &dim_extent_alignment) { - internal_assert(s.split_type == Split::PurifyRVar); - const auto &iter = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); }); - if (iter != rvars.end()) { - debug(4) << " Purify RVar " << iter->var << " into Var " << s.outer - << ", deleting it from the rvars list\n"; - rvars.erase(iter); +template +std::vector copy_convert(const std::vector &vec) { + return std::vector{} + vec; +} - vector splits_result = apply_split(s, "", dim_extent_alignment); - vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); - apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); +using SubstitutionMap = std::map; - return true; +// This is a helper function for building up a substitution map that +// corresponds to pushing down a nest of lets. The lets should be fed +// to this function from innermost to outermost. This is equivalent to +// building a let-nest as a one-hole context and then simplifying. +void add_let(SubstitutionMap &proj, const std::string &name, const Expr &value) { + for (auto &[_, e] : proj) { + e = substitute(name, value, e); } - return false; + proj.emplace(name, value); } -/** Apply rename directives on the reduction variables. */ -bool apply_rename(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values, map &dim_extent_alignment) { - internal_assert(s.split_type == Split::RenameVar); - const auto &iter = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); }); - if (iter != rvars.end()) { - debug(4) << " Renaming " << iter->var << " into " << s.outer << "\n"; - iter->var = s.outer; - - vector splits_result = apply_split(s, "", dim_extent_alignment); - vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); - apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); - - return true; +void rebind(SubstitutionMap &proj, const std::string &name, const Expr &value) { + for (auto &[_, e] : proj) { + e = substitute(name, value, e); + } + if (!proj.count(name)) { + proj.emplace(name, value); } - return false; } -/** Apply scheduling directives (e.g. split, fuse, etc.) on the reduction - * variables. */ -bool apply_split_directive(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values) { - map dim_extent_alignment; - for (const ReductionVariable &rv : rvars) { - dim_extent_alignment[rv.var] = rv.extent; - } +std::pair project_rdom(const std::vector &dims, const Definition &def) { + const auto &rdom = def.schedule().rvars(); + const auto &splits = def.schedule().splits(); + const auto &predicate = def.predicate(); - vector> rvar_bounds; - for (const ReductionVariable &rv : rvars) { - rvar_bounds.emplace_back(rv.var + ".loop_min", rv.min); - rvar_bounds.emplace_back(rv.var + ".loop_max", simplify(rv.min + rv.extent - 1)); - rvar_bounds.emplace_back(rv.var + ".loop_extent", rv.extent); + // Compute new bounds for the RDom by walking backwards through the splits. + SubstitutionMap bounds_projection{}; + for (const Split &split : reverse_view(splits)) { + for (const auto &[name, value] : compute_loop_bounds_after_split(split, "")) { + add_let(bounds_projection, name, value); + } + } + for (const ReductionVariable &rv : rdom) { + add_let(bounds_projection, rv.var + ".loop_min", rv.min); + add_let(bounds_projection, rv.var + ".loop_max", rv.min + rv.extent - 1); + add_let(bounds_projection, rv.var + ".loop_extent", rv.extent); } - bool found = false; - switch (s.split_type) { - case Split::SplitVar: - found = apply_split(s, rvars, predicates, args, values, dim_extent_alignment); - break; - case Split::FuseVars: - found = apply_fuse(s, rvars, predicates, args, values, dim_extent_alignment); - break; - case Split::PurifyRVar: - found = apply_purify(s, rvars, predicates, args, values, dim_extent_alignment); - break; - case Split::RenameVar: - found = apply_rename(s, rvars, predicates, args, values, dim_extent_alignment); - break; + // Build the new RDom + std::vector new_rvars; + for (const Dim &dim : dims) { + Expr new_min = substitute(bounds_projection, Variable::make(Int(32), dim.var + ".loop_min")); + Expr new_extent = substitute(bounds_projection, Variable::make(Int(32), dim.var + ".loop_extent")); + new_rvars.push_back(ReductionVariable{dim.var, simplify(new_min), simplify(new_extent)}); } + ReductionDomain new_rdom{new_rvars}; + new_rdom.where(predicate); - if (found) { - for (const auto &let : rvar_bounds) { - substitute_var_in_exprs(let.first, let.second, predicates); - substitute_var_in_exprs(let.first, let.second, args); - substitute_var_in_exprs(let.first, let.second, values); + // Compute the mapping of old dimensions to the projected RDom values. + SubstitutionMap dim_projection{}; + SubstitutionMap dim_extent_alignment{}; + for (const ReductionVariable &rv : rdom) { + add_let(dim_projection, rv.var, Variable::make(Int(32), rv.var)); + dim_extent_alignment[rv.var] = rv.extent; + } + for (const Split &split : splits) { + for (const auto &result : apply_split(split, "", dim_extent_alignment)) { + switch (result.type) { + case ApplySplitResult::Substitution: + case ApplySplitResult::SubstitutionInCalls: + case ApplySplitResult::SubstitutionInProvides: + case ApplySplitResult::LetStmt: + add_let(dim_projection, result.name, result.value); + break; + case ApplySplitResult::PredicateCalls: + case ApplySplitResult::PredicateProvides: + case ApplySplitResult::Predicate: + new_rdom.where(substitute(bounds_projection, result.value)); + break; + case ApplySplitResult::BlendProvides: + // TODO: what to do here? BlendProvides is not included in the above checks. + break; + } } } - return found; -} - -} // anonymous namespace + for (auto &[_, e] : dim_projection) { + e = substitute(bounds_projection, e); + } + for (const ReductionVariable &rv : new_rdom.domain()) { + rebind(dim_projection, rv.var, Variable::make(Int(32), rv.var, new_rdom)); + } -Func Stage::rfactor(const RVar &r, const Var &v) { - definition.schedule().touched() = true; - return rfactor({{r, v}}); + return std::make_pair(new_rdom, dim_projection); } -Func Stage::rfactor(vector> preserved) { +} // namespace + +Func Stage::rfactor(const vector> &preserved) { user_assert(!definition.is_init()) << "rfactor() must be called on an update definition\n"; definition.schedule().touched() = true; - const string &func_name = function.name(); - vector &args = definition.args(); - vector &values = definition.values(); - - // Figure out which pure vars were used in this update definition. - std::set pure_vars_used; - internal_assert(args.size() == dim_vars.size()); - for (size_t i = 0; i < args.size(); i++) { - if (const Internal::Variable *var = args[i].as()) { - if (var->name == dim_vars[i].name()) { - pure_vars_used.insert(var->name); - } - } - } - // Check whether the operator is associative and determine the operator and // its identity for each value in the definition if it is a Tuple - const auto &prover_result = prove_associativity(func_name, args, values); + const auto &prover_result = prove_associativity(function.name(), definition.args(), definition.values()); - user_assert(prover_result.associative()) - << "Failed to call rfactor() on " << name() - << " since it can't prove associativity of the operator\n"; - internal_assert(prover_result.size() == values.size()); + const auto &error = rfactor_validate_args(preserved, prover_result, definition.schedule().dims()); + user_assert(!error) << "In schedule for " << name() << ": " << *error << "\n" + << dump_argument_list(); - vector &splits = definition.schedule().splits(); - vector &dims = definition.schedule().dims(); - vector &rvars = definition.schedule().rvars(); - vector predicates = definition.split_predicate(); - - Scope scope; // Contains list of RVars lifted to the intermediate Func - vector rvars_removed; + // sort preserved by the dimension ordering + vector preserved_rvars; + vector preserved_vars; + vector preserved_rdims; + vector intermediate_rdims; + { + std::unordered_map dim_ordering; + for (int i = 0; i < definition.schedule().dims().size(); i++) { + dim_ordering.emplace(definition.schedule().dims()[i].var, i); + } - vector is_rfactored(dims.size(), false); - for (const pair &i : preserved) { - const RVar &rv = i.first; - const Var &v = i.second; - { - // Check that the RVar are in the dims list - const auto &iter = std::find_if(dims.begin(), dims.end(), - [&rv](const Dim &dim) { return var_name_match(dim.var, rv.name()); }); - user_assert((iter != dims.end()) && (*iter).is_rvar()) - << "In schedule for " << name() - << ", can't perform rfactor() on " << rv.name() - << " since it is not in the reduction domain\n" - << dump_argument_list(); - is_rfactored[iter - dims.begin()] = true; + std::vector> preserved_with_dims; + for (const auto &[rv, v] : preserved) { + const std::optional rdim = find_by_var_name(definition.schedule().dims(), rv); + internal_assert(rdim); + preserved_with_dims.emplace_back(rv, v, *rdim); } - { - // Check that the new pure Vars we used to rename the RVar aren't already in the dims list - const auto &iter = std::find_if(dims.begin(), dims.end(), - [&v](const Dim &dim) { return var_name_match(dim.var, v.name()); }); - user_assert(iter == dims.end()) - << "In schedule for " << name() - << ", can't rename the rvars " << rv.name() << " into " << v.name() - << ", since it is already used in this Func's schedule elsewhere.\n" - << dump_argument_list(); + + std::sort(preserved_with_dims.begin(), preserved_with_dims.end(), [&](const auto &lhs, const auto &rhs) { + return dim_ordering.at(std::get<2>(lhs).var) < dim_ordering.at(std::get<2>(rhs).var); + }); + + for (const auto &[rv, v, dim] : preserved_with_dims) { + preserved_rvars.push_back(rv); + preserved_vars.push_back(v); + preserved_rdims.push_back(dim); } - } - // If the operator is associative but non-commutative, rfactor() on inner - // dimensions (excluding the outer dimensions) is not valid. - if (!prover_result.commutative()) { - int last_rvar = -1; - for (int i = dims.size() - 1; i >= 0; --i) { - if ((last_rvar != -1) && is_rfactored[i]) { - user_assert(is_rfactored[last_rvar]) - << "In schedule for " << name() - << ", can't rfactor an inner dimension " << dims[i].var - << " without rfactoring the outer dimensions, since the " - << "operator is non-commutative.\n" - << dump_argument_list(); - } - if (dims[i].is_rvar()) { - last_rvar = i; + for (const Dim &dim : definition.schedule().dims()) { + if (dim.is_rvar() && !match_by_dim_name(dim.var, preserved_rvars)) { + intermediate_rdims.push_back(dim); } } } - // We need to apply the split directives on the reduction vars, so that we can - // correctly lift the RVars not in 'rvars_kept' and distribute the RVars to the - // intermediate and merge Funcs. - { - vector temp; - for (const Split &s : splits) { - // If it's already applied, we should remove it from the split list. - if (!apply_split_directive(s, rvars, predicates, args, values)) { - temp.push_back(s); - } - } - splits = temp; - } - - // Reduction domain of the intermediate update definition - vector intm_rvars; - for (const auto &rv : rvars) { - const auto &iter = std::find_if(preserved.begin(), preserved.end(), - [&rv](const pair &pair) { return var_name_match(rv.var, pair.first.name()); }); - if (iter == preserved.end()) { - intm_rvars.push_back(rv); - scope.push(rv.var, rv.var); - } - } - RDom intm_rdom(intm_rvars); - - // Sort the Rvars kept and their Vars replacement based on the RVars of - // the reduction domain AFTER applying the split directives, so that we - // can have a consistent args order for the update definition of the - // intermediate and new merge Funcs. - std::sort(preserved.begin(), preserved.end(), - [&](const pair &lhs, const pair &rhs) { - const auto &iter_lhs = std::find_if(rvars.begin(), rvars.end(), - [&lhs](const ReductionVariable &rv) { return var_name_match(rv.var, lhs.first.name()); }); - const auto &iter_rhs = std::find_if(rvars.begin(), rvars.end(), - [&rhs](const ReductionVariable &rv) { return var_name_match(rv.var, rhs.first.name()); }); - return iter_lhs < iter_rhs; - }); - // The list of RVars to keep in the new update definition - vector rvars_kept(preserved.size()); - // List of pure Vars to replace the RVars in the intermediate's update definition - vector vars_rename(preserved.size()); - for (size_t i = 0; i < preserved.size(); ++i) { - const auto &val = preserved[i]; - rvars_kept[i] = val.first; - vars_rename[i] = val.second; - } - - // List of RVars for the new reduction domain. Any RVars not in 'rvars_kept' - // are removed from the RDom + // Intermediate func + Func intm(function.name() + "_intm"); + + // Intermediate pure definition { - vector temp; - for (const auto &rv : rvars) { - const auto &iter = std::find_if(rvars_kept.begin(), rvars_kept.end(), - [&rv](const RVar &rvar) { return var_name_match(rv.var, rvar.name()); }); - if (iter != rvars_kept.end()) { - temp.push_back(rv); - } else { - rvars_removed.push_back(rv.var); - } - } - rvars.swap(temp); + intm(dim_vars + preserved_vars) = Tuple(prover_result.pattern.identities); } - RDom f_rdom(rvars); - - // Init definition of the intermediate Func - // Compute args of the init definition of the intermediate Func. - // Replace the RVars, which are in 'rvars_kept', with the specified new pure - // Vars. Also, add the pure Vars of the original init definition as part of - // the args. - // For example, if we have the following Func f: - // f(x, y) = 10 - // f(r.x, r.y) += h(r.x, r.y) - // Calling f.update(0).rfactor({{r.y, u}}) will generate the following - // intermediate Func: - // f_intm(x, y, u) = 0 - // f_intm(r.x, u, u) += h(r.x, u) + // Intermediate update definition + { + auto [intermediate_rdom, intermediate_map] = project_rdom(intermediate_rdims, definition); + for (int i = 0; i < preserved.size(); i++) { + rebind(intermediate_map, preserved_rdims[i].var, preserved_vars[i]); + } + for (const auto &var : dim_vars) { + intermediate_map.erase(var.name()); + } - vector init_args; - init_args.insert(init_args.end(), dim_vars.begin(), dim_vars.end()); - init_args.insert(init_args.end(), vars_rename.begin(), vars_rename.end()); + intermediate_rdom.set_predicate(simplify(substitute(intermediate_map, intermediate_rdom.predicate()))); - vector init_vals(values.size()); - for (size_t i = 0; i < init_vals.size(); ++i) { - init_vals[i] = prover_result.pattern.identities[i]; - } + vector args = definition.args() + preserved_vars; + vector values; + for (const auto &val : definition.values()) { + values.push_back(substitute_self_reference(val, function.name(), intm.function(), preserved_vars)); + } + args = substitute(intermediate_map, args); + values = substitute(intermediate_map, values); + intm.function().define_update(args, values, intermediate_rdom); - Func intm(func_name + "_intm"); - intm(init_args) = Tuple(init_vals); + // Intermediate schedule + intm.function().update(0).schedule() = definition.schedule().get_copy(); - // Args of the update definition of the intermediate Func - vector update_args(args.size() + vars_rename.size()); + auto &intm_dims = intm.function().update(0).schedule().dims(); - // We need to substitute the reference to the old RDom's RVars with - // the new RDom's RVars. Also, substitute the reference to RVars which - // are in 'rvars_kept' with their corresponding new pure Vars - map substitution_map; - for (size_t i = 0; i < intm_rvars.size(); ++i) { - substitution_map[intm_rvars[i].var] = intm_rdom[i]; - } - for (size_t i = 0; i < vars_rename.size(); i++) { - update_args[i + args.size()] = vars_rename[i]; - RVar rvar_kept = rvars_kept[i]; - // Find the full name of rvar_kept in rvars - const auto &iter = std::find_if(rvars.begin(), rvars.end(), - [&rvar_kept](const ReductionVariable &rv) { return var_name_match(rv.var, rvar_kept.name()); }); - substitution_map[iter->var] = vars_rename[i]; - } - for (size_t i = 0; i < args.size(); i++) { - Expr arg = substitute(substitution_map, args[i]); - update_args[i] = arg; - } + // Replace rvar dims IN the preserved list with their Vars in the INTERMEDIATE Func + for (auto &dim : intm_dims) { + const auto it = std::find_if(preserved_rvars.begin(), preserved_rvars.end(), [&](const auto &rv) { + return var_name_match(dim.var, rv.name()); + }); + if (it != preserved_rvars.end()) { + const auto offset = it - preserved_rvars.begin(); + const auto &var = preserved_vars[offset]; + const auto &pure_dim = find_by_var_name(intm.function().definition().schedule().dims(), var); + internal_assert(pure_dim); + dim = *pure_dim; + } + } - // Compute the predicates for the intermediate Func and the new update definition - for (const Expr &pred : predicates) { - Expr subs_pred = substitute(substitution_map, pred); - intm_rdom.where(subs_pred); - if (!expr_uses_vars(pred, scope)) { - // Only keep the predicate that does not depend on the lifted RVars - // (either explicitly or implicitly). For example, if 'rx' is split - // into 'rxo' and 'rxi' and 'rxo' is part of the lifted RVars, we'll - // ignore every predicate that depends on 'rx' - f_rdom.where(pred); + // Add factored pure dims to the INTERMEDIATE func just before outermost + DimSet dims; + dims.insert(intm_dims.begin(), intm_dims.end()); + for (const Var &dim_v : preserved_vars) { + const std::optional &dim = find_by_var_name(intm.function().definition().schedule().dims(), dim_v); + internal_assert(dim) << "Failed to find " << dim_v.name() << " in list of pure dims"; + if (!dims.count(*dim)) { + intm_dims.insert(intm_dims.end() - 1, *dim); + } } - } - definition.predicate() = f_rdom.domain().predicate(); - // The update values the intermediate Func should compute - vector update_vals(values.size()); - for (size_t i = 0; i < update_vals.size(); i++) { - Expr val = substitute(substitution_map, values[i]); - // Need to update the self-reference in the update definition to point - // to the new intermediate Func - val = substitute_self_reference(val, func_name, intm.function(), vars_rename); - update_vals[i] = val; + intm.function().update(0).schedule().rvars() = intermediate_rdom.domain(); } - // There may not actually be a reference to the RDom in the args or values, - // so we use Function::define_update, which lets pass pass an explicit RDom. - intm.function().define_update(update_args, update_vals, intm_rdom.domain()); - // Determine the dims and schedule of the update definition of the - // intermediate Func. We copy over the schedule from the original - // update definition (e.g. split, parallelize, vectorize, etc.) - intm.function().update(0).schedule().dims() = dims; - intm.function().update(0).schedule().splits() = splits; + // Preserved update definition + { + auto [preserved_rdom, _] = project_rdom(preserved_rdims, definition); - // Copy over the storage order of the original pure dims - vector &intm_storage_dims = intm.function().schedule().storage_dims(); - internal_assert(intm_storage_dims.size() == - function.schedule().storage_dims().size() + vars_rename.size()); - for (size_t i = 0; i < function.schedule().storage_dims().size(); ++i) { - intm_storage_dims[i] = function.schedule().storage_dims()[i]; - } + // Replace the current definition with calls to the intermediate func. + vector dim_exprs = copy_convert(dim_vars); + vector f_load_args = dim_exprs; + for (const ReductionVariable &rv : preserved_rdom.domain()) { + f_load_args.push_back(Variable::make(Int(32), rv.var, preserved_rdom)); + } - for (size_t i = 0; i < rvars_kept.size(); ++i) { - // Apply the purify directive that replaces the RVar in rvars_kept - // with a pure Var - intm.update(0).purify(rvars_kept[i], vars_rename[i]); - } + SubstitutionMap replacements; + for (size_t i = 0; i < preserved.size(); i++) { + replacements.emplace(preserved_rdims[i].var, preserved_vars[i]); + } + for (size_t i = 0; i < definition.values().size(); ++i) { + if (!prover_result.ys[i].var.empty()) { + Expr r = (definition.values().size() == 1) ? Expr(intm(f_load_args)) : Expr(intm(f_load_args)[i]); + replacements.emplace(prover_result.ys[i].var, r); + } - // Determine the dims of the new update definition - - // The new update definition needs all the pure vars of the Func, but the - // one we're rfactoring may not have used them all. Add any missing ones to - // the dims list. - - // Add pure Vars from the original init definition to the dims list - // if they are not already in the list - for (const Var &v : dim_vars) { - if (!pure_vars_used.count(v.name())) { - Dim d = {v.name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto}; - // Insert it just before Var::outermost - dims.insert(dims.end() - 1, d); - } - } + if (!prover_result.xs[i].var.empty()) { + Expr prev_val = Call::make(intm.types()[i], function.name(), + dim_exprs, Call::CallType::Halide, + FunctionPtr(), i); + replacements.emplace(prover_result.xs[i].var, prev_val); + } else { + user_warning << "Update definition of " << name() << " at index " << i + << " doesn't depend on the previous value. This isn't a" + << " reduction operation\n"; + } + } - // Then, we need to remove lifted RVars from the dims list - for (const string &rv : rvars_removed) { - remove(rv); - } + std::vector reducing_dims; + { + DimSet preserved_rdim_set; + preserved_rdim_set.insert(preserved_rdims.begin(), preserved_rdims.end()); - // Define the new update definition which refers to the intermediate Func. - // Using the same example as above, the new update definition is: - // f(x, y) += f_intm(x, y, r.y) + // Remove rvar dims NOT IN the preserved list from the REDUCING Func + for (const auto &dim : definition.schedule().dims()) { + if (!dim.is_rvar() || preserved_rdim_set.count(dim)) { + reducing_dims.push_back(dim); + } + } - // Args for store in the new update definition - vector f_store_args(dim_vars.size()); - for (size_t i = 0; i < f_store_args.size(); ++i) { - f_store_args[i] = dim_vars[i]; - } - - // Call's args to the intermediate Func in the new update definition - vector f_load_args; - f_load_args.insert(f_load_args.end(), dim_vars.begin(), dim_vars.end()); - for (int i = 0; i < f_rdom.dimensions(); ++i) { - f_load_args.push_back(f_rdom[i]); - } - internal_assert(f_load_args.size() == init_args.size()); + // Add missing pure vars to the REDUCING func just before outermost + for (int i = 0; i < dim_vars.size(); i++) { + if (!expr_uses_var(definition.args()[i], dim_vars[i].name())) { + Dim d = {dim_vars[i].name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto}; + reducing_dims.insert(reducing_dims.end() - 1, d); + } + } + } - // Update value of the new update definition. It loads values from - // the intermediate Func. - vector f_values(values.size()); + definition.args() = dim_exprs; + definition.predicate() = const_true(); // TODO: replace with strongest postcondition of the intermediate predicate with the eliminated rvars havoc'd + definition.schedule().dims() = std::move(reducing_dims); + definition.schedule().rvars() = preserved_rdom.domain(); + definition.values() = substitute(replacements, prover_result.pattern.ops); + } - // There might be cross-dependencies between tuple elements, so we need - // to collect all substitutions first. - map replacements; - for (size_t i = 0; i < f_values.size(); ++i) { - if (!prover_result.ys[i].var.empty()) { - Expr r = (values.size() == 1) ? Expr(intm(f_load_args)) : Expr(intm(f_load_args)[i]); - replacements.emplace(prover_result.ys[i].var, r); + // Clean up the splits lists + for (Stage st : {*this, intm.update(0)}) { + Scope<> dims; + for (const Var &v : st.dim_vars) { + dims.push(v.name()); } - - if (!prover_result.xs[i].var.empty()) { - Expr prev_val = Call::make(intm.types()[i], func_name, - f_store_args, Call::CallType::Halide, - FunctionPtr(), i); - replacements.emplace(prover_result.xs[i].var, prev_val); - } else { - user_warning << "Update definition of " << name() << " at index " << i - << " doesn't depend on the previous value. This isn't a" - << " reduction operation\n"; + for (const ReductionVariable &rv : st.definition.schedule().rvars()) { + dims.push(rv.var); } + std::vector new_splits; + for (const Split &split : st.definition.schedule().splits()) { + switch (split.split_type) { + case Split::SplitVar: + if (dims.contains(split.old_var)) { + dims.pop(split.old_var); + dims.push(split.outer); + dims.push(split.inner); + new_splits.push_back(split); + } + break; + case Split::FuseVars: + if (dims.contains(split.outer) && dims.contains(split.inner)) { + dims.pop(split.outer); + dims.pop(split.inner); + dims.push(split.old_var); + new_splits.push_back(split); + } + break; + case Split::PurifyRVar: + case Split::RenameVar: + if (dims.contains(split.old_var)) { + dims.pop(split.old_var); + dims.push(split.outer); + new_splits.push_back(split); + } + break; + } + } + st.definition.schedule().splits().swap(new_splits); } - for (size_t i = 0; i < f_values.size(); ++i) { - f_values[i] = substitute(replacements, prover_result.pattern.ops[i]); - } - - // Update the definition - args.swap(f_store_args); - values.swap(f_values); return intm; } @@ -1187,7 +1107,7 @@ void Stage::split(const string &old, const string &outer, const string &inner, c bool round_up_ok = !exact; if (round_up_ok && !definition.is_init()) { // If it's the outermost split in this dimension, RoundUp - // is OK. Otherwise we need GuardWithIf to avoid + // is OK. Otherwise, we need GuardWithIf to avoid // recomputing values in the case where the inner split // factor does not divide the outer split factor. std::set inner_vars; @@ -1224,7 +1144,7 @@ void Stage::split(const string &old, const string &outer, const string &inner, c bool predicate_loads_ok = !exact; if (predicate_loads_ok && tail == TailStrategy::PredicateLoads) { // If it's the outermost split in this dimension, PredicateLoads - // is OK. Otherwise we can't prove it's safe. + // is OK. Otherwise, we can't prove it's safe. std::set inner_vars; for (const Split &s : definition.schedule().splits()) { switch (s.split_type) { diff --git a/src/Func.h b/src/Func.h index 32d8f1e58c69..a0f374eb874f 100644 --- a/src/Func.h +++ b/src/Func.h @@ -184,7 +184,7 @@ class Stage { * */ // @{ - Func rfactor(std::vector> preserved); + Func rfactor(const std::vector> &preserved); Func rfactor(const RVar &r, const Var &v); // @} diff --git a/src/Substitute.h b/src/Substitute.h index 22bdf640b7a8..55e7ac1cf0fa 100644 --- a/src/Substitute.h +++ b/src/Substitute.h @@ -37,6 +37,16 @@ Expr substitute(const Expr &find, const Expr &replacement, const Expr &expr); Stmt substitute(const Expr &find, const Expr &replacement, const Stmt &stmt); // @} +/** Substitute a container of Exprs or Stmts out of place */ +template +T substitute(const std::map &replacements, const T &container) { + T output; + std::transform(container.begin(), container.end(), std::back_inserter(output), [&](const auto &expr_or_stmt) { + return substitute(replacements, expr_or_stmt); + }); + return output; +} + /** Substitutions where the IR may be a general graph (and not just a * DAG). */ // @{ From 530bd3f46c2d132cb579463542eed842a8c5adbe Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Sun, 24 Nov 2024 09:52:29 -0500 Subject: [PATCH 03/19] Remove unused operators --- src/Func.cpp | 24 ++---------------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 9b03c64aa5ae..6f2a10d36309 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -684,8 +684,7 @@ std::optional rfactor_validate_args(const vector> & is_rfactored.insert(*rv_dim); // Check that the new pure Vars we used to rename the RVar aren't already in the dims list - const auto &v_dim = find_by_var_name(dims, v); - if (v_dim) { + if (find_by_var_name(dims, v)) { std::stringstream s; s << "can't rename the rvars " << rv.name() << " into " << v.name() << ", since it is already used in this Func's schedule elsewhere."; @@ -714,29 +713,10 @@ std::optional rfactor_validate_args(const vector> & return std::nullopt; } -template -std::vector &operator+=(std::vector &base, const std::vector &other) { - base.insert(base.end(), other.begin(), other.end()); - return base; -} - -template -std::vector &operator+=(std::vector &base, const std::vector &other) { - std::copy(other.begin(), other.end(), std::back_inserter(base)); - return base; -} - -template -std::vector operator+(const std::vector &base, const std::vector &other) { - std::vector merged = base; - merged += other; - return merged; -} - template std::vector operator+(const std::vector &base, const std::vector &other) { std::vector merged = base; - merged += other; + merged.insert(merged.end(), other.begin(), other.end()); return merged; } From 6483c12188842263ecec1fe0c3d5a7445905df94 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Sun, 24 Nov 2024 10:22:21 -0500 Subject: [PATCH 04/19] Clean up dim/var matching helpers --- src/Func.cpp | 49 ++++++++++++++----------------------------------- 1 file changed, 14 insertions(+), 35 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 6f2a10d36309..74cd1328d835 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -632,41 +632,21 @@ struct DimEq { }; using DimSet = std::unordered_set; -template -struct has_name : std::false_type {}; - -template -struct has_name> : std::true_type {}; - -template -std::optional match_by_dim_name(const std::string &dim, const std::vector &items) { - const auto has_v = std::find_if(items.begin(), items.end(), [&dim](auto &x) { - if constexpr (has_name::value) { - return var_name_match(dim, x.name()); - } else { - return var_name_match(dim, x.var); - } +std::optional find_rvar(const std::vector &items, const Dim &dim) { + const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) { + return var_name_match(dim.var, x.name()); }); return has_v == items.end() ? std::nullopt : std::make_optional(*has_v); } -template -std::optional find_by_var_name(const std::vector &items, const std::string &name) { - const auto has_v = std::find_if(items.begin(), items.end(), [&name](auto &x) { - if constexpr (has_name::value) { - return var_name_match(x.name(), name); - } else { - return var_name_match(x.var, name); - } +std::optional find_dim(const std::vector &items, const VarOrRVar &v) { + const std::string &name = v.name(); + const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) { + return var_name_match(x.var, name); }); return has_v == items.end() ? std::nullopt : std::make_optional(*has_v); } -template -std::optional find_by_var_name(const std::vector &items, const VarOrRVar &v) { - return find_by_var_name(items, v.name()); -} - std::optional rfactor_validate_args(const vector> &preserved, const AssociativeOp &prover_result, const std::vector &dims) { if (!prover_result.associative()) { return "can't perform rfactor() because we can't prove associativity of the operator"; @@ -675,7 +655,7 @@ std::optional rfactor_validate_args(const vector> & DimSet is_rfactored; for (const auto &[rv, v] : preserved) { // Check that the RVar are in the dims list - const auto &rv_dim = find_by_var_name(dims, rv); + const auto &rv_dim = find_dim(dims, rv); if (!(rv_dim && rv_dim->is_rvar())) { std::stringstream s; s << "can't perform rfactor() on " << rv.name() << " since it is not in the reduction domain"; @@ -684,7 +664,7 @@ std::optional rfactor_validate_args(const vector> & is_rfactored.insert(*rv_dim); // Check that the new pure Vars we used to rename the RVar aren't already in the dims list - if (find_by_var_name(dims, v)) { + if (find_dim(dims, v)) { std::stringstream s; s << "can't rename the rvars " << rv.name() << " into " << v.name() << ", since it is already used in this Func's schedule elsewhere."; @@ -840,7 +820,7 @@ Func Stage::rfactor(const vector> &preserved) { std::vector> preserved_with_dims; for (const auto &[rv, v] : preserved) { - const std::optional rdim = find_by_var_name(definition.schedule().dims(), rv); + const std::optional rdim = find_dim(definition.schedule().dims(), rv); internal_assert(rdim); preserved_with_dims.emplace_back(rv, v, *rdim); } @@ -856,7 +836,7 @@ Func Stage::rfactor(const vector> &preserved) { } for (const Dim &dim : definition.schedule().dims()) { - if (dim.is_rvar() && !match_by_dim_name(dim.var, preserved_rvars)) { + if (dim.is_rvar() && !find_rvar(preserved_rvars, dim)) { intermediate_rdims.push_back(dim); } } @@ -882,12 +862,11 @@ Func Stage::rfactor(const vector> &preserved) { intermediate_rdom.set_predicate(simplify(substitute(intermediate_map, intermediate_rdom.predicate()))); - vector args = definition.args() + preserved_vars; + vector args = substitute(intermediate_map, definition.args() + preserved_vars); vector values; for (const auto &val : definition.values()) { values.push_back(substitute_self_reference(val, function.name(), intm.function(), preserved_vars)); } - args = substitute(intermediate_map, args); values = substitute(intermediate_map, values); intm.function().define_update(args, values, intermediate_rdom); @@ -904,7 +883,7 @@ Func Stage::rfactor(const vector> &preserved) { if (it != preserved_rvars.end()) { const auto offset = it - preserved_rvars.begin(); const auto &var = preserved_vars[offset]; - const auto &pure_dim = find_by_var_name(intm.function().definition().schedule().dims(), var); + const auto &pure_dim = find_dim(intm.function().definition().schedule().dims(), var); internal_assert(pure_dim); dim = *pure_dim; } @@ -914,7 +893,7 @@ Func Stage::rfactor(const vector> &preserved) { DimSet dims; dims.insert(intm_dims.begin(), intm_dims.end()); for (const Var &dim_v : preserved_vars) { - const std::optional &dim = find_by_var_name(intm.function().definition().schedule().dims(), dim_v); + const std::optional &dim = find_dim(intm.function().definition().schedule().dims(), dim_v); internal_assert(dim) << "Failed to find " << dim_v.name() << " in list of pure dims"; if (!dims.count(*dim)) { intm_dims.insert(intm_dims.end() - 1, *dim); From 7ffdb66075196bab8e05fce07962a49ed4358c86 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Sun, 24 Nov 2024 10:25:07 -0500 Subject: [PATCH 05/19] Fix definition in PyStage.cpp --- python_bindings/src/halide/halide_/PyStage.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python_bindings/src/halide/halide_/PyStage.cpp b/python_bindings/src/halide/halide_/PyStage.cpp index b412a6f2b39e..fac47fa3cf1f 100644 --- a/python_bindings/src/halide/halide_/PyStage.cpp +++ b/python_bindings/src/halide/halide_/PyStage.cpp @@ -14,7 +14,7 @@ void define_stage(py::module &m) { .def("dump_argument_list", &Stage::dump_argument_list) .def("name", &Stage::name) - .def("rfactor", (Func(Stage::*)(std::vector>)) & Stage::rfactor, + .def("rfactor", (Func(Stage::*)(const std::vector> &)) & Stage::rfactor, py::arg("preserved")) .def("rfactor", (Func(Stage::*)(const RVar &, const Var &)) & Stage::rfactor, py::arg("r"), py::arg("v")) From 3fb172f2ee26b4b7ce7846d471f4864901b5dd3c Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Sun, 24 Nov 2024 10:59:33 -0500 Subject: [PATCH 06/19] Use and_condition_over_domain to predicate the reducing definition in rfactor --- src/Func.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 74cd1328d835..31657327bd5e 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -905,7 +905,8 @@ Func Stage::rfactor(const vector> &preserved) { // Preserved update definition { - auto [preserved_rdom, _] = project_rdom(preserved_rdims, definition); + auto [preserved_rdom, preserved_map] = project_rdom(preserved_rdims, definition); + preserved_rdom.set_predicate(simplify(substitute(preserved_map, preserved_rdom.predicate()))); // Replace the current definition with calls to the intermediate func. vector dim_exprs = copy_convert(dim_vars); @@ -957,8 +958,13 @@ Func Stage::rfactor(const vector> &preserved) { } } + Scope intm_rdom; + for (const auto &[var, min, extent] : intm.update(0).definition.schedule().rvars()) { + intm_rdom.push(var, Interval{min, min + extent - 1}); + } + definition.args() = dim_exprs; - definition.predicate() = const_true(); // TODO: replace with strongest postcondition of the intermediate predicate with the eliminated rvars havoc'd + definition.predicate() = !and_condition_over_domain(simplify(!preserved_rdom.predicate()), intm_rdom); definition.schedule().dims() = std::move(reducing_dims); definition.schedule().rvars() = preserved_rdom.domain(); definition.values() = substitute(replacements, prover_result.pattern.ops); From 27922009d78b01edbe2771e35d03e8e8d43b5315 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Sun, 24 Nov 2024 11:21:39 -0500 Subject: [PATCH 07/19] Disallow rfactor() on funcs with RVar+Var fused schedules --- src/Func.cpp | 45 +++++++++++++++++++++++++++++++++--- test/common/expect_abort.cpp | 3 +++ 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 31657327bd5e..958a762b2120 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -647,14 +647,16 @@ std::optional find_dim(const std::vector &items, const VarOrRVar &v) { return has_v == items.end() ? std::nullopt : std::make_optional(*has_v); } -std::optional rfactor_validate_args(const vector> &preserved, const AssociativeOp &prover_result, const std::vector &dims) { +std::optional rfactor_validate_args(const vector> &preserved, const AssociativeOp &prover_result, const Definition &definition) { + const std::vector &dims = definition.schedule().dims(); + if (!prover_result.associative()) { return "can't perform rfactor() because we can't prove associativity of the operator"; } DimSet is_rfactored; for (const auto &[rv, v] : preserved) { - // Check that the RVar are in the dims list + // Check that the RVars are in the dims list const auto &rv_dim = find_dim(dims, rv); if (!(rv_dim && rv_dim->is_rvar())) { std::stringstream s; @@ -690,6 +692,43 @@ std::optional rfactor_validate_args(const vector> & } } + // Check that no Vars were fused into RVars + Scope<> rdims; + for (const ReductionVariable &rv : definition.schedule().rvars()) { + rdims.push(rv.var); + } + for (const Split &split : definition.schedule().splits()) { + switch (split.split_type) { + case Split::SplitVar: + if (rdims.contains(split.old_var)) { + rdims.pop(split.old_var); + rdims.push(split.outer); + rdims.push(split.inner); + } + break; + case Split::FuseVars: + if (rdims.contains(split.outer) || rdims.contains(split.inner)) { + if (!(rdims.contains(split.outer) && rdims.contains(split.inner))) { + std::stringstream s; + s << "Cannot rfactor a Func that has fused a Var into an RVar (" + << split.outer << ", " << split.inner << ")"; + return s.str(); + } + rdims.pop(split.outer); + rdims.pop(split.inner); + rdims.push(split.old_var); + } + break; + case Split::PurifyRVar: + case Split::RenameVar: + if (rdims.contains(split.old_var)) { + rdims.pop(split.old_var); + rdims.push(split.outer); + } + break; + } + } + return std::nullopt; } @@ -803,7 +842,7 @@ Func Stage::rfactor(const vector> &preserved) { // its identity for each value in the definition if it is a Tuple const auto &prover_result = prove_associativity(function.name(), definition.args(), definition.values()); - const auto &error = rfactor_validate_args(preserved, prover_result, definition.schedule().dims()); + const auto &error = rfactor_validate_args(preserved, prover_result, definition); user_assert(!error) << "In schedule for " << name() << ": " << *error << "\n" << dump_argument_list(); diff --git a/test/common/expect_abort.cpp b/test/common/expect_abort.cpp index cb09a7242921..fec89b0913b7 100644 --- a/test/common/expect_abort.cpp +++ b/test/common/expect_abort.cpp @@ -19,6 +19,9 @@ auto handler = ([]() { << std::flush; suppress_abort = false; std::abort(); // We should never EXPECT an internal error + } catch (const Halide::Error &e) { + std::cerr << e.what() << "\n" + << std::flush; } catch (const std::exception &e) { std::cerr << e.what() << "\n" << std::flush; From 9ccacd230597f4e1aa5da2f57fa9de7131192c14 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Sun, 24 Nov 2024 11:23:58 -0500 Subject: [PATCH 08/19] Use size_t in place of int --- src/Func.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 958a762b2120..9a39f82ba30c 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -853,7 +853,7 @@ Func Stage::rfactor(const vector> &preserved) { vector intermediate_rdims; { std::unordered_map dim_ordering; - for (int i = 0; i < definition.schedule().dims().size(); i++) { + for (size_t i = 0; i < definition.schedule().dims().size(); i++) { dim_ordering.emplace(definition.schedule().dims()[i].var, i); } @@ -892,7 +892,7 @@ Func Stage::rfactor(const vector> &preserved) { // Intermediate update definition { auto [intermediate_rdom, intermediate_map] = project_rdom(intermediate_rdims, definition); - for (int i = 0; i < preserved.size(); i++) { + for (size_t i = 0; i < preserved.size(); i++) { rebind(intermediate_map, preserved_rdims[i].var, preserved_vars[i]); } for (const auto &var : dim_vars) { @@ -989,7 +989,7 @@ Func Stage::rfactor(const vector> &preserved) { } // Add missing pure vars to the REDUCING func just before outermost - for (int i = 0; i < dim_vars.size(); i++) { + for (size_t i = 0; i < dim_vars.size(); i++) { if (!expr_uses_var(definition.args()[i], dim_vars[i].name())) { Dim d = {dim_vars[i].name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto}; reducing_dims.insert(reducing_dims.end() - 1, d); From e154fc17584e789ec6f0845157b13fb7e244b853 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Mon, 25 Nov 2024 08:46:41 -0500 Subject: [PATCH 09/19] Clean up uses of split_predicate() --- src/BoundsInference.cpp | 24 ++++++++---------------- src/Derivative.cpp | 2 +- src/ScheduleFunctions.cpp | 3 +-- 3 files changed, 10 insertions(+), 19 deletions(-) diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 724adb993afd..aba76f7798ed 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -14,6 +14,7 @@ #include #include +#include namespace Halide { namespace Internal { @@ -297,7 +298,6 @@ class BoundsInference : public IRMutator { } // Default case (no specialization) - vector predicates = def.split_predicate(); for (const ReductionVariable &rv : def.schedule().rvars()) { rvars.insert(rv); } @@ -308,23 +308,15 @@ class BoundsInference : public IRMutator { } vecs[1] = def.values(); + vector predicates = def.split_predicate(); for (size_t i = 0; i < result.size(); ++i) { for (const Expr &val : vecs[i]) { - if (!predicates.empty()) { - Expr cond_val = Call::make(val.type(), - Internal::Call::if_then_else, - {likely(predicates[0]), val}, - Internal::Call::PureIntrinsic); - for (size_t i = 1; i < predicates.size(); ++i) { - cond_val = Call::make(cond_val.type(), - Internal::Call::if_then_else, - {likely(predicates[i]), cond_val}, - Internal::Call::PureIntrinsic); - } - result[i].emplace_back(const_true(), cond_val); - } else { - result[i].emplace_back(const_true(), val); - } + Expr cond_val = std::accumulate( + predicates.begin(), predicates.end(), val, + [](const auto &acc, const auto &pred) { + return Call::make(acc.type(), Call::if_then_else, {likely(pred), acc}, Call::PureIntrinsic); + }); + result[i].emplace_back(const_true(), cond_val); } } diff --git a/src/Derivative.cpp b/src/Derivative.cpp index 2520d27e290f..06451732d80c 100644 --- a/src/Derivative.cpp +++ b/src/Derivative.cpp @@ -1534,7 +1534,7 @@ void ReverseAccumulationVisitor::propagate_halide_function_call( // f(r.x) = ... && r is associative // => f(x) = ... if (var != nullptr && var->reduction_domain.defined() && - var->reduction_domain.split_predicate().empty()) { + is_const_one(var->reduction_domain.predicate())) { ReductionDomain rdom = var->reduction_domain; int rvar_id = -1; for (int rid = 0; rid < (int)rdom.domain().size(); rid++) { diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index c7a257dd085e..74f0fa738e42 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -178,7 +178,6 @@ Stmt build_loop_nest( const auto &dims = func.args(); const auto &func_s = func.schedule(); const auto &stage_s = def.schedule(); - const auto &predicates = def.split_predicate(); // We'll build it from inside out, starting from the body, // then wrapping it in for loops. @@ -306,7 +305,7 @@ Stmt build_loop_nest( } // Put all the reduction domain predicates into the containers vector. - for (Expr pred : predicates) { + for (Expr pred : def.split_predicate()) { pred = qualify(prefix, pred); // Add a likely qualifier if there isn't already one if (Call::as_intrinsic(pred, {Call::likely, Call::likely_if_innermost})) { From ed3c4059c862d417a6bd65dc34da33f3c5728c70 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Mon, 25 Nov 2024 08:54:20 -0500 Subject: [PATCH 10/19] Clean out some excess std:: qualifications --- src/Func.cpp | 47 +++++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 9a39f82ba30c..719a128c8f8f 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -36,8 +36,11 @@ namespace Halide { using std::map; using std::ofstream; +using std::optional; using std::pair; using std::string; +using std::tuple; +using std::unordered_map; using std::vector; using namespace Internal; @@ -632,23 +635,23 @@ struct DimEq { }; using DimSet = std::unordered_set; -std::optional find_rvar(const std::vector &items, const Dim &dim) { +optional find_rvar(const vector &items, const Dim &dim) { const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) { return var_name_match(dim.var, x.name()); }); return has_v == items.end() ? std::nullopt : std::make_optional(*has_v); } -std::optional find_dim(const std::vector &items, const VarOrRVar &v) { - const std::string &name = v.name(); +optional find_dim(const vector &items, const VarOrRVar &v) { + const string &name = v.name(); const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) { return var_name_match(x.var, name); }); return has_v == items.end() ? std::nullopt : std::make_optional(*has_v); } -std::optional rfactor_validate_args(const vector> &preserved, const AssociativeOp &prover_result, const Definition &definition) { - const std::vector &dims = definition.schedule().dims(); +optional rfactor_validate_args(const vector> &preserved, const AssociativeOp &prover_result, const Definition &definition) { + const vector &dims = definition.schedule().dims(); if (!prover_result.associative()) { return "can't perform rfactor() because we can't prove associativity of the operator"; @@ -677,7 +680,7 @@ std::optional rfactor_validate_args(const vector> & // If the operator is associative but non-commutative, rfactor() on inner // dimensions (excluding the outer dimensions) is not valid. if (!prover_result.commutative()) { - std::optional last_rvar; + optional last_rvar; for (const auto &d : reverse_view(dims)) { if (is_rfactored.count(d) && last_rvar && !is_rfactored.count(*last_rvar)) { std::stringstream s; @@ -733,31 +736,31 @@ std::optional rfactor_validate_args(const vector> & } template -std::vector operator+(const std::vector &base, const std::vector &other) { - std::vector merged = base; +vector operator+(const vector &base, const vector &other) { + vector merged = base; merged.insert(merged.end(), other.begin(), other.end()); return merged; } template -std::vector copy_convert(const std::vector &vec) { - return std::vector{} + vec; +vector copy_convert(const vector &vec) { + return vector{} + vec; } -using SubstitutionMap = std::map; +using SubstitutionMap = std::map; // This is a helper function for building up a substitution map that // corresponds to pushing down a nest of lets. The lets should be fed // to this function from innermost to outermost. This is equivalent to // building a let-nest as a one-hole context and then simplifying. -void add_let(SubstitutionMap &proj, const std::string &name, const Expr &value) { +void add_let(SubstitutionMap &proj, const string &name, const Expr &value) { for (auto &[_, e] : proj) { e = substitute(name, value, e); } proj.emplace(name, value); } -void rebind(SubstitutionMap &proj, const std::string &name, const Expr &value) { +void rebind(SubstitutionMap &proj, const string &name, const Expr &value) { for (auto &[_, e] : proj) { e = substitute(name, value, e); } @@ -766,7 +769,7 @@ void rebind(SubstitutionMap &proj, const std::string &name, const Expr &value) { } } -std::pair project_rdom(const std::vector &dims, const Definition &def) { +pair project_rdom(const vector &dims, const Definition &def) { const auto &rdom = def.schedule().rvars(); const auto &splits = def.schedule().splits(); const auto &predicate = def.predicate(); @@ -785,7 +788,7 @@ std::pair project_rdom(const std::vector } // Build the new RDom - std::vector new_rvars; + vector new_rvars; for (const Dim &dim : dims) { Expr new_min = substitute(bounds_projection, Variable::make(Int(32), dim.var + ".loop_min")); Expr new_extent = substitute(bounds_projection, Variable::make(Int(32), dim.var + ".loop_extent")); @@ -828,7 +831,7 @@ std::pair project_rdom(const std::vector rebind(dim_projection, rv.var, Variable::make(Int(32), rv.var, new_rdom)); } - return std::make_pair(new_rdom, dim_projection); + return {new_rdom, dim_projection}; } } // namespace @@ -852,14 +855,14 @@ Func Stage::rfactor(const vector> &preserved) { vector preserved_rdims; vector intermediate_rdims; { - std::unordered_map dim_ordering; + unordered_map dim_ordering; for (size_t i = 0; i < definition.schedule().dims().size(); i++) { dim_ordering.emplace(definition.schedule().dims()[i].var, i); } - std::vector> preserved_with_dims; + vector> preserved_with_dims; for (const auto &[rv, v] : preserved) { - const std::optional rdim = find_dim(definition.schedule().dims(), rv); + const optional rdim = find_dim(definition.schedule().dims(), rv); internal_assert(rdim); preserved_with_dims.emplace_back(rv, v, *rdim); } @@ -932,7 +935,7 @@ Func Stage::rfactor(const vector> &preserved) { DimSet dims; dims.insert(intm_dims.begin(), intm_dims.end()); for (const Var &dim_v : preserved_vars) { - const std::optional &dim = find_dim(intm.function().definition().schedule().dims(), dim_v); + const optional &dim = find_dim(intm.function().definition().schedule().dims(), dim_v); internal_assert(dim) << "Failed to find " << dim_v.name() << " in list of pure dims"; if (!dims.count(*dim)) { intm_dims.insert(intm_dims.end() - 1, *dim); @@ -976,7 +979,7 @@ Func Stage::rfactor(const vector> &preserved) { } } - std::vector reducing_dims; + vector reducing_dims; { DimSet preserved_rdim_set; preserved_rdim_set.insert(preserved_rdims.begin(), preserved_rdims.end()); @@ -1018,7 +1021,7 @@ Func Stage::rfactor(const vector> &preserved) { for (const ReductionVariable &rv : st.definition.schedule().rvars()) { dims.push(rv.var); } - std::vector new_splits; + vector new_splits; for (const Split &split : st.definition.schedule().splits()) { switch (split.split_type) { case Split::SplitVar: From 22add4880bba2b6c5f2d89f78968a8e316e50d73 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Mon, 25 Nov 2024 10:36:47 -0500 Subject: [PATCH 11/19] Use dim_match instead of var_name_match --- src/Func.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 719a128c8f8f..6ae66636b7b5 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -637,15 +637,14 @@ using DimSet = std::unordered_set; optional find_rvar(const vector &items, const Dim &dim) { const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) { - return var_name_match(dim.var, x.name()); + return dim_match(dim, x); }); return has_v == items.end() ? std::nullopt : std::make_optional(*has_v); } optional find_dim(const vector &items, const VarOrRVar &v) { - const string &name = v.name(); const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) { - return var_name_match(x.var, name); + return dim_match(x, v); }); return has_v == items.end() ? std::nullopt : std::make_optional(*has_v); } @@ -920,7 +919,7 @@ Func Stage::rfactor(const vector> &preserved) { // Replace rvar dims IN the preserved list with their Vars in the INTERMEDIATE Func for (auto &dim : intm_dims) { const auto it = std::find_if(preserved_rvars.begin(), preserved_rvars.end(), [&](const auto &rv) { - return var_name_match(dim.var, rv.name()); + return dim_match(dim, rv); }); if (it != preserved_rvars.end()) { const auto offset = it - preserved_rvars.begin(); From 1355d3e8d9cdab4fba10f31b562cdb6f6ffcf2f6 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Mon, 25 Nov 2024 10:48:03 -0500 Subject: [PATCH 12/19] Use unordered_set directly instead of DimSet --- src/Func.cpp | 35 ++++++++++++++--------------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 6ae66636b7b5..a4a9363494dd 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -41,6 +41,7 @@ using std::pair; using std::string; using std::tuple; using std::unordered_map; +using std::unordered_set; using std::vector; using namespace Internal; @@ -623,18 +624,6 @@ Func Stage::rfactor(const RVar &r, const Var &v) { // Helpers for rfactor implementation namespace { -struct DimHash { - std::size_t operator()(const Dim &s) const noexcept { - return std::hash{}(s.var); - } -}; -struct DimEq { - bool operator()(const Dim &lhs, const Dim &rhs) const noexcept { - return lhs.var == rhs.var; - } -}; -using DimSet = std::unordered_set; - optional find_rvar(const vector &items, const Dim &dim) { const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) { return dim_match(dim, x); @@ -656,7 +645,7 @@ optional rfactor_validate_args(const vector> &preserved, return "can't perform rfactor() because we can't prove associativity of the operator"; } - DimSet is_rfactored; + unordered_set is_rfactored; for (const auto &[rv, v] : preserved) { // Check that the RVars are in the dims list const auto &rv_dim = find_dim(dims, rv); @@ -665,7 +654,7 @@ optional rfactor_validate_args(const vector> &preserved, s << "can't perform rfactor() on " << rv.name() << " since it is not in the reduction domain"; return s.str(); } - is_rfactored.insert(*rv_dim); + is_rfactored.insert(rv_dim->var); // Check that the new pure Vars we used to rename the RVar aren't already in the dims list if (find_dim(dims, v)) { @@ -681,7 +670,7 @@ optional rfactor_validate_args(const vector> &preserved, if (!prover_result.commutative()) { optional last_rvar; for (const auto &d : reverse_view(dims)) { - if (is_rfactored.count(d) && last_rvar && !is_rfactored.count(*last_rvar)) { + if (is_rfactored.count(d.var) && last_rvar && !is_rfactored.count(last_rvar->var)) { std::stringstream s; s << "can't rfactor an inner dimension " << d.var << " without rfactoring the outer dimensions, since the " @@ -931,12 +920,14 @@ Func Stage::rfactor(const vector> &preserved) { } // Add factored pure dims to the INTERMEDIATE func just before outermost - DimSet dims; - dims.insert(intm_dims.begin(), intm_dims.end()); + unordered_set dims; + for (const auto &dim : intm_dims) { + dims.insert(dim.var); + } for (const Var &dim_v : preserved_vars) { const optional &dim = find_dim(intm.function().definition().schedule().dims(), dim_v); internal_assert(dim) << "Failed to find " << dim_v.name() << " in list of pure dims"; - if (!dims.count(*dim)) { + if (!dims.count(dim->var)) { intm_dims.insert(intm_dims.end() - 1, *dim); } } @@ -980,12 +971,14 @@ Func Stage::rfactor(const vector> &preserved) { vector reducing_dims; { - DimSet preserved_rdim_set; - preserved_rdim_set.insert(preserved_rdims.begin(), preserved_rdims.end()); + unordered_set preserved_rdim_set; + for (const auto &dim : preserved_rdims) { + preserved_rdim_set.insert(dim.var); + } // Remove rvar dims NOT IN the preserved list from the REDUCING Func for (const auto &dim : definition.schedule().dims()) { - if (!dim.is_rvar() || preserved_rdim_set.count(dim)) { + if (!dim.is_rvar() || preserved_rdim_set.count(dim.var)) { reducing_dims.push_back(dim); } } From a201afea818589783a25eb17eddde31e0b19b71d Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Mon, 25 Nov 2024 10:57:24 -0500 Subject: [PATCH 13/19] Compute preserved rdims set earlier to drop find_rvar --- src/Func.cpp | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index a4a9363494dd..a17bc3ee4057 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -624,13 +624,6 @@ Func Stage::rfactor(const RVar &r, const Var &v) { // Helpers for rfactor implementation namespace { -optional find_rvar(const vector &items, const Dim &dim) { - const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) { - return dim_match(dim, x); - }); - return has_v == items.end() ? std::nullopt : std::make_optional(*has_v); -} - optional find_dim(const vector &items, const VarOrRVar &v) { const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) { return dim_match(x, v); @@ -841,6 +834,7 @@ Func Stage::rfactor(const vector> &preserved) { vector preserved_rvars; vector preserved_vars; vector preserved_rdims; + unordered_set preserved_rdims_set; vector intermediate_rdims; { unordered_map dim_ordering; @@ -848,14 +842,15 @@ Func Stage::rfactor(const vector> &preserved) { dim_ordering.emplace(definition.schedule().dims()[i].var, i); } - vector> preserved_with_dims; + using PreservedData = tuple; + vector preserved_with_dims; for (const auto &[rv, v] : preserved) { const optional rdim = find_dim(definition.schedule().dims(), rv); internal_assert(rdim); preserved_with_dims.emplace_back(rv, v, *rdim); } - std::sort(preserved_with_dims.begin(), preserved_with_dims.end(), [&](const auto &lhs, const auto &rhs) { + std::sort(preserved_with_dims.begin(), preserved_with_dims.end(), [&](const PreservedData &lhs, const PreservedData &rhs) { return dim_ordering.at(std::get<2>(lhs).var) < dim_ordering.at(std::get<2>(rhs).var); }); @@ -863,10 +858,11 @@ Func Stage::rfactor(const vector> &preserved) { preserved_rvars.push_back(rv); preserved_vars.push_back(v); preserved_rdims.push_back(dim); + preserved_rdims_set.insert(dim.var); } for (const Dim &dim : definition.schedule().dims()) { - if (dim.is_rvar() && !find_rvar(preserved_rvars, dim)) { + if (dim.is_rvar() && !preserved_rdims_set.count(dim.var)) { intermediate_rdims.push_back(dim); } } @@ -971,14 +967,9 @@ Func Stage::rfactor(const vector> &preserved) { vector reducing_dims; { - unordered_set preserved_rdim_set; - for (const auto &dim : preserved_rdims) { - preserved_rdim_set.insert(dim.var); - } - // Remove rvar dims NOT IN the preserved list from the REDUCING Func for (const auto &dim : definition.schedule().dims()) { - if (!dim.is_rvar() || preserved_rdim_set.count(dim.var)) { + if (!dim.is_rvar() || preserved_rdims_set.count(dim.var)) { reducing_dims.push_back(dim); } } From 7a8c68679f9128c35c762266e91e682301e40c7c Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Mon, 25 Nov 2024 11:35:50 -0500 Subject: [PATCH 14/19] Hoist projection code into common block --- src/Func.cpp | 53 ++++++++++++++++++++++++++-------------------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index a17bc3ee4057..9e21ff4137ba 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -868,6 +868,29 @@ Func Stage::rfactor(const vector> &preserved) { } } + // Project the RDom into each side + ReductionDomain intermediate_rdom, preserved_rdom; + SubstitutionMap intermediate_map, preserved_map; + { + // Intermediate + std::tie(intermediate_rdom, intermediate_map) = project_rdom(intermediate_rdims, definition); + for (size_t i = 0; i < preserved.size(); i++) { + rebind(intermediate_map, preserved_rdims[i].var, preserved_vars[i]); + } + for (const auto &var : dim_vars) { + intermediate_map.erase(var.name()); + } + intermediate_rdom.set_predicate(simplify(substitute(intermediate_map, intermediate_rdom.predicate()))); + + // Preserved + std::tie(preserved_rdom, preserved_map) = project_rdom(preserved_rdims, definition); + Scope intm_rdom; + for (const auto &[var, min, extent] : intermediate_rdom.domain()) { + intm_rdom.push(var, Interval{min, min + extent - 1}); + } + preserved_rdom.set_predicate(!and_condition_over_domain(simplify(substitute(preserved_map, !preserved_rdom.predicate())), intm_rdom)); + } + // Intermediate func Func intm(function.name() + "_intm"); @@ -878,16 +901,6 @@ Func Stage::rfactor(const vector> &preserved) { // Intermediate update definition { - auto [intermediate_rdom, intermediate_map] = project_rdom(intermediate_rdims, definition); - for (size_t i = 0; i < preserved.size(); i++) { - rebind(intermediate_map, preserved_rdims[i].var, preserved_vars[i]); - } - for (const auto &var : dim_vars) { - intermediate_map.erase(var.name()); - } - - intermediate_rdom.set_predicate(simplify(substitute(intermediate_map, intermediate_rdom.predicate()))); - vector args = substitute(intermediate_map, definition.args() + preserved_vars); vector values; for (const auto &val : definition.values()) { @@ -933,9 +946,6 @@ Func Stage::rfactor(const vector> &preserved) { // Preserved update definition { - auto [preserved_rdom, preserved_map] = project_rdom(preserved_rdims, definition); - preserved_rdom.set_predicate(simplify(substitute(preserved_map, preserved_rdom.predicate()))); - // Replace the current definition with calls to the intermediate func. vector dim_exprs = copy_convert(dim_vars); vector f_load_args = dim_exprs; @@ -943,21 +953,17 @@ Func Stage::rfactor(const vector> &preserved) { f_load_args.push_back(Variable::make(Int(32), rv.var, preserved_rdom)); } - SubstitutionMap replacements; - for (size_t i = 0; i < preserved.size(); i++) { - replacements.emplace(preserved_rdims[i].var, preserved_vars[i]); - } for (size_t i = 0; i < definition.values().size(); ++i) { if (!prover_result.ys[i].var.empty()) { Expr r = (definition.values().size() == 1) ? Expr(intm(f_load_args)) : Expr(intm(f_load_args)[i]); - replacements.emplace(prover_result.ys[i].var, r); + add_let(preserved_map, prover_result.ys[i].var, r); } if (!prover_result.xs[i].var.empty()) { Expr prev_val = Call::make(intm.types()[i], function.name(), dim_exprs, Call::CallType::Halide, FunctionPtr(), i); - replacements.emplace(prover_result.xs[i].var, prev_val); + add_let(preserved_map, prover_result.xs[i].var, prev_val); } else { user_warning << "Update definition of " << name() << " at index " << i << " doesn't depend on the previous value. This isn't a" @@ -983,16 +989,11 @@ Func Stage::rfactor(const vector> &preserved) { } } - Scope intm_rdom; - for (const auto &[var, min, extent] : intm.update(0).definition.schedule().rvars()) { - intm_rdom.push(var, Interval{min, min + extent - 1}); - } - definition.args() = dim_exprs; - definition.predicate() = !and_condition_over_domain(simplify(!preserved_rdom.predicate()), intm_rdom); + definition.predicate() = preserved_rdom.predicate(); definition.schedule().dims() = std::move(reducing_dims); definition.schedule().rvars() = preserved_rdom.domain(); - definition.values() = substitute(replacements, prover_result.pattern.ops); + definition.values() = substitute(preserved_map, prover_result.pattern.ops); } // Clean up the splits lists From 7819855678b08bc51278a9ab7ce798e8ae31c4e7 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Mon, 25 Nov 2024 12:13:56 -0500 Subject: [PATCH 15/19] Use structured bindings --- src/Func.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 9e21ff4137ba..a36f80a6cc45 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -762,10 +762,10 @@ pair project_rdom(const vector &dims, con add_let(bounds_projection, name, value); } } - for (const ReductionVariable &rv : rdom) { - add_let(bounds_projection, rv.var + ".loop_min", rv.min); - add_let(bounds_projection, rv.var + ".loop_max", rv.min + rv.extent - 1); - add_let(bounds_projection, rv.var + ".loop_extent", rv.extent); + for (const auto &[var, min, extent] : rdom) { + add_let(bounds_projection, var + ".loop_min", min); + add_let(bounds_projection, var + ".loop_max", min + extent - 1); + add_let(bounds_projection, var + ".loop_extent", extent); } // Build the new RDom From 9b7a4a22609e268148ef03abe6f65a808cc34f2b Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Mon, 25 Nov 2024 12:37:36 -0500 Subject: [PATCH 16/19] Drop rebind() as add_let() was equivalent --- src/Func.cpp | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index a36f80a6cc45..dc77fb89be46 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -741,15 +741,6 @@ void add_let(SubstitutionMap &proj, const string &name, const Expr &value) { proj.emplace(name, value); } -void rebind(SubstitutionMap &proj, const string &name, const Expr &value) { - for (auto &[_, e] : proj) { - e = substitute(name, value, e); - } - if (!proj.count(name)) { - proj.emplace(name, value); - } -} - pair project_rdom(const vector &dims, const Definition &def) { const auto &rdom = def.schedule().rvars(); const auto &splits = def.schedule().splits(); @@ -788,9 +779,6 @@ pair project_rdom(const vector &dims, con for (const Split &split : splits) { for (const auto &result : apply_split(split, "", dim_extent_alignment)) { switch (result.type) { - case ApplySplitResult::Substitution: - case ApplySplitResult::SubstitutionInCalls: - case ApplySplitResult::SubstitutionInProvides: case ApplySplitResult::LetStmt: add_let(dim_projection, result.name, result.value); break; @@ -799,6 +787,9 @@ pair project_rdom(const vector &dims, con case ApplySplitResult::Predicate: new_rdom.where(substitute(bounds_projection, result.value)); break; + case ApplySplitResult::Substitution: + case ApplySplitResult::SubstitutionInCalls: + case ApplySplitResult::SubstitutionInProvides: case ApplySplitResult::BlendProvides: // TODO: what to do here? BlendProvides is not included in the above checks. break; @@ -809,7 +800,7 @@ pair project_rdom(const vector &dims, con e = substitute(bounds_projection, e); } for (const ReductionVariable &rv : new_rdom.domain()) { - rebind(dim_projection, rv.var, Variable::make(Int(32), rv.var, new_rdom)); + add_let(dim_projection, rv.var, Variable::make(Int(32), rv.var, new_rdom)); } return {new_rdom, dim_projection}; @@ -875,7 +866,7 @@ Func Stage::rfactor(const vector> &preserved) { // Intermediate std::tie(intermediate_rdom, intermediate_map) = project_rdom(intermediate_rdims, definition); for (size_t i = 0; i < preserved.size(); i++) { - rebind(intermediate_map, preserved_rdims[i].var, preserved_vars[i]); + add_let(intermediate_map, preserved_rdims[i].var, preserved_vars[i]); } for (const auto &var : dim_vars) { intermediate_map.erase(var.name()); From ce9af71e2c77411f1fd48e018c5e9694ca45ab2a Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Mon, 25 Nov 2024 12:55:29 -0500 Subject: [PATCH 17/19] More cleaning --- src/Func.cpp | 10 +++++----- test/error/rfactor_after_var_and_rvar_fusion.cpp | 2 +- test/error/rfactor_fused_var_and_rvar.cpp | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index dc77fb89be46..88ad5e78986e 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -772,9 +772,9 @@ pair project_rdom(const vector &dims, con // Compute the mapping of old dimensions to the projected RDom values. SubstitutionMap dim_projection{}; SubstitutionMap dim_extent_alignment{}; - for (const ReductionVariable &rv : rdom) { - add_let(dim_projection, rv.var, Variable::make(Int(32), rv.var)); - dim_extent_alignment[rv.var] = rv.extent; + for (const auto &[var, _, extent] : rdom) { + add_let(dim_projection, var, Variable::make(Int(32), var)); + dim_extent_alignment[var] = extent; } for (const Split &split : splits) { for (const auto &result : apply_split(split, "", dim_extent_alignment)) { @@ -791,7 +791,7 @@ pair project_rdom(const vector &dims, con case ApplySplitResult::SubstitutionInCalls: case ApplySplitResult::SubstitutionInProvides: case ApplySplitResult::BlendProvides: - // TODO: what to do here? BlendProvides is not included in the above checks. + // The lets returned by ApplySplit are sufficient break; } } @@ -799,7 +799,7 @@ pair project_rdom(const vector &dims, con for (auto &[_, e] : dim_projection) { e = substitute(bounds_projection, e); } - for (const ReductionVariable &rv : new_rdom.domain()) { + for (const auto &rv : new_rdom.domain()) { add_let(dim_projection, rv.var, Variable::make(Int(32), rv.var, new_rdom)); } diff --git a/test/error/rfactor_after_var_and_rvar_fusion.cpp b/test/error/rfactor_after_var_and_rvar_fusion.cpp index 8b94fbdd1b16..acda4e4bb6fb 100644 --- a/test/error/rfactor_after_var_and_rvar_fusion.cpp +++ b/test/error/rfactor_after_var_and_rvar_fusion.cpp @@ -22,4 +22,4 @@ int main(int argc, char **argv) { printf("Success!\n"); return 0; -} \ No newline at end of file +} diff --git a/test/error/rfactor_fused_var_and_rvar.cpp b/test/error/rfactor_fused_var_and_rvar.cpp index a167ca543c47..64a79c269690 100644 --- a/test/error/rfactor_fused_var_and_rvar.cpp +++ b/test/error/rfactor_fused_var_and_rvar.cpp @@ -23,4 +23,4 @@ int main(int argc, char **argv) { printf("Success!\n"); return 0; -} \ No newline at end of file +} From 576654dbc055fddf69315ee66a97ac6a8a7dba3a Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Mon, 25 Nov 2024 16:56:27 -0500 Subject: [PATCH 18/19] Remove not-helpful-enough helpers --- src/Func.cpp | 51 +++++++++++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 88ad5e78986e..a8c00746f410 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -607,11 +607,14 @@ class SubstituteSelfReference : public IRMutator { /** Substitute all self-reference calls to 'func' with 'substitute' which * args (LHS) is the old args (LHS) plus 'new_args' in that order. * Expect this method to be called on the value (RHS) of an update definition. */ -Expr substitute_self_reference(Expr val, const string &func, const Function &substitute, - const vector &new_args) { +vector substitute_self_reference(const vector &values, const string &func, + const Function &substitute, const vector &new_args) { SubstituteSelfReference subs(func, substitute, new_args); - val = subs.mutate(val); - return val; + vector result; + for (const auto &val : values) { + result.push_back(subs.mutate(val)); + } + return result; } } // anonymous namespace @@ -716,18 +719,6 @@ optional rfactor_validate_args(const vector> &preserved, return std::nullopt; } -template -vector operator+(const vector &base, const vector &other) { - vector merged = base; - merged.insert(merged.end(), other.begin(), other.end()); - return merged; -} - -template -vector copy_convert(const vector &vec) { - return vector{} + vec; -} - using SubstitutionMap = std::map; // This is a helper function for building up a substitution map that @@ -821,6 +812,12 @@ Func Stage::rfactor(const vector> &preserved) { user_assert(!error) << "In schedule for " << name() << ": " << *error << "\n" << dump_argument_list(); + const vector dim_vars_exprs = [&] { + vector result; + result.insert(result.end(), dim_vars.begin(), dim_vars.end()); + return result; + }(); + // sort preserved by the dimension ordering vector preserved_rvars; vector preserved_vars; @@ -887,16 +884,19 @@ Func Stage::rfactor(const vector> &preserved) { // Intermediate pure definition { - intm(dim_vars + preserved_vars) = Tuple(prover_result.pattern.identities); + vector args = dim_vars_exprs; + args.insert(args.end(), preserved_vars.begin(), preserved_vars.end()); + intm(args) = Tuple(prover_result.pattern.identities); } // Intermediate update definition { - vector args = substitute(intermediate_map, definition.args() + preserved_vars); - vector values; - for (const auto &val : definition.values()) { - values.push_back(substitute_self_reference(val, function.name(), intm.function(), preserved_vars)); - } + vector args = definition.args(); + args.insert(args.end(), preserved_vars.begin(), preserved_vars.end()); + args = substitute(intermediate_map, args); + + vector values = definition.values(); + values = substitute_self_reference(values, function.name(), intm.function(), preserved_vars); values = substitute(intermediate_map, values); intm.function().define_update(args, values, intermediate_rdom); @@ -938,8 +938,7 @@ Func Stage::rfactor(const vector> &preserved) { // Preserved update definition { // Replace the current definition with calls to the intermediate func. - vector dim_exprs = copy_convert(dim_vars); - vector f_load_args = dim_exprs; + vector f_load_args = dim_vars_exprs; for (const ReductionVariable &rv : preserved_rdom.domain()) { f_load_args.push_back(Variable::make(Int(32), rv.var, preserved_rdom)); } @@ -952,7 +951,7 @@ Func Stage::rfactor(const vector> &preserved) { if (!prover_result.xs[i].var.empty()) { Expr prev_val = Call::make(intm.types()[i], function.name(), - dim_exprs, Call::CallType::Halide, + dim_vars_exprs, Call::CallType::Halide, FunctionPtr(), i); add_let(preserved_map, prover_result.xs[i].var, prev_val); } else { @@ -980,7 +979,7 @@ Func Stage::rfactor(const vector> &preserved) { } } - definition.args() = dim_exprs; + definition.args() = dim_vars_exprs; definition.predicate() = preserved_rdom.predicate(); definition.schedule().dims() = std::move(reducing_dims); definition.schedule().rvars() = preserved_rdom.domain(); From fd04843efea14a68f8b249a841610efa97161feb Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Mon, 25 Nov 2024 17:29:22 -0500 Subject: [PATCH 19/19] Reorganize update definitions to mirror each other --- src/Func.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index a8c00746f410..52fbd6225579 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -830,15 +830,14 @@ Func Stage::rfactor(const vector> &preserved) { dim_ordering.emplace(definition.schedule().dims()[i].var, i); } - using PreservedData = tuple; - vector preserved_with_dims; + vector> preserved_with_dims; for (const auto &[rv, v] : preserved) { const optional rdim = find_dim(definition.schedule().dims(), rv); internal_assert(rdim); preserved_with_dims.emplace_back(rv, v, *rdim); } - std::sort(preserved_with_dims.begin(), preserved_with_dims.end(), [&](const PreservedData &lhs, const PreservedData &rhs) { + std::sort(preserved_with_dims.begin(), preserved_with_dims.end(), [&](const auto &lhs, const auto &rhs) { return dim_ordering.at(std::get<2>(lhs).var) < dim_ordering.at(std::get<2>(rhs).var); }); @@ -901,9 +900,7 @@ Func Stage::rfactor(const vector> &preserved) { intm.function().define_update(args, values, intermediate_rdom); // Intermediate schedule - intm.function().update(0).schedule() = definition.schedule().get_copy(); - - auto &intm_dims = intm.function().update(0).schedule().dims(); + vector intm_dims = definition.schedule().dims(); // Replace rvar dims IN the preserved list with their Vars in the INTERMEDIATE Func for (auto &dim : intm_dims) { @@ -924,14 +921,16 @@ Func Stage::rfactor(const vector> &preserved) { for (const auto &dim : intm_dims) { dims.insert(dim.var); } - for (const Var &dim_v : preserved_vars) { - const optional &dim = find_dim(intm.function().definition().schedule().dims(), dim_v); - internal_assert(dim) << "Failed to find " << dim_v.name() << " in list of pure dims"; + for (const Var &var : preserved_vars) { + const optional &dim = find_dim(intm.function().definition().schedule().dims(), var); + internal_assert(dim) << "Failed to find " << var.name() << " in list of pure dims"; if (!dims.count(dim->var)) { intm_dims.insert(intm_dims.end() - 1, *dim); } } + intm.function().update(0).schedule() = definition.schedule().get_copy(); + intm.function().update(0).schedule().dims() = std::move(intm_dims); intm.function().update(0).schedule().rvars() = intermediate_rdom.domain(); } @@ -980,10 +979,10 @@ Func Stage::rfactor(const vector> &preserved) { } definition.args() = dim_vars_exprs; + definition.values() = substitute(preserved_map, prover_result.pattern.ops); definition.predicate() = preserved_rdom.predicate(); definition.schedule().dims() = std::move(reducing_dims); definition.schedule().rvars() = preserved_rdom.domain(); - definition.values() = substitute(preserved_map, prover_result.pattern.ops); } // Clean up the splits lists