diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index 4032c9d159e..c7021706865 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -63,7 +63,449 @@ void reduceProductTo(int64_t& z, int64_t& y, int64_t& x, const int64_t max) { } } -std::unique_ptr innerReductionHeuristic( +std::unique_ptr inner2dReductionHeuristic( + const int64_t total_reduction_numel, + const int64_t total_iteration_numel, + const int64_t inner_most_dimension_numel, + const int64_t n_tensor_inputs, + const int64_t max_input_dtype_size, + const size_t vectorize_factor) { + // Set some targets for parallelization + + const int64_t n_elems = total_reduction_numel * total_iteration_numel; + + // WARNING: At some point we may want to generate heuristics for another + // device that is not the current device. + const int64_t max_threads_per_sm = + (int64_t)at::cuda::getCurrentDeviceProperties() + ->maxThreadsPerMultiProcessor; + + const int64_t device_multiprocessor_count = + (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + + auto const max_unroll = ceilDiv( + // Available unrolling based on size of data type + (int64_t)16 / (int64_t)max_input_dtype_size, + // Reduce unrolling if we have many inputs, start reduction at 4 inputs + scheduler_utils::lastPow2( + std::max((int64_t)n_tensor_inputs >> 2, (int64_t)1))); + + // Conservative value, could be set to larger based on arch if necessary. + constexpr int64_t l1_cache = (int64_t)32 * 1024; + // Could change per generation, but for l1 we want to consider active threads, + // not resident + constexpr int64_t active_threads = 1024; + + // if data fits in l2 and we need more parallelization in the reduction dim, + // we can use a smaller warp size. While thread local data fits in l1, and + // reduction dim is really small, we can use <32 threads per warp. + const bool fits_in_l2 = n_elems * max_input_dtype_size * n_tensor_inputs < + at::cuda::getCurrentDeviceProperties()->l2CacheSize; + + // If it fits in l2, we just want to make sure each warp uses 32Bytes. Set + // minimum warp as 16 threads instead of 32 as if we have a small reduction + // dim going a bit smaller than 32 usually helps. + const int64_t warp_size_based_on_l2 = + fits_in_l2 ? (int64_t)32 / max_input_dtype_size : 16; + + // Check how many elements it would take per thread to start thrashing l1 + // set that to minimum number we want to reduce per thread. + const int64_t warp_size_based_on_l1 = std::min( + ceilDiv( + total_reduction_numel, + std::max( + l1_cache / + (n_tensor_inputs * max_input_dtype_size * active_threads), + (int64_t)1)), + (int64_t)16); + + // Take the smaller + const int64_t min_warp_size = + std::min(warp_size_based_on_l1, warp_size_based_on_l2); + + // Initialization + int64_t target_blocks = 1; + int64_t target_unroll = 1; + int64_t target_iterations = 1; + + // Try to set a minmum amount of work for each thread, as cross thread + // communication is slow so it shouldn't be done for every element in the + // reduction. + int64_t min_target_iterations = + std::max((int64_t)32 / (int64_t)max_input_dtype_size, (int64_t)1); + + // Start trying to break parallelization up across threads, + // unrolling/iterations, and blocks. + + // target_threads_in_block is the cap on a thread block, the minimum is based + // on min_warp_size + int64_t target_threads_in_block = std::max( + min_warp_size, ceilDiv(total_reduction_numel, min_target_iterations)); + + // If we have one warp per block, check if that's enough to saturate the SMs + target_blocks = ceilDiv(n_elems, min_warp_size); + + // If we have more than a wave of blocks, put parallelism into unrolling and + // target iterations + if (target_blocks > device_multiprocessor_count) { + auto available_unroll = std::max( + n_elems / (min_warp_size * device_multiprocessor_count), (int64_t)1); + + // Spread across unrolling and iterations, want a balance of the two so flip + // back and forth to alternate adding to them. + bool flip = true; + + while (available_unroll > 1 && + (target_unroll < max_unroll || + // Prefer unrolling + target_iterations < max_unroll)) { + if (target_unroll * 2 <= max_unroll && flip) { + target_unroll *= 2; + } + + if (target_iterations * 2 <= max_unroll && !flip) { + target_iterations *= 2; + } + + available_unroll = std::max( + n_elems / + (min_warp_size * device_multiprocessor_count * target_unroll * + target_iterations), + (int64_t)1); + + flip = !flip; + } + + // Recompute target blocks + target_blocks = + ceilDiv(n_elems, min_warp_size * target_unroll * target_iterations); + } + + // Cap target blocks to 4 waves + target_blocks = std::min(target_blocks, device_multiprocessor_count * 4); + + if (target_blocks * target_unroll * target_iterations < n_elems) { + // targetting 4 waves, so try to use a quarter of available threads + target_threads_in_block = std::min( + ceilDiv(n_elems, target_blocks * target_unroll), + ceilDiv(max_threads_per_sm, (int64_t)4)); + } + + // Round up to nearest warp. + if (target_threads_in_block % min_warp_size != 0) { + target_threads_in_block += + min_warp_size - target_threads_in_block % min_warp_size; + } + + // To get to target threads: + // Prioritize + // (1) x dim in reduction + // (2) unrolling in reduction + // (3) y in output + // To get target blocks: + // Prioritize + // (1) x dim in multiple outputs + // (2) y dim in multiple reductions + + // Cross grid inner reduction, number of blocks to cross-grid on + int64_t gridim = 1; + // Cross grid outer reduction, number of blocks to cross-grid on + int64_t grodim = 1; + // Blocks for outputs + int64_t godim = 1; + + // Threads for reduction + int64_t bdimx = 1; + // Threads for outputs + int64_t bdimy = 1; + // Threads for outer reduction dimension + int64_t bdimz = 1; + + // Unroll amount + int64_t inner_reduction_unroll_factor = 1; + int64_t outer_reduction_unroll_factor = 1; + int64_t iter_unroll_factor = 1; + + inner_reduction_unroll_factor = + vectorize_factor > 1 ? (int64_t)vectorize_factor : 1; + + // Grab what we can out of reduction domain, but don't go over a warp size yet + bdimx = std::min( + std::max( + ceilDiv(inner_most_dimension_numel, inner_reduction_unroll_factor), + (int64_t)min_warp_size), + target_threads_in_block); + + // If we're not just barely covering the dimension, round to a more friendly + // number + if (bdimx * inner_reduction_unroll_factor != inner_most_dimension_numel) { + // Round bdimx down to multiple of warp size or power 2 + if (bdimx < min_warp_size) { + bdimx = scheduler_utils::lastPow2(bdimx); + } else { + bdimx = bdimx - bdimx % min_warp_size; + } + } + + // Put everything else in bdimy for now + bdimy = std::max(min_warp_size / bdimx, (int64_t)1); + + // If 3D fill the rest of the threads into bdimz + bdimz = std::min( + std::min( + std::max(target_threads_in_block / (bdimx * bdimy), (int64_t)1), + ceilDiv(total_reduction_numel, inner_most_dimension_numel)), + scheduler_utils::z_block_limit); + + // If 3D doesn't fill out the threads, adjust to add to bdimy + bdimy = std::max(target_threads_in_block / (bdimx * bdimz), (int64_t)1); + + // If we don't have a full warp and have an unroll factor, move unroll into + // bdimx + if (bdimx * bdimy * bdimz < min_warp_size && + inner_reduction_unroll_factor > 1) { + bdimx = std::min( + std::max(inner_most_dimension_numel, min_warp_size), + target_threads_in_block); + + inner_reduction_unroll_factor = + std::min(ceilDiv(inner_most_dimension_numel, bdimx), max_unroll); + + // Readjust bdimy and bdimz + bdimy = std::max(min_warp_size / bdimx, (int64_t)1); + + bdimz = std::min( + std::max(target_threads_in_block / (bdimx * bdimy), (int64_t)1), + ceilDiv(total_reduction_numel, inner_most_dimension_numel)); + + bdimy = std::max(target_threads_in_block / (bdimx * bdimz), (int64_t)1); + } + + godim = ceilDiv(total_iteration_numel, bdimy); + + bool vectorize = false; + + // Move unrolling factor into vectorization upto vectorization limit. + if (vectorize_factor > 1 && inner_reduction_unroll_factor > 1) { + vectorize = true; + inner_reduction_unroll_factor = std::min( + scheduler_utils::lastPow2(inner_reduction_unroll_factor), + (int64_t)vectorize_factor); + } + + // Attempt to put some unrolling into the outer reduction if inner hasn't + // taken the max unrolling + if (inner_reduction_unroll_factor < max_unroll) { + outer_reduction_unroll_factor = std::min( + ceilDiv(max_unroll, inner_reduction_unroll_factor), + ceilDiv( + ceilDiv(total_reduction_numel, inner_most_dimension_numel), bdimz)); + } + + int64_t remainder_in_reduction = ceilDiv( + total_reduction_numel, + bdimx * inner_reduction_unroll_factor * bdimz * + outer_reduction_unroll_factor * target_iterations); + + int64_t remainder_in_inner_dim = ceilDiv( + inner_most_dimension_numel, + bdimx * inner_reduction_unroll_factor * target_iterations); + + // If we haven't gotten to the max_unroll case, try to take it out of the + // iteration domain + if (inner_reduction_unroll_factor * outer_reduction_unroll_factor < + max_unroll) { + // Don't go over a combined inner/outer unroll of max_unroll + auto unroll_available = ceilDiv( + max_unroll, + inner_reduction_unroll_factor * outer_reduction_unroll_factor); + + if (unroll_available > 1 && godim > 2 * device_multiprocessor_count) { + unroll_available = std::min( + unroll_available, ceilDiv(godim, 2 * device_multiprocessor_count)); + iter_unroll_factor = unroll_available; + } + } + + godim = ceilDiv(total_iteration_numel, bdimy * iter_unroll_factor); + + // Clang tidy + constexpr int64_t kEight = 8; + // Cross grid reduction if we haven't hit our target blocks, and we have manyr + // reduction elements. + if ((godim < target_blocks && remainder_in_reduction >= 0) || + (remainder_in_reduction >= kEight)) { + auto grdim = std::min(remainder_in_reduction, bdimx * bdimy * kEight); + + gridim = remainder_in_inner_dim; + grodim = std::max(grdim / gridim, (int64_t)1); + grodim = std::max( + std::min(remainder_in_reduction / remainder_in_inner_dim, grodim), + (int64_t)1); + } + + // Try to do some cleanup of ragged waves on device, don't do this if we're + // trying to do a 3D schedule. godim is a remainder of a split, so can only + // control gridim + if (grodim == 1 && + // If we have less than 8 waves of blocks + gridim * godim < device_multiprocessor_count * kEight && + // And we don't have an even divisible number of blocks + (gridim * godim) % device_multiprocessor_count != 0 && + // And we have more than one wave + gridim * godim > device_multiprocessor_count) { + // round waves down + auto waves = + std::max((godim * gridim) / device_multiprocessor_count, (int64_t)1); + auto new_gridim = + std::max((waves * device_multiprocessor_count) / godim, (int64_t)1); + if ( + // If difference is less than 25% of the original gridim + (new_gridim - gridim) * 4 < gridim && + // and difference is less than 25% of the original number of blocks + ((new_gridim * godim) - (gridim * godim)) * 4 < gridim * godim) { + gridim = new_gridim; + } + } + + if (grodim > 1 || gridim > 1) { + // Grid reductions do not support unrolling iteration dimension, revert if + // set. Recalculate godim. + if (iter_unroll_factor) { + iter_unroll_factor = 1; + godim = ceilDiv(total_iteration_numel, bdimy * iter_unroll_factor); + } + // This could mess up parallelization which could be redone, but that would + // require iterating over this entire function. + } + + auto rparams = std::make_unique(); + rparams->fastest_dim = true; + rparams->cross_block_inner_reduction = true; + rparams->block_dim_inner_reduction = ParallelType::TIDx; + rparams->cross_grid_inner_reduction = gridim > 1; + rparams->multiple_reds_per_blk = bdimy > 1; + bool pad_bdimx = bdimx > 16 && + bdimx * bdimy < + (int64_t)at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock; + rparams->pad_inner_reduction_to_warp = pad_bdimx; + + if (rparams->pad_inner_reduction_to_warp) { + // Adjust bdimx based on padding + auto min_warp_size = + (int64_t)at::cuda::getCurrentDeviceProperties()->warpSize; + bdimx = bdimx % min_warp_size == 0 + ? bdimx + : bdimx + min_warp_size - bdimx % min_warp_size; + } + + rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor; + rparams->vectorize_inner_reduction = vectorize; + + if (rparams->multiple_reds_per_blk) { + rparams->block_dim_iter_dom = ParallelType::TIDy; + } + + rparams->unroll_factor_iter_dom = iter_unroll_factor; + + rparams->schedule_3D = total_reduction_numel != inner_most_dimension_numel; + // Outer reduction domain + if (rparams->schedule_3D) { + rparams->cross_grid_outer_reduction = grodim > 1; + if (bdimz > 1) { + rparams->block_dim_outer_reduction = ParallelType::TIDz; + rparams->cross_block_outer_reduction = true; + } + rparams->unroll_factor_outer_reduction = outer_reduction_unroll_factor; + } + + int64_t gdimx = LaunchParams::UNINITIALIZED_VAL; + int64_t gdimy = LaunchParams::UNINITIALIZED_VAL; + int64_t gdimz = LaunchParams::UNINITIALIZED_VAL; + + // If we have a cross grid case we want to have gdimy assigned to godim and + // gdimx assigned to grdim. Otherwise it's helpful to pull godim into gdimx in + // case it's larger than gdimy can hold, as not doing so can thrash the cache. + + if (rparams->cross_grid_inner_reduction) { + rparams->grid_dim_inner_reduction = ParallelType::BIDx; + rparams->split_grid_dim_inner_reduction = true; + gdimx = std::min(gridim, scheduler_utils::x_grid_limit); + + rparams->grid_dim_iter_dom = ParallelType::BIDy; + if (godim > scheduler_utils::y_grid_limit) { + rparams->split_grid_dim_iter_dom_outer = true; + gdimy = std::min(godim, scheduler_utils::y_grid_limit); + } + + } else { + rparams->grid_dim_iter_dom = ParallelType::BIDx; + if (gdimx > scheduler_utils::x_grid_limit) { + rparams->split_grid_dim_iter_dom_outer = true; + gdimx = godim; + } + } + + if (rparams->cross_grid_outer_reduction) { + if (rparams->cross_block_inner_reduction) { + rparams->grid_dim_outer_reduction = ParallelType::BIDz; + gdimz = std::min(grodim, scheduler_utils::z_grid_limit); + rparams->split_grid_dim_outer_reduction = true; + } else { + rparams->grid_dim_outer_reduction = ParallelType::BIDy; + gdimy = std::min(grodim, scheduler_utils::y_grid_limit); + rparams->split_grid_dim_outer_reduction = true; + } + } + + rparams->lparams = LaunchParams( + gdimx, + gdimy, + gdimz, + bdimx, + bdimy > 1 ? bdimy : LaunchParams::UNINITIALIZED_VAL, + bdimz > 1 ? bdimz : LaunchParams::UNINITIALIZED_VAL); + + if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { + debug() << "\n===== Reduction Stats ========\n" + << "total_reduction_numel: " + << total_reduction_numel / inner_most_dimension_numel << " * " + << inner_most_dimension_numel << "\n" + << "total_iteration_numel: " << total_iteration_numel << "\n" + << "vectorize_factor: " << vectorize_factor << "\n" + << "n_tensor_inputs: " << n_tensor_inputs << "\n" + << "max_input_dtype_size: " << max_input_dtype_size << "\n" + << "block(" << bdimx << ", " << bdimy << ", " << bdimz << ")" + << std::endl; + debug() << rparams->toString() << std::endl; + } + + // If 3d, check if it's supported by the scheduler, otherwise force 2D + // schedule + if (rparams->schedule_3D) { + if (rparams->multiple_reds_per_blk && + (rparams->cross_grid_inner_reduction || + rparams->cross_grid_outer_reduction)) { + if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { + debug() << "\n===== UNSUPPORTED REDUCTION HEURISTIC ========\n"; + debug() << rparams->multiple_reds_per_blk << ", " + << (rparams->unroll_factor_inner_reduction > 1) << ", " + << rparams->cross_grid_inner_reduction << std::endl; + } + return inner2dReductionHeuristic( + total_reduction_numel, + total_iteration_numel, + total_reduction_numel, + n_tensor_inputs, + max_input_dtype_size, + vectorize_factor); + } + } + + return rparams; +} + +std::unique_ptr inner3dReductionHeuristic( const int64_t total_reduction_numel, const int64_t total_iteration_numel, const int64_t inner_most_dimension_numel, @@ -492,7 +934,7 @@ std::unique_ptr innerReductionHeuristic( << (rparams->unroll_factor_inner_reduction > 1) << ", " << rparams->cross_grid_inner_reduction << std::endl; } - return innerReductionHeuristic( + return inner2dReductionHeuristic( total_reduction_numel, total_iteration_numel, total_reduction_numel, @@ -1052,13 +1494,24 @@ std::unique_ptr reductionHeuristic( const int64_t max_input_dtype_size, const size_t vectorize_factor) { if (fastest_dim_reduction) { - return innerReductionHeuristic( - total_reduction_numel, - total_iteration_numel, - inner_most_dimension_numel, - (int64_t)n_tensor_inputs, - (int64_t)max_input_dtype_size, - vectorize_factor); + if (total_reduction_numel == inner_most_dimension_numel) { + return inner2dReductionHeuristic( + total_reduction_numel, + total_iteration_numel, + inner_most_dimension_numel, + (int64_t)n_tensor_inputs, + (int64_t)max_input_dtype_size, + vectorize_factor); + } else { + return inner3dReductionHeuristic( + total_reduction_numel, + total_iteration_numel, + inner_most_dimension_numel, + (int64_t)n_tensor_inputs, + (int64_t)max_input_dtype_size, + vectorize_factor); + } + } else { // 3D schedules not enabled for outer reductions return outerReductionHeuristic(