forked from pytorch/pytorch
Sort: Use cub::WarpMergeSort for small sorts (32 < n <= 128) (pytorch#96223)

We currently use `bitonicSortKVInplace` for sorts of size `n <= 32` but use `radixSortKVInplace` for `32 < n <= 4096`. Bitonic sort is also unstable, which forces stable sorts to fall back to radix sort, which is up to 4x slower in this small regime.

This PR adds a new kernel, `warpMergeSortKVInplace`, which uses `cub::WarpMergeSort` to implement sorts with `32 < n <= 128` and all stable sorts with `n <= 128`. This results in up to a 2x speedup for unstable sorts and up to 15x for stable sorts, depending on the input geometry.

This also doesn't increase the total number of kernels, since we are replacing radix sorts of size 32 and 128.

Pull Request resolved: pytorch#96223
Approved by: https://github.com/ngimel
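The size-based selection described in the message can be sketched as a small host-side function. This is only an illustration of the thresholds quoted above, not the code from the PR: the enum `SortPath`, the function `chooseSortPath`, and the `Other` case for `n > 4096` are hypothetical names, and treating the stable cutoff as inclusive at 128 is an assumption.

```cpp
#include <cstdint>

// Hypothetical labels for the code paths named in the commit message.
enum class SortPath {
  Bitonic,    // bitonicSortKVInplace: unstable, n <= 32
  WarpMerge,  // warpMergeSortKVInplace: 32 < n <= 128, plus small stable sorts
  Radix,      // radixSortKVInplace: remaining sizes up to 4096
  Other       // sizes above 4096 are not discussed in the message
};

// Pick a path from the slice length n and the requested stability.
// Thresholds follow the commit message; the inclusive bound of 128 for
// stable sorts is an assumption.
SortPath chooseSortPath(std::int64_t n, bool stable) {
  if (n <= 32 && !stable) {
    return SortPath::Bitonic;    // bitonic sort is unstable
  }
  if (n <= 128) {
    return SortPath::WarpMerge;  // merge sort preserves the order of equal keys
  }
  if (n <= 4096) {
    return SortPath::Radix;
  }
  return SortPath::Other;
}
```

Under these assumptions, a stable sort of 16 elements takes the warp merge sort path rather than radix sort, which is the case where the message reports the largest speedup.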
1 parent 3b54592, commit 5d8c7e7
Showing 3 changed files with 187 additions and 12 deletions.