Skip to content

Commit

Permalink
Set default max_num_threadblock_graphs to 1 for attention workloads (#93
Browse files Browse the repository at this point in the history
)

* fix & upd range

* adjust heuristics for grid dim enumeration

* better search statistics display

* add verbose flag

* set default max_num_threadblock_graphs to 1 for attentions

---------

Co-authored-by: Mengdi Wu <[email protected]>
  • Loading branch information
wmdi and wmdi authored Oct 1, 2024
1 parent 11e6140 commit 8edf81c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
15 changes: 8 additions & 7 deletions src/search/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ GeneratorConfig GeneratorConfig::get_default_config() {
2 /* max_num_threadblock_graph_outputs */,
8 /* search_thread */,
{
// type::KN_MATMUL_OP,
// type::KN_EXP_OP,
// type::KN_SILU_OP,
// type::KN_ADD_OP,
// type::KN_MUL_OP,
// type::KN_DIV_OP,
type::KN_MATMUL_OP,
type::KN_EXP_OP,
type::KN_SILU_OP,
type::KN_ADD_OP,
type::KN_MUL_OP,
type::KN_DIV_OP,
type::KN_REDUCTION_2_OP,
type::KN_CUSTOMIZED_OP,
} /* knop_to_explore */,
{
Expand Down Expand Up @@ -54,7 +55,7 @@ GeneratorConfig GeneratorConfig::get_default_config() {

void GeneratorConfig::enable_attention_specific_optimization() {
_enable_attention_specific_optimization = true;
max_num_threadblock_graphs = 2;
max_num_kernel_graph_op = 7;
}

void GeneratorConfig::enable_concat_matmul_transformation() {
Expand Down
2 changes: 1 addition & 1 deletion src/search/dim_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ std::vector<dim3>

auto generate_1d_grids = [&](std::vector<int> const &dims) {
std::vector<dim3> cands;
for (size_t x = 32; x <= 64; x *= 2) {
for (size_t x = 8; x <= 128; x *= 2) {
for (int dim : dims) {
if (dim % x == 0) {
cands.push_back({dim / x, 1, 1});
Expand Down

0 comments on commit 8edf81c

Please sign in to comment.