Skip to content

Commit

Permalink
Merge branch 'master' into dist-sampler-relabel-1
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Sep 15, 2023
2 parents 0b891bc + 3bf8802 commit c37f09b
Show file tree
Hide file tree
Showing 7 changed files with 398 additions and 46 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.3.0] - 2023-MM-DD
### Added
- Added low-level support for distributed neighborhood sampling ([#246](https://github.com/pyg-team/pyg-lib/pull/246), [#253](https://github.com/pyg-team/pyg-lib/pull/253), [#254](https://github.com/pyg-team/pyg-lib/pull/254))
- Added low-level support for distributed neighborhood sampling ([#246](https://github.com/pyg-team/pyg-lib/pull/246), [#252](https://github.com/pyg-team/pyg-lib/pull/252), [#253](https://github.com/pyg-team/pyg-lib/pull/253), [#254](https://github.com/pyg-team/pyg-lib/pull/254))
- Added support for homogeneous and heterogeneous biased neighborhood sampling ([#247](https://github.com/pyg-team/pyg-lib/pull/247), [#251](https://github.com/pyg-team/pyg-lib/pull/251))
- Added dispatch for XPU device in `index_sort` ([#243](https://github.com/pyg-team/pyg-lib/pull/243))
- Added `metis` partitioning ([#229](https://github.com/pyg-team/pyg-lib/pull/229))
Expand Down
170 changes: 170 additions & 0 deletions pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/library.h>

#include "parallel_hashmap/phmap.h"

#include "pyg_lib/csrc/sampler/cpu/mapper.h"
#include "pyg_lib/csrc/utils/cpu/convert.h"
#include "pyg_lib/csrc/utils/types.h"

namespace pyg {
namespace sampler {

namespace {

// Merges one-hop sampler outputs produced by `num_partitions` graph
// partitions into single, globally ordered node/edge (and, for disjoint
// sampling, batch) tensors. The global order of the seeds is given by
// `partition_ids[j]` (which partition sampled seed `j`) and
// `partition_orders[j]` (the seed's local position inside that partition).
// Returns the merged node ids, edge ids, optional batch ids, and the number
// of sampled neighbors per seed node.
template <bool disjoint>
std::tuple<at::Tensor,
           at::Tensor,
           c10::optional<at::Tensor>,
           std::vector<int64_t>>
merge_outputs(
    const std::vector<at::Tensor>& node_ids,
    const std::vector<at::Tensor>& edge_ids,
    const std::vector<std::vector<int64_t>>& cumsum_neighbors_per_node,
    const std::vector<int64_t>& partition_ids,
    const std::vector<int64_t>& partition_orders,
    const int64_t num_partitions,
    const int64_t num_neighbors,
    const c10::optional<at::Tensor>& batch) {
  at::Tensor out_node_id;
  at::Tensor out_edge_id;
  c10::optional<at::Tensor> out_batch = c10::nullopt;

  // `offset` is the per-seed stride in the scratch buffers below. With a
  // fixed fanout it equals `num_neighbors`; otherwise it is derived from the
  // largest observed neighborhood across all partitions.
  auto offset = num_neighbors;

  if (num_neighbors < 0) {
    // find maximum population
    std::vector<std::vector<int64_t>> population(num_partitions);
    std::vector<int64_t> max_populations(num_partitions);

    at::parallel_for(0, num_partitions, 1, [&](size_t _s, size_t _e) {
      for (auto p_id = _s; p_id < _e; p_id++) {
        // Adjacent differences of the cumulative sums give the number of
        // neighbors sampled for each seed within partition `p_id`.
        auto cummsum1 =
            std::vector<int64_t>(cumsum_neighbors_per_node[p_id].begin() + 1,
                                 cumsum_neighbors_per_node[p_id].end());
        auto cummsum2 =
            std::vector<int64_t>(cumsum_neighbors_per_node[p_id].begin(),
                                 cumsum_neighbors_per_node[p_id].end() - 1);
        std::transform(cummsum1.begin(), cummsum1.end(), cummsum2.begin(),
                       std::back_inserter(population[p_id]),
                       [](int64_t a, int64_t b) { return std::abs(a - b); });
        auto max =
            *max_element(population[p_id].begin(), population[p_id].end());
        max_populations[p_id] = max;
      }
    });
    offset = *max_element(max_populations.begin(), max_populations.end());
  }

  // Total number of seed nodes across all partitions.
  const auto p_size = partition_ids.size();
  std::vector<int64_t> sampled_neighbors_per_node(p_size);

  const auto scalar_type = node_ids[0].scalar_type();
  AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "merge_outputs_kernel", [&] {
    // Fixed-stride scratch buffers, pre-filled with -1 so unused slots can be
    // stripped out after the parallel copy below.
    std::vector<scalar_t> sampled_node_ids(p_size * offset, -1);
    std::vector<scalar_t> sampled_edge_ids(p_size * offset, -1);
    std::vector<std::vector<scalar_t>> sampled_node_ids_vec(p_size);
    std::vector<std::vector<scalar_t>> sampled_edge_ids_vec(p_size);

    std::vector<scalar_t> sampled_batch;
    if constexpr (disjoint) {
      sampled_batch = std::vector<scalar_t>(p_size * offset, -1);
    }
    // `batch` is guaranteed to be set by the caller when `disjoint` is true.
    const auto batch_data =
        disjoint ? batch.value().data_ptr<scalar_t>() : nullptr;

    for (auto p_id = 0; p_id < num_partitions; p_id++) {
      sampled_node_ids_vec[p_id] =
          pyg::utils::to_vector<scalar_t>(node_ids[p_id]);
      sampled_edge_ids_vec[p_id] =
          pyg::utils::to_vector<scalar_t>(edge_ids[p_id]);
    }
    at::parallel_for(0, p_size, 1, [&](size_t _s, size_t _e) {
      for (auto j = _s; j < _e; j++) {
        auto p_id = partition_ids[j];
        auto p_order = partition_orders[j];

        // When it comes to node and batch, we omit seed nodes.
        // In the case of edges, we take into account all sampled edge ids.
        auto begin_node = cumsum_neighbors_per_node[p_id][p_order];
        auto begin_edge = begin_node - cumsum_neighbors_per_node[p_id][0];

        auto end_node = cumsum_neighbors_per_node[p_id][p_order + 1];
        auto end_edge = end_node - cumsum_neighbors_per_node[p_id][0];

        // Copy this seed's neighborhood into its fixed-stride slot `j`.
        std::copy(sampled_node_ids_vec[p_id].begin() + begin_node,
                  sampled_node_ids_vec[p_id].begin() + end_node,
                  sampled_node_ids.begin() + j * offset);
        std::copy(sampled_edge_ids_vec[p_id].begin() + begin_edge,
                  sampled_edge_ids_vec[p_id].begin() + end_edge,
                  sampled_edge_ids.begin() + j * offset);

        if constexpr (disjoint) {
          // All neighbors of seed `j` belong to its subgraph `batch_data[j]`.
          std::fill(sampled_batch.begin() + j * offset,
                    sampled_batch.begin() + j * offset + end_node - begin_node,
                    batch_data[j]);
        }

        sampled_neighbors_per_node[j] = end_node - begin_node;
      }
    });

    // Remove auxiliary -1 numbers:
    auto neg =
        std::remove(sampled_node_ids.begin(), sampled_node_ids.end(), -1);
    sampled_node_ids.erase(neg, sampled_node_ids.end());
    out_node_id = pyg::utils::from_vector(sampled_node_ids);

    neg = std::remove(sampled_edge_ids.begin(), sampled_edge_ids.end(), -1);
    sampled_edge_ids.erase(neg, sampled_edge_ids.end());
    out_edge_id = pyg::utils::from_vector(sampled_edge_ids);

    if constexpr (disjoint) {
      neg = std::remove(sampled_batch.begin(), sampled_batch.end(), -1);
      sampled_batch.erase(neg, sampled_batch.end());
      out_batch = pyg::utils::from_vector(sampled_batch);
    }
  });

  return std::make_tuple(out_node_id, out_edge_id, out_batch,
                         sampled_neighbors_per_node);
}

// Dispatches to the correct `merge_outputs<>` instantiation based on the
// runtime `disjoint` flag. Written as `if`/`else` (rather than two separate
// `if`s) so every control path ends in a `return`, preventing spurious
// "control reaches end of non-void function" warnings at the expansion site
// and avoiding a redundant second branch test.
#define DISPATCH_MERGE_OUTPUTS(disjoint, ...) \
  if (disjoint)                               \
    return merge_outputs<true>(__VA_ARGS__);  \
  else                                        \
    return merge_outputs<false>(__VA_ARGS__);

} // namespace

// CPU kernel behind `pyg::merge_sampler_outputs`. Forwards all arguments to
// the `merge_outputs<>` implementation, selecting the disjoint or
// non-disjoint instantiation from the runtime `disjoint` flag.
std::tuple<at::Tensor,
           at::Tensor,
           c10::optional<at::Tensor>,
           std::vector<int64_t>>
merge_sampler_outputs_kernel(
    const std::vector<at::Tensor>& node_ids,
    const std::vector<at::Tensor>& edge_ids,
    const std::vector<std::vector<int64_t>>& cumsum_neighbors_per_node,
    const std::vector<int64_t>& partition_ids,
    const std::vector<int64_t>& partition_orders,
    const int64_t num_partitions,
    const int64_t num_neighbors,
    const c10::optional<at::Tensor>& batch,
    bool disjoint) {
  if (disjoint) {
    return merge_outputs<true>(node_ids, edge_ids, cumsum_neighbors_per_node,
                               partition_ids, partition_orders, num_partitions,
                               num_neighbors, batch);
  }
  return merge_outputs<false>(node_ids, edge_ids, cumsum_neighbors_per_node,
                              partition_ids, partition_orders, num_partitions,
                              num_neighbors, batch);
}

// We use `BackendSelect` as a fallback to the dispatcher logic as automatic
// dispatching of std::vector<at::Tensor> is not yet supported by PyTorch.
// See: pytorch/aten/src/ATen/templates/RegisterBackendSelect.cpp.
// Binds `merge_sampler_outputs_kernel` as the implementation of the
// `pyg::merge_sampler_outputs` operator registered in dist_merge_outputs.cpp.
TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::merge_sampler_outputs"),
         TORCH_FN(merge_sampler_outputs_kernel));
}

} // namespace sampler
} // namespace pyg
59 changes: 59 additions & 0 deletions pyg_lib/csrc/sampler/dist_merge_outputs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include "dist_merge_outputs.h"

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

#include "pyg_lib/csrc/utils/check.h"

namespace pyg {
namespace sampler {

// Public entry point for merging distributed sampler outputs. Validates the
// inputs and forwards the call to the registered `pyg::merge_sampler_outputs`
// operator (implemented in dist_merge_outputs_kernel.cpp).
// Returns the merged node ids, edge ids, optional batch ids, and the number
// of sampled neighbors per input node.
std::tuple<at::Tensor,
           at::Tensor,
           c10::optional<at::Tensor>,
           std::vector<int64_t>>
merge_sampler_outputs(
    const std::vector<at::Tensor>& node_ids,
    const std::vector<at::Tensor>& edge_ids,
    const std::vector<std::vector<int64_t>>& cumsum_neighbors_per_node,
    const std::vector<int64_t>& partition_ids,
    const std::vector<int64_t>& partition_orders,
    const int64_t num_partitions,
    const int64_t num_neighbors,
    const c10::optional<at::Tensor>& batch,
    bool disjoint) {
  std::vector<at::TensorArg> node_ids_args;
  std::vector<at::TensorArg> edge_ids_args;
  pyg::utils::fill_tensor_args(node_ids_args, node_ids, "node_ids", 0);
  pyg::utils::fill_tensor_args(edge_ids_args, edge_ids, "edge_ids", 0);

  at::CheckedFrom c{"merge_sampler_outputs"};
  at::checkAllDefined(c, {node_ids_args});
  at::checkAllDefined(c, {edge_ids_args});

  // The kernel indexes `node_ids` and `edge_ids` in lockstep per partition,
  // so a size mismatch would lead to out-of-bounds access downstream.
  TORCH_CHECK(node_ids.size() == edge_ids.size(),
              "Number of node samples must match the number of edge samples");

  TORCH_CHECK(partition_ids.size() == partition_orders.size(),
              "Every partition ID must be assigned a sampling order");

  if (disjoint) {
    TORCH_CHECK(batch.has_value(),
                "Disjoint sampling requires 'batch' to be specified");
  }

  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("pyg::merge_sampler_outputs", "")
                       .typed<decltype(merge_sampler_outputs)>();
  return op.call(node_ids, edge_ids, cumsum_neighbors_per_node, partition_ids,
                 partition_orders, num_partitions, num_neighbors, batch,
                 disjoint);
}

// Registers the operator schema so that backend implementations (see
// dist_merge_outputs_kernel.cpp) can attach to `pyg::merge_sampler_outputs`.
TORCH_LIBRARY_FRAGMENT(pyg, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "pyg::merge_sampler_outputs(Tensor[] node_ids, Tensor[] edge_ids, "
      "int[][] cumsum_neighbors_per_node, int[] partition_ids, int[] "
      "partition_orders, int num_partitions, int num_neighbors, Tensor? "
      "batch, bool disjoint) -> (Tensor, Tensor, Tensor?, int[])"));
}

} // namespace sampler
} // namespace pyg
34 changes: 34 additions & 0 deletions pyg_lib/csrc/sampler/dist_merge_outputs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once

#include <ATen/ATen.h>
#include "pyg_lib/csrc/macros.h"
#include "pyg_lib/csrc/utils/types.h"

namespace pyg {
namespace sampler {

// For distributed training purposes. Merges sampler outputs from different
// partitions, so that they are sorted according to the sampling order.
// Removes seed nodes from sampled nodes and calculates how many neighbors
// were sampled by each source node based on the cumulative sum of sampled
// neighbors for each input node.
// Returns the unified node, edge and batch indices as well as the merged
// cumulative sum of sampled neighbors.
// `batch` is required when `disjoint` is true; `num_neighbors < 0` denotes a
// variable fanout per seed node.
PYG_API
std::tuple<at::Tensor,
           at::Tensor,
           c10::optional<at::Tensor>,
           std::vector<int64_t>>
merge_sampler_outputs(
    const std::vector<at::Tensor>& node_ids,
    const std::vector<at::Tensor>& edge_ids,
    const std::vector<std::vector<int64_t>>& cumsum_neighbors_per_node,
    const std::vector<int64_t>& partition_ids,
    const std::vector<int64_t>& partition_orders,
    const int64_t num_partitions,
    const int64_t num_neighbors,
    const c10::optional<at::Tensor>& batch = c10::nullopt,
    bool disjoint = false);

} // namespace sampler
} // namespace pyg
5 changes: 5 additions & 0 deletions pyg_lib/csrc/sampler/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ hetero_neighbor_sample(
std::string strategy = "uniform",
bool return_edge_id = true);

// For distributed sampling purposes. Leverages the `neighbor_sample` function
// internally. Samples one-hop neighborhoods with duplicates from all node
// indices in `seed` in the graph given by `(rowptr, col)`.
// Returns the original node and edge indices for all sampled nodes and edges.
// Lastly, returns the cumulative sum of sampled neighbors for each input node.
PYG_API
std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
const at::Tensor& rowptr,
Expand Down
45 changes: 0 additions & 45 deletions pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,50 +165,6 @@ def hetero_neighbor_sample(
num_nodes_per_hop_dict, num_edges_per_hop_dict)


def dist_neighbor_sample(
    rowptr: Tensor,
    col: Tensor,
    seed: Tensor,
    num_neighbors: int,
    time: Optional[Tensor] = None,
    seed_time: Optional[Tensor] = None,
    edge_weight: Optional[Tensor] = None,
    csc: bool = False,
    replace: bool = False,
    directed: bool = True,
    disjoint: bool = False,
    temporal_strategy: str = 'uniform',
) -> Tuple[Tensor, Tensor, List[int]]:
    r"""For distributed sampling purposes. Leverages the
    :meth:`neighbor_sample`. Samples a one-hop neighborhood with duplicates
    from all node indices in :obj:`seed` in the graph given by
    :obj:`(rowptr, col)`.

    Args:
        num_neighbors (int): Maximum number of neighbors to sample in the
            current layer.
        kwargs: Arguments of :meth:`neighbor_sample`.

    Returns:
        (torch.Tensor, torch.Tensor, List[int]): Returns the original node and
        edge indices for all sampled nodes and edges. Lastly, returns the
        cumulative sum of the amount of sampled neighbors for each input node.
    """
    return torch.ops.pyg.dist_neighbor_sample(
        rowptr,
        col,
        seed,
        num_neighbors,
        time,
        seed_time,
        edge_weight,
        csc,
        replace,
        directed,
        disjoint,
        temporal_strategy,
    )


def subgraph(
rowptr: Tensor,
col: Tensor,
Expand Down Expand Up @@ -351,7 +307,6 @@ def hetero_relabel_neighborhood(
__all__ = [
'neighbor_sample',
'hetero_neighbor_sample',
'dist_neighbor_sample',
'subgraph',
'random_walk',
'relabel_neighborhood',
Expand Down
Loading

0 comments on commit c37f09b

Please sign in to comment.