diff --git a/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp
index 9e030b76e..1c3884289 100644
--- a/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp
+++ b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp
@@ -30,13 +30,37 @@ merge_outputs(
   at::Tensor out_node;
   c10::optional<at::Tensor> out_edge_id = c10::nullopt;
   c10::optional<at::Tensor> out_batch = c10::nullopt;
+  int64_t offset = one_hop_num;
+
+  if (one_hop_num < 0) {
+    // All neighbors are sampled: size each seed's slot by the largest
+    // neighborhood found in any partition.
+    std::vector<int64_t> max_populations(partitions_num, 0);
+
+    at::parallel_for(0, partitions_num, 1, [&](size_t _s, size_t _e) {
+      for (auto p_id = _s; p_id < _e; p_id++) {
+        // Each iteration writes only max_populations[p_id]: nothing to race.
+        const auto& cumsum = cumm_sampled_nbrs_per_node[p_id];
+        int64_t max_population = 0;
+        for (size_t i = 1; i < cumsum.size(); i++) {
+          const int64_t population = std::abs(cumsum[i] - cumsum[i - 1]);
+          max_population = std::max(max_population, population);
+        }
+        max_populations[p_id] = max_population;
+      }
+    });
+    // Fold instead of *max_element so zero partitions is well-defined.
+    offset = 0;
+    for (const auto population : max_populations)
+      offset = std::max(offset, population);
+  }
 
   const auto p_size = partition_ids.size();
   std::vector<int64_t> sampled_nbrs_per_node(p_size);
 
   const auto scalar_type = nodes[0].scalar_type();
   AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "merge_outputs_kernel", [&] {
-    std::vector<scalar_t> sampled_nodes(p_size * one_hop_num, -1);
+    std::vector<scalar_t> sampled_nodes(p_size * offset, -1);
     std::vector<scalar_t> sampled_edge_ids;
     std::vector<scalar_t> sampled_batch;
     std::vector<std::vector<scalar_t>> sampled_nodes_vec(p_size);
@@ -44,11 +68,11 @@ merge_outputs(
     std::vector<std::vector<scalar_t>> batch_vec(p_size);
 
     if constexpr (with_edge) {
-      sampled_edge_ids = std::vector<scalar_t>(p_size * one_hop_num, -1);
+      sampled_edge_ids = std::vector<scalar_t>(p_size * offset, -1);
       edge_ids_vec = std::vector<std::vector<scalar_t>>(p_size);
     }
     if constexpr (disjoint) {
-      sampled_batch = std::vector<scalar_t>(p_size * one_hop_num, -1);
+      sampled_batch = std::vector<scalar_t>(p_size * offset, -1);
       batch_vec = std::vector<std::vector<scalar_t>>(p_size);
     }
 
@@ -77,15 +101,15 @@ merge_outputs(
 
         std::copy(sampled_nodes_vec[p_id].begin() + begin,
                   sampled_nodes_vec[p_id].begin() + end,
-                  sampled_nodes.begin() + j * one_hop_num);
+                  sampled_nodes.begin() + j * offset);
         if constexpr (with_edge)
           std::copy(edge_ids_vec[p_id].begin() + begin_edge,
                     edge_ids_vec[p_id].begin() + end_edge,
-                    sampled_edge_ids.begin() + j * one_hop_num);
+                    sampled_edge_ids.begin() + j * offset);
         if constexpr (disjoint)
           std::copy(batch_vec[p_id].begin() + begin,
                     batch_vec[p_id].begin() + end,
-                    sampled_batch.begin() + j * one_hop_num);
+                    sampled_batch.begin() + j * offset);
 
         sampled_nbrs_per_node[j] = end - begin;
       }
diff --git a/test/csrc/sampler/test_dist_merge_outputs.cpp b/test/csrc/sampler/test_dist_merge_outputs.cpp
index 3136be4e5..8a40d7f84 100644
--- a/test/csrc/sampler/test_dist_merge_outputs.cpp
+++ b/test/csrc/sampler/test_dist_merge_outputs.cpp
@@ -40,6 +40,42 @@ TEST(DistMergeOutputsTest, BasicAssertions) {
   EXPECT_EQ(std::get<3>(out), expected_sampled_nbrs_per_node);
 }
 
+TEST(DistMergeOutputsAllNeighborsTest, BasicAssertions) {
+  auto options = at::TensorOptions().dtype(at::kLong);
+
+  auto partitions_num = 3;
+  auto one_hop_num = -1;
+  bool disjoint = false;
+  bool with_edge = true;
+
+  // seed = {0, 1, 2, 3}
+  const std::vector<at::Tensor> nodes = {at::tensor({2, 7, 8}, options),
+                                         at::tensor({0, 1, 4, 5, 6}, options),
+                                         at::tensor({3, 9, 10, 11}, options)};
+  const std::vector<at::Tensor> edge_ids = {at::tensor({17, 18}, options),
+                                            at::tensor({14, 15, 16}, options),
+                                            at::tensor({19, 20, 21}, options)};
+
+  const std::vector<std::vector<int64_t>> cumm_sampled_nbrs_per_node = {
+      {1, 3}, {2, 4, 5}, {1, 4}};
+  const std::vector<int64_t> partition_ids = {1, 1, 0, 2};
+  const std::vector<int64_t> partition_orders = {0, 1, 0, 0};
+
+  auto out = pyg::sampler::merge_sampler_outputs(
+      nodes, cumm_sampled_nbrs_per_node, partition_ids, partition_orders,
+      partitions_num, one_hop_num, edge_ids, /*batch=*/c10::nullopt, disjoint,
+      with_edge);
+
+  auto expected_nodes = at::tensor({4, 5, 6, 7, 8, 9, 10, 11}, options);
+  EXPECT_TRUE(at::equal(std::get<0>(out), expected_nodes));
+
+  auto expected_edges = at::tensor({14, 15, 16, 17, 18, 19, 20, 21}, options);
+  EXPECT_TRUE(at::equal(std::get<1>(out).value(), expected_edges));
+
+  const std::vector<int64_t> expected_sampled_nbrs_per_node = {2, 1, 2, 3};
+  EXPECT_EQ(std::get<3>(out), expected_sampled_nbrs_per_node);
+}
+
 TEST(DistDisjointMergeOutputsTest, BasicAssertions) {
   auto options = at::TensorOptions().dtype(at::kLong);
 