Skip to content

Commit

Permalink
change variables names
Browse files Browse the repository at this point in the history
  • Loading branch information
kgajdamo committed Sep 15, 2023
1 parent 597bcd7 commit ecd8e9b
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 23 deletions.
22 changes: 11 additions & 11 deletions pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ std::tuple<at::Tensor,
std::vector<int64_t>>
merge_outputs(
const std::vector<at::Tensor>& nodes,
const std::vector<std::vector<int64_t>>& cumm_sampled_nbrs_per_node,
const std::vector<std::vector<int64_t>>& cumsum_neighbors_per_node,
const std::vector<int64_t>& partition_ids,
const std::vector<int64_t>& partition_orders,
const int64_t partitions_num,
Expand All @@ -40,11 +40,11 @@ merge_outputs(
at::parallel_for(0, partitions_num, 1, [&](size_t _s, size_t _e) {
for (auto p_id = _s; p_id < _e; p_id++) {
auto cummsum1 =
std::vector<int64_t>(cumm_sampled_nbrs_per_node[p_id].begin() + 1,
cumm_sampled_nbrs_per_node[p_id].end());
std::vector<int64_t>(cumsum_neighbors_per_node[p_id].begin() + 1,
cumsum_neighbors_per_node[p_id].end());
auto cummsum2 =
std::vector<int64_t>(cumm_sampled_nbrs_per_node[p_id].begin(),
cumm_sampled_nbrs_per_node[p_id].end() - 1);
std::vector<int64_t>(cumsum_neighbors_per_node[p_id].begin(),
cumsum_neighbors_per_node[p_id].end() - 1);
std::transform(cummsum1.begin(), cummsum1.end(), cummsum2.begin(),
std::back_inserter(population),
[](int64_t a, int64_t b) { return std::abs(a - b); });
Expand Down Expand Up @@ -90,11 +90,11 @@ merge_outputs(

// When it comes to node and batch, we omit seed nodes.
// In the case of edges, we take into account all sampled edge ids.
auto begin = cumm_sampled_nbrs_per_node[p_id][p_order];
auto begin_edge = begin - cumm_sampled_nbrs_per_node[p_id][0];
auto begin = cumsum_neighbors_per_node[p_id][p_order];
auto begin_edge = begin - cumsum_neighbors_per_node[p_id][0];

auto end = cumm_sampled_nbrs_per_node[p_id][p_order + 1];
auto end_edge = end - cumm_sampled_nbrs_per_node[p_id][0];
auto end = cumsum_neighbors_per_node[p_id][p_order + 1];
auto end_edge = end - cumsum_neighbors_per_node[p_id][0];

std::copy(sampled_nodes_vec[p_id].begin() + begin,
sampled_nodes_vec[p_id].begin() + end,
Expand Down Expand Up @@ -155,7 +155,7 @@ std::tuple<at::Tensor,
std::vector<int64_t>>
merge_sampler_outputs_kernel(
const std::vector<at::Tensor>& nodes,
const std::vector<std::vector<int64_t>>& cumm_sampled_nbrs_per_node,
const std::vector<std::vector<int64_t>>& cumsum_neighbors_per_node,
const std::vector<int64_t>& partition_ids,
const std::vector<int64_t>& partition_orders,
const int64_t partitions_num,
Expand All @@ -164,7 +164,7 @@ merge_sampler_outputs_kernel(
const c10::optional<at::Tensor>& batch,
bool disjoint,
bool with_edge) {
DISPATCH_MERGE_OUTPUTS(disjoint, with_edge, nodes, cumm_sampled_nbrs_per_node,
DISPATCH_MERGE_OUTPUTS(disjoint, with_edge, nodes, cumsum_neighbors_per_node,
partition_ids, partition_orders, partitions_num,
one_hop_num, edge_ids, batch);
}
Expand Down
2 changes: 1 addition & 1 deletion pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ std::tuple<at::Tensor,
std::vector<int64_t>>
merge_sampler_outputs_kernel(
const std::vector<at::Tensor>& nodes,
const std::vector<std::vector<int64_t>>& cumm_sampled_nbrs_per_node,
const std::vector<std::vector<int64_t>>& cumsum_neighbors_per_node,
const std::vector<int64_t>& partition_ids,
const std::vector<int64_t>& partition_orders,
const int64_t partitions_num,
Expand Down
6 changes: 3 additions & 3 deletions pyg_lib/csrc/sampler/dist_merge_outputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ std::tuple<at::Tensor,
std::vector<int64_t>>
merge_sampler_outputs(
const std::vector<at::Tensor>& nodes,
const std::vector<std::vector<int64_t>>& cumm_sampled_nbrs_per_node,
const std::vector<std::vector<int64_t>>& cumsum_neighbors_per_node,
const std::vector<int64_t>& partition_ids,
const std::vector<int64_t>& partition_orders,
const int64_t partitions_num,
Expand Down Expand Up @@ -45,15 +45,15 @@ merge_sampler_outputs(
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::merge_sampler_outputs", "")
.typed<decltype(merge_sampler_outputs)>();
return op.call(nodes, cumm_sampled_nbrs_per_node, partition_ids,
return op.call(nodes, cumsum_neighbors_per_node, partition_ids,
partition_orders, partitions_num, one_hop_num, edge_ids, batch,
disjoint, with_edge);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::merge_sampler_outputs(Tensor[] nodes, "
"int[][] cumm_sampled_nbrs_per_node, int[] partition_ids, int[] "
"int[][] cumsum_neighbors_per_node, int[] partition_ids, int[] "
"partition_orders, int partitions_num, int one_hop_num, Tensor[]? "
"edge_ids, Tensor? batch, bool disjoint, bool with_edge) -> (Tensor, "
"Tensor?, Tensor?, int[])"));
Expand Down
4 changes: 2 additions & 2 deletions pyg_lib/csrc/sampler/dist_merge_outputs.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ namespace sampler {
// For distributed training purpose. Merges samplers outputs from different
// partitions, so that they are sorted according to the sampling order.
// Removes seed nodes from sampled nodes and calculates how many neighbors
// were sampled by each src node based on the :obj:`cumm_sampled_nbrs_per_node`.
// were sampled by each src node based on the :obj:`cumsum_neighbors_per_node`.
PYG_API
std::tuple<at::Tensor,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>,
std::vector<int64_t>>
merge_sampler_outputs(
const std::vector<at::Tensor>& nodes,
const std::vector<std::vector<int64_t>>& cumm_sampled_nbrs_per_node,
const std::vector<std::vector<int64_t>>& cumsum_neighbors_per_node,
const std::vector<int64_t>& partition_ids,
const std::vector<int64_t>& partition_orders,
const int64_t partitions_num,
Expand Down
12 changes: 6 additions & 6 deletions test/csrc/sampler/test_dist_merge_outputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ TEST(DistMergeOutputsTest, BasicAssertions) {
at::tensor({14, 15, 16}, options),
at::tensor({19, 20}, options)};

const std::vector<std::vector<int64_t>> cumm_sampled_nbrs_per_node = {
const std::vector<std::vector<int64_t>> cumsum_neighbors_per_node = {
{1, 3}, {2, 4, 5}, {1, 3}};
const std::vector<int64_t> partition_ids = {1, 1, 0, 2};
const std::vector<int64_t> partition_orders = {0, 1, 0, 0};

auto out = pyg::sampler::merge_sampler_outputs(
nodes, cumm_sampled_nbrs_per_node, partition_ids, partition_orders,
nodes, cumsum_neighbors_per_node, partition_ids, partition_orders,
partitions_num, one_hop_num, edge_ids, /*batch=*/c10::nullopt, disjoint,
with_edge);

Expand Down Expand Up @@ -56,13 +56,13 @@ TEST(DistMergeOutputsAllNeighborsTest, BasicAssertions) {
at::tensor({14, 15, 16}, options),
at::tensor({19, 20, 21}, options)};

const std::vector<std::vector<int64_t>> cumm_sampled_nbrs_per_node = {
const std::vector<std::vector<int64_t>> cumsum_neighbors_per_node = {
{1, 3}, {2, 4, 5}, {1, 4}};
const std::vector<int64_t> partition_ids = {1, 1, 0, 2};
const std::vector<int64_t> partition_orders = {0, 1, 0, 0};

auto out = pyg::sampler::merge_sampler_outputs(
nodes, cumm_sampled_nbrs_per_node, partition_ids, partition_orders,
nodes, cumsum_neighbors_per_node, partition_ids, partition_orders,
partitions_num, one_hop_num, edge_ids, /*batch=*/c10::nullopt, disjoint,
with_edge);

Expand Down Expand Up @@ -90,13 +90,13 @@ TEST(DistDisjointMergeOutputsTest, BasicAssertions) {
at::tensor({3, 9, 10}, options)};
const auto batch = at::tensor({0, 1, 2, 3}, options);

const std::vector<std::vector<int64_t>> cumm_sampled_nbrs_per_node = {
const std::vector<std::vector<int64_t>> cumsum_neighbors_per_node = {
{1, 3}, {2, 4, 5}, {1, 3}};
const std::vector<int64_t> partition_ids = {1, 1, 0, 2};
const std::vector<int64_t> partition_orders = {0, 1, 0, 0};

auto out = pyg::sampler::merge_sampler_outputs(
nodes, cumm_sampled_nbrs_per_node, partition_ids, partition_orders,
nodes, cumsum_neighbors_per_node, partition_ids, partition_orders,
partitions_num, one_hop_num, /*edge_ids=*/c10::nullopt, batch, disjoint,
with_edge);

Expand Down

0 comments on commit ecd8e9b

Please sign in to comment.