From 3a4d43621774c2c109951fd2f0f76df2369137ad Mon Sep 17 00:00:00 2001 From: kgajdamo Date: Wed, 6 Sep 2023 16:07:57 +0200 Subject: [PATCH] Update `dist_neighbor_sample` (#253) This code belongs to the part of the whole distributed training for PyG. This PR is complementary to the [#246](https://github.com/pyg-team/pyg-lib/pull/246) and introduces some updates. What has been changed: * Removed not needed `dist_hetero_neighbor_sample` function (due to the fact, that distributed sampling have a loop over the layers in python, in case of hetero at the moment when we call `neighbor_sample` we have only one edge type. So it becomes actually homo and we don't need the `dist_hetero_neighbor_sample` and can use `dist_neighbor_sample` instead.) * Removed all not used outputs and left only the following: `node`, `edge_ids`, `cummsum_sampled_nbrs_per_node`. * Changed `std::vector num_neighbors` input list into `int64_t one_hop_num`. Added: * Unit tests --------- Co-authored-by: rusty1s --- CHANGELOG.md | 2 +- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 106 ++++---------- pyg_lib/csrc/sampler/cpu/neighbor_kernel.h | 36 +---- pyg_lib/csrc/sampler/neighbor.cpp | 103 +++----------- pyg_lib/csrc/sampler/neighbor.h | 41 +----- pyg_lib/sampler/__init__.py | 45 ++++++ test/csrc/sampler/test_dist_neighbor.cpp | 141 +++++++++++++++++++ test/csrc/sampler/test_neighbor.cpp | 35 ++--- 8 files changed, 248 insertions(+), 261 deletions(-) create mode 100644 test/csrc/sampler/test_dist_neighbor.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index cb0660b18..2307c9b10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [0.3.0] - 2023-MM-DD ### Added -- Added low-level support for distributed neighborhood sampling ([#246](https://github.com/pyg-team/pyg-lib/pull/246)) +- Added low-level support for distributed neighborhood sampling ([#246](https://github.com/pyg-team/pyg-lib/pull/246), [#253](https://github.com/pyg-team/pyg-lib/pull/253)) - Added support for homogeneous and heterogeneous biased neighborhood sampling ([#247](https://github.com/pyg-team/pyg-lib/pull/247), [#251](https://github.com/pyg-team/pyg-lib/pull/251)) - Added dispatch for XPU device in `index_sort` ([#243](https://github.com/pyg-team/pyg-lib/pull/243)) - Added `metis` partitioning ([#229](https://github.com/pyg-team/pyg-lib/pull/229)) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 71ea6de65..cfc679d01 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -755,39 +755,23 @@ sample(const std::vector& node_types, if (!replace && !directed && !disjoint && !return_edge_id) \ return sample(__VA_ARGS__); -#define DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, ...) \ - if (replace && directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); +#define DISPATCH_DIST_SAMPLE(replace, directed, disjoint, ...) \ + if (replace && directed && disjoint) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint) \ + return sample(__VA_ARGS__); } // namespace @@ -847,17 +831,11 @@ hetero_neighbor_sample_kernel( edge_weight_dict, csc, temporal_strategy); } -std::tuple, - std::vector, - std::vector, - std::vector> +std::tuple> dist_neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, - const std::vector& num_neighbors, + const int64_t num_neighbors, const c10::optional& time, const c10::optional& seed_time, const c10::optional& edge_weight, @@ -865,39 +843,14 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr, bool replace, bool directed, bool disjoint, - std::string temporal_strategy, - bool return_edge_id) { - DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col, - seed, num_neighbors, time, seed_time, edge_weight, csc, - temporal_strategy); -} - -std::tuple, - c10::Dict, - c10::Dict, - c10::optional>, - c10::Dict>, - c10::Dict>> -dist_hetero_neighbor_sample_kernel( - const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& rowptr_dict, - const c10::Dict& col_dict, - const c10::Dict& seed_dict, - const c10::Dict>& num_neighbors_dict, - const c10::optional>& time_dict, - const c10::optional>& seed_time_dict, - const c10::optional>& edge_weight_dict, - bool csc, - bool replace, - bool directed, - bool disjoint, - std::string temporal_strategy, - bool return_edge_id) { - DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, node_types, - edge_types, rowptr_dict, col_dict, seed_dict, - num_neighbors_dict, time_dict, seed_time_dict, - edge_weight_dict, csc, temporal_strategy); + std::string temporal_strategy) { + const auto out = [&] { + DISPATCH_DIST_SAMPLE(replace, directed, disjoint, rowptr, col, seed, + {num_neighbors}, time, seed_time, edge_weight, csc, + temporal_strategy); + }(); + return std::make_tuple(std::get<2>(out), std::get<3>(out).value(), + std::get<6>(out)); } TORCH_LIBRARY_IMPL(pyg, CPU, m) { @@ -918,10 +871,5 @@ TORCH_LIBRARY_IMPL(pyg, CPU, m) { TORCH_FN(dist_neighbor_sample_kernel)); } -TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) { - m.impl(TORCH_SELECTIVE_NAME("pyg::dist_hetero_neighbor_sample"), - TORCH_FN(dist_hetero_neighbor_sample_kernel)); -} - } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h index 9c6e3aa51..2d1dfa831 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h @@ -48,17 +48,11 @@ hetero_neighbor_sample_kernel( std::string temporal_strategy, bool return_edge_id); -std::tuple, - std::vector, - std::vector, - std::vector> +std::tuple> dist_neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, - const std::vector& num_neighbors, + const int64_t num_neighbors, const c10::optional& time, const c10::optional& seed_time, const c10::optional& edge_weight, @@ -66,31 +60,7 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr, bool replace, bool directed, bool disjoint, - std::string temporal_strategy, - bool return_edge_id); - -std::tuple, - c10::Dict, - c10::Dict, - c10::optional>, - c10::Dict>, - c10::Dict>> -dist_hetero_neighbor_sample_kernel( - const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& rowptr_dict, - const c10::Dict& col_dict, - const c10::Dict& seed_dict, - const c10::Dict>& num_neighbors_dict, - const c10::optional>& time_dict, - const c10::optional>& seed_time_dict, - const c10::optional>& edge_weight_dict, - bool csc, - bool replace, - bool directed, - bool disjoint, - std::string temporal_strategy, - bool return_edge_id); + std::string temporal_strategy); } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index 6907ab5d1..38c0cc3cb 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ b/pyg_lib/csrc/sampler/neighbor.cpp @@ -94,26 +94,19 @@ hetero_neighbor_sample( temporal_strategy, return_edge_id); } -std::tuple, - std::vector, - std::vector, - std::vector> -dist_neighbor_sample(const at::Tensor& rowptr, - const at::Tensor& col, - const at::Tensor& seed, - const std::vector& num_neighbors, - const c10::optional& time, - const c10::optional& seed_time, - const c10::optional& edge_weight, - bool csc, - bool replace, - bool directed, - bool disjoint, - std::string temporal_strategy, - bool return_edge_id) { +std::tuple> dist_neighbor_sample( + const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& seed, + const int64_t num_neighbors, + const c10::optional& time, + const c10::optional& seed_time, + const c10::optional& edge_weight, + bool csc, + bool replace, + bool directed, + bool disjoint, + std::string temporal_strategy) { at::TensorArg rowptr_t{rowptr, "rowtpr", 1}; at::TensorArg col_t{col, "col", 1}; at::TensorArg seed_t{seed, "seed", 1}; @@ -126,60 +119,7 @@ dist_neighbor_sample(const at::Tensor& rowptr, .findSchemaOrThrow("pyg::dist_neighbor_sample", "") .typed(); return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight, - csc, replace, directed, disjoint, temporal_strategy, - return_edge_id); -} - -std::tuple, - c10::Dict, - c10::Dict, - c10::optional>, - c10::Dict>, - c10::Dict>> -dist_hetero_neighbor_sample( - const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& rowptr_dict, - const c10::Dict& col_dict, - const c10::Dict& seed_dict, - const c10::Dict>& num_neighbors_dict, - const c10::optional>& time_dict, - const c10::optional>& seed_time_dict, - const c10::optional>& edge_weight_dict, - bool csc, - bool replace, - bool directed, - bool disjoint, - std::string temporal_strategy, - bool return_edge_id) { - TORCH_CHECK(rowptr_dict.size() == col_dict.size(), - "Number of edge types in 'rowptr_dict' and 'col_dict' must match") - - std::vector rowptr_dict_args; - std::vector col_dict_args; - std::vector seed_dict_args; - pyg::utils::fill_tensor_args(rowptr_dict_args, rowptr_dict, "rowptr_dict", 0); - pyg::utils::fill_tensor_args(col_dict_args, col_dict, "col_dict", 0); - pyg::utils::fill_tensor_args(seed_dict_args, seed_dict, "seed_dict", 0); - at::CheckedFrom c{"dist_hetero_neighbor_sample"}; - - at::checkAllDefined(c, rowptr_dict_args); - at::checkAllDefined(c, col_dict_args); - at::checkAllDefined(c, seed_dict_args); - at::checkAllSameType(c, rowptr_dict_args); - at::checkAllSameType(c, col_dict_args); - at::checkAllSameType(c, seed_dict_args); - at::checkSameType(c, rowptr_dict_args[0], col_dict_args[0]); - at::checkSameType(c, rowptr_dict_args[0], seed_dict_args[0]); - - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("pyg::dist_hetero_neighbor_sample", "") - .typed(); - return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict, - num_neighbors_dict, time_dict, seed_time_dict, - edge_weight_dict, csc, replace, directed, disjoint, - temporal_strategy, return_edge_id); + csc, replace, directed, disjoint, temporal_strategy); } TORCH_LIBRARY_FRAGMENT(pyg, m) { @@ -201,22 +141,11 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { "(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), " "Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))")); m.def(TORCH_SELECTIVE_SCHEMA( - "pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " + "pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int " "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? " "edge_weight = None, bool csc = False, bool replace = False, bool " "directed = True, bool disjoint = False, str temporal_strategy = " - "'uniform', bool return_edge_id = True) " - "-> (Tensor, Tensor, Tensor, Tensor?, int[], int[], int[])")); - m.def(TORCH_SELECTIVE_SCHEMA( - "pyg::dist_hetero_neighbor_sample(str[] node_types, (str, str, str)[] " - "edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, " - "Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, " - "Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? seed_time_dict " - "= None, Dict(str, Tensor)? edge_weight_dict = None, bool csc = False, " - "bool replace = False, bool directed = True, bool disjoint = False, " - "str temporal_strategy = 'uniform', bool return_edge_id = True) -> " - "(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), " - "Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))")); + "'uniform') -> (Tensor, Tensor, int[])")); } } // namespace sampler diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index 8abcd42b2..b589a0a1c 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -62,18 +62,11 @@ hetero_neighbor_sample( bool return_edge_id = true); PYG_API -std::tuple, - std::vector, - std::vector, - std::vector> -dist_neighbor_sample( +std::tuple> dist_neighbor_sample( const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, - const std::vector& num_neighbors, + const int64_t num_neighbors, const c10::optional& time = c10::nullopt, const c10::optional& seed_time = c10::nullopt, const c10::optional& edge_weight = c10::nullopt, @@ -81,35 +74,7 @@ dist_neighbor_sample( bool replace = false, bool directed = true, bool disjoint = false, - std::string strategy = "uniform", - bool return_edge_id = true); - -PYG_API -std::tuple, - c10::Dict, - c10::Dict, - c10::optional>, - c10::Dict>, - c10::Dict>> -dist_hetero_neighbor_sample( - const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& rowptr_dict, - const c10::Dict& col_dict, - const c10::Dict& seed_dict, - const c10::Dict>& num_neighbors_dict, - const c10::optional>& time_dict = - c10::nullopt, - const c10::optional>& seed_time_dict = - c10::nullopt, - const c10::optional>& edge_weight_dict = - c10::nullopt, - bool csc = false, - bool replace = false, - bool directed = true, - bool disjoint = false, - std::string strategy = "uniform", - bool return_edge_id = true); + std::string strategy = "uniform"); } // namespace sampler } // namespace pyg diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index dd0cc1a0f..2087628e3 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -165,6 +165,50 @@ def hetero_neighbor_sample( num_nodes_per_hop_dict, num_edges_per_hop_dict) +def dist_neighbor_sample( + rowptr: Tensor, + col: Tensor, + seed: Tensor, + num_neighbors: int, + time: Optional[Tensor] = None, + seed_time: Optional[Tensor] = None, + edge_weight: Optional[Tensor] = None, + csc: bool = False, + replace: bool = False, + directed: bool = True, + disjoint: bool = False, + temporal_strategy: str = 'uniform', +) -> Tuple[Tensor, Tensor, List[int]]: + r"""For distributed sampling purpose. Leverages the + :meth:`neighbor_sample`. Samples one hop neighborhood with duplicates from + all node indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`. + + Args: + num_neighbors (int): Maximum number of neighbors to sample in the + current layer. + kwargs: Arguments of :meth:`neighbor_sample`. + + Returns: + (torch.Tensor, torch.Tensor, List[int]): Returns the original node and + edge indices for all sampled nodes and edges. Lastly, returns the + cummulative sum of the amount of sampled neighbors for each input node. + """ + return torch.ops.pyg.dist_neighbor_sample( + rowptr, + col, + seed, + num_neighbors, + time, + seed_time, + edge_weight, + csc, + replace, + directed, + disjoint, + temporal_strategy, + ) + + def subgraph( rowptr: Tensor, col: Tensor, @@ -218,6 +262,7 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int, __all__ = [ 'neighbor_sample', 'hetero_neighbor_sample', + 'dist_neighbor_sample', 'subgraph', 'random_walk', ] diff --git a/test/csrc/sampler/test_dist_neighbor.cpp b/test/csrc/sampler/test_dist_neighbor.cpp new file mode 100644 index 000000000..f13d1c668 --- /dev/null +++ b/test/csrc/sampler/test_dist_neighbor.cpp @@ -0,0 +1,141 @@ +#include +#include + +#include "pyg_lib/csrc/sampler/neighbor.h" +#include "pyg_lib/csrc/utils/types.h" +#include "test/csrc/graph.h" + +TEST(BasicDistNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + auto graph = cycle_graph(/*num_nodes=*/6, options); + + auto out = pyg::sampler::dist_neighbor_sample( + /*rowptr=*/std::get<0>(graph), + /*col=*/std::get<1>(graph), + /*seed=*/at::arange(2, 4, options), + /*num_neighbors=*/-1); + + auto expected_nodes = at::tensor({2, 3, 1, 3, 2, 4}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes)); + + auto expected_edges = at::tensor({4, 5, 6, 7}, options); + EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges)); + + std::vector expected_cumsum_neighbors_per_node = {2, 4, 6}; + EXPECT_EQ(std::get<2>(out), expected_cumsum_neighbors_per_node); +} + +TEST(WithoutReplacementNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + auto graph = cycle_graph(/*num_nodes=*/6, options); + + at::manual_seed(123456); + auto out = pyg::sampler::dist_neighbor_sample( + /*rowptr=*/std::get<0>(graph), + /*col=*/std::get<1>(graph), + /*seed=*/at::arange(2, 4, options), + /*num_neighbors=*/1); + + auto expected_nodes = at::tensor({2, 3, 1, 4}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes)); + + auto expected_edges = at::tensor({4, 7}, options); + EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges)); + + std::vector expected_cumsum_neighbors_per_node = {2, 3, 4}; + EXPECT_EQ(std::get<2>(out), expected_cumsum_neighbors_per_node); +} + +TEST(WithReplacementNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + auto graph = cycle_graph(/*num_nodes=*/6, options); + + at::manual_seed(123456); + auto out = pyg::sampler::dist_neighbor_sample( + /*rowptr=*/std::get<0>(graph), + /*col=*/std::get<1>(graph), + /*seed=*/at::arange(2, 4, options), + /*num_neighbors=*/2, + /*time=*/c10::nullopt, + /*seed_time=*/c10::nullopt, + /*edge_weight=*/c10::nullopt, + /*csc*/ false, + /*replace=*/true); + + auto expected_nodes = at::tensor({2, 3, 1, 3, 4, 4}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes)); + + auto expected_edges = at::tensor({4, 5, 7, 7}, options); + EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges)); + + std::vector expected_cumsum_neighbors_per_node = {2, 4, 6}; + EXPECT_EQ(std::get<2>(out), expected_cumsum_neighbors_per_node); +} + +TEST(DistDisjointNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + auto graph = cycle_graph(/*num_nodes=*/6, options); + + auto out = pyg::sampler::dist_neighbor_sample( + /*rowptr=*/std::get<0>(graph), + /*col=*/std::get<1>(graph), + /*seed=*/at::arange(2, 4, options), + /*num_neighbors=*/2, + /*time=*/c10::nullopt, + /*seed_time=*/c10::nullopt, + /*edge_weight=*/c10::nullopt, + /*csc*/ false, + /*replace=*/false, + /*directed=*/true, + /*disjoint=*/true); + + auto expected_nodes = + at::tensor({0, 2, 1, 3, 0, 1, 0, 3, 1, 2, 1, 4}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes.view({-1, 2}))); + + auto expected_edges = at::tensor({4, 5, 6, 7}, options); + EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges)); + + std::vector expected_cumsum_neighbors_per_node = {2, 4, 6}; + EXPECT_EQ(std::get<2>(out), expected_cumsum_neighbors_per_node); +} + +TEST(DistTemporalNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + auto graph = cycle_graph(/*num_nodes=*/6, options); + auto rowptr = std::get<0>(graph); + auto col = std::get<1>(graph); + + // Time is equal to node ID ... + auto time = at::arange(6, options); + // ... so we need to sort the column vector by time/node ID: + col = std::get<0>(at::sort(col.view({-1, 2}), /*dim=*/1)).flatten(); + + auto out = pyg::sampler::dist_neighbor_sample( + /*rowptr=*/rowptr, + /*col=*/col, + /*seed=*/at::arange(2, 4, options), + /*num_neighbors=*/2, + /*time=*/time, + /*seed_time=*/c10::nullopt, + /*edge_weight=*/c10::nullopt, + /*csc*/ false, + /*replace=*/false, + /*directed=*/true, + /*disjoint=*/true, + /*temporal_strategy=*/"uniform"); + + auto expected_nodes = at::tensor({0, 2, 1, 3, 0, 1, 1, 2}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes.view({-1, 2}))); + + auto expected_edges = at::tensor({4, 6}, options); + EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges)); + + std::vector expected_cumsum_neighbors_per_node = {2, 3, 4}; + EXPECT_EQ(std::get<2>(out), expected_cumsum_neighbors_per_node); +} diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index e8fb50165..20c94d18b 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -9,14 +9,12 @@ TEST(BasicNeighborTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); auto graph = cycle_graph(/*num_nodes=*/6, options); - auto seed = at::arange(2, 4, options); - std::vector num_neighbors = {-1, -1}; auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), - /*seed=*/seed, - /*num_neighbors=*/num_neighbors); + /*seed=*/at::arange(2, 4, options), + /*num_neighbors=*/{-1, -1}); auto expected_row = at::tensor({0, 0, 1, 1, 2, 2, 3, 3}, options); EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); @@ -36,15 +34,13 @@ TEST(WithoutReplacementNeighborTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); auto graph = cycle_graph(/*num_nodes=*/6, options); - auto seed = at::arange(2, 4, options); - std::vector num_neighbors = {1, 1}; at::manual_seed(123456); auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), - /*seed=*/seed, - /*num_neighbors=*/num_neighbors, + /*seed=*/at ::arange(2, 4, options), + /*num_neighbors=*/{1, 1}, /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, @@ -65,15 +61,13 @@ TEST(WithReplacementNeighborTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); auto graph = cycle_graph(/*num_nodes=*/6, options); - auto seed = at::arange(2, 4, options); - std::vector num_neighbors = {1, 1}; at::manual_seed(123456); auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), - /*seed=*/seed, - /*num_neighbors=*/num_neighbors, + /*seed=*/at::arange(2, 4, options), + /*num_neighbors=*/{1, 1}, /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, @@ -94,14 +88,12 @@ TEST(DisjointNeighborTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); auto graph = cycle_graph(/*num_nodes=*/6, options); - auto seed = at::arange(2, 4, options); - std::vector num_neighbors = {2, 2}; auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), - /*seed=*/seed, - /*num_neighbors=*/num_neighbors, + /*seed=*/at::arange(2, 4, options), + /*num_neighbors=*/{2, 2}, /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, @@ -128,7 +120,6 @@ TEST(TemporalNeighborTest, BasicAssertions) { auto graph = cycle_graph(/*num_nodes=*/6, options); auto rowptr = std::get<0>(graph); auto col = std::get<1>(graph); - auto seed = at::arange(2, 4, options); // Time is equal to node ID ... auto time = at::arange(6, options); @@ -138,7 +129,7 @@ TEST(TemporalNeighborTest, BasicAssertions) { auto out1 = pyg::sampler::neighbor_sample( /*rowptr=*/rowptr, /*col=*/col, - /*seed=*/seed, + /*seed=*/at::arange(2, 4, options), /*num_neighbors=*/{2, 2}, /*time=*/time, /*seed_time=*/c10::nullopt, @@ -162,7 +153,7 @@ TEST(TemporalNeighborTest, BasicAssertions) { auto out2 = pyg::sampler::neighbor_sample( /*rowptr=*/rowptr, /*col=*/col, - /*seed=*/seed, + /*seed=*/at::arange(2, 4, options), /*num_neighbors=*/{1, 2}, /*time=*/time, /*seed_time=*/c10::nullopt, @@ -224,8 +215,6 @@ TEST(BiasedNeighborTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); auto graph = cycle_graph(/*num_nodes=*/6, options); - auto seed = at::arange(0, 2, options); - std::vector num_neighbors = {1}; auto ones = at::ones(6).view({-1, 1}); auto zeros = at::zeros(6).view({-1, 1}); @@ -235,8 +224,8 @@ TEST(BiasedNeighborTest, BasicAssertions) { auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), - /*seed=*/seed, - /*num_neighbors=*/num_neighbors, + /*seed=*/at::arange(0, 2, options), + /*num_neighbors=*/{1}, /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/edge_weight);