Skip to content

Commit

Permalink
Update heuristics for grid dim enumeration (#91)
Browse files Browse the repository at this point in the history
* fix and update the forloop-range handling

* adjust heuristics for grid dim enumeration

---------

Co-authored-by: Mengdi Wu <[email protected]>
  • Loading branch information
wmdi and wmdi authored Sep 30, 2024
1 parent 3d53724 commit e787562
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 33 deletions.
2 changes: 2 additions & 0 deletions include/mirage/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(KNOperatorType,
{KN_REDUCTION_0_OP, "kn_reduction_0_op"},
{KN_REDUCTION_1_OP, "kn_reduction_1_op"},
{KN_REDUCTION_2_OP, "kn_reduction_2_op"},
{KN_RMS_NORM_OP, "kn_rms_norm_op"},
{KN_EXP_OP, "kn_exp_op"},
{KN_SQUARE_OP, "kn_square_op"},
{KN_SQRT_OP, "kn_sqrt_op"},
Expand Down Expand Up @@ -144,6 +145,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(
{TB_REDUCTION_0_TO_DIMX_OP, "tb_reduction_0_to_dimx_op"},
{TB_REDUCTION_1_TO_DIMX_OP, "tb_reduction_1_to_dimx_op"},
{TB_REDUCTION_2_TO_DIMX_OP, "tb_reduction_2_to_dimx_op"},
{TB_RMS_NORM_OP, "tb_rms_norm_op"},
{TB_CONCAT_0_OP, "tb_concat_0_op"},
{TB_CONCAT_1_OP, "tb_concat_1_op"},
{TB_CONCAT_2_OP, "tb_concat_2_op"},
Expand Down
13 changes: 7 additions & 6 deletions src/search/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ GeneratorConfig GeneratorConfig::get_default_config() {
2 /* max_num_threadblock_graph_outputs */,
8 /* search_thread */,
{
type::KN_MATMUL_OP,
type::KN_EXP_OP,
type::KN_SILU_OP,
type::KN_ADD_OP,
type::KN_MUL_OP,
type::KN_DIV_OP,
// type::KN_MATMUL_OP,
// type::KN_EXP_OP,
// type::KN_SILU_OP,
// type::KN_ADD_OP,
// type::KN_MUL_OP,
// type::KN_DIV_OP,
type::KN_CUSTOMIZED_OP,
} /* knop_to_explore */,
{
Expand All @@ -27,6 +27,7 @@ GeneratorConfig GeneratorConfig::get_default_config() {
type::TB_ADD_OP,
type::TB_MUL_OP,
type::TB_DIV_OP,
type::TB_RMS_NORM_OP,
type::TB_FORLOOP_ACCUM_NO_RED_OP,
type::TB_FORLOOP_ACCUM_RED_LD_SUM_OP,
type::TB_FORLOOP_ACCUM_RED_LD_MEAN_OP,
Expand Down
55 changes: 28 additions & 27 deletions src/search/dim_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,25 @@ std::vector<type::TBOperatorType> DimStrategy::get_tbop_cand() {
std::vector<dim3>
DimStrategy::get_grid_dim_cand(std::vector<DTensor> const &tensors) {

auto tip_filter = [](std::vector<int> const &tips, int x) {
for (int tip : tips) {
if (tip % x == 0) {
return true;
}
}
return false;
};

auto generate_1d_grids = [&](std::vector<int> const &tips) {
auto generate_1d_grids = [&](std::vector<int> const &dims) {
std::vector<dim3> cands;
for (size_t x = 4; x <= 128; x *= 2) {
if (tip_filter(tips, x)) {
cands.push_back({x, 1, 1});
for (size_t x = 32; x <= 64; x *= 2) {
for (int dim : dims) {
if (dim % x == 0) {
cands.push_back({dim / x, 1, 1});
}
}
}
return cands;
};

auto generate_2d_grids = [&](int x, std::vector<int> const &tips) {
auto generate_2d_grids = [&](int x, std::vector<int> const &dims) {
std::vector<dim3> cands;
for (size_t y = 1; y <= 16; y *= 2) {
if (tip_filter(tips, y)) {
cands.push_back({x, y, 1});
for (size_t y = 32; y <= 64; y *= 2) {
for (int dim : dims) {
if (dim % y == 0) {
cands.push_back({x, dim / y, 1});
}
}
}
return cands;
Expand All @@ -78,31 +73,35 @@ std::vector<dim3>
return -1;
};

auto get_tips = [&] {
std::unordered_set<int> tips;
auto get_dims = [&] {
std::unordered_set<int> dims;
for (DTensor const &tensor : tensors) {
for (int i = 0; i < tensor.num_dims; ++i) {
tips.insert(tensor.dim[i]);
dims.insert(tensor.dim[i]);
}
}
return std::vector<int>(tips.begin(), tips.end());
return std::vector<int>(dims.begin(), dims.end());
};

std::vector<dim3> cands = config.grid_dim_to_explore;

cands = vector_concat(cands, generate_1d_grids(get_tips()));
cands = vector_concat(cands, generate_1d_grids(get_dims()));
if (config._enable_attention_specific_optimization) {
int batch = get_batch();
if (batch != -1) {
cands = vector_concat(cands, generate_2d_grids(batch, get_tips()));
cands = vector_concat(cands, generate_2d_grids(batch, get_dims()));
}
if (tensors.size() > 2) {
cands.push_back({batch, 16, 4});
}
}

cands = deduplicate(cands);
cands = filter(cands, [](dim3 const &dim) {
return dim.x * dim.y * dim.z <= config::MAX_NUM_THREADBLOCKS_PER_KERNEL;
int num_threadblocks = dim.x * dim.y * dim.z;
return 32 <= num_threadblocks && num_threadblocks <= config::MAX_NUM_THREADBLOCKS_PER_KERNEL;
});

cands = deduplicate(cands);

if (config.randomized_branches) {
std::random_shuffle(cands.begin(), cands.end());
}
Expand Down Expand Up @@ -220,7 +219,6 @@ std::vector<std::vector<int3>>
tensors, grid_dim, config.imap_to_explore, {}, results);
} else {
std::vector<int3> imap_to_explore = {
{0, -1, -1},
{0, -1, 1},
{0, 1, -1},
{0, 2, -1},
Expand Down Expand Up @@ -329,6 +327,9 @@ std::vector<int> DimStrategy::get_forloop_range_cand(
std::vector<int> results;

for (int x : config.frange_to_explore) {
if (config._enable_attention_specific_optimization && x > 8) {
continue;
}
bool feasible = true;
for (size_t i = 0; i < input_tensors.size(); ++i) {
if (forloop_dim[i] == -1) {
Expand Down
11 changes: 11 additions & 0 deletions src/search/irange.cc
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,11 @@ ITBRange forward_propagate(ITBRange const &tbrange,
.truncate(op.output_tensors[0]));
break;
}
case type::TB_RMS_NORM_OP: {
ret = ITBRange(tbrange.range_set.extend_dim(op.output_tensors[0].num_dims - 1)
.truncate(op.output_tensors[0]));
break;
}
case type::TB_FORLOOP_ACCUM_NO_RED_OP: {
ret = tbrange.extend_forloop_dim();
break;
Expand Down Expand Up @@ -676,6 +681,12 @@ ITBRange backward_propagate(ITBRange const &tbrange,
ret = EXP_AS_IDENTITY ? tbrange : ITBRange();
break;
}
case type::TBOperatorType::TB_RMS_NORM_OP: {
ret = ITBRange(
tbrange.range_set.extend_dim(op.input_tensors[opd_idx].num_dims - 1)
.truncate(op.input_tensors[opd_idx]));
break;
}
case type::TBOperatorType::TB_ADD_OP:
case type::TBOperatorType::TB_MUL_OP: {
ret = tbrange;
Expand Down
7 changes: 7 additions & 0 deletions src/search/op_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ bool is_unary(type::TBOperatorType op) {
std::unordered_set<type::TBOperatorType> true_values{
type::TBOperatorType::TB_EXP_OP,
type::TBOperatorType::TB_SILU_OP,
type::TBOperatorType::TB_RMS_NORM_OP,
type::TBOperatorType::TB_REDUCTION_0_OP,
type::TBOperatorType::TB_REDUCTION_1_OP,
type::TBOperatorType::TB_REDUCTION_2_OP,
Expand Down Expand Up @@ -116,6 +117,10 @@ std::shared_ptr<AlgebraicPattern>
return std::make_shared<Exp>(opd);
case type::TBOperatorType::TB_SILU_OP:
return std::make_shared<Silu>(opd);
case type::TBOperatorType::TB_RMS_NORM_OP: {
return std::make_shared<Div>(
opd, std::make_shared<RMS>(tensor.dim[tensor.num_dims - 1], opd));
}
case type::TBOperatorType::TB_REDUCTION_0_OP:
return std::make_shared<Red>(tensor.dim[0], opd);
case type::TBOperatorType::TB_REDUCTION_1_OP:
Expand Down Expand Up @@ -317,6 +322,8 @@ TBOperator *create_op(threadblock::Graph &g,
case type::TBOperatorType::TB_EXP_OP:
case type::TBOperatorType::TB_SILU_OP:
return g.create_elementunary_op(input, type);
case type::TBOperatorType::TB_RMS_NORM_OP:
return g.create_rms_norm_op(input);
case type::TBOperatorType::TB_REDUCTION_0_OP:
case type::TBOperatorType::TB_REDUCTION_1_OP:
case type::TBOperatorType::TB_REDUCTION_2_OP: {
Expand Down
8 changes: 8 additions & 0 deletions src/threadblock/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,14 @@ void from_json(json const &j, Graph &graph) {
op.at("output_tensors")[0].at("guid").get<int>();
break;
}
case type::TBOperatorType::TB_RMS_NORM_OP: {
STensor const &output = graph.rms_norm(
get_tensor_from_guid(
op.at("input_tensors")[0].at("guid").get<int>()));
guid_mapping[output.guid] =
op.at("output_tensors")[0].at("guid").get<int>();
break;
}
case type::TBOperatorType::TB_ADD_OP:
case type::TBOperatorType::TB_MUL_OP:
case type::TBOperatorType::TB_DIV_OP: {
Expand Down

0 comments on commit e787562

Please sign in to comment.