Skip to content

Commit

Permalink
Update dist_neighbor_sample (#253)
Browse files Browse the repository at this point in the history
This code is part of the overall distributed training support for PyG.

This PR complements
[#246](#246) and introduces some
updates.

What has been changed:
* Removed the unneeded `dist_hetero_neighbor_sample` function. Distributed
sampling loops over the layers in Python, so in the heterogeneous case each
call to `neighbor_sample` handles only a single edge type. The call is
therefore effectively homogeneous, and `dist_neighbor_sample` can be used
instead of `dist_hetero_neighbor_sample`.
* Removed all unused outputs, keeping only the following: `node`,
`edge_ids`, `cummsum_sampled_nbrs_per_node`.
* Changed `std::vector<int64_t> num_neighbors` input list into `int64_t
one_hop_num`.

Added:
* Unit tests

---------

Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
kgajdamo and rusty1s authored Sep 6, 2023
1 parent 6af62de commit 3a4d436
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 261 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.3.0] - 2023-MM-DD
### Added
- Added low-level support for distributed neighborhood sampling ([#246](https://github.com/pyg-team/pyg-lib/pull/246))
- Added low-level support for distributed neighborhood sampling ([#246](https://github.com/pyg-team/pyg-lib/pull/246), [#253](https://github.com/pyg-team/pyg-lib/pull/253))
- Added support for homogeneous and heterogeneous biased neighborhood sampling ([#247](https://github.com/pyg-team/pyg-lib/pull/247), [#251](https://github.com/pyg-team/pyg-lib/pull/251))
- Added dispatch for XPU device in `index_sort` ([#243](https://github.com/pyg-team/pyg-lib/pull/243))
- Added `metis` partitioning ([#229](https://github.com/pyg-team/pyg-lib/pull/229))
Expand Down
106 changes: 27 additions & 79 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -755,39 +755,23 @@ sample(const std::vector<node_type>& node_types,
if (!replace && !directed && !disjoint && !return_edge_id) \
return sample<false, false, false, false, false>(__VA_ARGS__);

// Expands to a chain of `if (...) return sample<...>(__VA_ARGS__);`
// statements, one for each combination of the four runtime flags `replace`,
// `directed`, `disjoint` and `return_edge_id`. Each flag selects the template
// argument in the same position (first to fourth); the fifth template
// argument is always `true` in this macro.
// NOTE(review): the fifth argument presumably selects the distributed variant
// of `sample` -- confirm against the template parameter list of `sample`.
// The expansion contains `return` statements, so it returns from whichever
// function (or lambda) it is expanded in.
#define DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, ...) \
if (replace && directed && disjoint && return_edge_id) \
return sample<true, true, true, true, true>(__VA_ARGS__); \
if (replace && directed && disjoint && !return_edge_id) \
return sample<true, true, true, false, true>(__VA_ARGS__); \
if (replace && directed && !disjoint && return_edge_id) \
return sample<true, true, false, true, true>(__VA_ARGS__); \
if (replace && directed && !disjoint && !return_edge_id) \
return sample<true, true, false, false, true>(__VA_ARGS__); \
if (replace && !directed && disjoint && return_edge_id) \
return sample<true, false, true, true, true>(__VA_ARGS__); \
if (replace && !directed && disjoint && !return_edge_id) \
return sample<true, false, true, false, true>(__VA_ARGS__); \
if (replace && !directed && !disjoint && return_edge_id) \
return sample<true, false, false, true, true>(__VA_ARGS__); \
if (replace && !directed && !disjoint && !return_edge_id) \
return sample<true, false, false, false, true>(__VA_ARGS__); \
if (!replace && directed && disjoint && return_edge_id) \
return sample<false, true, true, true, true>(__VA_ARGS__); \
if (!replace && directed && disjoint && !return_edge_id) \
return sample<false, true, true, false, true>(__VA_ARGS__); \
if (!replace && directed && !disjoint && return_edge_id) \
return sample<false, true, false, true, true>(__VA_ARGS__); \
if (!replace && directed && !disjoint && !return_edge_id) \
return sample<false, true, false, false, true>(__VA_ARGS__); \
if (!replace && !directed && disjoint && return_edge_id) \
return sample<false, false, true, true, true>(__VA_ARGS__); \
if (!replace && !directed && disjoint && !return_edge_id) \
return sample<false, false, true, false, true>(__VA_ARGS__); \
if (!replace && !directed && !disjoint && return_edge_id) \
return sample<false, false, false, true, true>(__VA_ARGS__); \
if (!replace && !directed && !disjoint && !return_edge_id) \
return sample<false, false, false, false, true>(__VA_ARGS__);
// Expands to a chain of `if (...) return sample<...>(__VA_ARGS__);`
// statements, one for each combination of the three runtime flags `replace`,
// `directed` and `disjoint`. Each flag selects the template argument in the
// same position (first to third); the fourth and fifth template arguments are
// always `true` in this macro.
// NOTE(review): the fourth argument presumably corresponds to
// `return_edge_id` and the fifth to the distributed variant of `sample` --
// confirm against the template parameter list of `sample`.
// The eight branches cover every combination of three boolean flags, and each
// branch returns from whichever function (or lambda) the macro is expanded in.
#define DISPATCH_DIST_SAMPLE(replace, directed, disjoint, ...) \
if (replace && directed && disjoint) \
return sample<true, true, true, true, true>(__VA_ARGS__); \
if (replace && directed && !disjoint) \
return sample<true, true, false, true, true>(__VA_ARGS__); \
if (replace && !directed && disjoint) \
return sample<true, false, true, true, true>(__VA_ARGS__); \
if (replace && !directed && !disjoint) \
return sample<true, false, false, true, true>(__VA_ARGS__); \
if (!replace && directed && disjoint) \
return sample<false, true, true, true, true>(__VA_ARGS__); \
if (!replace && directed && !disjoint) \
return sample<false, true, false, true, true>(__VA_ARGS__); \
if (!replace && !directed && disjoint) \
return sample<false, false, true, true, true>(__VA_ARGS__); \
if (!replace && !directed && !disjoint) \
return sample<false, false, false, true, true>(__VA_ARGS__);

} // namespace

Expand Down Expand Up @@ -847,57 +831,26 @@ hetero_neighbor_sample_kernel(
edge_weight_dict, csc, temporal_strategy);
}

std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>>
std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>>
dist_neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col,
seed, num_neighbors, time, seed_time, edge_weight, csc,
temporal_strategy);
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
dist_hetero_neighbor_sample_kernel(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
const c10::Dict<rel_type, at::Tensor>& rowptr_dict,
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, node_types,
edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, time_dict, seed_time_dict,
edge_weight_dict, csc, temporal_strategy);
std::string temporal_strategy) {
const auto out = [&] {
DISPATCH_DIST_SAMPLE(replace, directed, disjoint, rowptr, col, seed,
{num_neighbors}, time, seed_time, edge_weight, csc,
temporal_strategy);
}();
return std::make_tuple(std::get<2>(out), std::get<3>(out).value(),
std::get<6>(out));
}

TORCH_LIBRARY_IMPL(pyg, CPU, m) {
Expand All @@ -918,10 +871,5 @@ TORCH_LIBRARY_IMPL(pyg, CPU, m) {
TORCH_FN(dist_neighbor_sample_kernel));
}

// Registers `dist_hetero_neighbor_sample_kernel` as the implementation of the
// `pyg::dist_hetero_neighbor_sample` operator under the `BackendSelect`
// dispatch key of the `pyg` library.
TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::dist_hetero_neighbor_sample"),
TORCH_FN(dist_hetero_neighbor_sample_kernel));
}

} // namespace sampler
} // namespace pyg
36 changes: 3 additions & 33 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,49 +48,19 @@ hetero_neighbor_sample_kernel(
std::string temporal_strategy,
bool return_edge_id);

std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>>
std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>>
dist_neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id);

std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
dist_hetero_neighbor_sample_kernel(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
const c10::Dict<rel_type, at::Tensor>& rowptr_dict,
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id);
std::string temporal_strategy);

} // namespace sampler
} // namespace pyg
103 changes: 16 additions & 87 deletions pyg_lib/csrc/sampler/neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,26 +94,19 @@ hetero_neighbor_sample(
temporal_strategy, return_edge_id);
}

