Skip to content

Commit

Permalink
only do grid split when needed (#2965)
Browse files Browse the repository at this point in the history
**Issue:** In the inner-outer persistent scheduler, the last step is an
outer reduction whose inner dim is parallelized by `vectorization`,
`bdimx`, and `gdimy`. The current main branch always does three splits using
`vectorization`, `bdimx`, and `gdimy`. However, the last split is not
needed if `vectorization * bdimx * gdimy >= inner dim`. For example:
```
T0 
logical domain : (iS264{gridDim.y}, iS265{i1})
 contiguity: t t
  Split: iS265{i1} by factor 4
  Split: iS997{( ceilDiv(i1, 4) )} by factor blockDim.x 
  Split: iS999{( ceilDiv(( ceilDiv(i1, 4) ), blockDim.x) )} by factor gridDim.y
```
The last split is redundant if `4 * blockDim.x * gridDim.y >= i1`
**Fix:**
Only do the last (grid) split when `vectorization * bdimx * gdimy < inner dim`.
**Influence:**
Removing this extra split saves one loop in the generated code.
Performance increases in some cases and decreases in others; all
changes are within 10%. See the [dashboard](http://nv/ekP).

---------

Co-authored-by: jjsjann123 <[email protected]>
  • Loading branch information
liqiangxl and jjsjann123 authored Oct 7, 2024
1 parent 615177d commit 2b9e9d6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
19 changes: 14 additions & 5 deletions csrc/scheduler/normalization_inner_outer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,8 +610,14 @@ std::unique_ptr<ReductionParams> innerOuterPersistentHeuristic(
rparams->pad_outer_reduction_to_warp = true;
}
rparams->block_dim_iter_dom = ParallelType::TIDy;
rparams->combined_split_grid_inner_dim =
iop.vectorization_factor_outer * iop.bdimy * iop.gdimy <
inner_dim_numel;
} else {
rparams->block_dim_inner_reduction_extra = ParallelType::TIDy;
rparams->combined_split_grid_inner_dim =
iop.vectorization_factor_outer * iop.bdimx * iop.gdimy <
inner_dim_numel;
rparams->static_bdimx = true;
rparams->static_bdimy = true;
iop.bdimz = ceilDiv(
Expand Down Expand Up @@ -841,9 +847,10 @@ void scheduleReductionCombinedOuter(
axisID, NamedScalar::getParallelDim(ParallelType::TIDx));
outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx);
}

outer_reduction_tv->split(
axisID, NamedScalar::getParallelDim(ParallelType::BIDy));
if (rparams->combined_split_grid_inner_dim) {
outer_reduction_tv->split(
axisID, NamedScalar::getParallelDim(ParallelType::BIDy));
}
outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy);

} else {
Expand All @@ -864,8 +871,10 @@ void scheduleReductionCombinedOuter(
outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx);
}

outer_reduction_tv->split(
axisID, NamedScalar::getParallelDim(ParallelType::BIDy));
if (rparams->combined_split_grid_inner_dim) {
outer_reduction_tv->split(
axisID, NamedScalar::getParallelDim(ParallelType::BIDy));
}

outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy);
}
Expand Down
4 changes: 4 additions & 0 deletions csrc/scheduler/reduction_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ class ReductionParams : public HeuristicParams {
bool tidx_for_outer_reduction = false;
// pad outer reduction to warp
bool pad_outer_reduction_to_warp = false;
// In the outer reduction part of the inner-outer persistent scheduler,
// whether to further split the inner dim by the grid dimension (BIDy).
bool combined_split_grid_inner_dim = false;
// partial result of outer reduction is written to gmem then read back in a
// different parallel pattern set the vectorization factor of its read and
// write
Expand Down Expand Up @@ -191,6 +194,7 @@ class ReductionParams : public HeuristicParams {
other->tidx_for_outer_reduction == tidx_for_outer_reduction &&
other->pad_outer_reduction_to_warp == pad_outer_reduction_to_warp &&
other->vectorization_factor_outer == vectorization_factor_outer &&
other->combined_split_grid_inner_dim == combined_split_grid_inner_dim &&
other->vectorization_factor_tmp_gmem_write ==
vectorization_factor_tmp_gmem_write;

Expand Down

0 comments on commit 2b9e9d6

Please sign in to comment.