diff --git a/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp index e198de8f3..f538eb9db 100644 --- a/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp @@ -20,7 +20,7 @@ std::tuple> merge_outputs( const std::vector& nodes, - const std::vector>& cumm_sampled_nbrs_per_node, + const std::vector>& cumsum_neighbors_per_node, const std::vector& partition_ids, const std::vector& partition_orders, const int64_t partitions_num, @@ -40,11 +40,11 @@ merge_outputs( at::parallel_for(0, partitions_num, 1, [&](size_t _s, size_t _e) { for (auto p_id = _s; p_id < _e; p_id++) { auto cummsum1 = - std::vector(cumm_sampled_nbrs_per_node[p_id].begin() + 1, - cumm_sampled_nbrs_per_node[p_id].end()); + std::vector(cumsum_neighbors_per_node[p_id].begin() + 1, + cumsum_neighbors_per_node[p_id].end()); auto cummsum2 = - std::vector(cumm_sampled_nbrs_per_node[p_id].begin(), - cumm_sampled_nbrs_per_node[p_id].end() - 1); + std::vector(cumsum_neighbors_per_node[p_id].begin(), + cumsum_neighbors_per_node[p_id].end() - 1); std::transform(cummsum1.begin(), cummsum1.end(), cummsum2.begin(), std::back_inserter(population), [](int64_t a, int64_t b) { return std::abs(a - b); }); @@ -90,11 +90,11 @@ merge_outputs( // When it comes to node and batch, we omit seed nodes. // In the case of edges, we take into account all sampled edge ids. - auto begin = cumm_sampled_nbrs_per_node[p_id][p_order]; - auto begin_edge = begin - cumm_sampled_nbrs_per_node[p_id][0]; + auto begin = cumsum_neighbors_per_node[p_id][p_order]; + auto begin_edge = begin - cumsum_neighbors_per_node[p_id][0]; - auto end = cumm_sampled_nbrs_per_node[p_id][p_order + 1]; - auto end_edge = end - cumm_sampled_nbrs_per_node[p_id][0]; + auto end = cumsum_neighbors_per_node[p_id][p_order + 1]; + auto end_edge = end - cumsum_neighbors_per_node[p_id][0]; std::copy(sampled_nodes_vec[p_id].begin() + begin, sampled_nodes_vec[p_id].begin() + end, @@ -155,7 +155,7 @@ std::tuple> merge_sampler_outputs_kernel( const std::vector& nodes, - const std::vector>& cumm_sampled_nbrs_per_node, + const std::vector>& cumsum_neighbors_per_node, const std::vector& partition_ids, const std::vector& partition_orders, const int64_t partitions_num, @@ -164,7 +164,7 @@ merge_sampler_outputs_kernel( const c10::optional& batch, bool disjoint, bool with_edge) { - DISPATCH_MERGE_OUTPUTS(disjoint, with_edge, nodes, cumm_sampled_nbrs_per_node, + DISPATCH_MERGE_OUTPUTS(disjoint, with_edge, nodes, cumsum_neighbors_per_node, partition_ids, partition_orders, partitions_num, one_hop_num, edge_ids, batch); } diff --git a/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.h b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.h index 1936e23d5..f0bdf2b69 100644 --- a/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.h @@ -11,7 +11,7 @@ std::tuple> merge_sampler_outputs_kernel( const std::vector& nodes, - const std::vector>& cumm_sampled_nbrs_per_node, + const std::vector>& cumsum_neighbors_per_node, const std::vector& partition_ids, const std::vector& partition_orders, const int64_t partitions_num, diff --git a/pyg_lib/csrc/sampler/dist_merge_outputs.cpp b/pyg_lib/csrc/sampler/dist_merge_outputs.cpp index 984fdf617..2e885c8ea 100644 --- a/pyg_lib/csrc/sampler/dist_merge_outputs.cpp +++ b/pyg_lib/csrc/sampler/dist_merge_outputs.cpp @@ -14,7 +14,7 @@ std::tuple> merge_sampler_outputs( const std::vector& nodes, - const std::vector>& cumm_sampled_nbrs_per_node, + const std::vector>& cumsum_neighbors_per_node, const std::vector& partition_ids, const std::vector& partition_orders, const int64_t partitions_num, @@ -45,7 +45,7 @@ merge_sampler_outputs( static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::merge_sampler_outputs", "") .typed(); - return op.call(nodes, cumm_sampled_nbrs_per_node, partition_ids, + return op.call(nodes, cumsum_neighbors_per_node, partition_ids, partition_orders, partitions_num, one_hop_num, edge_ids, batch, disjoint, with_edge); } @@ -53,7 +53,7 @@ merge_sampler_outputs( TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( "pyg::merge_sampler_outputs(Tensor[] nodes, " - "int[][] cumm_sampled_nbrs_per_node, int[] partition_ids, int[] " + "int[][] cumsum_neighbors_per_node, int[] partition_ids, int[] " "partition_orders, int partitions_num, int one_hop_num, Tensor[]? " "edge_ids, Tensor? batch, bool disjoint, bool with_edge) -> (Tensor, " "Tensor?, Tensor?, int[])")); diff --git a/pyg_lib/csrc/sampler/dist_merge_outputs.h b/pyg_lib/csrc/sampler/dist_merge_outputs.h index 5770025a1..99680a679 100644 --- a/pyg_lib/csrc/sampler/dist_merge_outputs.h +++ b/pyg_lib/csrc/sampler/dist_merge_outputs.h @@ -10,7 +10,7 @@ namespace sampler { // For distributed training purpose. Merges samplers outputs from different // partitions, so that they are sorted according to the sampling order. // Removes seed nodes from sampled nodes and calculates how many neighbors -// were sampled by each src node based on the :obj:`cumm_sampled_nbrs_per_node`. +// were sampled by each src node based on the :obj:`cumsum_neighbors_per_node`. PYG_API std::tuple, @@ -18,7 +18,7 @@ std::tuple> merge_sampler_outputs( const std::vector& nodes, - const std::vector>& cumm_sampled_nbrs_per_node, + const std::vector>& cumsum_neighbors_per_node, const std::vector& partition_ids, const std::vector& partition_orders, const int64_t partitions_num, diff --git a/test/csrc/sampler/test_dist_merge_outputs.cpp b/test/csrc/sampler/test_dist_merge_outputs.cpp index a0d21ec0e..85501de23 100644 --- a/test/csrc/sampler/test_dist_merge_outputs.cpp +++ b/test/csrc/sampler/test_dist_merge_outputs.cpp @@ -20,13 +20,13 @@ TEST(DistMergeOutputsTest, BasicAssertions) { at::tensor({14, 15, 16}, options), at::tensor({19, 20}, options)}; - const std::vector> cumm_sampled_nbrs_per_node = { + const std::vector> cumsum_neighbors_per_node = { {1, 3}, {2, 4, 5}, {1, 3}}; const std::vector partition_ids = {1, 1, 0, 2}; const std::vector partition_orders = {0, 1, 0, 0}; auto out = pyg::sampler::merge_sampler_outputs( - nodes, cumm_sampled_nbrs_per_node, partition_ids, partition_orders, + nodes, cumsum_neighbors_per_node, partition_ids, partition_orders, partitions_num, one_hop_num, edge_ids, /*batch=*/c10::nullopt, disjoint, with_edge); @@ -56,13 +56,13 @@ TEST(DistMergeOutputsAllNeighborsTest, BasicAssertions) { at::tensor({14, 15, 16}, options), at::tensor({19, 20, 21}, options)}; - const std::vector> cumm_sampled_nbrs_per_node = { + const std::vector> cumsum_neighbors_per_node = { {1, 3}, {2, 4, 5}, {1, 4}}; const std::vector partition_ids = {1, 1, 0, 2}; const std::vector partition_orders = {0, 1, 0, 0}; auto out = pyg::sampler::merge_sampler_outputs( - nodes, cumm_sampled_nbrs_per_node, partition_ids, partition_orders, + nodes, cumsum_neighbors_per_node, partition_ids, partition_orders, partitions_num, one_hop_num, edge_ids, /*batch=*/c10::nullopt, disjoint, with_edge); @@ -90,13 +90,13 @@ TEST(DistDisjointMergeOutputsTest, BasicAssertions) { at::tensor({3, 9, 10}, options)}; const auto batch = at::tensor({0, 1, 2, 3}, options); - const std::vector> cumm_sampled_nbrs_per_node = { + const std::vector> cumsum_neighbors_per_node = { {1, 3}, {2, 4, 5}, {1, 3}}; const std::vector partition_ids = {1, 1, 0, 2}; const std::vector partition_orders = {0, 1, 0, 0}; auto out = pyg::sampler::merge_sampler_outputs( - nodes, cumm_sampled_nbrs_per_node, partition_ids, partition_orders, + nodes, cumsum_neighbors_per_node, partition_ids, partition_orders, partitions_num, one_hop_num, /*edge_ids=*/c10::nullopt, batch, disjoint, with_edge);