diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 71ea6de65..addb9146d 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -847,17 +847,11 @@ hetero_neighbor_sample_kernel( edge_weight_dict, csc, temporal_strategy); } -std::tuple, - std::vector, - std::vector, - std::vector> +std::tuple, std::vector> dist_neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, - const std::vector& num_neighbors, + const int64_t one_hop_num, const c10::optional& time, const c10::optional& seed_time, const c10::optional& edge_weight, @@ -867,37 +861,12 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr, bool disjoint, std::string temporal_strategy, bool return_edge_id) { - DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col, - seed, num_neighbors, time, seed_time, edge_weight, csc, - temporal_strategy); -} - -std::tuple, - c10::Dict, - c10::Dict, - c10::optional>, - c10::Dict>, - c10::Dict>> -dist_hetero_neighbor_sample_kernel( - const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& rowptr_dict, - const c10::Dict& col_dict, - const c10::Dict& seed_dict, - const c10::Dict>& num_neighbors_dict, - const c10::optional>& time_dict, - const c10::optional>& seed_time_dict, - const c10::optional>& edge_weight_dict, - bool csc, - bool replace, - bool directed, - bool disjoint, - std::string temporal_strategy, - bool return_edge_id) { - DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, node_types, - edge_types, rowptr_dict, col_dict, seed_dict, - num_neighbors_dict, time_dict, seed_time_dict, - edge_weight_dict, csc, temporal_strategy); + const auto out = [&] { + DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, + col, seed, {one_hop_num}, time, seed_time, edge_weight, + csc, temporal_strategy); + }(); + return 
std::make_tuple(std::get<2>(out), std::get<3>(out), std::get<6>(out)); } TORCH_LIBRARY_IMPL(pyg, CPU, m) { @@ -918,10 +887,5 @@ TORCH_LIBRARY_IMPL(pyg, CPU, m) { TORCH_FN(dist_neighbor_sample_kernel)); } -TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) { - m.impl(TORCH_SELECTIVE_NAME("pyg::dist_hetero_neighbor_sample"), - TORCH_FN(dist_hetero_neighbor_sample_kernel)); -} - } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h index 9c6e3aa51..9fd9b3656 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h @@ -48,17 +48,11 @@ hetero_neighbor_sample_kernel( std::string temporal_strategy, bool return_edge_id); -std::tuple, - std::vector, - std::vector, - std::vector> +std::tuple, std::vector> dist_neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, - const std::vector& num_neighbors, + const int64_t one_hop_num, const c10::optional& time, const c10::optional& seed_time, const c10::optional& edge_weight, @@ -69,28 +63,5 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr, std::string temporal_strategy, bool return_edge_id); -std::tuple, - c10::Dict, - c10::Dict, - c10::optional>, - c10::Dict>, - c10::Dict>> -dist_hetero_neighbor_sample_kernel( - const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& rowptr_dict, - const c10::Dict& col_dict, - const c10::Dict& seed_dict, - const c10::Dict>& num_neighbors_dict, - const c10::optional>& time_dict, - const c10::optional>& seed_time_dict, - const c10::optional>& edge_weight_dict, - bool csc, - bool replace, - bool directed, - bool disjoint, - std::string temporal_strategy, - bool return_edge_id); - } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index 6907ab5d1..fde62bdec 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ 
b/pyg_lib/csrc/sampler/neighbor.cpp @@ -94,17 +94,11 @@ hetero_neighbor_sample( temporal_strategy, return_edge_id); } -std::tuple, - std::vector, - std::vector, - std::vector> +std::tuple, std::vector> dist_neighbor_sample(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, - const std::vector& num_neighbors, + const int64_t one_hop_num, const c10::optional& time, const c10::optional& seed_time, const c10::optional& edge_weight, @@ -125,63 +119,11 @@ dist_neighbor_sample(const at::Tensor& rowptr, static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::dist_neighbor_sample", "") .typed(); - return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight, + return op.call(rowptr, col, seed, one_hop_num, time, seed_time, edge_weight, csc, replace, directed, disjoint, temporal_strategy, return_edge_id); } -std::tuple, - c10::Dict, - c10::Dict, - c10::optional>, - c10::Dict>, - c10::Dict>> -dist_hetero_neighbor_sample( - const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& rowptr_dict, - const c10::Dict& col_dict, - const c10::Dict& seed_dict, - const c10::Dict>& num_neighbors_dict, - const c10::optional>& time_dict, - const c10::optional>& seed_time_dict, - const c10::optional>& edge_weight_dict, - bool csc, - bool replace, - bool directed, - bool disjoint, - std::string temporal_strategy, - bool return_edge_id) { - TORCH_CHECK(rowptr_dict.size() == col_dict.size(), - "Number of edge types in 'rowptr_dict' and 'col_dict' must match") - - std::vector rowptr_dict_args; - std::vector col_dict_args; - std::vector seed_dict_args; - pyg::utils::fill_tensor_args(rowptr_dict_args, rowptr_dict, "rowptr_dict", 0); - pyg::utils::fill_tensor_args(col_dict_args, col_dict, "col_dict", 0); - pyg::utils::fill_tensor_args(seed_dict_args, seed_dict, "seed_dict", 0); - at::CheckedFrom c{"dist_hetero_neighbor_sample"}; - - at::checkAllDefined(c, rowptr_dict_args); - at::checkAllDefined(c, col_dict_args); - 
at::checkAllDefined(c, seed_dict_args); - at::checkAllSameType(c, rowptr_dict_args); - at::checkAllSameType(c, col_dict_args); - at::checkAllSameType(c, seed_dict_args); - at::checkSameType(c, rowptr_dict_args[0], col_dict_args[0]); - at::checkSameType(c, rowptr_dict_args[0], seed_dict_args[0]); - - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("pyg::dist_hetero_neighbor_sample", "") - .typed(); - return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict, - num_neighbors_dict, time_dict, seed_time_dict, - edge_weight_dict, csc, replace, directed, disjoint, - temporal_strategy, return_edge_id); -} - TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( "pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " @@ -201,22 +143,12 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { "(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), " "Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))")); m.def(TORCH_SELECTIVE_SCHEMA( - "pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " - "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? " + "pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int " + "one_hop_num, Tensor? time = None, Tensor? seed_time = None, Tensor? " "edge_weight = None, bool csc = False, bool replace = False, bool " "directed = True, bool disjoint = False, str temporal_strategy = " "'uniform', bool return_edge_id = True) " - "-> (Tensor, Tensor, Tensor, Tensor?, int[], int[], int[])")); - m.def(TORCH_SELECTIVE_SCHEMA( - "pyg::dist_hetero_neighbor_sample(str[] node_types, (str, str, str)[] " - "edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, " - "Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, " - "Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? seed_time_dict " - "= None, Dict(str, Tensor)? 
edge_weight_dict = None, bool csc = False, " - "bool replace = False, bool directed = True, bool disjoint = False, " - "str temporal_strategy = 'uniform', bool return_edge_id = True) -> " - "(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), " - "Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))")); + "-> (Tensor, Tensor?, int[])")); } } // namespace sampler diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index 8abcd42b2..045cccce0 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -62,18 +62,12 @@ hetero_neighbor_sample( bool return_edge_id = true); PYG_API -std::tuple, - std::vector, - std::vector, - std::vector> +std::tuple, std::vector> dist_neighbor_sample( const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, - const std::vector& num_neighbors, + const int64_t one_hop_num, const c10::optional& time = c10::nullopt, const c10::optional& seed_time = c10::nullopt, const c10::optional& edge_weight = c10::nullopt, @@ -84,32 +78,5 @@ dist_neighbor_sample( std::string strategy = "uniform", bool return_edge_id = true); -PYG_API -std::tuple, - c10::Dict, - c10::Dict, - c10::optional>, - c10::Dict>, - c10::Dict>> -dist_hetero_neighbor_sample( - const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& rowptr_dict, - const c10::Dict& col_dict, - const c10::Dict& seed_dict, - const c10::Dict>& num_neighbors_dict, - const c10::optional>& time_dict = - c10::nullopt, - const c10::optional>& seed_time_dict = - c10::nullopt, - const c10::optional>& edge_weight_dict = - c10::nullopt, - bool csc = false, - bool replace = false, - bool directed = true, - bool disjoint = false, - std::string strategy = "uniform", - bool return_edge_id = true); - } // namespace sampler } // namespace pyg diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index dd0cc1a0f..980554834 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ 
-165,6 +165,43 @@ def hetero_neighbor_sample( num_nodes_per_hop_dict, num_edges_per_hop_dict) +def dist_neighbor_sample( + rowptr: Tensor, + col: Tensor, + seed: Tensor, + one_hop_num: int, + time: Optional[Tensor] = None, + seed_time: Optional[Tensor] = None, + edge_weight: Optional[Tensor] = None, + csc: bool = False, + replace: bool = False, + directed: bool = True, + disjoint: bool = False, + temporal_strategy: str = 'uniform', + return_edge_id: bool = True, +) -> Tuple[Tensor, Optional[Tensor], List[int]]: + r"""For distributed sampling purposes. Leverages the + :meth:`neighbor_sample`. Samples one hop neighborhood with duplicates from + all node indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`. + + Args: + one_hop_num (int): Max number of neighbors to sample in the current + layer. + kwargs: Arguments of :meth:`neighbor_sample`. + + Returns: + (torch.Tensor, Optional[torch.Tensor], List[int]): + Returns original node indices for all sampled nodes and in addition, + the indices of edges of the original graph. Lastly, returns cumulative + sum of the amount of sampled neighbors by each node in the :obj:`seed`.
+ """ + return torch.ops.pyg.dist_neighbor_sample(rowptr, col, seed, one_hop_num, + time, seed_time, edge_weight, + csc, replace, directed, disjoint, + temporal_strategy, + return_edge_id) + + def subgraph( rowptr: Tensor, col: Tensor, @@ -218,6 +255,7 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int, __all__ = [ 'neighbor_sample', 'hetero_neighbor_sample', + 'dist_neighbor_sample', 'subgraph', 'random_walk', ] diff --git a/test/csrc/sampler/test_dist_neighbor.cpp b/test/csrc/sampler/test_dist_neighbor.cpp new file mode 100644 index 000000000..63b9b8027 --- /dev/null +++ b/test/csrc/sampler/test_dist_neighbor.cpp @@ -0,0 +1,139 @@ +#include +#include + +#include "pyg_lib/csrc/sampler/neighbor.h" +#include "pyg_lib/csrc/utils/types.h" +#include "test/csrc/graph.h" + +TEST(FullDistNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + int num_nodes = 6; + auto graph = cycle_graph(num_nodes, options); + auto seed = at::arange(2, 4, options); + int64_t one_hop_num = -1; + + auto out = pyg::sampler::dist_neighbor_sample( + /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, + one_hop_num); + + // sample nodes with duplicates + auto expected_nodes = at::tensor({2, 3, 1, 3, 2, 4}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes)); + + auto expected_edges = at::tensor({4, 5, 6, 7}, options); + EXPECT_TRUE(at::equal(std::get<1>(out).value(), expected_edges)); + + std::vector expected_cumm_sum_nbrs_per_node = {2, 4, 6}; + EXPECT_EQ(std::get<2>(out), expected_cumm_sum_nbrs_per_node); +} + +TEST(WithoutReplacementNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + int num_nodes = 6; + auto graph = cycle_graph(num_nodes, options); + auto seed = at::arange(2, 4, options); + int64_t one_hop_num = 1; + + at::manual_seed(123456); + auto out = pyg::sampler::dist_neighbor_sample( + /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, + 
one_hop_num); + + // sample nodes with duplicates + auto expected_nodes = at::tensor({2, 3, 1, 4}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes)); + + auto expected_edges = at::tensor({4, 7}, options); + EXPECT_TRUE(at::equal(std::get<1>(out).value(), expected_edges)); + + std::vector expected_cumm_sum_nbrs_per_node = {2, 3, 4}; + EXPECT_EQ(std::get<2>(out), expected_cumm_sum_nbrs_per_node); +} + +TEST(WithReplacementNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + int num_nodes = 6; + auto graph = cycle_graph(num_nodes, options); + auto seed = at::arange(2, 4, options); + int64_t one_hop_num = 2; + + at::manual_seed(123456); + auto out = pyg::sampler::dist_neighbor_sample( + /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, + one_hop_num, /*time=*/c10::nullopt, + /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, + /*csc*/ false, /*replace=*/true); + + // sample nodes with duplicates + auto expected_nodes = at::tensor({2, 3, 1, 3, 4, 4}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes)); + + auto expected_edges = at::tensor({4, 5, 7, 7}, options); + EXPECT_TRUE(at::equal(std::get<1>(out).value(), expected_edges)); + + std::vector expected_cumm_sum_nbrs_per_node = {2, 4, 6}; + EXPECT_EQ(std::get<2>(out), expected_cumm_sum_nbrs_per_node); +} + +TEST(DistDisjointNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + int num_nodes = 6; + auto graph = cycle_graph(num_nodes, options); + auto seed = at::arange(2, 4, options); + int64_t one_hop_num = 2; + // auto batch = at::tensor({0, 1}, options); + + auto out = pyg::sampler::dist_neighbor_sample( + /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, + one_hop_num, /*time=*/c10::nullopt, + /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, /*csc*/ false, + /*replace=*/false, /*directed=*/true, /*disjoint=*/true); + + // sample nodes with duplicates + auto 
expected_nodes = + at::tensor({0, 2, 1, 3, 0, 1, 0, 3, 1, 2, 1, 4}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes.view({-1, 2}))); + + auto expected_edges = at::tensor({4, 5, 6, 7}, options); + EXPECT_TRUE(at::equal(std::get<1>(out).value(), expected_edges)); + + std::vector expected_cumm_sum_nbrs_per_node = {2, 4, 6}; + EXPECT_EQ(std::get<2>(out), expected_cumm_sum_nbrs_per_node); +} + +TEST(DistTemporalNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + int num_nodes = 6; + auto graph = cycle_graph(num_nodes, options); + auto rowptr = std::get<0>(graph); + auto col = std::get<1>(graph); + + auto seed = at::arange(2, 4, options); + int64_t one_hop_num = 2; + + // Time is equal to node ID ... + auto time = at::arange(6, options); + // ... so we need to sort the column vector by time/node ID: + col = std::get<0>(at::sort(col.view({-1, 2}), /*dim=*/1)).flatten(); + + auto out = pyg::sampler::dist_neighbor_sample( + rowptr, col, seed, one_hop_num, time, + /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, + /*csc*/ false, /*replace=*/false, /*directed=*/true, + /*disjoint=*/true, /*temporal_strategy=*/"uniform"); + + // sample nodes with duplicates + auto expected_nodes = at::tensor({0, 2, 1, 3, 0, 1, 1, 2}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes.view({-1, 2}))); + + auto expected_edges = at::tensor({4, 6}, options); + EXPECT_TRUE(at::equal(std::get<1>(out).value(), expected_edges)); + + std::vector expected_cumm_sum_nbrs_per_node = {2, 3, 4}; + EXPECT_EQ(std::get<2>(out), expected_cumm_sum_nbrs_per_node); +}