pyg::subgraph CUDA implementation #42
Open
rusty1s wants to merge 11 commits into master from subgraph_cuda
Commits (11):
07f6a85  Update (rusty1s)
058c0bc  update (rusty1s)
f85ff1e  update (rusty1s)
4aa86b9  initial commit (rusty1s)
4bb8087  update (rusty1s)
ef68b2b  update (rusty1s)
377e261  update (rusty1s)
9b5d9f7  typo (rusty1s)
9b4e6e6  changelog (rusty1s)
bf23274  update (rusty1s)
9113818  Merge branch 'master' into subgraph_cuda (ZenoTan)
The first of the two new files adds the CUDA kernels for pyg::subgraph:

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>

#include "pyg_lib/csrc/utils/cuda/helpers.h"

namespace pyg {
namespace sampler {

namespace {

// Computes, for every node in `nodes`, its degree within the induced
// subgraph. Each node is handled by a full warp so that reads from
// `col_data` stay contiguous.
template <typename scalar_t>
__global__ void subgraph_deg_kernel_impl(
    const scalar_t* __restrict__ rowptr_data,
    const scalar_t* __restrict__ col_data,
    const scalar_t* __restrict__ nodes_data,
    const scalar_t* __restrict__ to_local_node_data,
    scalar_t* __restrict__ out_data,
    int64_t num_nodes) {
  CUDA_1D_KERNEL_LOOP(scalar_t, thread_idx, WARP * num_nodes) {
    scalar_t i = thread_idx / WARP;
    scalar_t lane = thread_idx % WARP;

    auto v = nodes_data[i];

    scalar_t deg = 0;
    for (auto j = rowptr_data[v] + lane; j < rowptr_data[v + 1]; j += WARP) {
      if (to_local_node_data[col_data[j]] >= 0)  // contiguous access
        deg++;
    }

    for (int offset = 16; offset > 0; offset /= 2)  // warp-level reduction
      deg += __shfl_down_sync(FULL_WARP_MASK, deg, offset);

    if (lane == 0)
      out_data[i] = deg;
  }
}

// Fills the column (and, optionally, edge id) vectors of the induced
// subgraph. Note that the body is still a stub in this PR:
template <typename scalar_t, bool return_edge_id>
__global__ void subgraph_kernel_impl(
    const scalar_t* __restrict__ rowptr_data,
    const scalar_t* __restrict__ col_data,
    const scalar_t* __restrict__ nodes_data,
    const scalar_t* __restrict__ to_local_node_data,
    const scalar_t* __restrict__ out_rowptr_data,
    scalar_t* __restrict__ out_col_data,
    scalar_t* __restrict__ out_edge_id_data,
    int64_t num_nodes) {
  CUDA_1D_KERNEL_LOOP(scalar_t, thread_idx, WARP * num_nodes) {
    scalar_t i = thread_idx / WARP;
    scalar_t lane = thread_idx % WARP;
    // TODO: Write the retained neighbors of node `i` (and their edge ids if
    // `return_edge_id` is set) starting at offset `out_rowptr_data[i]`.
  }
}

std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph_kernel(
    const at::Tensor& rowptr,
    const at::Tensor& col,
    const at::Tensor& nodes,
    const bool return_edge_id) {
  TORCH_CHECK(rowptr.is_cuda(), "'rowptr' must be a CUDA tensor");
  TORCH_CHECK(col.is_cuda(), "'col' must be a CUDA tensor");
  TORCH_CHECK(nodes.is_cuda(), "'nodes' must be a CUDA tensor");

  const auto stream = at::cuda::getCurrentCUDAStream();

  // We maintain an O(num_nodes) vector to map global node indices to local
  // ones (-1 marks nodes outside the subgraph).
  // TODO: Can we avoid the O(num_nodes) storage requirement?
  // TODO: Consider caching this tensor across consecutive runs.
  const auto to_local_node = nodes.new_full({rowptr.size(0) - 1}, -1);
  const auto arange = at::arange(nodes.size(0), nodes.options());
  to_local_node.index_copy_(/*dim=*/0, nodes, arange);

  const auto deg = nodes.new_empty({nodes.size(0)});
  const auto out_rowptr = rowptr.new_zeros({nodes.size(0) + 1});
  at::Tensor out_col;
  c10::optional<at::Tensor> out_edge_id = c10::nullopt;

  AT_DISPATCH_INTEGRAL_TYPES(nodes.scalar_type(), "subgraph_kernel", [&] {
    const auto rowptr_data = rowptr.data_ptr<scalar_t>();
    const auto col_data = col.data_ptr<scalar_t>();
    const auto nodes_data = nodes.data_ptr<scalar_t>();
    const auto to_local_node_data = to_local_node.data_ptr<scalar_t>();
    auto deg_data = deg.data_ptr<scalar_t>();

    // Compute induced subgraph degrees, parallelized with 32 threads per
    // node:
    subgraph_deg_kernel_impl<<<pyg::utils::blocks(WARP * nodes.size(0)),
                               pyg::utils::threads(), 0, stream>>>(
        rowptr_data, col_data, nodes_data, to_local_node_data, deg_data,
        nodes.size(0));

    // Convert degrees into the output CSR row pointer:
    auto tmp = out_rowptr.narrow(0, 1, nodes.size(0));
    at::cumsum_out(tmp, deg, /*dim=*/0);

    // The number of edges in the induced subgraph is the last entry of
    // `out_rowptr` (reading it synchronizes with the device):
    const auto num_edges = static_cast<int64_t>(
        out_rowptr.select(0, nodes.size(0)).item<scalar_t>());
    out_col = col.new_empty({num_edges});

    if (return_edge_id) {
      out_edge_id = col.new_empty({num_edges});
      subgraph_kernel_impl<scalar_t, true>
          <<<pyg::utils::blocks(WARP * nodes.size(0)), pyg::utils::threads(),
             0, stream>>>(rowptr_data, col_data, nodes_data,
                          to_local_node_data,
                          out_rowptr.data_ptr<scalar_t>(),
                          out_col.data_ptr<scalar_t>(),
                          out_edge_id.value().data_ptr<scalar_t>(),
                          nodes.size(0));
    } else {
      subgraph_kernel_impl<scalar_t, false>
          <<<pyg::utils::blocks(WARP * nodes.size(0)), pyg::utils::threads(),
             0, stream>>>(rowptr_data, col_data, nodes_data,
                          to_local_node_data,
                          out_rowptr.data_ptr<scalar_t>(),
                          out_col.data_ptr<scalar_t>(),
                          /*out_edge_id_data=*/nullptr, nodes.size(0));
    }
  });

  return std::make_tuple(out_rowptr, out_col, out_edge_id);
}

}  // namespace

TORCH_LIBRARY_IMPL(pyg, CUDA, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::subgraph"), TORCH_FN(subgraph_kernel));
}

}  // namespace sampler
}  // namespace pyg
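Taken together, the kernels follow the standard two-pass CSR pattern: count per-node degrees, prefix-sum them into a row pointer, then fill the columns. As a reference for the intended semantics only, here is a minimal single-threaded C++ sketch; the function name and plain-vector interface are illustrative and not part of this PR:

#include <cstdint>
#include <tuple>
#include <vector>

// Reference semantics for pyg::subgraph on CSR input: given a graph
// (rowptr, col) and a list of `nodes`, return the CSR representation of the
// induced subgraph with nodes relabeled to 0..nodes.size()-1, plus the
// original edge ids of the retained edges.
std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>>
subgraph_reference(const std::vector<int64_t>& rowptr,
                   const std::vector<int64_t>& col,
                   const std::vector<int64_t>& nodes) {
  const auto num_global_nodes = static_cast<int64_t>(rowptr.size()) - 1;
  // Global-to-local node mapping (-1 marks nodes outside the subgraph):
  std::vector<int64_t> to_local_node(num_global_nodes, -1);
  for (size_t i = 0; i < nodes.size(); ++i)
    to_local_node[nodes[i]] = static_cast<int64_t>(i);

  std::vector<int64_t> out_rowptr(nodes.size() + 1, 0);
  std::vector<int64_t> out_col, out_edge_id;
  for (size_t i = 0; i < nodes.size(); ++i) {
    const auto v = nodes[i];
    for (auto j = rowptr[v]; j < rowptr[v + 1]; ++j) {
      const auto w = to_local_node[col[j]];
      if (w >= 0) {  // neighbor is part of the induced subgraph
        out_col.push_back(w);
        out_edge_id.push_back(j);
      }
    }
    out_rowptr[i + 1] = static_cast<int64_t>(out_col.size());
  }
  return {out_rowptr, out_col, out_edge_id};
}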
The second new file adds the shared CUDA launch helpers (pyg_lib/csrc/utils/cuda/helpers.h):

#pragma once

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

namespace pyg {
namespace utils {

// Maximum number of threads per block on the current device, capped at 1024:
__host__ inline int threads() {
  const auto props = at::cuda::getCurrentDeviceProperties();
  return std::min(props->maxThreadsPerBlock, 1024);
}

// Number of blocks needed to cover `numel` elements, capped by an estimate
// of how many blocks the device can keep resident at once:
template <typename scalar_t>
__host__ inline scalar_t blocks(scalar_t numel) {
  const auto props = at::cuda::getCurrentDeviceProperties();
  const auto blocks_per_sm = props->maxThreadsPerMultiProcessor / 256;
  const auto max_blocks = props->multiProcessorCount * blocks_per_sm;
  const auto max_threads = threads();
  return std::min<scalar_t>(max_blocks,
                            (numel + max_threads - 1) / max_threads);
}

#define WARP 32
#define FULL_WARP_MASK 0xffffffff

// Grid-stride loop: each thread starts at its global index and advances by
// the total number of launched threads until `n` elements are covered:
#define CUDA_1D_KERNEL_LOOP(scalar_t, i, n)                            \
  for (scalar_t i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
       i += (blockDim.x * gridDim.x))

}  // namespace utils
}  // namespace pyg
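As a usage sketch of these helpers, the pieces combine into the usual capped-grid launch; the scale_kernel below is illustrative only and not part of this PR:

#include "pyg_lib/csrc/utils/cuda/helpers.h"

// Illustrative grid-stride kernel written with the helper macro:
__global__ void scale_kernel(const float* __restrict__ in,
                             float* __restrict__ out,
                             float alpha,
                             int64_t numel) {
  CUDA_1D_KERNEL_LOOP(int64_t, i, numel) {
    out[i] = alpha * in[i];
  }
}

void scale(const float* in, float* out, float alpha, int64_t numel,
           cudaStream_t stream) {
  // `blocks()` caps the grid at what the device can keep resident; the
  // grid-stride loop inside the kernel covers any remaining elements:
  scale_kernel<<<pyg::utils::blocks(numel), pyg::utils::threads(), 0,
                 stream>>>(in, out, alpha, numel);
}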
Review discussion:
I'm actually not sure it is necessary to parallelize with 32 threads per node. Most of the time we are dealing with sparse data, and a lot of threads will never enter the for loop.
If you are looking for extreme performance, you can bundle to_local_node_data and col_data into one iterator structure and use this cub::DeviceSegmentedReduce function; I haven't seen anything perform better in the past: https://nvlabs.github.io/cub/structcub_1_1_device_segmented_reduce.html#a4854a13561cb66d46aa617aab16b8825
Do you have an example of bundling to_local_node_data and col_data into one iterator structure? This looks really interesting. I am okay with dropping the warp-level parallelism for now, but we would lose the contiguous access to col_data and probably under-utilize the number of threads available on modern GPUs.
, and probably under-utilize the number of threads available on modern GPUs.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On a second look, this doesn't seem possible, since col_data refers to edges while to_local_node_data refers to nodes, and we actually want to do the compute across the number of nodes in the induced subgraph.
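For concreteness, here is a rough sketch of what the suggested bundling could look like for the degree computation, assuming cub's TransformInputIterator, CountingInputIterator and DeviceSegmentedReduce. The functor and function names are illustrative, and whether this layout actually sidesteps the node-versus-edge indexing concern raised above is left open here:

#include <cub/cub.cuh>

// Hypothetical functor: maps an edge's column index to 1 if its endpoint is
// part of the induced subgraph, and to 0 otherwise.
template <typename scalar_t>
struct InSubgraphOp {
  const scalar_t* to_local_node_data;
  __host__ __device__ scalar_t operator()(scalar_t c) const {
    return to_local_node_data[c] >= 0 ? scalar_t(1) : scalar_t(0);
  }
};

// Hypothetical functor: maps local node i to an offset into `col`, read from
// `rowptr` at its global id (`shift` is 0 for segment begins, 1 for ends).
template <typename scalar_t>
struct SegmentOffsetOp {
  const scalar_t* rowptr_data;
  const scalar_t* nodes_data;
  int shift;
  __host__ __device__ scalar_t operator()(int64_t i) const {
    return rowptr_data[nodes_data[i] + shift];
  }
};

template <typename scalar_t>
void segmented_subgraph_deg(const scalar_t* rowptr_data,
                            const scalar_t* col_data,
                            const scalar_t* nodes_data,
                            const scalar_t* to_local_node_data,
                            scalar_t* deg_data,
                            int num_nodes,
                            cudaStream_t stream) {
  // Bundle col_data and to_local_node_data into one input iterator that
  // yields a 0/1 indicator per edge:
  cub::TransformInputIterator<scalar_t, InSubgraphOp<scalar_t>,
                              const scalar_t*>
      in(col_data, InSubgraphOp<scalar_t>{to_local_node_data});
  // Per-node segment boundaries, gathered from rowptr at the global ids:
  cub::CountingInputIterator<int64_t> counting(0);
  cub::TransformInputIterator<scalar_t, SegmentOffsetOp<scalar_t>,
                              cub::CountingInputIterator<int64_t>>
      begin(counting, SegmentOffsetOp<scalar_t>{rowptr_data, nodes_data, 0}),
      end(counting, SegmentOffsetOp<scalar_t>{rowptr_data, nodes_data, 1});

  // Standard two-phase cub call: first query temporary storage, then run:
  void* tmp = nullptr;
  size_t tmp_bytes = 0;
  cub::DeviceSegmentedReduce::Sum(tmp, tmp_bytes, in, deg_data, num_nodes,
                                  begin, end, stream);
  cudaMalloc(&tmp, tmp_bytes);
  cub::DeviceSegmentedReduce::Sum(tmp, tmp_bytes, in, deg_data, num_nodes,
                                  begin, end, stream);
  cudaFree(tmp);
}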