EmbeddingBackward exclusive_scan thrust->cub (pytorch#66566)
Summary:
Fixes #{issue number}

Pull Request resolved: pytorch#66566

Reviewed By: H-Huang

Differential Revision: D31637660

Pulled By: ngimel

fbshipit-source-id: 8093432bb9a9b902bb6bab7da221f0bcd7e9fb34
zasdfgbnm authored and facebook-github-bot committed Oct 15, 2021
1 parent bd25f92 commit b5b7d6a
Showing 9 changed files with 46 additions and 37 deletions.
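The substance of the change is swapping Thrust algorithm calls for CUB device primitives inside the embedding backward kernels. For orientation, here is a standalone sketch (not the PyTorch wrapper itself) of the same exclusive scan written both ways; the function name, the int element type, and the plain cudaMalloc/cudaFree for temporary storage are illustrative choices, not code from this PR:

#include <cub/cub.cuh>
#include <thrust/execution_policy.h>
#include <thrust/scan.h>

// Illustrative sketch only: the same exclusive scan via Thrust and via CUB.
// The CUB path uses the two-phase protocol (size query, then the real call)
// that PyTorch's cuda::cub::exclusive_scan wrapper hides from its callers.
void exclusive_scan_two_ways(const int* d_in, int* d_out, int n, cudaStream_t stream) {
  // Thrust: one call; temporary storage is allocated behind the scenes.
  thrust::exclusive_scan(thrust::cuda::par.on(stream), d_in, d_in + n, d_out, 0);

  // CUB: first call computes the temp-storage size, second call does the work.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceScan::ExclusiveScan(d_temp, temp_bytes, d_in, d_out, cub::Sum(), 0, n, stream);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceScan::ExclusiveScan(d_temp, temp_bytes, d_in, d_out, cub::Sum(), 0, n, stream);
  cudaFree(d_temp);
}
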
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/cub.cuh
@@ -103,7 +103,7 @@ inline int get_num_bits(uint64_t max_key) {
}

template<typename key_t>
-static inline void sort_keys(
+static inline void radix_sort_keys(
const key_t *keys_in, key_t *keys_out,
int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
@@ -126,7 +126,7 @@ static inline void sort_keys(
}

template<typename key_t, typename value_t>
-static inline void sort_pairs(
+static inline void radix_sort_pairs(
const key_t *keys_in, key_t *keys_out,
const value_t *values_in, value_t *values_out,
int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
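The rename from sort_keys/sort_pairs to radix_sort_keys/radix_sort_pairs makes explicit that these wrappers perform a radix sort, which is why they accept begin_bit/end_bit. The sketch below shows the kind of CUB call such a wrapper presumably reduces to; the function name and the explicit cudaMalloc/cudaFree are assumptions for illustration, not the wrapper's actual body:

#include <cub/cub.cuh>

// Sketch: a bit-range-restricted radix sort over (key, value) pairs using CUB.
// Restricting [begin_bit, end_bit) means fewer radix passes, which is why the
// call sites compute get_num_bits() for the largest possible key.
template <typename key_t, typename value_t>
void radix_sort_pairs_sketch(const key_t* d_keys_in, key_t* d_keys_out,
                             const value_t* d_values_in, value_t* d_values_out,
                             int n, int begin_bit, int end_bit,
                             cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // First call only fills in temp_bytes.
  cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                  d_values_in, d_values_out, n, begin_bit, end_bit, stream);
  cudaMalloc(&d_temp, temp_bytes);
  // Second call performs the sort.
  cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                  d_values_in, d_values_out, n, begin_bit, end_bit, stream);
  cudaFree(d_temp);
}
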
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Embedding.cu
@@ -277,7 +277,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
auto range = at::arange(num_indices, indices.options());
int64_t nbits = cuda::cub::get_num_bits(num_weights);
-cuda::cub::sort_pairs(
+cuda::cub::radix_sort_pairs(
indices.data_ptr<index_t>(), sorted_indices.data_ptr<index_t>(),
range.data_ptr<index_t>(), orig_indices.data_ptr<index_t>(),
num_indices, false/*, 0, nbits*/);
37 changes: 11 additions & 26 deletions aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
@@ -1,17 +1,13 @@
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/ThrustAllocator.h>
+#include <ATen/cuda/cub.cuh>
#include <ATen/TensorUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/cuda/SortingCommon.cuh>

#include <ATen/AccumulateType.h>

-#include <thrust/device_ptr.h>
-#include <thrust/execution_policy.h>
-#include <thrust/unique.h>
-
#include <c10/macros/Macros.h>

namespace at {
@@ -179,6 +175,9 @@ __global__ void sum_and_scatter(

} // anon namespace

+template<typename index_t>
+int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets);

Tensor embedding_backward_cuda_kernel(
const Tensor &grad,
const Tensor &orig_indices,
@@ -192,8 +191,6 @@ Tensor embedding_backward_cuda_kernel(
const Tensor &per_sample_weights) {

auto stream = at::cuda::getCurrentCUDAStream();
-at::cuda::ThrustAllocator allocator;
-auto policy = thrust::cuda::par(allocator).on(stream);
const ptrdiff_t numel = sorted_indices.numel();

auto grad_weight = at::zeros({num_weights, grad.size(-1)}, grad.options());
@@ -205,20 +202,7 @@ Tensor embedding_backward_cuda_kernel(
// Unit: index in `sorted_indices` and `orig_indices`
AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
auto segment_offsets = at::empty({numel}, orig_indices.options());
-int64_t num_of_segments;
-{
-auto sorted_indices_dev = thrust::device_ptr<index_t>(sorted_indices.data_ptr<index_t>());
-auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-auto dummy_dev = thrust::device_ptr<index_t>(dummy.data_ptr<index_t>());
-auto ends = thrust::unique_by_key_copy(
-policy,
-sorted_indices_dev,
-sorted_indices_dev + numel,
-thrust::make_counting_iterator(0),
-dummy_dev,
-thrust::device_ptr<index_t>(segment_offsets.data_ptr<index_t>()));
-num_of_segments = thrust::get<0>(ends) - dummy_dev;
-}
+int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);

// We split the segments up into sizes of `NROWS_PER_THREAD`
// Compute the number partial-segments per segment (some partial-segments
@@ -238,11 +222,12 @@ Tensor embedding_backward_cuda_kernel(
// start position of each _segment_ in `partial_segment_offset`.
// Unit: index in `partial_segment_offset`
auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options());
-thrust::exclusive_scan(
-policy,
-thrust::device_ptr<index_t>(partials_per_segment.data_ptr<index_t>()),
-thrust::device_ptr<index_t>(partials_per_segment.data_ptr<index_t>()+num_of_segments),
-thrust::device_ptr<index_t>(partials_per_segment_offset.data_ptr<index_t>()));
+cuda::cub::exclusive_scan(
+partials_per_segment.data_ptr<index_t>(),
+partials_per_segment_offset.data_ptr<index_t>(),
+cub::Sum(),
+index_t(0),
+num_of_segments);

// The total number of partial-segments is the sum of `partials_per_segment_offset`
const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item<index_t>() +
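For readers following the offset arithmetic in the hunk above: the exclusive scan converts per-segment partial-segment counts into start offsets, and the total partial-segment count is the last offset plus the last count. A tiny host-side illustration with assumed example values, using C++17 std::exclusive_scan in place of the device scan:

#include <cassert>
#include <numeric>
#include <vector>

int main() {
  // Assumed example: segment i was split into partials_per_segment[i] pieces.
  std::vector<int> partials_per_segment = {2, 1, 3};
  std::vector<int> partials_per_segment_offset(partials_per_segment.size());
  // Exclusive scan with initial value 0, mirroring cuda::cub::exclusive_scan above.
  std::exclusive_scan(partials_per_segment.begin(), partials_per_segment.end(),
                      partials_per_segment_offset.begin(), 0);          // -> {0, 2, 3}
  // Same expression as the kernel's num_of_partial_segments computation.
  int num_of_partial_segments =
      partials_per_segment.back() + partials_per_segment_offset.back(); // 3 + 3 == 6
  assert(num_of_partial_segments == 6);
  return 0;
}
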
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -177,7 +177,7 @@ Tensor embedding_bag_backward_cuda_sum_avg(
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
auto range = at::arange(num_indices, indices.options());
int64_t nbits = cuda::cub::get_num_bits(num_weights);
-cuda::cub::sort_pairs(
+cuda::cub::radix_sort_pairs(
indices.data_ptr<index_t>(), sorted_indices.data_ptr<index_t>(),
range.data_ptr<index_t>(), orig_indices.data_ptr<index_t>(),
num_indices, false/*, 0, nbits*/);
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Indexing.cu
@@ -240,7 +240,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Ten
// linearIndex can not be negative, and we take advantage of this
// fact to sort on less bits for better performance.
int64_t nbits = cuda::cub::get_num_bits(largestIndex(self) / sliceSize);
-cuda::cub::sort_pairs(
+cuda::cub::radix_sort_pairs(
linearIndex.data_ptr<int64_t>(), sorted_indices.data_ptr<int64_t>(),
range.data_ptr<int64_t>(), orig_indices.data_ptr<int64_t>(),
num_indices, false, 0, nbits);
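The get_num_bits helper (defined in cub.cuh; only its closing brace is visible in the first hunk of this commit) is what lets call sites such as the one above sort on fewer bits. One plausible implementation is sketched below purely to make the end_bit arguments concrete; the helper in the tree may differ:

#include <cstdint>

// Count how many low-order bits are needed to represent any key in [0, max_key].
// A radix sort can then be limited to end_bit = get_num_bits(max_key) instead of 64.
inline int get_num_bits_sketch(uint64_t max_key) {
  int nbits = 1;
  while (max_key > 1) {
    max_key >>= 1;
    ++nbits;
  }
  return nbits;
}
// Example: keys bounded by 5 (0b101) need only 3 bits, so a 64-bit radix sort
// restricted to bits [0, 3) skips the passes over bits that are always zero.
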
24 changes: 24 additions & 0 deletions aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
@@ -73,4 +73,28 @@ void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count
template
void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);

+template<typename index_t>
+int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) {
+auto stream = at::cuda::getCurrentCUDAStream();
+at::cuda::ThrustAllocator allocator;
+auto policy = thrust::cuda::par(allocator).on(stream);
+const ptrdiff_t numel = sorted_indices.numel();
+auto sorted_indices_dev = thrust::device_ptr<index_t>(sorted_indices.data_ptr<index_t>());
+auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+auto dummy_dev = thrust::device_ptr<index_t>(dummy.data_ptr<index_t>());
+auto ends = thrust::unique_by_key_copy(
+policy,
+sorted_indices_dev,
+sorted_indices_dev + numel,
+thrust::make_counting_iterator(0),
+dummy_dev,
+thrust::device_ptr<index_t>(segment_offsets.data_ptr<index_t>()));
+return thrust::get<0>(ends) - dummy_dev;
+}
+
+template
+int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets);
+template
+int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets);
+
}}
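The Thrust call that finds segment boundaries now lives in this legacy-helpers translation unit rather than in the CUB-based kernel file. What unique_by_key_copy over sorted indices plus a counting iterator produces (run starts and the number of runs) is easiest to see on a toy input; the values and the main() harness below are assumptions for illustration:

#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/unique.h>
#include <vector>

int main() {
  // Assumed toy input: indices already sorted, three distinct runs.
  std::vector<int> h_sorted = {1, 1, 1, 4, 4, 7};
  thrust::device_vector<int> sorted_indices(h_sorted.begin(), h_sorted.end());
  thrust::device_vector<int> dummy(sorted_indices.size());            // unique keys {1, 4, 7}
  thrust::device_vector<int> segment_offsets(sorted_indices.size());  // run starts {0, 3, 5}

  auto ends = thrust::unique_by_key_copy(
      sorted_indices.begin(), sorted_indices.end(),
      thrust::make_counting_iterator(0),
      dummy.begin(),
      segment_offsets.begin());
  // Number of segments = number of unique keys written.
  int num_of_segments = ends.first - dummy.begin();                   // == 3 here
  return num_of_segments == 3 ? 0 : 1;
}
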
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/Randperm.cu
@@ -87,7 +87,7 @@ Tensor& randperm_out_cuda(int64_t n, c10::optional<Generator> generator, Tensor&
using dtype = OpaqueType<sizeof(scalar_t)>;
auto shuffled_data_ = reinterpret_cast<dtype*>(shuffled_data);
dtype* range_data = reinterpret_cast<dtype*>(range.data_ptr());
-at::cuda::cub::sort_pairs<int, dtype>(
+at::cuda::cub::radix_sort_pairs<int, dtype>(
keys.data_ptr<int>(), keys_out,
range_data, shuffled_data_,
n, false, 0, bits);
@@ -103,7 +103,7 @@ Tensor& randperm_out_cuda(int64_t n, c10::optional<Generator> generator, Tensor&
using dtype = OpaqueType<sizeof(scalar_t)>;
auto shuffled_data_ = reinterpret_cast<dtype*>(shuffled_data);
dtype* range_data = reinterpret_cast<dtype*>(range.data_ptr());
-at::cuda::cub::sort_pairs<int64_t, dtype>(
+at::cuda::cub::radix_sort_pairs<int64_t, dtype>(
keys.data_ptr<int64_t>(), keys_out,
range_data, shuffled_data_,
n, false, 0, bits);
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/Sort.cu
@@ -271,14 +271,14 @@ inline void segmented_sort_pairs_by_full_sort(
auto indices_and_segment2 = at::empty_like(indices_and_segment);
auto i_s_ptr2 = reinterpret_cast<int2 *>(indices_and_segment2.data_ptr<int>());

-at::cuda::cub::sort_pairs<scalar_t, int2>(
+at::cuda::cub::radix_sort_pairs<scalar_t, int2>(
self_ptr, nullptr, i_s_ptr, i_s_ptr2,
n, descending);

TORCH_INTERNAL_ASSERT(segment_bits <= 32);

// sort on lower 32bits, i.e. segment index
-at::cuda::cub::sort_keys<int64_t>(
+at::cuda::cub::radix_sort_keys<int64_t>(
reinterpret_cast<int64_t *>(i_s_ptr2), reinterpret_cast<int64_t *>(i_s_ptr),
n, false, 0, segment_bits);

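The two sorts in the hunk above implement a segmented sort by relying on radix sort's stability: the first pass orders all (value, (index, segment)) pairs by value, and the second pass re-sorts only on the segment bits, so each segment keeps its values in sorted order. A host-side schematic of the same idea follows; the Entry struct and function name are illustrative, and on the GPU the second, stable pass is the radix_sort_keys call restricted to the low segment_bits bits:

#include <algorithm>
#include <cstdint>
#include <vector>

struct Entry { int32_t segment; float value; };

// Schematic of segmented_sort_pairs_by_full_sort: two full sorts, the second
// of which is stable and keyed only on the segment id, yielding values sorted
// within each segment.
void segmented_sort_sketch(std::vector<Entry>& e, bool descending) {
  std::sort(e.begin(), e.end(), [descending](const Entry& a, const Entry& b) {
    return descending ? a.value > b.value : a.value < b.value;  // pass 1: by value
  });
  std::stable_sort(e.begin(), e.end(), [](const Entry& a, const Entry& b) {
    return a.segment < b.segment;                               // pass 2: by segment only
  });
}
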
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/UniqueCub.cu
@@ -141,13 +141,13 @@ std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(
Tensor sorted_indices;
if (!return_inverse) {
if (!consecutive) {
-cuda::cub::sort_keys(self_c.data_ptr<scalar_t>(), sorted_data, num_inp);
+cuda::cub::radix_sort_keys(self_c.data_ptr<scalar_t>(), sorted_data, num_inp);
}
} else {
if (!consecutive) {
Tensor range = at::arange(0, num_inp, options);
sorted_indices = at::empty({num_inp}, options);
-cuda::cub::sort_pairs(
+cuda::cub::radix_sort_pairs(
self_c.data_ptr<scalar_t>(),
sorted_data,
range.data_ptr<int64_t>(),