Sort: Use cub::WarpMergeSort for small sorts (32 < n <= 128) (pytorch#96223)

We currently use `bitonicSortKVInplace` for sorts of size `n <= 32`
but use `radixSortKVInplace` for `32 < n <= 4096`. Bitonic sort is
also unstable, so stable sorts must fall back to the radix sort
kernel, which is up to 4x slower in this small regime.

This PR adds a new kernel `warpMergeSortKVInplace` using
`cub::WarpMergeSort` to implement sorts with `32 < n <= 128` and all
stable sorts with `n <= 128`. This results in up to a 2x speedup for
unstable sorts and up to 15x for stable sorts, depending on the input
geometry.

This also doesn't increase the total number of kernels since we are
replacing radix-sorts of size 32 and 128.
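
For context, here is a minimal keys-only sketch of the `cub::WarpMergeSort` pattern the new kernel builds on (requires the CUB shipped with CUDA 11.6+). The kernel name, launch shape, and `FLT_MAX` padding below are illustrative assumptions, not part of this PR; the actual `warpMergeSortKVInPlace` in the diff sorts key/value pairs and uses NaN-aware padding instead.

```cuda
// Minimal sketch: one warp sorts a slice of up to 128 floats held in registers.
// Launch as warp_sort_demo<<<1, 32>>>(d_data, n) with n <= 128.
#include <cfloat>
#include <cub/cub.cuh>

__global__ void warp_sort_demo(float* data, int n) {
  constexpr int items_per_thread = 4;  // 32 lanes * 4 items = 128 keys per warp
  using WarpSort = cub::WarpMergeSort<float, items_per_thread, 32>;
  __shared__ typename WarpSort::TempStorage temp_storage;

  // Gather the slice into registers, padding out-of-range slots with a
  // large sentinel. (The real kernel pads with NaN or the type's upper
  // bound so that +inf inputs cannot mix with the padding.)
  float keys[items_per_thread];
  for (int i = 0; i < items_per_thread; ++i) {
    const int idx = threadIdx.x + i * 32;
    keys[i] = idx < n ? data[idx] : FLT_MAX;
  }

  // Cooperative merge sort across the warp; afterwards lane t holds
  // ranks [t * items_per_thread, (t + 1) * items_per_thread).
  WarpSort(temp_storage).Sort(keys, [](float a, float b) { return a < b; });

  // Write back only the valid (non-padding) ranks.
  for (int i = 0; i < items_per_thread; ++i) {
    const int rank = threadIdx.x * items_per_thread + i;
    if (rank < n) data[rank] = keys[i];
  }
}
```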
Pull Request resolved: pytorch#96223
Approved by: https://github.com/ngimel
peterbell10 authored and pytorchmergebot committed Mar 23, 2023
1 parent 3b54592 commit 5d8c7e7
Showing 3 changed files with 187 additions and 12 deletions.
6 changes: 6 additions & 0 deletions aten/src/ATen/cuda/DeviceUtils.cuh
@@ -14,6 +14,12 @@ __device__ __forceinline__ unsigned int ACTIVE_MASK()
#endif
}

__device__ __forceinline__ void WARP_SYNC(unsigned mask = 0xffffffff) {
#if !defined(USE_ROCM)
return __syncwarp(mask);
#endif
}

#if defined(USE_ROCM)
__device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
{
109 changes: 97 additions & 12 deletions aten/src/ATen/native/cuda/Sort.cu
@@ -7,6 +7,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/native/cuda/SortUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>

@@ -28,12 +29,21 @@ static int minimum_grid_for_occupancy(T kernel, int max_block_size) {
return minGridSize;
}

// For very small sorts, use bitonicSortKVInPlace which performs
// better because it can sort multiple arrays within the same block of
// threads, improving occupancy.
//
// TODO: cub in CUDA 11.6 has a WarpMergeSort primitive that could
// replace the bitonic sort here.
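// has_nan<T>(): true if T can represent a (quiet) NaN. Used below to pick a
// padding key that sorts after every valid key in an ascending sort.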
template <typename T>
constexpr bool has_nan() {
if constexpr (std::numeric_limits<T>::is_specialized) {
return std::numeric_limits<T>::has_quiet_NaN;
} else if constexpr (
c10::is_complex<T>::value ||
std::is_same_v<T, c10::BFloat16> ||
std::is_same_v<T, c10::Half>) {
return true;
}
}

// For very small unstable sorts (n <= 32), use bitonicSortKVInPlace
// which can sort multiple arrays within the same block of threads,
// improving occupancy.
struct SmallBitonicSort {
template <int A, typename K, typename V, typename IndexType>
void sort(
@@ -94,8 +104,79 @@
}
};

// For medium sizes (32 < n <= 4096) use radixSortKVInplace for better
// performance than the bitonic sort kernel.
// For small sorts (n <= 128) we use warpMergeSortKVInPlace which
// sorts one slice per warp and potentially multiple slices in the
// same block for improved occupancy with large batch sizes.
template <int sort_size>
struct WarpMergeSort {

template <int A, typename K, typename V, typename IndexType>
void sort(
at::cuda::detail::TensorInfo<K, IndexType> keyInfo,
IndexType keySlices,
IndexType keySliceSize,
IndexType keySliceStride,
at::cuda::detail::TensorInfo<V, IndexType> valueInfo,
IndexType valueSliceStride,
bool descending) {
constexpr int max_block_y = 16;
const int block_x = at::cuda::warp_size();

TORCH_INTERNAL_ASSERT(keySliceSize <= sort_size);

// Scale batch size down if the grid would be too small
const auto min_grid = minimum_grid_for_occupancy(
warpMergeSortKVInPlace<
A, -1, sort_size, max_block_y,
K, V, LTOp<K, true>, IndexType>,
block_x * max_block_y);
const auto max_batch = std::max(IndexType{1}, keySlices / min_grid);
const int block_y = std::min(IndexType(max_block_y), max_batch);
dim3 block(block_x, block_y);

dim3 grid;
const int grid_count = (keySlices + block_y - 1) / block_y;
TORCH_INTERNAL_ASSERT(getGridFromTiles(grid_count, grid),
"Too many slices to sort");
const auto stream = at::cuda::getCurrentCUDAStream();

if (descending) {
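// Pad out-of-range slots with the lowest possible key so they sort
// to the end under the descending (greater-than) comparator.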
const K invalid_key = at::numeric_limits<K>::lower_bound();
warpMergeSortKVInPlace<A, -1, sort_size, max_block_y>
<<<grid, block, 0, stream>>>(
keyInfo,
keySlices,
keySliceSize,
keySliceStride,
valueInfo,
valueSliceStride,
GTOp<K, true>(),
invalid_key);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
const K invalid_key = []{
// NAN is sorted after inf
if constexpr(has_nan<K>()) {
return K(NAN);
}
return at::numeric_limits<K>::upper_bound();
}();
warpMergeSortKVInPlace<A, -1, sort_size, max_block_y>
<<<grid, block, 0, stream>>>(
keyInfo,
keySlices,
keySliceSize,
keySliceStride,
valueInfo,
valueSliceStride,
LTOp<K, true>(),
invalid_key);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
};

// For medium sizes (128 < n <= 4096) use radixSortKVInplace.
struct MediumRadixSort {

template <int A, typename K, typename V, typename IndexType>
@@ -134,14 +215,13 @@ struct MediumRadixSort {
break;
case 128:
case 64:
HANDLE_CASE(128, 4);
break;
case 32:
case 16:
case 8:
case 4:
case 2:
HANDLE_CASE(32, 2);
TORCH_INTERNAL_ASSERT(
false, "Expected size <= 128 to be handled by a different algorithm");
break;
case 1:
/* Nothing to do, data already sorted */
@@ -272,9 +352,14 @@ void sortKeyValueInplace(
int dim,
bool descending,
bool stable) {
if (!stable && key.size(dim) <= 32) {
const auto sort_size = key.size(dim);
if (sort_size <= 1) {
return; // Already sorted
} else if (!stable && sort_size <= 32) {
// NOTE: Bitonic sort is unstable
sortCommon(SmallBitonicSort{}, key, value, dim, descending);
} else if (sort_size <= 128) {
sortCommon(WarpMergeSort<128>{}, key, value, dim, descending);
} else {
sortCommon(MediumRadixSort{}, key, value, dim, descending);
}
84 changes: 84 additions & 0 deletions aten/src/ATen/native/cuda/SortUtils.cuh
@@ -5,6 +5,7 @@
#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/Sort.h>
#include <ATen/native/StridedRandomAccessor.h>
@@ -153,6 +154,89 @@ bitonicSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
}
}

template <int KeyDims, int ValueDims, int sort_size, int max_block_dim_y,
typename K, typename V, typename Comparator, typename IndexType>
C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y)
__global__ void
warpMergeSortKVInPlace(
at::cuda::detail::TensorInfo<K, IndexType> keys,
IndexType keySlices,
IndexType keySliceSize,
IndexType keySliceStride,
at::cuda::detail::TensorInfo<V, IndexType> values,
IndexType valueSliceStride,
Comparator comp,
K invalid_key) {
// Find the slice of the tensor that we are sorting
// NOTE: blockDim.y may be less than max_block_dim_y
const IndexType blockIndex = getLinearBlockId<IndexType>();
const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;

// If this row is out of bounds exit early
if (linearIndex >= keySlices) {
return;
}

const IndexType keyStartOffset =
at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
const IndexType valueStartOffset =
at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);

K *keys_slice = &keys.data[keyStartOffset];
V *values_slice = &values.data[valueStartOffset];

StridedRandomAccessor<K, IndexType> keys_iter(keys_slice, keySliceStride);
StridedRandomAccessor<V, IndexType> values_iter(values_slice, valueSliceStride);

namespace cub = ROCM_HIPCUB(at_cuda_detail::cub);

assert(blockDim.x == C10_WARP_SIZE);
assert(blockDim.y <= max_block_dim_y);
constexpr int items_per_thread = sort_size / C10_WARP_SIZE;
static_assert(
items_per_thread * C10_WARP_SIZE == sort_size,
"sort_size must be a multiple of C10_WARP_SIZE");


using LoadKeys = cub::WarpLoad<K, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
using LoadValues = cub::WarpLoad<V, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
using Sort = cub::WarpMergeSort<K, items_per_thread, C10_WARP_SIZE, V>;
using StoreKeys = cub::WarpStore<K, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
using StoreValues = cub::WarpStore<V, items_per_thread, cub::WARP_STORE_TRANSPOSE>;

__shared__ union {
typename LoadKeys::TempStorage load_keys;
typename LoadValues::TempStorage load_values;
typename Sort::TempStorage sort;
typename StoreKeys::TempStorage store_keys;
typename StoreValues::TempStorage store_values;
} tmp_storage[max_block_dim_y];

auto& warp_storage = tmp_storage[threadIdx.y];

// Load inputs
K local_keys[items_per_thread];
V local_values[items_per_thread];

const auto invalid_value = V{};
LoadKeys(warp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key);
WARP_SYNC();
LoadValues(warp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value);
WARP_SYNC();

// Sort! We use stable sort to ensure that invalid values are never
// sorted before valid values. In testing it performed the same as
// .Sort, so there is no down-side.
Sort(warp_storage.sort).StableSort(
local_keys, local_values, comp, keySliceSize, invalid_key);
WARP_SYNC();

// Store outputs
StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize);
WARP_SYNC();
StoreValues(warp_storage.store_values).Store(values_iter, local_values, keySliceSize);
}

template <int KeyDims, int ValueDims,
int block_size, int items_per_thread,
typename K, typename V, typename IndexType>