std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>>
dist_neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy) {
at::TensorArg rowptr_t{rowptr, "rowtpr", 1};
at::TensorArg col_t{col, "col", 1};
at::TensorArg seed_t{seed, "seed", 1};
Expand All @@ -126,60 +119,7 @@ dist_neighbor_sample(const at::Tensor& rowptr,
.findSchemaOrThrow("pyg::dist_neighbor_sample", "")
.typed<decltype(dist_neighbor_sample)>();
return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight,
csc, replace, directed, disjoint, temporal_strategy,
return_edge_id);
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
dist_hetero_neighbor_sample(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
const c10::Dict<rel_type, at::Tensor>& rowptr_dict,
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
TORCH_CHECK(rowptr_dict.size() == col_dict.size(),
"Number of edge types in 'rowptr_dict' and 'col_dict' must match")

std::vector<at::TensorArg> rowptr_dict_args;
std::vector<at::TensorArg> col_dict_args;
std::vector<at::TensorArg> seed_dict_args;
pyg::utils::fill_tensor_args(rowptr_dict_args, rowptr_dict, "rowptr_dict", 0);
pyg::utils::fill_tensor_args(col_dict_args, col_dict, "col_dict", 0);
pyg::utils::fill_tensor_args(seed_dict_args, seed_dict, "seed_dict", 0);
at::CheckedFrom c{"dist_hetero_neighbor_sample"};

at::checkAllDefined(c, rowptr_dict_args);
at::checkAllDefined(c, col_dict_args);
at::checkAllDefined(c, seed_dict_args);
at::checkAllSameType(c, rowptr_dict_args);
at::checkAllSameType(c, col_dict_args);
at::checkAllSameType(c, seed_dict_args);
at::checkSameType(c, rowptr_dict_args[0], col_dict_args[0]);
at::checkSameType(c, rowptr_dict_args[0], seed_dict_args[0]);

static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::dist_hetero_neighbor_sample", "")
.typed<decltype(dist_hetero_neighbor_sample)>();
return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, time_dict, seed_time_dict,
edge_weight_dict, csc, replace, directed, disjoint,
temporal_strategy, return_edge_id);
csc, replace, directed, disjoint, temporal_strategy);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
Expand All @@ -201,22 +141,11 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) {
"(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), "
"Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] "
"pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int "
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? "
"edge_weight = None, bool csc = False, bool replace = False, bool "
"directed = True, bool disjoint = False, str temporal_strategy = "
"'uniform', bool return_edge_id = True) "
"-> (Tensor, Tensor, Tensor, Tensor?, int[], int[], int[])"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::dist_hetero_neighbor_sample(str[] node_types, (str, str, str)[] "
"edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, "
"Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, "
"Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? seed_time_dict "
"= None, Dict(str, Tensor)? edge_weight_dict = None, bool csc = False, "
"bool replace = False, bool directed = True, bool disjoint = False, "
"str temporal_strategy = 'uniform', bool return_edge_id = True) -> "
"(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), "
"Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))"));
"'uniform') -> (Tensor, Tensor, int[])"));
}

} // namespace sampler
Expand Down
41 changes: 3 additions & 38 deletions pyg_lib/csrc/sampler/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,54 +62,19 @@ hetero_neighbor_sample(
bool return_edge_id = true);

PYG_API
std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>>
dist_neighbor_sample(
std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& time = c10::nullopt,
const c10::optional<at::Tensor>& seed_time = c10::nullopt,
const c10::optional<at::Tensor>& edge_weight = c10::nullopt,
bool csc = false,
bool replace = false,
bool directed = true,
bool disjoint = false,
std::string strategy = "uniform",
bool return_edge_id = true);

PYG_API
std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
dist_hetero_neighbor_sample(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
const c10::Dict<rel_type, at::Tensor>& rowptr_dict,
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict =
c10::nullopt,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict =
c10::nullopt,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict =
c10::nullopt,
bool csc = false,
bool replace = false,
bool directed = true,
bool disjoint = false,
std::string strategy = "uniform",
bool return_edge_id = true);
std::string strategy = "uniform");

} // namespace sampler
} // namespace pyg
Loading

0 comments on commit 3a4d436

Please sign in to comment.