diff --git a/CHANGELOG.md b/CHANGELOG.md index 40e84e35f..cb0660b18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [0.3.0] - 2023-MM-DD ### Added +- Added low-level support for distributed neighborhood sampling ([#246](https://github.com/pyg-team/pyg-lib/pull/246)) - Added support for homogeneous and heterogeneous biased neighborhood sampling ([#247](https://github.com/pyg-team/pyg-lib/pull/247), [#251](https://github.com/pyg-team/pyg-lib/pull/251)) - Added dispatch for XPU device in `index_sort` ([#243](https://github.com/pyg-team/pyg-lib/pull/243)) - Added `metis` partitioning ([#229](https://github.com/pyg-team/pyg-lib/pull/229)) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index e039ea7f2..71ea6de65 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -25,7 +25,8 @@ template + bool save_edge_ids, + bool distributed> class NeighborSampler { public: NeighborSampler(const scalar_t* rowptr, @@ -239,6 +240,16 @@ class NeighborSampler { const auto global_dst_node_value = col_[edge_id]; const auto global_dst_node = to_node_t(global_dst_node_value, global_src_node); + + // In the distributed sampling case, we do not perform any mapping: + if constexpr (distributed) { + out_global_dst_nodes.push_back(global_dst_node); + if (save_edge_ids) { + sampled_edge_ids_.push_back(edge_id); + } + return; + } + const auto res = dst_mapper.insert(global_dst_node); if (res.second) { // not yet sampled. out_global_dst_nodes.push_back(global_dst_node); @@ -266,12 +277,17 @@ class NeighborSampler { // Homogeneous neighbor sampling /////////////////////////////////////////////// -template +template std::tuple, std::vector, + std::vector, std::vector> sample(const at::Tensor& rowptr, const at::Tensor& col, @@ -302,6 +318,9 @@ sample(const at::Tensor& rowptr, c10::optional out_edge_id = c10::nullopt; std::vector num_sampled_nodes_per_hop; std::vector num_sampled_edges_per_hop; + std::vector cumsum_neighbors_per_node = + distributed ? std::vector(1, seed.size(0)) + : std::vector(); AT_DISPATCH_INTEGRAL_TYPES(seed.scalar_type(), "sample_kernel", [&] { typedef std::pair pair_scalar_t; @@ -309,7 +328,7 @@ sample(const at::Tensor& rowptr, // TODO(zeyuan): Do not force int64_t for time type. typedef int64_t temporal_t; typedef NeighborSampler + return_edge_id, distributed> NeighborSamplerImpl; pyg::random::RandintEngine generator; @@ -359,6 +378,8 @@ sample(const at::Tensor& rowptr, /*dst_mapper=*/mapper, /*generator=*/generator, /*out_global_dst_nodes=*/sampled_nodes); + if constexpr (distributed) + cumsum_neighbors_per_node.push_back(sampled_nodes.size()); } } else if (!time.has_value()) { for (size_t i = begin; i < end; ++i) { @@ -369,6 +390,8 @@ sample(const at::Tensor& rowptr, /*dst_mapper=*/mapper, /*generator=*/generator, /*out_global_dst_nodes=*/sampled_nodes); + if constexpr (distributed) + cumsum_neighbors_per_node.push_back(sampled_nodes.size()); } } else if constexpr (!std::is_scalar::value) { // Temporal: const auto time_data = time.value().data_ptr(); @@ -382,6 +405,8 @@ sample(const at::Tensor& rowptr, /*dst_mapper=*/mapper, /*generator=*/generator, /*out_global_dst_nodes=*/sampled_nodes); + if constexpr (distributed) + cumsum_neighbors_per_node.push_back(sampled_nodes.size()); } } begin = end, end = sampled_nodes.size(); @@ -400,12 +425,17 @@ sample(const at::Tensor& rowptr, }); return std::make_tuple(out_row, out_col, out_node_id, out_edge_id, - num_sampled_nodes_per_hop, num_sampled_edges_per_hop); + num_sampled_nodes_per_hop, num_sampled_edges_per_hop, + cumsum_neighbors_per_node); } // Heterogeneous neighbor sampling ///////////////////////////////////////////// -template +template std::tuple, c10::Dict, c10::Dict, @@ -472,7 +502,7 @@ sample(const std::vector& node_types, typedef std::conditional_t node_t; typedef int64_t temporal_t; typedef NeighborSampler + return_edge_id, distributed> NeighborSamplerImpl; pyg::random::RandintEngine generator; @@ -691,39 +721,73 @@ sample(const std::vector& node_types, // Dispatcher ////////////////////////////////////////////////////////////////// -#define DISPATCH_SAMPLE(replace, directed, disjount, return_edge_id, ...) \ +#define DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, ...) \ if (replace && directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (replace && directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (replace && directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (replace && directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (replace && !directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (replace && !directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (replace && !directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (replace && !directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (!replace && directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (!replace && directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (!replace && directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (!replace && directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (!replace && !directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (!replace && !directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (!replace && !directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ + return sample(__VA_ARGS__); \ if (!replace && !directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); + return sample(__VA_ARGS__); + +#define DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, ...) \ + if (replace && directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); } // namespace @@ -746,9 +810,13 @@ neighbor_sample_kernel(const at::Tensor& rowptr, bool disjoint, std::string temporal_strategy, bool return_edge_id) { - DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col, - seed, num_neighbors, time, seed_time, edge_weight, csc, - temporal_strategy); + const auto out = [&] { + DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col, + seed, num_neighbors, time, seed_time, edge_weight, csc, + temporal_strategy); + }(); + return std::make_tuple(std::get<0>(out), std::get<1>(out), std::get<2>(out), + std::get<3>(out), std::get<4>(out), std::get<5>(out)); } std::tuple, @@ -779,6 +847,59 @@ hetero_neighbor_sample_kernel( edge_weight_dict, csc, temporal_strategy); } +std::tuple, + std::vector, + std::vector, + std::vector> +dist_neighbor_sample_kernel(const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& seed, + const std::vector& num_neighbors, + const c10::optional& time, + const c10::optional& seed_time, + const c10::optional& edge_weight, + bool csc, + bool replace, + bool directed, + bool disjoint, + std::string temporal_strategy, + bool return_edge_id) { + DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col, + seed, num_neighbors, time, seed_time, edge_weight, csc, + temporal_strategy); +} + +std::tuple, + c10::Dict, + c10::Dict, + c10::optional>, + c10::Dict>, + c10::Dict>> +dist_hetero_neighbor_sample_kernel( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& rowptr_dict, + const c10::Dict& col_dict, + const c10::Dict& seed_dict, + const c10::Dict>& num_neighbors_dict, + const c10::optional>& time_dict, + const c10::optional>& seed_time_dict, + const c10::optional>& edge_weight_dict, + bool csc, + bool replace, + bool directed, + bool disjoint, + std::string temporal_strategy, + bool return_edge_id) { + DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, node_types, + edge_types, rowptr_dict, col_dict, seed_dict, + num_neighbors_dict, time_dict, seed_time_dict, + edge_weight_dict, csc, temporal_strategy); +} + TORCH_LIBRARY_IMPL(pyg, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("pyg::neighbor_sample"), TORCH_FN(neighbor_sample_kernel)); @@ -792,5 +913,15 @@ TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) { TORCH_FN(hetero_neighbor_sample_kernel)); } +TORCH_LIBRARY_IMPL(pyg, CPU, m) { + m.impl(TORCH_SELECTIVE_NAME("pyg::dist_neighbor_sample"), + TORCH_FN(dist_neighbor_sample_kernel)); +} + +TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) { + m.impl(TORCH_SELECTIVE_NAME("pyg::dist_hetero_neighbor_sample"), + TORCH_FN(dist_hetero_neighbor_sample_kernel)); +} + } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h index 629e98001..9c6e3aa51 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h @@ -48,5 +48,49 @@ hetero_neighbor_sample_kernel( std::string temporal_strategy, bool return_edge_id); +std::tuple, + std::vector, + std::vector, + std::vector> +dist_neighbor_sample_kernel(const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& seed, + const std::vector& num_neighbors, + const c10::optional& time, + const c10::optional& seed_time, + const c10::optional& edge_weight, + bool csc, + bool replace, + bool directed, + bool disjoint, + std::string temporal_strategy, + bool return_edge_id); + +std::tuple, + c10::Dict, + c10::Dict, + c10::optional>, + c10::Dict>, + c10::Dict>> +dist_hetero_neighbor_sample_kernel( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& rowptr_dict, + const c10::Dict& col_dict, + const c10::Dict& seed_dict, + const c10::Dict>& num_neighbors_dict, + const c10::optional>& time_dict, + const c10::optional>& seed_time_dict, + const c10::optional>& edge_weight_dict, + bool csc, + bool replace, + bool directed, + bool disjoint, + std::string temporal_strategy, + bool return_edge_id); + } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index 1fdb6a354..6907ab5d1 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ b/pyg_lib/csrc/sampler/neighbor.cpp @@ -94,6 +94,94 @@ hetero_neighbor_sample( temporal_strategy, return_edge_id); } +std::tuple, + std::vector, + std::vector, + std::vector> +dist_neighbor_sample(const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& seed, + const std::vector& num_neighbors, + const c10::optional& time, + const c10::optional& seed_time, + const c10::optional& edge_weight, + bool csc, + bool replace, + bool directed, + bool disjoint, + std::string temporal_strategy, + bool return_edge_id) { + at::TensorArg rowptr_t{rowptr, "rowtpr", 1}; + at::TensorArg col_t{col, "col", 1}; + at::TensorArg seed_t{seed, "seed", 1}; + + at::CheckedFrom c = "dist_neighbor_sample"; + at::checkAllDefined(c, {rowptr_t, col_t, seed_t}); + at::checkAllSameType(c, {rowptr_t, col_t, seed_t}); + + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("pyg::dist_neighbor_sample", "") + .typed(); + return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight, + csc, replace, directed, disjoint, temporal_strategy, + return_edge_id); +} + +std::tuple, + c10::Dict, + c10::Dict, + c10::optional>, + c10::Dict>, + c10::Dict>> +dist_hetero_neighbor_sample( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& rowptr_dict, + const c10::Dict& col_dict, + const c10::Dict& seed_dict, + const c10::Dict>& num_neighbors_dict, + const c10::optional>& time_dict, + const c10::optional>& seed_time_dict, + const c10::optional>& edge_weight_dict, + bool csc, + bool replace, + bool directed, + bool disjoint, + std::string temporal_strategy, + bool return_edge_id) { + TORCH_CHECK(rowptr_dict.size() == col_dict.size(), + "Number of edge types in 'rowptr_dict' and 'col_dict' must match") + + std::vector rowptr_dict_args; + std::vector col_dict_args; + std::vector seed_dict_args; + pyg::utils::fill_tensor_args(rowptr_dict_args, rowptr_dict, "rowptr_dict", 0); + pyg::utils::fill_tensor_args(col_dict_args, col_dict, "col_dict", 0); + pyg::utils::fill_tensor_args(seed_dict_args, seed_dict, "seed_dict", 0); + at::CheckedFrom c{"dist_hetero_neighbor_sample"}; + + at::checkAllDefined(c, rowptr_dict_args); + at::checkAllDefined(c, col_dict_args); + at::checkAllDefined(c, seed_dict_args); + at::checkAllSameType(c, rowptr_dict_args); + at::checkAllSameType(c, col_dict_args); + at::checkAllSameType(c, seed_dict_args); + at::checkSameType(c, rowptr_dict_args[0], col_dict_args[0]); + at::checkSameType(c, rowptr_dict_args[0], seed_dict_args[0]); + + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("pyg::dist_hetero_neighbor_sample", "") + .typed(); + return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict, + num_neighbors_dict, time_dict, seed_time_dict, + edge_weight_dict, csc, replace, directed, disjoint, + temporal_strategy, return_edge_id); +} + TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( "pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " @@ -112,6 +200,23 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { "str temporal_strategy = 'uniform', bool return_edge_id = True) -> " "(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), " "Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))")); + m.def(TORCH_SELECTIVE_SCHEMA( + "pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " + "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? " + "edge_weight = None, bool csc = False, bool replace = False, bool " + "directed = True, bool disjoint = False, str temporal_strategy = " + "'uniform', bool return_edge_id = True) " + "-> (Tensor, Tensor, Tensor, Tensor?, int[], int[], int[])")); + m.def(TORCH_SELECTIVE_SCHEMA( + "pyg::dist_hetero_neighbor_sample(str[] node_types, (str, str, str)[] " + "edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, " + "Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, " + "Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? seed_time_dict " + "= None, Dict(str, Tensor)? edge_weight_dict = None, bool csc = False, " + "bool replace = False, bool directed = True, bool disjoint = False, " + "str temporal_strategy = 'uniform', bool return_edge_id = True) -> " + "(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), " + "Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))")); } } // namespace sampler diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index 3091624fb..8abcd42b2 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -61,5 +61,55 @@ hetero_neighbor_sample( std::string strategy = "uniform", bool return_edge_id = true); +PYG_API +std::tuple, + std::vector, + std::vector, + std::vector> +dist_neighbor_sample( + const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& seed, + const std::vector& num_neighbors, + const c10::optional& time = c10::nullopt, + const c10::optional& seed_time = c10::nullopt, + const c10::optional& edge_weight = c10::nullopt, + bool csc = false, + bool replace = false, + bool directed = true, + bool disjoint = false, + std::string strategy = "uniform", + bool return_edge_id = true); + +PYG_API +std::tuple, + c10::Dict, + c10::Dict, + c10::optional>, + c10::Dict>, + c10::Dict>> +dist_hetero_neighbor_sample( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& rowptr_dict, + const c10::Dict& col_dict, + const c10::Dict& seed_dict, + const c10::Dict>& num_neighbors_dict, + const c10::optional>& time_dict = + c10::nullopt, + const c10::optional>& seed_time_dict = + c10::nullopt, + const c10::optional>& edge_weight_dict = + c10::nullopt, + bool csc = false, + bool replace = false, + bool directed = true, + bool disjoint = false, + std::string strategy = "uniform", + bool return_edge_id = true); + } // namespace sampler } // namespace pyg diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index 5cf08a5df..dd0cc1a0f 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -97,7 +97,7 @@ def hetero_neighbor_sample( return_edge_id: bool = True, ) -> Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor], Dict[ NodeType, Tensor], Optional[Dict[EdgeType, Tensor]], Dict[ - NodeType, List[int]], Dict[NodeType, List[int]]]: + NodeType, List[int]], Dict[EdgeType, List[int]]]: r"""Recursively samples neighbors from all node indices in :obj:`seed_dict` in the heterogeneous graph given by :obj:`(rowptr_dict, col_dict)`.