-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add knobs control inner dim unroll and outer dim unroll in pointwise scheduler #3275
Changes from all commits
12568b6
09ba0b6
c86a0b2
f5b349f
23efc80
67450af
35af092
089dd85
e2221b5
00571b4
377b7fc
12ad2e6
94680b6
7e04577
c5b0365
1307ba8
4ea639b
6c22a3b
b23cb41
7e0cbff
50ba432
cb60e11
8616bc9
416bc31
63a38f1
5293c7e
0710686
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -369,8 +369,21 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics( | |
// not used. should allow to use both unroll and vectorization together in | ||
// heuristics tuning. | ||
if (params->vectorization_factor == 1) { | ||
params->unroll_factor = scheduler_utils::safeDiv( | ||
auto total_unroll = scheduler_utils::safeDiv( | ||
max_vect_unroll_factor, params->vectorization_factor); | ||
// for 1D scheduler, unroll the inner dimension | ||
// since there is no outer dimension. | ||
if (break_point == 0) { | ||
params->unroll_factor_inner = total_unroll; | ||
params->unroll_factor_outer = 1L; | ||
} else { | ||
// for 2D scheduler, unroll the outer dimension | ||
// to prioritize reuse across different rows, will | ||
// be revised in heuristics tuning, e.g. unroll different | ||
// dims based on the broadcast dimension. | ||
params->unroll_factor_inner = 1L; | ||
params->unroll_factor_outer = total_unroll; | ||
} | ||
} | ||
|
||
NVF_ERROR(right_elem_count > 0 || break_point == 0); | ||
|
@@ -394,7 +407,10 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics( | |
<< "num_elems: " << n_elems << "\n" | ||
<< "elem_counts: " << elem_counts << "\n" | ||
<< "max_input_dtype_size: " << max_input_dtype_size << "\n" | ||
<< "unroll_factor: " << params->unroll_factor << std::endl | ||
<< "unroll_factor_inner: " << params->unroll_factor_inner | ||
<< std::endl | ||
<< "unroll_factor_outer: " << params->unroll_factor_outer | ||
<< std::endl | ||
<< "vectorize_factor: " << params->vectorization_factor << std::endl | ||
<< "\n" | ||
<< "logical_reorder_map: "; | ||
|
@@ -677,11 +693,17 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { | |
reference_tv->reorder({{lhs_i, 0}, {-1, 1}}); | ||
|
||
// vectorization without unroll | ||
if (pparams->unroll_factor == 1 && pparams->vectorization_factor > 1) { | ||
if (pparams->unroll_factor_outer == 1 && | ||
pparams->unroll_factor_inner == 1 && | ||
pparams->vectorization_factor > 1) { | ||
reference_tv->split(1, pparams->vectorization_factor); | ||
reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); | ||
reference_tv->split(0, 1); | ||
// [outer, Unswitch | i-remainder, TIDx, Vectorization] | ||
// Here and in the following comments: | ||
// prefix [i] represents the inner dimension | ||
// prefix [o] represents the outer dimension | ||
// [|] separates the outer and inner dimensions | ||
reference_tv->axis(1)->parallelize(ParallelType::Unswitch); | ||
reference_tv->axis(3)->parallelize(ParallelType::TIDx); | ||
// Vectorization is propagated separately | ||
|
@@ -700,41 +722,58 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { | |
reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); | ||
// [outer | i-remainder, TIDx, Vect] | ||
|
||
reference_tv->split(0, pparams->unroll_factor); | ||
// [o-remainder, Unroll| i-remainder, TIDx, Vect] | ||
if (pparams->unroll_factor_inner > 1) { | ||
reference_tv->split(1, pparams->unroll_factor_inner); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We are splitting on dimension 1, which is the TIDx here, right? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is for the 2D scheduler, start with There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So this is a behavior change then. If we look at the above commented code change, we are doing
Which means the old behavior (outer unroll) is being changed to a default inner unroll instead? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point. We should assign the unroll to the inner dim only when the scheduler is 1D; for the 2D scheduler it should be assigned to the outer dim.
|
||
} | ||
// [outer| i-remainder, i-Unroll, TIDx, Vect] | ||
|
||
if (pparams->unroll_factor_outer > 1) { | ||
reference_tv->split(0, pparams->unroll_factor_outer); | ||
} | ||
// [o-remainder, o-Unroll, | i-remainder, i-Unroll, TIDx, Vect] | ||
|
||
reference_tv->split(0, 1); | ||
// [o-remainder, Unswitch, Unroll | i-remainder, TIDx, Vect] | ||
// [o-remainder, Unswitch, o-Unroll | i-remainder, i-Unroll, TIDx, Vect] | ||
|
||
reference_tv->reorder({{3, 1}}); | ||
// [o-remainder, i-remainder, Unswitch, Unroll, TIDx, Vect] | ||
int i_remainder_pos = pparams->unroll_factor_outer > 1 ? 3 : 2; | ||
reference_tv->reorder({{i_remainder_pos, 1}}); | ||
// [o-remainder, i-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect] | ||
|
||
reference_tv->axis(2)->parallelize(ParallelType::Unswitch); | ||
// Here we do not set axis(3)->parallelize(Unroll) because we do not want | ||
// it to be propagated. We manually unroll by splitting the inline | ||
// propagation process into two steps: | ||
// step 1: inline at the unswitch position for cached inputs and outputs | ||
// step 2: inline at the inner most dim for the rest of the graph | ||
reference_tv->axis(4)->parallelize(ParallelType::TIDx); | ||
int tidx_pos = 3; | ||
if (pparams->unroll_factor_inner > 1) { | ||
tidx_pos++; | ||
} | ||
if (pparams->unroll_factor_outer > 1) { | ||
tidx_pos++; | ||
} | ||
reference_tv->axis(tidx_pos)->parallelize(ParallelType::TIDx); | ||
if (pparams->vectorization_factor > 1) { | ||
vectorize_id = reference_tv->axis(5); | ||
// can't use {-1}, there may be deviceId | ||
vectorize_id = reference_tv->axis(tidx_pos + 1); | ||
} | ||
// [o-remainder, i-remainder, Unswitch, Unroll, TIDx, Vect] | ||
// [o-remainder, i-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect] | ||
} | ||
|
||
// Move out of the way to furthest left point | ||
reference_tv->reorder({{1, 0}}); | ||
// [i-remainder, o-remainder, Unswitch, Unroll, TIDx, Vect] | ||
// [i-remainder, o-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect] | ||
if (pparams->split_block) { | ||
reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy)); | ||
// [i-remainder, o-remainder, TIDy, Unswitch, Unroll, TIDx, Vect] | ||
// [i-remainder, o-remainder, TIDy, Unswitch, o-Unroll, i-Unroll, TIDx, | ||
// Vect] | ||
if (pparams->flip_grid_binding) { | ||
// [BIDy | BIDx, TIDy | Unswitch, Unroll, TIDx, Vect] | ||
// [BIDy | BIDx, TIDy | Unswitch, o-Unroll, i-Unroll, TIDx, Vect] | ||
reference_tv->axis(1)->parallelize(ParallelType::BIDx); | ||
reference_tv->axis(2)->parallelize(ParallelType::TIDy); | ||
if (pparams->split_grid_y_dim) { | ||
// [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx, | ||
// Vect] | ||
// [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, o-Unroll, | ||
// i-Unroll, TIDx, Vect] | ||
reference_tv->split(0, 65535); | ||
reference_tv->axis(1)->parallelize(ParallelType::BIDy); | ||
unswitch_pos = 5; | ||
|
@@ -743,12 +782,12 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { | |
unswitch_pos = 4; | ||
} | ||
} else { | ||
// [BIDx | BIDy TIDy | Unswitch, Unroll, TIDx, Vect] | ||
// [BIDx | BIDy TIDy | Unswitch, o-Unroll, i-Unroll, TIDx, Vect] | ||
reference_tv->axis(0)->parallelize(ParallelType::BIDx); | ||
reference_tv->axis(2)->parallelize(ParallelType::TIDy); | ||
if (pparams->split_grid_y_dim) { | ||
// [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, Unroll, TIDx, | ||
// Vect] | ||
// [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, o-Unroll, | ||
// i-Unroll, TIDx, Vect] | ||
reference_tv->split(1, 65535); | ||
reference_tv->axis(2)->parallelize(ParallelType::BIDy); | ||
unswitch_pos = 5; | ||
|
@@ -796,7 +835,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { | |
// unmerged...] | ||
reference_tv->reorder({{-1, 0}}); | ||
|
||
if (pparams->unroll_factor == 1 && pparams->vectorization_factor > 1) { | ||
if (pparams->unroll_factor_inner == 1 && | ||
pparams->vectorization_factor > 1) { | ||
// Vectorize | ||
reference_tv->split(0, pparams->vectorization_factor); | ||
// Unswitch | ||
|
@@ -822,7 +862,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { | |
// Threads | ||
reference_tv->split(0, kThreadX); | ||
// Unroll | ||
reference_tv->split(0, pparams->unroll_factor); | ||
if (pparams->unroll_factor_inner > 1) { | ||
reference_tv->split(0, pparams->unroll_factor_inner); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. qq: we are not using There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, this |
||
} | ||
// Unswitch | ||
reference_tv->split(0, 1); | ||
|
||
|
@@ -834,9 +876,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { | |
// propagation process into two steps: | ||
// step 1: inline at the unswitch position for cached inputs and outputs | ||
// step 2: inline at the inner most dim for the rest of the graph | ||
reference_tv->axis(3)->parallelize(ParallelType::TIDx); | ||
int tidx_pos = pparams->unroll_factor_inner > 1 ? 3 : 2; | ||
reference_tv->axis(tidx_pos)->parallelize(ParallelType::TIDx); | ||
if (pparams->vectorization_factor > 1) { | ||
vectorize_id = reference_tv->axis(4); | ||
vectorize_id = reference_tv->axis(tidx_pos + 1); | ||
} | ||
} | ||
unswitch_pos = 2; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍🏼