[GraphBolt][CUDA] Reduce and hide unique_and_compact synchronizations. (dmlc#6841)
mfbalin authored Dec 28, 2023
1 parent 0856913 commit 22a2513
Showing 2 changed files with 102 additions and 54 deletions.
27 changes: 21 additions & 6 deletions graphbolt/src/cuda/common.h
@@ -104,26 +104,41 @@ inline bool is_zero<dim3>(dim3 size) {
  */
 template <typename scalar_t>
 struct CopyScalar {
-  CopyScalar(const scalar_t* device_ptr) : is_ready_(false) {
-    pinned_scalar_ = torch::empty(
-        sizeof(scalar_t),
-        c10::TensorOptions().dtype(torch::kBool).pinned_memory(true));
+  CopyScalar() : is_ready_(true) { init_pinned_storage(); }
+
+  void record(at::cuda::CUDAStream stream = GetCurrentStream()) {
+    copy_event_.record(stream);
+    is_ready_ = false;
+  }
+
+  scalar_t* get() {
+    return reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr());
+  }
+
+  CopyScalar(const scalar_t* device_ptr) {
+    init_pinned_storage();
     auto stream = GetCurrentStream();
     CUDA_CALL(cudaMemcpyAsync(
         reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr()), device_ptr,
         sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
-    copy_event_.record(stream);
+    record(stream);
   }
 
   operator scalar_t() {
     if (!is_ready_) {
       copy_event_.synchronize();
       is_ready_ = true;
     }
-    return reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr())[0];
+    return *get();
   }
 
  private:
+  void init_pinned_storage() {
+    pinned_scalar_ = torch::empty(
+        sizeof(scalar_t),
+        c10::TensorOptions().dtype(torch::kBool).pinned_memory(true));
+  }
+
   torch::Tensor pinned_scalar_;
   at::cuda::CUDAEvent copy_event_;
   bool is_ready_;
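For reference, the following is a minimal standalone sketch of the mechanism CopyScalar wraps: keep the scalar in pinned (page-locked) host memory, enqueue an asynchronous device-to-host copy, record a CUDA event, and block only when the host first reads the value. It uses the plain CUDA runtime API instead of the torch pinned tensor, at::cuda::CUDAEvent and CUDA_CALL macro used above, and the PinnedScalar and fill_value names are illustrative, not GraphBolt API.

// pinned_scalar_sketch.cu -- simplified analogue of CopyScalar, for illustration.
#include <cuda_runtime.h>

#include <cstdio>

template <typename T>
struct PinnedScalar {
  // Enqueues a device-to-host copy of *device_ptr; does not block the CPU.
  PinnedScalar(const T* device_ptr, cudaStream_t stream) {
    cudaMallocHost(&pinned_, sizeof(T));  // page-locked host buffer
    cudaEventCreateWithFlags(&event_, cudaEventDisableTiming);
    cudaMemcpyAsync(
        pinned_, device_ptr, sizeof(T), cudaMemcpyDeviceToHost, stream);
    cudaEventRecord(event_, stream);
  }

  ~PinnedScalar() {
    cudaFreeHost(pinned_);
    cudaEventDestroy(event_);
  }

  // Blocks only on first read, i.e. when the value is actually needed.
  operator T() {
    if (!ready_) {
      cudaEventSynchronize(event_);
      ready_ = true;
    }
    return *pinned_;
  }

 private:
  T* pinned_ = nullptr;
  cudaEvent_t event_{};
  bool ready_ = false;
};

__global__ void fill_value(int* out) { *out = 42; }

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  int* d_value;
  cudaMalloc(&d_value, sizeof(int));
  fill_value<<<1, 1, 0, stream>>>(d_value);
  PinnedScalar<int> result(d_value, stream);  // copy is queued, CPU keeps going
  // ... more GPU work could be enqueued here before the value is needed ...
  std::printf("%d\n", static_cast<int>(result));  // first read synchronizes
  cudaFree(d_value);
  cudaStreamDestroy(stream);
  return 0;
}

The default constructor and get() accessor added in the diff extend this idea: device-side code (for example a CUB reduction) can write straight into the pinned buffer returned by get(), and record() marks the point after which the value becomes valid.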
129 changes: 81 additions & 48 deletions graphbolt/src/cuda/unique_and_compact_impl.cu
Expand Up @@ -15,6 +15,7 @@
#include <thrust/remove.h>

#include <cub/cub.cuh>
#include <type_traits>

#include "./common.h"
#include "./utils.h"
@@ -32,6 +33,24 @@ struct EqualityFunc {
   }
 };
 
+#define DefineReductionFunction(reduce_fn, name) \
+  template <typename scalar_iterator_t> \
+  auto name(const scalar_iterator_t input, int64_t size) { \
+    auto allocator = cuda::GetAllocator(); \
+    auto stream = cuda::GetCurrentStream(); \
+    using scalar_t = std::remove_reference_t<decltype(input[0])>; \
+    cuda::CopyScalar<scalar_t> result; \
+    size_t workspace_size = 0; \
+    reduce_fn(nullptr, workspace_size, input, result.get(), size, stream); \
+    auto tmp_storage = allocator.AllocateStorage<char>(workspace_size); \
+    reduce_fn( \
+        tmp_storage.get(), workspace_size, input, result.get(), size, stream); \
+    return result; \
+  }
+
+DefineReductionFunction(cub::DeviceReduce::Max, Max);
+DefineReductionFunction(cub::DeviceReduce::Min, Min);
+
 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
     const torch::Tensor src_ids, const torch::Tensor dst_ids,
     const torch::Tensor unique_dst_ids, int num_bits) {
@@ -48,17 +67,13 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
         auto dst_ids_ptr = dst_ids.data_ptr<scalar_t>();
         auto unique_dst_ids_ptr = unique_dst_ids.data_ptr<scalar_t>();
 
-        // If the given num_bits argument is not in the reasonable range,
-        // we recompute it to speedup the expensive sort operations.
-        if (num_bits <= 0 || num_bits > sizeof(scalar_t) * 8) {
-          auto max_id = thrust::reduce(
-              exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
-              static_cast<scalar_t>(0), thrust::maximum<scalar_t>{});
-          max_id = thrust::reduce(
-              exec_policy, unique_dst_ids_ptr,
-              unique_dst_ids_ptr + unique_dst_ids.size(0), max_id,
-              thrust::maximum<scalar_t>{});
-          num_bits = cuda::NumberOfBits(max_id + 1);
+        // If num_bits is not given, compute maximum vertex ids to compute
+        // num_bits later to speedup the expensive sort operations.
+        cuda::CopyScalar<scalar_t> max_id_src;
+        cuda::CopyScalar<scalar_t> max_id_dst;
+        if (num_bits == 0) {
+          max_id_src = Max(src_ids_ptr, src_ids.size(0));
+          max_id_dst = Max(unique_dst_ids_ptr, unique_dst_ids.size(0));
         }
 
         // Sort the unique_dst_ids tensor.
@@ -78,41 +93,55 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
         auto only_src =
             torch::empty(src_ids.size(0), sorted_unique_dst_ids.options());
         {
-          auto only_src_size =
-              thrust::remove_copy_if(
-                  exec_policy, src_ids_ptr, src_ids_ptr + src_ids.size(0),
-                  is_dst.get(), only_src.data_ptr<scalar_t>(),
-                  thrust::identity<bool>{}) -
-              only_src.data_ptr<scalar_t>();
-          only_src = only_src.slice(0, 0, only_src_size);
+          auto is_src = thrust::make_transform_iterator(
+              is_dst.get(), thrust::logical_not<bool>{});
+          cuda::CopyScalar<int64_t> only_src_size;
+          size_t workspace_size = 0;
+          cub::DeviceSelect::Flagged(
+              nullptr, workspace_size, src_ids_ptr, is_src,
+              only_src.data_ptr<scalar_t>(), only_src_size.get(),
+              src_ids.size(0), stream);
+          auto tmp_storage = allocator.AllocateStorage<char>(workspace_size);
+          cub::DeviceSelect::Flagged(
+              tmp_storage.get(), workspace_size, src_ids_ptr, is_src,
+              only_src.data_ptr<scalar_t>(), only_src_size.get(),
+              src_ids.size(0), stream);
+          stream.synchronize();
+          only_src = only_src.slice(0, 0, static_cast<int64_t>(only_src_size));
         }
 
-        // Sort the only_src tensor so that we can unique it with Encode
-        // operation later.
+        // The code block above synchronizes, ensuring safe access to max_id_src
+        // and max_id_dst.
+        if (num_bits == 0) {
+          num_bits = cuda::NumberOfBits(
+              1 + std::max(
+                      static_cast<scalar_t>(max_id_src),
+                      static_cast<scalar_t>(max_id_dst)));
+        }
+
+        // Sort the only_src tensor so that we can unique it later.
         auto sorted_only_src = Sort<false>(
             only_src.data_ptr<scalar_t>(), only_src.size(0), num_bits);
 
         auto unique_only_src =
             torch::empty(only_src.size(0), src_ids.options());
         auto unique_only_src_ptr = unique_only_src.data_ptr<scalar_t>();
-        auto unique_only_src_cnt = allocator.AllocateStorage<scalar_t>(1);
 
         { // Compute the unique operation on the only_src tensor.
-          size_t workspace_size;
-          CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
+          cuda::CopyScalar<int64_t> unique_only_src_size;
+          size_t workspace_size = 0;
+          CUDA_CALL(cub::DeviceSelect::Unique(
              nullptr, workspace_size, sorted_only_src.data_ptr<scalar_t>(),
-              unique_only_src_ptr, cub::DiscardOutputIterator{},
-              unique_only_src_cnt.get(), only_src.size(0), stream));
-          auto temp = allocator.AllocateStorage<char>(workspace_size);
-          CUDA_CALL(cub::DeviceRunLengthEncode::Encode(
-              temp.get(), workspace_size, sorted_only_src.data_ptr<scalar_t>(),
-              unique_only_src_ptr, cub::DiscardOutputIterator{},
-              unique_only_src_cnt.get(), only_src.size(0), stream));
-
-          auto unique_only_src_size =
-              cuda::CopyScalar(unique_only_src_cnt.get());
+              unique_only_src_ptr, unique_only_src_size.get(), only_src.size(0),
+              stream));
+          auto tmp_storage = allocator.AllocateStorage<char>(workspace_size);
+          CUDA_CALL(cub::DeviceSelect::Unique(
+              tmp_storage.get(), workspace_size,
+              sorted_only_src.data_ptr<scalar_t>(), unique_only_src_ptr,
+              unique_only_src_size.get(), only_src.size(0), stream));
+          stream.synchronize();
           unique_only_src = unique_only_src.slice(
-              0, 0, static_cast<scalar_t>(unique_only_src_size));
+              0, 0, static_cast<int64_t>(unique_only_src_size));
         }
 
         auto real_order = torch::cat({unique_dst_ids, unique_only_src});
@@ -123,39 +152,43 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
         // Holds the found locations of the src and dst ids in the sorted_order.
         // Later is used to lookup the new ids of the src_ids and dst_ids
         // tensors.
-        auto new_src_ids_loc =
-            allocator.AllocateStorage<scalar_t>(src_ids.size(0));
         auto new_dst_ids_loc =
             allocator.AllocateStorage<scalar_t>(dst_ids.size(0));
         thrust::lower_bound(
             exec_policy, sorted_order_ptr,
-            sorted_order_ptr + sorted_order.size(0), src_ids_ptr,
-            src_ids_ptr + src_ids.size(0), new_src_ids_loc.get());
-        thrust::lower_bound(
-            exec_policy, sorted_order_ptr,
             sorted_order_ptr + sorted_order.size(0), dst_ids_ptr,
             dst_ids_ptr + dst_ids.size(0), new_dst_ids_loc.get());
-        { // Check if unique_dst_ids includes all dst_ids.
+
+        cuda::CopyScalar<bool> all_exist;
+        // Check if unique_dst_ids includes all dst_ids.
+        if (dst_ids.size(0) > 0) {
           thrust::counting_iterator<int64_t> iota(0);
           auto equal_it = thrust::make_transform_iterator(
               iota, EqualityFunc<scalar_t>{
                         sorted_order_ptr, new_dst_ids_loc.get(), dst_ids_ptr});
-          auto all_exist = thrust::all_of(
-              exec_policy, equal_it, equal_it + dst_ids.size(0),
-              thrust::identity<bool>());
-          if (!all_exist) {
-            throw std::out_of_range("Some ids not found.");
-          }
+          all_exist = Min(equal_it, dst_ids.size(0));
+          all_exist.record();
         }
 
+        auto new_src_ids_loc =
+            allocator.AllocateStorage<scalar_t>(src_ids.size(0));
+        thrust::lower_bound(
+            exec_policy, sorted_order_ptr,
+            sorted_order_ptr + sorted_order.size(0), src_ids_ptr,
+            src_ids_ptr + src_ids.size(0), new_src_ids_loc.get());
+
         // Finally, lookup the new compact ids of the src and dst tensors via
         // gather operations.
         auto new_src_ids = torch::empty_like(src_ids);
-        auto new_dst_ids = torch::empty_like(dst_ids);
         thrust::gather(
             exec_policy, new_src_ids_loc.get(),
             new_src_ids_loc.get() + src_ids.size(0),
             new_ids.data_ptr<int64_t>(), new_src_ids.data_ptr<scalar_t>());
+        // Perform check before we gather for the dst indices.
+        if (dst_ids.size(0) > 0 && !static_cast<bool>(all_exist)) {
+          throw std::out_of_range("Some ids not found.");
+        }
+        auto new_dst_ids = torch::empty_like(dst_ids);
         thrust::gather(
             exec_policy, new_dst_ids_loc.get(),
             new_dst_ids_loc.get() + dst_ids.size(0),

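As a standalone illustration of the deferred-reduction pattern behind the new Max()/Min() helpers and the relocated num_bits computation: CUB's usual two-pass call (one call to query the temporary-storage size, one to run the reduction) writes its result directly into pinned host memory, an event is recorded instead of synchronizing immediately, and the host blocks on that event only where the value is needed. This is a simplified sketch assuming plain cudaMalloc/cudaMallocHost rather than GraphBolt's caching allocator and CopyScalar; names such as h_max and done are illustrative.

// deferred_reduction_sketch.cu -- two-pass CUB reduction into pinned memory.
#include <cub/cub.cuh>

#include <cstdio>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  constexpr int kSize = 1 << 20;
  int* d_in;
  cudaMalloc(&d_in, kSize * sizeof(int));
  cudaMemsetAsync(d_in, 0, kSize * sizeof(int), stream);  // all zeros, max is 0

  // The reduction result goes straight into pinned host memory (device
  // accessible on 64-bit platforms with unified addressing), as with
  // result.get() in the Max()/Min() helpers above.
  int* h_max;
  cudaMallocHost(&h_max, sizeof(int));

  // Pass 1 queries the temporary storage size, pass 2 runs the reduction.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::Max(d_temp, temp_bytes, d_in, h_max, kSize, stream);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceReduce::Max(d_temp, temp_bytes, d_in, h_max, kSize, stream);

  // Record an event instead of synchronizing right away.
  cudaEvent_t done;
  cudaEventCreateWithFlags(&done, cudaEventDisableTiming);
  cudaEventRecord(done, stream);

  // ... enqueue more independent GPU work on the stream here; the pending
  // device-to-host result does not stall the CPU yet ...

  // Block only where the host actually needs the value (in the diff above,
  // when num_bits is derived for the radix sort).
  cudaEventSynchronize(done);
  std::printf("max = %d\n", *h_max);

  cudaEventDestroy(done);
  cudaFreeHost(h_max);
  cudaFree(d_temp);
  cudaFree(d_in);
  cudaStreamDestroy(stream);
  return 0;
}

In the committed code the existing stream.synchronize() after cub::DeviceSelect::Flagged doubles as that wait, which is why the comment in the diff notes that max_id_src and max_id_dst are safe to read at that point.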