
refactor 2d_inner_reduction_heuristics
liqiangxl committed Nov 3, 2024
1 parent bdeb737 commit 06f8a2a
Showing 1 changed file with 117 additions and 148 deletions.
265 changes: 117 additions & 148 deletions csrc/scheduler/reduction.cpp
@@ -63,31 +63,30 @@ void reduceProductTo(int64_t& z, int64_t& y, int64_t& x, const int64_t max) {
}
}

std::unique_ptr<ReductionParams> inner2dReductionHeuristic(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t n_tensor_inputs,
const int64_t max_input_dtype_size,
const size_t vectorize_factor) {
// Set some targets for parallelization
const int64_t n_elems = total_reduction_numel * total_iteration_numel;
std::tuple<int64_t, int64_t, int64_t> getThreadsPerBlockPerSmAndSmCount() {
auto dev_prop = at::cuda::getCurrentDeviceProperties();
return {
dev_prop->maxThreadsPerBlock,
dev_prop->maxThreadsPerMultiProcessor,
dev_prop->multiProcessorCount};
}

// WARNING: At some point we may want to generate heuristics for another
// device that is not the current device.
const int64_t max_threads_per_sm =
(int64_t)dev_prop->maxThreadsPerMultiProcessor;

const int64_t device_multiprocessor_count =
(int64_t)dev_prop->multiProcessorCount;

auto const max_unroll = ceilDiv(
int64_t getMaxUnroll(
const int64_t max_input_dtype_size,
const int64_t n_tensor_inputs) {
return ceilDiv(
// Available unrolling based on size of data type
(int64_t)16 / (int64_t)max_input_dtype_size,
(int64_t)16 / max_input_dtype_size,
// Reduce unrolling if we have many inputs, start reduction at 4 inputs
scheduler_utils::lastPow2(
std::max((int64_t)n_tensor_inputs >> 2, (int64_t)1)));
scheduler_utils::lastPow2(std::max(n_tensor_inputs >> 2, (int64_t)1)));
}
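
A minimal usage sketch of getMaxUnroll (illustrative values only, not taken from this commit), assuming ceilDiv and scheduler_utils::lastPow2 are in scope as elsewhere in this file:

// Illustrative only: unrolling targets roughly 16 bytes per input and shrinks as the
// number of tensor inputs grows (n_tensor_inputs / 4, rounded down to a power of two).
int64_t unroll_fp32 = getMaxUnroll(/*max_input_dtype_size=*/4, /*n_tensor_inputs=*/2); // ceilDiv(4, 1) = 4
int64_t unroll_fp16 = getMaxUnroll(/*max_input_dtype_size=*/2, /*n_tensor_inputs=*/8); // ceilDiv(8, 2) = 4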

int64_t getL1L2WarpSize(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t n_tensor_inputs,
const int64_t max_input_dtype_size) {
const int64_t n_elems = total_reduction_numel * total_iteration_numel;
// Conservative value, could be set to larger based on arch if necessary.
constexpr int64_t l1_cache = (int64_t)32 * 1024;
// Could change per generation, but for l1 we want to consider active threads,
@@ -116,11 +115,18 @@ std::unique_ptr<ReductionParams> inner2dReductionHeuristic(
(n_tensor_inputs * max_input_dtype_size * active_threads),
(int64_t)1)),
(int64_t)16);
return std::min(warp_size_based_on_l1, warp_size_based_on_l2);
}
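
A worked example of the bound above, with all values assumed for illustration: a single 4-byte input, reduction length 1024, iteration length 1024, on a device whose L2 is assumed larger than the roughly 4 MB footprint:

// Illustrative only:
//   L2 bound: data fits in L2, so 32 bytes / 4-byte elements -> 8 threads
//   L1 bound: min(ceilDiv(1024, 32768 / (1 * 4 * 1024)), 16) = min(128, 16) -> 16 threads
int64_t min_warp = getL1L2WarpSize(
    /*total_reduction_numel=*/1024,
    /*total_iteration_numel=*/1024,
    /*n_tensor_inputs=*/1,
    /*max_input_dtype_size=*/4); // std::min(16, 8) = 8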

// Take the smaller
const int64_t min_warp_size =
std::min(warp_size_based_on_l1, warp_size_based_on_l2);

std::tuple<int64_t, int64_t, int64_t> getTargetThreadsBlocksIterations(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t n_tensor_inputs,
const int64_t max_input_dtype_size,
const int64_t sm_count,
const int64_t max_threads_per_sm,
const int64_t max_unroll,
const int64_t min_warp_size) {
// Initialization
int64_t target_blocks = 1;
int64_t target_unroll = 1;
@@ -141,13 +147,14 @@ std::unique_ptr<ReductionParams> inner2dReductionHeuristic(
min_warp_size, ceilDiv(total_reduction_numel, min_target_iterations));

// If we have one warp per block, check if that's enough to saturate the SMs
const int64_t n_elems = total_reduction_numel * total_iteration_numel;
target_blocks = ceilDiv(n_elems, min_warp_size);

// If we have more than a wave of blocks, put parallelism into unrolling and
// target iterations
if (target_blocks > device_multiprocessor_count) {
auto available_unroll = std::max(
n_elems / (min_warp_size * device_multiprocessor_count), (int64_t)1);
if (target_blocks > sm_count) {
auto available_unroll =
std::max(n_elems / (min_warp_size * sm_count), (int64_t)1);

// Spread across unrolling and iterations, want a balance of the two so flip
// back and forth to alternate adding to them.
@@ -167,8 +174,7 @@ std::unique_ptr<ReductionParams> inner2dReductionHeuristic(

available_unroll = std::max(
n_elems /
(min_warp_size * device_multiprocessor_count * target_unroll *
target_iterations),
(min_warp_size * sm_count * target_unroll * target_iterations),
(int64_t)1);

flip = !flip;
@@ -180,7 +186,7 @@ std::unique_ptr<ReductionParams> inner2dReductionHeuristic(
}

// Cap target blocks to 4 waves
target_blocks = std::min(target_blocks, device_multiprocessor_count * 4);
target_blocks = std::min(target_blocks, sm_count * 4);

if (target_blocks * target_unroll * target_iterations < n_elems) {
// targeting 4 waves, so try to use a quarter of available threads
@@ -194,129 +200,120 @@ std::unique_ptr<ReductionParams> inner2dReductionHeuristic(
target_threads_in_block +=
min_warp_size - target_threads_in_block % min_warp_size;
}
return {target_threads_in_block, target_blocks, target_iterations};
}
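
A hedged usage sketch of the helper above. The device limits and problem sizes are assumed, not produced by this commit; only the two properties visible in the code are noted, namely that the block count is capped at four waves and the thread count is rounded up to a multiple of min_warp_size.

// Illustrative only: a hypothetical 16384 x 1024 reduction of one 4-byte input
// on a device assumed to have 108 SMs and 2048 threads per SM.
auto [threads_in_block, blocks, iterations] = getTargetThreadsBlocksIterations(
    /*total_reduction_numel=*/1024,
    /*total_iteration_numel=*/16384,
    /*n_tensor_inputs=*/1,
    /*max_input_dtype_size=*/4,
    /*sm_count=*/108,
    /*max_threads_per_sm=*/2048,
    /*max_unroll=*/4,
    /*min_warp_size=*/8);
// blocks <= 4 * 108 = 432 (capped at four waves); threads_in_block is a multiple of 8.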

// To get to target threads:
// Prioritize
// (1) x dim in reduction
// (2) unrolling in reduction
// (3) y in output
// To get target blocks:
// Prioritize
// (1) x dim in multiple outputs
// (2) y dim in multiple reductions

// Cross grid inner reduction, number of blocks to cross-grid on
int64_t gridim = 1;

// Blocks for outputs
int64_t godim = 1;

// Threads for reduction
int64_t bdimx = 1;
// Threads for outputs
int64_t bdimy = 1;

// Unroll amount
int64_t inner_reduction_unroll_factor = 1;
int64_t iter_unroll_factor = 1;

inner_reduction_unroll_factor =
vectorize_factor > 1 ? (int64_t)vectorize_factor : 1;

// Grab what we can out of reduction domain, but don't go over a warp size yet
bdimx = std::min(
std::max(
ceilDiv(total_reduction_numel, inner_reduction_unroll_factor),
(int64_t)min_warp_size),
target_threads_in_block);

std::unique_ptr<ReductionParams> inner2dReductionHeuristic(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t n_tensor_inputs,
const int64_t max_input_dtype_size,
const size_t max_vectorize_factor) {
// Set some targets for parallelization
auto [threads_per_block, threads_per_sm, sm_count] =
getThreadsPerBlockPerSmAndSmCount();
auto const max_unroll = getMaxUnroll(max_input_dtype_size, n_tensor_inputs);
const int64_t min_warp_size = getL1L2WarpSize(
total_reduction_numel,
total_iteration_numel,
n_tensor_inputs,
max_input_dtype_size);
auto [target_threads_in_block, target_blocks, target_iterations] =
getTargetThreadsBlocksIterations(
total_reduction_numel,
total_iteration_numel,
n_tensor_inputs,
max_input_dtype_size,
sm_count,
threads_per_sm,
max_unroll,
min_warp_size);

// Parallelization strategy:
// [] indicates optional
// Reduction dim: Serial, [grdim], TIDx, Vect
// Iteration dim: Serial, godim, [TIDy]

// Max vectorization factor
int64_t vect_factor = std::min(
scheduler_utils::lastPow2(max_unroll), (int64_t)max_vectorize_factor);
int64_t after_vect = total_reduction_numel / vect_factor;

// Set bdimx and bdimy
// Prioritize setting bdimx to the max threads in a block
// Put what is left into bdimy
int64_t bdimx =
std::min(std::max(after_vect, min_warp_size), target_threads_in_block);
// If we're not just barely covering the dimension, round to a more friendly
// number
if (bdimx * inner_reduction_unroll_factor != total_reduction_numel) {
if (bdimx * vect_factor != total_reduction_numel) {
// Round bdimx down to a multiple of the warp size or a power of 2
if (bdimx < min_warp_size) {
bdimx = scheduler_utils::lastPow2(bdimx);
} else {
bdimx = bdimx - bdimx % min_warp_size;
}
}

// Put everything else in bdimy for now
bdimy = std::max(target_threads_in_block / bdimx, (int64_t)1);

int64_t bdimy = std::max(target_threads_in_block / bdimx, (int64_t)1);
// If we don't have a full warp and have an unroll factor, move unroll into
// bdimx
if (bdimx * bdimy < min_warp_size && inner_reduction_unroll_factor > 1) {
if (bdimx * bdimy < min_warp_size && vect_factor > 1) {
bdimx = std::min(
std::max(total_reduction_numel, min_warp_size),
target_threads_in_block);
inner_reduction_unroll_factor =
std::min(ceilDiv(total_reduction_numel, bdimx), max_unroll);
vect_factor = std::min(ceilDiv(total_reduction_numel, bdimx), max_unroll);
bdimy = std::max(target_threads_in_block / bdimx, (int64_t)1);
}
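
As a worked illustration of the bdimx/bdimy selection above (all numbers assumed for the example, not computed by this commit):

// Illustrative only: total_reduction_numel = 402, vect_factor = 4,
// min_warp_size = 16, target_threads_in_block = 256.
//   after_vect = 402 / 4 = 100
//   bdimx      = min(max(100, 16), 256) = 100; since 100 * 4 != 402,
//                round down to a multiple of 16 -> 96
//   bdimy      = max(256 / 96, 1) = 2, i.e. two reductions share each block.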

godim = ceilDiv(total_iteration_numel, bdimy);

bool vectorize = false;

// Move the unrolling factor into vectorization, up to the vectorization limit.
if (vectorize_factor > 1 && inner_reduction_unroll_factor > 1) {
vectorize = true;
inner_reduction_unroll_factor = std::min(
scheduler_utils::lastPow2(inner_reduction_unroll_factor),
(int64_t)vectorize_factor);
}

int64_t remainder_in_reduction = ceilDiv(
total_reduction_numel,
bdimx * inner_reduction_unroll_factor * target_iterations);

// set iteration blocks and iteration unroll
int64_t godim = ceilDiv(total_iteration_numel, bdimy);
int64_t remainder_in_reduction =
ceilDiv(total_reduction_numel, bdimx * vect_factor * target_iterations);
// If we haven't gotten to the max_unroll case, try to take it out of the
// iteration domain
if (inner_reduction_unroll_factor < max_unroll) {
int64_t iter_unroll_factor = 1;
if (vect_factor < max_unroll) {
// Don't go over a combined inner/outer unroll of max_unroll
auto unroll_available = ceilDiv(max_unroll, inner_reduction_unroll_factor);
auto unroll_available = ceilDiv(max_unroll, vect_factor);

if (unroll_available > 1 && godim > 2 * device_multiprocessor_count) {
unroll_available = std::min(
unroll_available, ceilDiv(godim, 2 * device_multiprocessor_count));
if (unroll_available > 1 && godim > 2 * sm_count) {
unroll_available =
std::min(unroll_available, ceilDiv(godim, 2 * sm_count));
iter_unroll_factor = unroll_available;
}
}

godim = ceilDiv(total_iteration_numel, bdimy * iter_unroll_factor);

// Clang tidy
// set reduction blocks, grdim
int64_t grdim = 1;
constexpr int64_t kEight = 8;
// Cross grid reduction if we haven't hit our target blocks, and we have many
// reduction elements.
if ((godim < target_blocks && remainder_in_reduction >= 0) ||
(remainder_in_reduction >= kEight)) {
gridim = remainder_in_reduction;
grdim = remainder_in_reduction;
}

// Try to do some cleanup of ragged waves on device
if ( // If we have less than 8 waves of blocks
gridim * godim < device_multiprocessor_count * kEight &&
grdim * godim < sm_count * kEight &&
// And we don't have an evenly divisible number of blocks
(gridim * godim) % device_multiprocessor_count != 0 &&
(grdim * godim) % sm_count != 0 &&
// And we have more than one wave
gridim * godim > device_multiprocessor_count) {
grdim * godim > sm_count) {
// round waves down
auto waves =
std::max((godim * gridim) / device_multiprocessor_count, (int64_t)1);
auto new_gridim =
std::max((waves * device_multiprocessor_count) / godim, (int64_t)1);
auto waves = std::max((godim * grdim) / sm_count, (int64_t)1);
auto new_grdim = std::max((waves * sm_count) / godim, (int64_t)1);
if (
// If difference is less than 25% of the original gridim
(new_gridim - gridim) * 4 < gridim &&
// If difference is less than 25% of the original grdim
(new_grdim - grdim) * 4 < grdim &&
// and difference is less than 25% of the original number of blocks
((new_gridim * godim) - (gridim * godim)) * 4 < gridim * godim) {
gridim = new_gridim;
((new_grdim * godim) - (grdim * godim)) * 4 < grdim * godim) {
grdim = new_grdim;
}
}
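
A numeric illustration of the ragged-wave cleanup above (all values assumed for the example only):

// Illustrative only: assume sm_count = 108, godim = 54, grdim = 5.
//   grdim * godim = 270 blocks = 2.5 waves, and 270 % 108 != 0
//   waves     = 270 / 108      = 2
//   new_grdim = (2 * 108) / 54 = 4
//   Both 25% checks pass, so grdim becomes 4: 216 blocks = exactly 2 full waves,
//   with each block covering a slightly larger serial reduction chunk.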

if (gridim > 1) {
if (grdim > 1) {
// Grid reductions do not support unrolling iteration dimension, revert if
// set. Recalculate godim.
if (iter_unroll_factor) {
@@ -332,11 +329,9 @@ std::unique_ptr<ReductionParams> inner2dReductionHeuristic(
rparams->fastest_dim = true;
rparams->cross_block_inner_reduction = true;
rparams->block_dim_inner_reduction = ParallelType::TIDx;
rparams->cross_grid_inner_reduction = gridim > 1;
rparams->cross_grid_inner_reduction = grdim > 1;
rparams->multiple_reds_per_blk = bdimy > 1;
bool pad_bdimx = bdimx > 16 &&
bdimx * bdimy <
(int64_t)at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
bool pad_bdimx = bdimx > 16 && bdimx * bdimy < threads_per_block;
rparams->pad_inner_reduction_to_warp = pad_bdimx;

if (rparams->pad_inner_reduction_to_warp) {
@@ -348,8 +343,8 @@ std::unique_ptr<ReductionParams> inner2dReductionHeuristic(
: bdimx + min_warp_size - bdimx % min_warp_size;
}

rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor;
rparams->vectorize_inner_reduction = vectorize;
rparams->unroll_factor_inner_reduction = vect_factor;
rparams->vectorize_inner_reduction = vect_factor > 1;

if (rparams->multiple_reds_per_blk) {
rparams->block_dim_iter_dom = ParallelType::TIDy;
@@ -367,7 +362,7 @@ std::unique_ptr<ReductionParams> inner2dReductionHeuristic(
if (rparams->cross_grid_inner_reduction) {
rparams->grid_dim_inner_reduction = ParallelType::BIDx;
rparams->split_grid_dim_inner_reduction = true;
gdimx = std::min(gridim, scheduler_utils::x_grid_limit);
gdimx = std::min(grdim, scheduler_utils::x_grid_limit);

rparams->grid_dim_iter_dom = ParallelType::BIDy;
if (godim > scheduler_utils::y_grid_limit) {
@@ -395,7 +390,7 @@ std::unique_ptr<ReductionParams> inner2dReductionHeuristic(
debug() << "\n===== Inner 2D Reduction Stats ========\n"
<< "total_reduction_numel: " << total_reduction_numel << "\n"
<< "total_iteration_numel: " << total_iteration_numel << "\n"
<< "vectorize_factor: " << vectorize_factor << "\n"
<< "vectorize_factor: " << vect_factor << "\n"
<< "n_tensor_inputs: " << n_tensor_inputs << "\n"
<< "max_input_dtype_size: " << max_input_dtype_size << "\n"
<< "block(" << bdimx << ", " << bdimy << ", " << 1 << ")"
@@ -432,38 +427,12 @@ std::unique_ptr<ReductionParams> inner3dReductionHeuristic(
scheduler_utils::lastPow2(
std::max((int64_t)n_tensor_inputs >> 2, (int64_t)1)));

// Conservative value, could be set to larger based on arch if necessary.
constexpr int64_t l1_cache = (int64_t)32 * 1024;
// Could change per generation, but for l1 we want to consider active threads,
// not resident
constexpr int64_t active_threads = 1024;

// if data fits in l2 and we need more parallelization in the reduction dim,
// we can use a smaller warp size. While thread local data fits in l1, and
// reduction dim is really small, we can use <32 threads per warp.
const bool fits_in_l2 = n_elems * max_input_dtype_size * n_tensor_inputs <
at::cuda::getCurrentDeviceProperties()->l2CacheSize;

// If it fits in l2, we just want to make sure each warp uses 32Bytes. Set
// minimum warp as 16 threads instead of 32 as if we have a small reduction
// dim going a bit smaller than 32 usually helps.
const int64_t warp_size_based_on_l2 =
fits_in_l2 ? (int64_t)32 / max_input_dtype_size : 16;

// Check how many elements it would take per thread to start thrashing l1
// set that to minimum number we want to reduce per thread.
const int64_t warp_size_based_on_l1 = std::min(
ceilDiv(
total_reduction_numel,
std::max(
l1_cache /
(n_tensor_inputs * max_input_dtype_size * active_threads),
(int64_t)1)),
(int64_t)16);

// Take the smaller
const int64_t min_warp_size =
std::min(warp_size_based_on_l1, warp_size_based_on_l2);
const int64_t min_warp_size = getL1L2WarpSize(
total_reduction_numel,
total_iteration_numel,
n_tensor_inputs,
max_input_dtype_size);

// Initialization
int64_t target_blocks = 1;
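
Putting the inner-2D pieces together, here is a hedged sketch of how the parallelization strategy described in the diff (Reduction dim: Serial, [grdim], TIDx, Vect; Iteration dim: Serial, godim, [TIDy]) could map one hypothetical problem onto a launch. Every number below is assumed for illustration and is not produced by this commit:

// Illustrative mapping only:
//   total_reduction_numel = 8192, vect_factor = 4, bdimx = 256, grdim = 2
//     reduction dim: [serial 4] x [BIDx 2] x [TIDx 256] x [vectorize 4] = 8192
//   total_iteration_numel = 4096, bdimy = 1, iter_unroll_factor = 1
//     iteration dim: [BIDy 4096]
//   -> block = (256, 1, 1), gdimx = min(2, x_grid_limit) = 2, godim = 4096 on BIDy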
