From 5d8c7e7ea47cb6e1faf333430889a804de87536e Mon Sep 17 00:00:00 2001
From: Peter Bell
Date: Thu, 23 Mar 2023 00:38:40 +0000
Subject: [PATCH] Sort: Use cub::WarpMergeSort for small sorts (32 < n <= 128)
 (#96223)

We currently use `bitonicSortKVInPlace` for sorts of size `n <= 32` but
use `radixSortKVInPlace` for `32 < n <= 4096`. Bitonic sort is also
unstable, which forces stable sorts to fall back to `radixSortKVInPlace`,
which is up to 4x slower in this small regime.

This PR adds a new kernel `warpMergeSortKVInPlace` using
`cub::WarpMergeSort` to implement sorts with `32 < n <= 128` and all
stable sorts with `n <= 128`. This results in up to a 2x speedup for
unstable sorts and up to a 15x speedup for stable sorts, depending on
the input geometry.

This also doesn't increase the total number of kernels, since we are
replacing the radix sorts of size 32 and 128.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96223
Approved by: https://github.com/ngimel
---
 aten/src/ATen/cuda/DeviceUtils.cuh      |   6 ++
 aten/src/ATen/native/cuda/Sort.cu       | 109 +++++++++++++++++++++---
 aten/src/ATen/native/cuda/SortUtils.cuh |  84 ++++++++++++++++++
 3 files changed, 187 insertions(+), 12 deletions(-)

diff --git a/aten/src/ATen/cuda/DeviceUtils.cuh b/aten/src/ATen/cuda/DeviceUtils.cuh
index dc17aa80ca84b1..c0a2fc47c0069b 100644
--- a/aten/src/ATen/cuda/DeviceUtils.cuh
+++ b/aten/src/ATen/cuda/DeviceUtils.cuh
@@ -14,6 +14,12 @@ __device__ __forceinline__ unsigned int ACTIVE_MASK()
 #endif
 }
 
+__device__ __forceinline__ void WARP_SYNC(unsigned mask = 0xffffffff) {
+#if !defined(USE_ROCM)
+  return __syncwarp(mask);
+#endif
+}
+
 #if defined(USE_ROCM)
 __device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
 {
diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu
index cb66b6571c5638..30a61490cb692c 100644
--- a/aten/src/ATen/native/cuda/Sort.cu
+++ b/aten/src/ATen/native/cuda/Sort.cu
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -28,12 +29,21 @@ static int minimum_grid_for_occupancy(T kernel, int max_block_size) {
   return minGridSize;
 }
 
-// For very small sorts, use bitonicSortKVInPlace which performs
-// better because it can sort multiple arrays within the same block of
-// threads, improving occupancy.
-//
-// TODO: cub in CUDA 11.6 has a WarpMergeSort primitive that could
-// replace the bitonic sort here.
+template <typename T>
+constexpr bool has_nan() {
+  if constexpr (std::numeric_limits<T>::is_specialized) {
+    return std::numeric_limits<T>::has_quiet_NaN;
+  } else if constexpr (
+      c10::is_complex<T>::value ||
+      std::is_same_v<T, c10::BFloat16> ||
+      std::is_same_v<T, c10::Half>) {
+    return true;
+  }
+}
+
+// For very small unstable sorts (n <= 32), use bitonicSortKVInPlace
+// which can sort multiple arrays within the same block of threads,
+// improving occupancy.
 struct SmallBitonicSort {
   template <int A, typename K, typename V, typename IndexType>
   void sort(
@@ -94,8 +104,79 @@
   }
 };
 
-// For medium sizes (32 < n <= 4096) use radixSortKVInplace for better
-// performance than the bitonic sort kernel.
+// For small sorts (n <= 128) we use warpMergeSortKVInPlace which
+// sorts one slice per warp and potentially multiple slices in the
+// same block for improved occupancy with large batch sizes.
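+// Slices shorter than sort_size are padded with an invalid_key that
+// compares after every valid key under the chosen comparator, so the
+// padding sorts to the back and is never written out (loads and
+// stores are bounded by keySliceSize).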
+template <int sort_size>
+struct WarpMergeSort {
+
+  template <int A, typename K, typename V, typename IndexType>
+  void sort(
+      at::cuda::detail::TensorInfo<K, IndexType> keyInfo,
+      IndexType keySlices,
+      IndexType keySliceSize,
+      IndexType keySliceStride,
+      at::cuda::detail::TensorInfo<V, IndexType> valueInfo,
+      IndexType valueSliceStride,
+      bool descending) {
+    constexpr int max_block_y = 16;
+    const int block_x = at::cuda::warp_size();
+
+    TORCH_INTERNAL_ASSERT(keySliceSize <= sort_size);
+
+    // Scale batch size down if the grid would be too small
+    const auto min_grid = minimum_grid_for_occupancy(
+        warpMergeSortKVInPlace<
+            A, -1, sort_size, max_block_y,
+            K, V, LTOp<K, true>, IndexType>,
+        block_x * max_block_y);
+    const auto max_batch = std::max(IndexType{1}, keySlices / min_grid);
+    const int block_y = std::min(IndexType(max_block_y), max_batch);
+    dim3 block(block_x, block_y);
+
+    dim3 grid;
+    const int grid_count = (keySlices + block_y - 1) / block_y;
+    TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid),
+                          "Too many slices to sort");
+    const auto stream = at::cuda::getCurrentCUDAStream();
+
+    if (descending) {
+      const K invalid_key = at::numeric_limits<K>::lower_bound();
+      warpMergeSortKVInPlace<A, -1, sort_size, max_block_y>
+          <<<grid, block, 0, stream>>>(
+              keyInfo,
+              keySlices,
+              keySliceSize,
+              keySliceStride,
+              valueInfo,
+              valueSliceStride,
+              GTOp<K, true>(),
+              invalid_key);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    } else {
+      const K invalid_key = []{
+        // NAN is sorted after inf
+        if constexpr(has_nan<K>()) {
+          return K(NAN);
+        }
+        return at::numeric_limits<K>::upper_bound();
+      }();
+      warpMergeSortKVInPlace<A, -1, sort_size, max_block_y>
+          <<<grid, block, 0, stream>>>(
+              keyInfo,
+              keySlices,
+              keySliceSize,
+              keySliceStride,
+              valueInfo,
+              valueSliceStride,
+              LTOp<K, true>(),
+              invalid_key);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    }
+  }
+};
+
+// For medium sizes (128 < n <= 4096) use radixSortKVInplace.
 struct MediumRadixSort {
   template <int A, typename K, typename V, typename IndexType>
   void sort(
@@ -134,14 +215,13 @@
         break;
       case 128:
       case 64:
-        HANDLE_CASE(128, 4);
-        break;
       case 32:
       case 16:
       case 8:
       case 4:
       case 2:
-        HANDLE_CASE(32, 2);
+        TORCH_INTERNAL_ASSERT(
+          false, "Expected size <= 128 to be handled by a different algorithm");
         break;
       case 1:
         /* Nothing to do, data already sorted */
@@ -272,9 +352,14 @@ void sortKeyValueInplace(
     int dim,
     bool descending,
     bool stable) {
-  if (!stable && key.size(dim) <= 32) {
+  const auto sort_size = key.size(dim);
+  if (sort_size <= 1) {
+    return; // Already sorted
+  } else if (!stable && sort_size <= 32) {
     // NOTE: Bitonic sort is unstable
     sortCommon(SmallBitonicSort{}, key, value, dim, descending);
+  } else if (sort_size <= 128) {
+    sortCommon(WarpMergeSort<128>{}, key, value, dim, descending);
   } else {
     sortCommon(MediumRadixSort{}, key, value, dim, descending);
   }
diff --git a/aten/src/ATen/native/cuda/SortUtils.cuh b/aten/src/ATen/native/cuda/SortUtils.cuh
index a1d309ce709e22..172a260da96714 100644
--- a/aten/src/ATen/native/cuda/SortUtils.cuh
+++ b/aten/src/ATen/native/cuda/SortUtils.cuh
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -153,6 +154,89 @@ bitonicSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
   }
 }
 
+template <int KeyDims, int ValueDims, int sort_size, int max_block_dim_y,
+          typename K, typename V, typename Comparator, typename IndexType>
+C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y)
+__global__ void
+warpMergeSortKVInPlace(
+    at::cuda::detail::TensorInfo<K, IndexType> keys,
+    IndexType keySlices,
+    IndexType keySliceSize,
+    IndexType keySliceStride,
+    at::cuda::detail::TensorInfo<V, IndexType> values,
+    IndexType valueSliceStride,
+    Comparator comp,
+    K invalid_key) {
+  // Find the slice of the tensor that we are sorting
+  // NOTE: blockDim.y may be less than max_block_dim_y
+  const IndexType blockIndex = getLinearBlockId<IndexType>();
+  const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;
+
+  // If this row is out of bounds, exit early
+  if (linearIndex >= keySlices) {
+    return;
+  }
+
+  const IndexType keyStartOffset =
+    at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
+  const IndexType valueStartOffset =
+    at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);
+
+  K *keys_slice = &keys.data[keyStartOffset];
+  V *values_slice = &values.data[valueStartOffset];
+
+  StridedRandomAccessor<K, IndexType> keys_iter(keys_slice, keySliceStride);
+  StridedRandomAccessor<V, IndexType> values_iter(values_slice, valueSliceStride);
+
+  namespace cub = ROCM_HIPCUB(at_cuda_detail::cub);
+
+  assert(blockDim.x == C10_WARP_SIZE);
+  assert(blockDim.y <= max_block_dim_y);
+  constexpr int items_per_thread = sort_size / C10_WARP_SIZE;
+  static_assert(
+      items_per_thread * C10_WARP_SIZE == sort_size,
+      "sort_size must be a multiple of C10_WARP_SIZE");
+
+  using LoadKeys = cub::WarpLoad<K, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
+  using LoadValues = cub::WarpLoad<V, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
+  using Sort = cub::WarpMergeSort<K, items_per_thread, C10_WARP_SIZE, V>;
+  using StoreKeys = cub::WarpStore<K, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
+  using StoreValues = cub::WarpStore<V, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
+
+  __shared__ union {
+    typename LoadKeys::TempStorage load_keys;
+    typename LoadValues::TempStorage load_values;
+    typename Sort::TempStorage sort;
+    typename StoreKeys::TempStorage store_keys;
+    typename StoreValues::TempStorage store_values;
+  } tmp_storage[max_block_dim_y];
+
+  auto& warp_storage = tmp_storage[threadIdx.y];
+
+  // Load inputs
+  K local_keys[items_per_thread];
+  V local_values[items_per_thread];
+
+  const auto invalid_value = V{};
+  LoadKeys(warp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key);
+  WARP_SYNC();
+  LoadValues(warp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value);
+  WARP_SYNC();
+
+  // Sort! We use stable sort to ensure that invalid values are never
+  // sorted before valid values. In testing it performed the same as
+  // .Sort, so there is no down-side.
+  Sort(warp_storage.sort).StableSort(
+      local_keys, local_values, comp, keySliceSize, invalid_key);
+  WARP_SYNC();
+
+  // Store outputs
+  StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize);
+  WARP_SYNC();
+  StoreValues(warp_storage.store_values).Store(values_iter, local_values, keySliceSize);
+}
+
 template
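
For readers who have not used cub's warp-scope primitives before, here is a minimal stand-alone sketch of the building block the new kernel relies on. It is not part of the patch: the file name, the `warp_sort_demo` kernel, the fixed 32x4 = 128-key geometry, and the +inf padding are illustrative assumptions that mirror `warpMergeSortKVInPlace` above, and it needs a CUDA toolkit new enough to ship `cub::WarpMergeSort` (the removed TODO mentions CUDA 11.6).

// warp_merge_sort_demo.cu -- illustrative sketch, not part of the patch.
// Build (assumed): nvcc -o warp_sort_demo warp_merge_sort_demo.cu
#include <cub/warp/warp_merge_sort.cuh>
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>

constexpr int kWarpThreads    = 32;
constexpr int kItemsPerThread = 4;   // 32 threads * 4 items = 128-key sort
constexpr int kSortSize       = kWarpThreads * kItemsPerThread;

struct LessThan {
  __device__ bool operator()(float a, float b) const { return a < b; }
};

// One warp sorts up to n keys (n <= 128); out-of-range slots are padded
// with +inf so they sort to the back, mirroring invalid_key in the patch.
__global__ void warp_sort_demo(float* data, int n) {
  using WarpSort = cub::WarpMergeSort<float, kItemsPerThread, kWarpThreads>;
  __shared__ typename WarpSort::TempStorage tmp_storage;

  // Blocked arrangement: thread t owns items [t*4, t*4 + 4).
  float keys[kItemsPerThread];
  for (int i = 0; i < kItemsPerThread; ++i) {
    const int idx = threadIdx.x * kItemsPerThread + i;
    keys[i] = idx < n ? data[idx] : INFINITY;
  }

  WarpSort(tmp_storage).Sort(keys, LessThan{});

  // Write back only the valid prefix; the padding stays in registers.
  for (int i = 0; i < kItemsPerThread; ++i) {
    const int idx = threadIdx.x * kItemsPerThread + i;
    if (idx < n) data[idx] = keys[i];
  }
}

int main() {
  constexpr int n = 100;   // 32 < n <= 128, the regime this patch targets
  float host[n];
  for (int i = 0; i < n; ++i) host[i] = float((i * 37) % n);

  float* dev = nullptr;
  cudaMalloc(&dev, n * sizeof(float));
  cudaMemcpy(dev, host, n * sizeof(float), cudaMemcpyHostToDevice);

  warp_sort_demo<<<1, kWarpThreads>>>(dev, n);   // exactly one warp
  cudaMemcpy(host, dev, n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(dev);

  for (int i = 1; i < n; ++i) {
    if (host[i - 1] > host[i]) { std::printf("not sorted!\n"); return 1; }
  }
  std::printf("sorted %d keys with one warp\n", n);
  return 0;
}

The patch's kernel differs mainly in that it sorts key/value pairs through strided accessors, takes the comparator and `invalid_key` as arguments, and packs up to 16 warps per block (`blockDim.y`) so that large batches of small slices still fill the GPU.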