From e787562e25f4b094835a9bc059cdd3de78eecab6 Mon Sep 17 00:00:00 2001 From: Mengdi Wu <48128384+wmdi@users.noreply.github.com> Date: Mon, 30 Sep 2024 10:14:01 -0400 Subject: [PATCH] Update heuristics for grid dim enumeration (#91) * fix & upd range * adjust heuristics for grid dim enumeration --------- Co-authored-by: Mengdi Wu --- include/mirage/type.h | 2 ++ src/search/config.cc | 13 ++++----- src/search/dim_strategy.cc | 55 +++++++++++++++++++------------------- src/search/irange.cc | 11 ++++++++ src/search/op_utils.cc | 7 +++++ src/threadblock/graph.cc | 8 ++++++ 6 files changed, 63 insertions(+), 33 deletions(-) diff --git a/include/mirage/type.h b/include/mirage/type.h index 95170a4..f96790b 100644 --- a/include/mirage/type.h +++ b/include/mirage/type.h @@ -68,6 +68,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(KNOperatorType, {KN_REDUCTION_0_OP, "kn_reduction_0_op"}, {KN_REDUCTION_1_OP, "kn_reduction_1_op"}, {KN_REDUCTION_2_OP, "kn_reduction_2_op"}, + {KN_RMS_NORM_OP, "kn_rms_norm_op"}, {KN_EXP_OP, "kn_exp_op"}, {KN_SQUARE_OP, "kn_square_op"}, {KN_SQRT_OP, "kn_sqrt_op"}, @@ -144,6 +145,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM( {TB_REDUCTION_0_TO_DIMX_OP, "tb_reduction_0_to_dimx_op"}, {TB_REDUCTION_1_TO_DIMX_OP, "tb_reduction_1_to_dimx_op"}, {TB_REDUCTION_2_TO_DIMX_OP, "tb_reduction_2_to_dimx_op"}, + {TB_RMS_NORM_OP, "tb_rms_norm_op"}, {TB_CONCAT_0_OP, "tb_concat_0_op"}, {TB_CONCAT_1_OP, "tb_concat_1_op"}, {TB_CONCAT_2_OP, "tb_concat_2_op"}, diff --git a/src/search/config.cc b/src/search/config.cc index 8ff6c15..6663b79 100644 --- a/src/search/config.cc +++ b/src/search/config.cc @@ -12,12 +12,12 @@ GeneratorConfig GeneratorConfig::get_default_config() { 2 /* max_num_threadblock_graph_outputs */, 8 /* search_thread */, { - type::KN_MATMUL_OP, - type::KN_EXP_OP, - type::KN_SILU_OP, - type::KN_ADD_OP, - type::KN_MUL_OP, - type::KN_DIV_OP, + // type::KN_MATMUL_OP, + // type::KN_EXP_OP, + // type::KN_SILU_OP, + // type::KN_ADD_OP, + // type::KN_MUL_OP, + // type::KN_DIV_OP, type::KN_CUSTOMIZED_OP, } /* knop_to_explore */, { @@ -27,6 +27,7 @@ GeneratorConfig GeneratorConfig::get_default_config() { type::TB_ADD_OP, type::TB_MUL_OP, type::TB_DIV_OP, + type::TB_RMS_NORM_OP, type::TB_FORLOOP_ACCUM_NO_RED_OP, type::TB_FORLOOP_ACCUM_RED_LD_SUM_OP, type::TB_FORLOOP_ACCUM_RED_LD_MEAN_OP, diff --git a/src/search/dim_strategy.cc b/src/search/dim_strategy.cc index b6498ca..6ada395 100644 --- a/src/search/dim_strategy.cc +++ b/src/search/dim_strategy.cc @@ -30,30 +30,25 @@ std::vector DimStrategy::get_tbop_cand() { std::vector DimStrategy::get_grid_dim_cand(std::vector const &tensors) { - auto tip_filter = [](std::vector const &tips, int x) { - for (int tip : tips) { - if (tip % x == 0) { - return true; - } - } - return false; - }; - - auto generate_1d_grids = [&](std::vector const &tips) { + auto generate_1d_grids = [&](std::vector const &dims) { std::vector cands; - for (size_t x = 4; x <= 128; x *= 2) { - if (tip_filter(tips, x)) { - cands.push_back({x, 1, 1}); + for (size_t x = 32; x <= 64; x *= 2) { + for (int dim : dims) { + if (dim % x == 0) { + cands.push_back({dim / x, 1, 1}); + } } } return cands; }; - auto generate_2d_grids = [&](int x, std::vector const &tips) { + auto generate_2d_grids = [&](int x, std::vector const &dims) { std::vector cands; - for (size_t y = 1; y <= 16; y *= 2) { - if (tip_filter(tips, y)) { - cands.push_back({x, y, 1}); + for (size_t y = 32; y <= 64; y *= 2) { + for (int dim : dims) { + if (dim % y == 0) { + cands.push_back({x, dim / y, 1}); + } } } return cands; @@ -78,31 +73,35 @@ std::vector return -1; }; - auto get_tips = [&] { - std::unordered_set tips; + auto get_dims = [&] { + std::unordered_set dims; for (DTensor const &tensor : tensors) { for (int i = 0; i < tensor.num_dims; ++i) { - tips.insert(tensor.dim[i]); + dims.insert(tensor.dim[i]); } } - return std::vector(tips.begin(), tips.end()); + return std::vector(dims.begin(), dims.end()); }; std::vector cands = config.grid_dim_to_explore; - cands = vector_concat(cands, generate_1d_grids(get_tips())); + cands = vector_concat(cands, generate_1d_grids(get_dims())); if (config._enable_attention_specific_optimization) { int batch = get_batch(); if (batch != -1) { - cands = vector_concat(cands, generate_2d_grids(batch, get_tips())); + cands = vector_concat(cands, generate_2d_grids(batch, get_dims())); + } + if (tensors.size() > 2) { + cands.push_back({batch, 16, 4}); } } - - cands = deduplicate(cands); cands = filter(cands, [](dim3 const &dim) { - return dim.x * dim.y * dim.z <= config::MAX_NUM_THREADBLOCKS_PER_KERNEL; + int num_threadblocks = dim.x * dim.y * dim.z; + return 32 <= num_threadblocks && num_threadblocks <= config::MAX_NUM_THREADBLOCKS_PER_KERNEL; }); + cands = deduplicate(cands); + if (config.randomized_branches) { std::random_shuffle(cands.begin(), cands.end()); } @@ -220,7 +219,6 @@ std::vector> tensors, grid_dim, config.imap_to_explore, {}, results); } else { std::vector imap_to_explore = { - {0, -1, -1}, {0, -1, 1}, {0, 1, -1}, {0, 2, -1}, @@ -329,6 +327,9 @@ std::vector DimStrategy::get_forloop_range_cand( std::vector results; for (int x : config.frange_to_explore) { + if (config._enable_attention_specific_optimization && x > 8) { + continue; + } bool feasible = true; for (size_t i = 0; i < input_tensors.size(); ++i) { if (forloop_dim[i] == -1) { diff --git a/src/search/irange.cc b/src/search/irange.cc index 92677dd..f368ded 100644 --- a/src/search/irange.cc +++ b/src/search/irange.cc @@ -629,6 +629,11 @@ ITBRange forward_propagate(ITBRange const &tbrange, .truncate(op.output_tensors[0])); break; } + case type::TB_RMS_NORM_OP: { + ret = ITBRange(tbrange.range_set.extend_dim(op.output_tensors[0].num_dims - 1) + .truncate(op.output_tensors[0])); + break; + } case type::TB_FORLOOP_ACCUM_NO_RED_OP: { ret = tbrange.extend_forloop_dim(); break; @@ -676,6 +681,12 @@ ITBRange backward_propagate(ITBRange const &tbrange, ret = EXP_AS_IDENTITY ? tbrange : ITBRange(); break; } + case type::TBOperatorType::TB_RMS_NORM_OP: { + ret = ITBRange( + tbrange.range_set.extend_dim(op.input_tensors[opd_idx].num_dims - 1) + .truncate(op.input_tensors[opd_idx])); + break; + } case type::TBOperatorType::TB_ADD_OP: case type::TBOperatorType::TB_MUL_OP: { ret = tbrange; diff --git a/src/search/op_utils.cc b/src/search/op_utils.cc index ad734f5..44ca0c4 100644 --- a/src/search/op_utils.cc +++ b/src/search/op_utils.cc @@ -18,6 +18,7 @@ bool is_unary(type::TBOperatorType op) { std::unordered_set true_values{ type::TBOperatorType::TB_EXP_OP, type::TBOperatorType::TB_SILU_OP, + type::TBOperatorType::TB_RMS_NORM_OP, type::TBOperatorType::TB_REDUCTION_0_OP, type::TBOperatorType::TB_REDUCTION_1_OP, type::TBOperatorType::TB_REDUCTION_2_OP, @@ -116,6 +117,10 @@ std::shared_ptr return std::make_shared(opd); case type::TBOperatorType::TB_SILU_OP: return std::make_shared(opd); + case type::TBOperatorType::TB_RMS_NORM_OP: { + return std::make_shared
( + opd, std::make_shared(tensor.dim[tensor.num_dims - 1], opd)); + } case type::TBOperatorType::TB_REDUCTION_0_OP: return std::make_shared(tensor.dim[0], opd); case type::TBOperatorType::TB_REDUCTION_1_OP: @@ -317,6 +322,8 @@ TBOperator *create_op(threadblock::Graph &g, case type::TBOperatorType::TB_EXP_OP: case type::TBOperatorType::TB_SILU_OP: return g.create_elementunary_op(input, type); + case type::TBOperatorType::TB_RMS_NORM_OP: + return g.create_rms_norm_op(input); case type::TBOperatorType::TB_REDUCTION_0_OP: case type::TBOperatorType::TB_REDUCTION_1_OP: case type::TBOperatorType::TB_REDUCTION_2_OP: { diff --git a/src/threadblock/graph.cc b/src/threadblock/graph.cc index 11f5d40..d108dad 100644 --- a/src/threadblock/graph.cc +++ b/src/threadblock/graph.cc @@ -750,6 +750,14 @@ void from_json(json const &j, Graph &graph) { op.at("output_tensors")[0].at("guid").get(); break; } + case type::TBOperatorType::TB_RMS_NORM_OP: { + STensor const &output = graph.rms_norm( + get_tensor_from_guid( + op.at("input_tensors")[0].at("guid").get())); + guid_mapping[output.guid] = + op.at("output_tensors")[0].at("guid").get(); + break; + } case type::TBOperatorType::TB_ADD_OP: case type::TBOperatorType::TB_MUL_OP: case type::TBOperatorType::TB_DIV_OP: {