Add support for distributed sampling (#246)
This code is one part of the broader distributed training support for PyG.

## Description
Neighbor sampling for distributed training differs from the sampling
currently implemented in pyg-lib. During distributed training, the nodes
of one batch can be sampled by different machines (and therefore by
different samplers). Since each sampler builds its own local mapping,
the resulting subtree/subgraph node indexing is incorrect.
To obtain correct results, it is necessary to sample one hop at a time
and then synchronise the outputs between machines.

Proposed algorithm:
1. First, sample only global node ids (`sampled_nodes`), keeping
duplicates, in `neighbor_sample`.
2. Do not build rows and cols yet; instead, record how many neighbors
were sampled for each node (`cumm_sum_sampled_nbrs_per_node`).
3. After each layer, synchronise and merge the outputs from the different
machines and take the new seed nodes (without duplicates) from
`sampled_nodes`.
4. Sample the next layer and repeat steps 1-3 until all layers are sampled.
5. Perform the global-to-local mapping using the mapper and create
(row, col) based on `sampled_nodes_with_duplicates` and
`sampled_nbrs_per_node`.

Step 3 was implemented in pytorch_geometric.
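
Steps 1 and 2 can be illustrated with a minimal, self-contained sketch for a single hop on a single machine. This is not the pyg-lib kernel: `OneHopResult` and `sample_one_hop` are hypothetical names, the graph is assumed to be given as plain `rowptr`/`col` vectors, and every neighbor is taken instead of a random subset so that the output is deterministic. The running-size vector starts with the number of seeds, mirroring how the kernel in this commit initializes `cumsum_neighbors_per_node` with `seed.size(0)`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container for the output of one distributed hop.
struct OneHopResult {
  // Seeds followed by their neighbors, as global ids, duplicates kept.
  std::vector<int64_t> sampled_nodes;
  // Size of `sampled_nodes` before the first seed and after each seed.
  std::vector<int64_t> cumsum_neighbors_per_node;
};

// Assumption: takes all neighbors of each seed rather than sampling randomly.
OneHopResult sample_one_hop(const std::vector<int64_t>& rowptr,
                            const std::vector<int64_t>& col,
                            const std::vector<int64_t>& seed) {
  OneHopResult out;
  out.sampled_nodes = seed;
  out.cumsum_neighbors_per_node.push_back(static_cast<int64_t>(seed.size()));
  for (const int64_t v : seed) {
    // Guard against seeds that are not valid node ids of this graph.
    assert(v >= 0 && static_cast<std::size_t>(v) + 1 < rowptr.size());
    for (int64_t e = rowptr[v]; e < rowptr[v + 1]; ++e) {
      // No deduplication and no global-to-local mapping at this stage.
      out.sampled_nodes.push_back(col[e]);
    }
    out.cumsum_neighbors_per_node.push_back(
        static_cast<int64_t>(out.sampled_nodes.size()));
  }
  return out;
}
```

For a graph with `rowptr = {0, 2, 3, 5}`, `col = {1, 2, 2, 0, 1}` and seeds `{0, 1}`, this sketch yields `sampled_nodes = {0, 1, 1, 2, 2}` and a running-size vector of `{2, 4, 5}`. Keeping duplicates and deferring the mapping is what lets the outputs of several machines be merged before any local index is assigned.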

## Added
- new argument `distributed` to the `neighbor_sample` function to enable
the algorithm described above.
- new argument `batch` to the `neighbor_sample` function that allows
specifying the initial subgraph indices for seed nodes (used with
disjoint).
- new return value `cumm_sum_sampled_nbrs_per_node` from the
`neighbor_sample` function, holding the cumulative sum of the sampled
neighbors per node.

- new function `relabel_neighborhood` that is used after all layers have
been sampled; its purpose is to relabel the global indices of the sampled
nodes to local subtree/subgraph indices (row, col).
- new function `hetero_relabel_neighborhood` (same as
`relabel_neighborhood` but for heterogeneous graphs). Returns
(`row_dict`, `col_dict`).
- unit tests
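
The relabeling idea can be sketched as follows. This is a simplified illustration under stated assumptions, not the signature or implementation of `relabel_neighborhood`: `relabel_sketch` is a hypothetical name, the inputs are plain vectors, per-node counts are used (they can be obtained by differencing a cumulative sum), source nodes are assumed to be numbered locally in order of first appearance with seeds first, and options such as `csc`, `disjoint` and `batch` are ignored.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical helper: maps global neighbor ids to local (row, col) pairs.
// `sampled_nodes_with_duplicates` holds only the sampled neighbors, grouped
// by source node; `sampled_nbrs_per_node[i]` is the group size for local
// source node i.
std::pair<std::vector<int64_t>, std::vector<int64_t>> relabel_sketch(
    const std::vector<int64_t>& seed,
    const std::vector<int64_t>& sampled_nodes_with_duplicates,
    const std::vector<int64_t>& sampled_nbrs_per_node) {
  // Global id -> local id, assigned in order of first appearance.
  std::unordered_map<int64_t, int64_t> to_local;
  for (const int64_t v : seed) {
    to_local.emplace(v, static_cast<int64_t>(to_local.size()));
  }

  std::vector<int64_t> row, col;
  std::size_t offset = 0;
  for (std::size_t src = 0; src < sampled_nbrs_per_node.size(); ++src) {
    for (int64_t k = 0; k < sampled_nbrs_per_node[src]; ++k) {
      assert(offset < sampled_nodes_with_duplicates.size());
      const int64_t global_dst = sampled_nodes_with_duplicates[offset++];
      // emplace keeps the existing local id if the node was seen before.
      const auto res =
          to_local.emplace(global_dst, static_cast<int64_t>(to_local.size()));
      row.push_back(static_cast<int64_t>(src));
      col.push_back(res.first->second);
    }
  }
  return {std::move(row), std::move(col)};
}
```

With seeds `{10, 20}`, neighbors `{20, 30, 30, 10}` and counts `{2, 1, 1}`, the sketch produces `row = {0, 0, 1, 2}` and `col = {1, 2, 2, 0}`: node 30 receives local id 2 the first time it is seen and reuses it afterwards.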

---------

Co-authored-by: rusty1s <[email protected]>
kgajdamo and rusty1s authored Sep 5, 2023
1 parent 888238c commit 6af62de
Showing 6 changed files with 358 additions and 27 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.3.0] - 2023-MM-DD
### Added
- Added low-level support for distributed neighborhood sampling ([#246](https://github.com/pyg-team/pyg-lib/pull/246))
- Added support for homogeneous and heterogeneous biased neighborhood sampling ([#247](https://github.com/pyg-team/pyg-lib/pull/247), [#251](https://github.com/pyg-team/pyg-lib/pull/251))
- Added dispatch for XPU device in `index_sort` ([#243](https://github.com/pyg-team/pyg-lib/pull/243))
- Added `metis` partitioning ([#229](https://github.com/pyg-team/pyg-lib/pull/229))
183 changes: 157 additions & 26 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
@@ -25,7 +25,8 @@ template <typename node_t,
typename temporal_t,
bool replace,
bool save_edges,
bool save_edge_ids>
bool save_edge_ids,
bool distributed>
class NeighborSampler {
public:
NeighborSampler(const scalar_t* rowptr,
@@ -239,6 +240,16 @@ class NeighborSampler {
const auto global_dst_node_value = col_[edge_id];
const auto global_dst_node =
to_node_t(global_dst_node_value, global_src_node);

// In the distributed sampling case, we do not perform any mapping:
if constexpr (distributed) {
out_global_dst_nodes.push_back(global_dst_node);
if (save_edge_ids) {
sampled_edge_ids_.push_back(edge_id);
}
return;
}

const auto res = dst_mapper.insert(global_dst_node);
if (res.second) { // not yet sampled.
out_global_dst_nodes.push_back(global_dst_node);
@@ -266,12 +277,17 @@ class NeighborSampler {

// Homogeneous neighbor sampling ///////////////////////////////////////////////

template <bool replace, bool directed, bool disjoint, bool return_edge_id>
template <bool replace,
bool directed,
bool disjoint,
bool return_edge_id,
bool distributed>
std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>>
sample(const at::Tensor& rowptr,
const at::Tensor& col,
@@ -302,14 +318,17 @@ sample(const at::Tensor& rowptr,
c10::optional<at::Tensor> out_edge_id = c10::nullopt;
std::vector<int64_t> num_sampled_nodes_per_hop;
std::vector<int64_t> num_sampled_edges_per_hop;
std::vector<int64_t> cumsum_neighbors_per_node =
distributed ? std::vector<int64_t>(1, seed.size(0))
: std::vector<int64_t>();

AT_DISPATCH_INTEGRAL_TYPES(seed.scalar_type(), "sample_kernel", [&] {
typedef std::pair<scalar_t, scalar_t> pair_scalar_t;
typedef std::conditional_t<!disjoint, scalar_t, pair_scalar_t> node_t;
// TODO(zeyuan): Do not force int64_t for time type.
typedef int64_t temporal_t;
typedef NeighborSampler<node_t, scalar_t, temporal_t, replace, directed,
return_edge_id>
return_edge_id, distributed>
NeighborSamplerImpl;

pyg::random::RandintEngine<scalar_t> generator;
@@ -359,6 +378,8 @@ sample(const at::Tensor& rowptr,
/*dst_mapper=*/mapper,
/*generator=*/generator,
/*out_global_dst_nodes=*/sampled_nodes);
if constexpr (distributed)
cumsum_neighbors_per_node.push_back(sampled_nodes.size());
}
} else if (!time.has_value()) {
for (size_t i = begin; i < end; ++i) {
@@ -369,6 +390,8 @@ sample(const at::Tensor& rowptr,
/*dst_mapper=*/mapper,
/*generator=*/generator,
/*out_global_dst_nodes=*/sampled_nodes);
if constexpr (distributed)
cumsum_neighbors_per_node.push_back(sampled_nodes.size());
}
} else if constexpr (!std::is_scalar<node_t>::value) { // Temporal:
const auto time_data = time.value().data_ptr<temporal_t>();
@@ -382,6 +405,8 @@ sample(const at::Tensor& rowptr,
/*dst_mapper=*/mapper,
/*generator=*/generator,
/*out_global_dst_nodes=*/sampled_nodes);
if constexpr (distributed)
cumsum_neighbors_per_node.push_back(sampled_nodes.size());
}
}
begin = end, end = sampled_nodes.size();
@@ -400,12 +425,17 @@ sample(const at::Tensor& rowptr,
});

return std::make_tuple(out_row, out_col, out_node_id, out_edge_id,
num_sampled_nodes_per_hop, num_sampled_edges_per_hop);
num_sampled_nodes_per_hop, num_sampled_edges_per_hop,
cumsum_neighbors_per_node);
}

// Heterogeneous neighbor sampling /////////////////////////////////////////////

template <bool replace, bool directed, bool disjoint, bool return_edge_id>
template <bool replace,
bool directed,
bool disjoint,
bool return_edge_id,
bool distributed>
std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
@@ -472,7 +502,7 @@ sample(const std::vector<node_type>& node_types,
typedef std::conditional_t<!disjoint, scalar_t, pair_scalar_t> node_t;
typedef int64_t temporal_t;
typedef NeighborSampler<node_t, scalar_t, temporal_t, replace, directed,
return_edge_id>
return_edge_id, distributed>
NeighborSamplerImpl;

pyg::random::RandintEngine<scalar_t> generator;
@@ -691,39 +721,73 @@ sample(const std::vector<node_type>& node_types,

// Dispatcher //////////////////////////////////////////////////////////////////

#define DISPATCH_SAMPLE(replace, directed, disjount, return_edge_id, ...) \
#define DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, ...) \
if (replace && directed && disjoint && return_edge_id) \
return sample<true, true, true, true>(__VA_ARGS__); \
return sample<true, true, true, true, false>(__VA_ARGS__); \
if (replace && directed && disjoint && !return_edge_id) \
return sample<true, true, true, false>(__VA_ARGS__); \
return sample<true, true, true, false, false>(__VA_ARGS__); \
if (replace && directed && !disjoint && return_edge_id) \
return sample<true, true, false, true>(__VA_ARGS__); \
return sample<true, true, false, true, false>(__VA_ARGS__); \
if (replace && directed && !disjoint && !return_edge_id) \
return sample<true, true, false, false>(__VA_ARGS__); \
return sample<true, true, false, false, false>(__VA_ARGS__); \
if (replace && !directed && disjoint && return_edge_id) \
return sample<true, false, true, true>(__VA_ARGS__); \
return sample<true, false, true, true, false>(__VA_ARGS__); \
if (replace && !directed && disjoint && !return_edge_id) \
return sample<true, false, true, false>(__VA_ARGS__); \
return sample<true, false, true, false, false>(__VA_ARGS__); \
if (replace && !directed && !disjoint && return_edge_id) \
return sample<true, false, false, true>(__VA_ARGS__); \
return sample<true, false, false, true, false>(__VA_ARGS__); \
if (replace && !directed && !disjoint && !return_edge_id) \
return sample<true, false, false, false>(__VA_ARGS__); \
return sample<true, false, false, false, false>(__VA_ARGS__); \
if (!replace && directed && disjoint && return_edge_id) \
return sample<false, true, true, true>(__VA_ARGS__); \
return sample<false, true, true, true, false>(__VA_ARGS__); \
if (!replace && directed && disjoint && !return_edge_id) \
return sample<false, true, true, false>(__VA_ARGS__); \
return sample<false, true, true, false, false>(__VA_ARGS__); \
if (!replace && directed && !disjoint && return_edge_id) \
return sample<false, true, false, true>(__VA_ARGS__); \
return sample<false, true, false, true, false>(__VA_ARGS__); \
if (!replace && directed && !disjoint && !return_edge_id) \
return sample<false, true, false, false>(__VA_ARGS__); \
return sample<false, true, false, false, false>(__VA_ARGS__); \
if (!replace && !directed && disjoint && return_edge_id) \
return sample<false, false, true, true>(__VA_ARGS__); \
return sample<false, false, true, true, false>(__VA_ARGS__); \
if (!replace && !directed && disjoint && !return_edge_id) \
return sample<false, false, true, false>(__VA_ARGS__); \
return sample<false, false, true, false, false>(__VA_ARGS__); \
if (!replace && !directed && !disjoint && return_edge_id) \
return sample<false, false, false, true>(__VA_ARGS__); \
return sample<false, false, false, true, false>(__VA_ARGS__); \
if (!replace && !directed && !disjoint && !return_edge_id) \
return sample<false, false, false, false>(__VA_ARGS__);
return sample<false, false, false, false, false>(__VA_ARGS__);

#define DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, ...) \
if (replace && directed && disjoint && return_edge_id) \
return sample<true, true, true, true, true>(__VA_ARGS__); \
if (replace && directed && disjoint && !return_edge_id) \
return sample<true, true, true, false, true>(__VA_ARGS__); \
if (replace && directed && !disjoint && return_edge_id) \
return sample<true, true, false, true, true>(__VA_ARGS__); \
if (replace && directed && !disjoint && !return_edge_id) \
return sample<true, true, false, false, true>(__VA_ARGS__); \
if (replace && !directed && disjoint && return_edge_id) \
return sample<true, false, true, true, true>(__VA_ARGS__); \
if (replace && !directed && disjoint && !return_edge_id) \
return sample<true, false, true, false, true>(__VA_ARGS__); \
if (replace && !directed && !disjoint && return_edge_id) \
return sample<true, false, false, true, true>(__VA_ARGS__); \
if (replace && !directed && !disjoint && !return_edge_id) \
return sample<true, false, false, false, true>(__VA_ARGS__); \
if (!replace && directed && disjoint && return_edge_id) \
return sample<false, true, true, true, true>(__VA_ARGS__); \
if (!replace && directed && disjoint && !return_edge_id) \
return sample<false, true, true, false, true>(__VA_ARGS__); \
if (!replace && directed && !disjoint && return_edge_id) \
return sample<false, true, false, true, true>(__VA_ARGS__); \
if (!replace && directed && !disjoint && !return_edge_id) \
return sample<false, true, false, false, true>(__VA_ARGS__); \
if (!replace && !directed && disjoint && return_edge_id) \
return sample<false, false, true, true, true>(__VA_ARGS__); \
if (!replace && !directed && disjoint && !return_edge_id) \
return sample<false, false, true, false, true>(__VA_ARGS__); \
if (!replace && !directed && !disjoint && return_edge_id) \
return sample<false, false, false, true, true>(__VA_ARGS__); \
if (!replace && !directed && !disjoint && !return_edge_id) \
return sample<false, false, false, false, true>(__VA_ARGS__);

} // namespace

@@ -746,9 +810,13 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col,
seed, num_neighbors, time, seed_time, edge_weight, csc,
temporal_strategy);
const auto out = [&] {
DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col,
seed, num_neighbors, time, seed_time, edge_weight, csc,
temporal_strategy);
}();
return std::make_tuple(std::get<0>(out), std::get<1>(out), std::get<2>(out),
std::get<3>(out), std::get<4>(out), std::get<5>(out));
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
@@ -779,6 +847,59 @@ hetero_neighbor_sample_kernel(
edge_weight_dict, csc, temporal_strategy);
}

std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>>
dist_neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col,
seed, num_neighbors, time, seed_time, edge_weight, csc,
temporal_strategy);
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
dist_hetero_neighbor_sample_kernel(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
const c10::Dict<rel_type, at::Tensor>& rowptr_dict,
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, node_types,
edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, time_dict, seed_time_dict,
edge_weight_dict, csc, temporal_strategy);
}

TORCH_LIBRARY_IMPL(pyg, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::neighbor_sample"),
TORCH_FN(neighbor_sample_kernel));
@@ -792,5 +913,15 @@ TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) {
TORCH_FN(hetero_neighbor_sample_kernel));
}

TORCH_LIBRARY_IMPL(pyg, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::dist_neighbor_sample"),
TORCH_FN(dist_neighbor_sample_kernel));
}

TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::dist_hetero_neighbor_sample"),
TORCH_FN(dist_hetero_neighbor_sample_kernel));
}

} // namespace sampler
} // namespace pyg
44 changes: 44 additions & 0 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.h
@@ -48,5 +48,49 @@ hetero_neighbor_sample_kernel(
std::string temporal_strategy,
bool return_edge_id);

std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>>
dist_neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id);

std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
dist_hetero_neighbor_sample_kernel(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
const c10::Dict<rel_type, at::Tensor>& rowptr_dict,
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id);

} // namespace sampler
} // namespace pyg