Skip to content

Commit

Permalink
merge outputs when num neighbors = -1
Browse files Browse the repository at this point in the history
  • Loading branch information
kgajdamo committed Sep 8, 2023
1 parent eb6bcaf commit b6b55e8
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 6 deletions.
36 changes: 30 additions & 6 deletions pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,49 @@ merge_outputs(
at::Tensor out_node;
c10::optional<at::Tensor> out_edge_id = c10::nullopt;
c10::optional<at::Tensor> out_batch = c10::nullopt;
int64_t offset = one_hop_num;

if (one_hop_num < 0) {
  // A negative `one_hop_num` gives no fixed per-node slot width, so derive
  // `offset` from the data: the largest absolute difference between two
  // consecutive entries of any `cumm_sampled_nbrs_per_node[p_id]`.
  //
  // Each worker only reads its partition's vector and writes its own slot of
  // `max_populations`; all scratch state is local to the loop iteration, so
  // nothing mutable is shared between the `at::parallel_for` workers.
  std::vector<int64_t> max_populations(partitions_num);

  at::parallel_for(0, partitions_num, 1, [&](size_t _s, size_t _e) {
    for (auto p_id = _s; p_id < _e; p_id++) {
      const auto& cumm_nbrs = cumm_sampled_nbrs_per_node[p_id];
      int64_t max_population = 0;
      // Starting at index 1 also handles empty and single-entry vectors:
      // they have no consecutive pair, so the maximum stays at 0.
      for (size_t i = 1; i < cumm_nbrs.size(); i++) {
        max_population = std::max<int64_t>(
            max_population, std::abs(cumm_nbrs[i] - cumm_nbrs[i - 1]));
      }
      max_populations[p_id] = max_population;
    }
  });

  // Dereferencing max_element on an empty range is undefined; with no
  // partitions there is nothing to merge, so fall back to a zero width.
  offset = max_populations.empty()
               ? 0
               : *std::max_element(max_populations.begin(),
                                   max_populations.end());
}

const auto p_size = partition_ids.size();
std::vector<int64_t> sampled_nbrs_per_node(p_size);

const auto scalar_type = nodes[0].scalar_type();
AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "merge_outputs_kernel", [&] {
std::vector<scalar_t> sampled_nodes(p_size * one_hop_num, -1);
std::vector<scalar_t> sampled_nodes(p_size * offset, -1);
std::vector<scalar_t> sampled_edge_ids;
std::vector<scalar_t> sampled_batch;
std::vector<std::vector<scalar_t>> sampled_nodes_vec(p_size);
std::vector<std::vector<scalar_t>> edge_ids_vec;
std::vector<std::vector<scalar_t>> batch_vec(p_size);

if constexpr (with_edge) {
sampled_edge_ids = std::vector<scalar_t>(p_size * one_hop_num, -1);
sampled_edge_ids = std::vector<scalar_t>(p_size * offset, -1);
edge_ids_vec = std::vector<std::vector<scalar_t>>(p_size);
}
if constexpr (disjoint) {
sampled_batch = std::vector<scalar_t>(p_size * one_hop_num, -1);
sampled_batch = std::vector<scalar_t>(p_size * offset, -1);
batch_vec = std::vector<std::vector<scalar_t>>(p_size);
}

Expand Down Expand Up @@ -77,15 +101,15 @@ merge_outputs(

std::copy(sampled_nodes_vec[p_id].begin() + begin,
sampled_nodes_vec[p_id].begin() + end,
sampled_nodes.begin() + j * one_hop_num);
sampled_nodes.begin() + j * offset);
if constexpr (with_edge)
std::copy(edge_ids_vec[p_id].begin() + begin_edge,
edge_ids_vec[p_id].begin() + end_edge,
sampled_edge_ids.begin() + j * one_hop_num);
sampled_edge_ids.begin() + j * offset);
if constexpr (disjoint)
std::copy(batch_vec[p_id].begin() + begin,
batch_vec[p_id].begin() + end,
sampled_batch.begin() + j * one_hop_num);
sampled_batch.begin() + j * offset);

sampled_nbrs_per_node[j] = end - begin;
}
Expand Down
36 changes: 36 additions & 0 deletions test/csrc/sampler/test_dist_merge_outputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,42 @@ TEST(DistMergeOutputsTest, BasicAssertions) {
EXPECT_EQ(std::get<3>(out), expected_sampled_nbrs_per_node);
}

TEST(DistMergeOutputsAllNeighborsTest, BasicAssertions) {
  // Merges the sampler outputs of three partitions when the number of
  // neighbors is -1 (no fixed per-node count), with edge ids enabled and
  // disjoint mode disabled.
  const auto long_opts = at::TensorOptions().dtype(at::kLong);

  const auto num_partitions = 3;
  const auto num_neighbors = -1;
  const bool is_disjoint = false;
  const bool has_edges = true;

  // seed = {0, 1, 2, 3}
  const std::vector<int64_t> seed_partition_ids = {1, 1, 0, 2};
  const std::vector<int64_t> seed_partition_orders = {0, 1, 0, 0};

  // Per-partition sampler outputs.
  const std::vector<at::Tensor> partition_nodes = {
      at::tensor({2, 7, 8}, long_opts),
      at::tensor({0, 1, 4, 5, 6}, long_opts),
      at::tensor({3, 9, 10, 11}, long_opts)};
  const std::vector<at::Tensor> partition_edge_ids = {
      at::tensor({17, 18}, long_opts),
      at::tensor({14, 15, 16}, long_opts),
      at::tensor({19, 20, 21}, long_opts)};
  const std::vector<std::vector<int64_t>> partition_cumm_nbrs = {
      {1, 3}, {2, 4, 5}, {1, 4}};

  // Expected merged result, laid out in seed order.
  const auto want_nodes = at::tensor({4, 5, 6, 7, 8, 9, 10, 11}, long_opts);
  const auto want_edges =
      at::tensor({14, 15, 16, 17, 18, 19, 20, 21}, long_opts);
  const std::vector<int64_t> want_nbrs_per_node = {2, 1, 2, 3};

  auto merged = pyg::sampler::merge_sampler_outputs(
      partition_nodes, partition_cumm_nbrs, seed_partition_ids,
      seed_partition_orders, num_partitions, num_neighbors,
      partition_edge_ids, /*batch=*/c10::nullopt, is_disjoint, has_edges);

  EXPECT_TRUE(at::equal(std::get<0>(merged), want_nodes));
  EXPECT_TRUE(at::equal(std::get<1>(merged).value(), want_edges));
  EXPECT_EQ(std::get<3>(merged), want_nbrs_per_node);
}

TEST(DistDisjointMergeOutputsTest, BasicAssertions) {
auto options = at::TensorOptions().dtype(at::kLong);

Expand Down

0 comments on commit b6b55e8

Please sign in to comment.