Skip to content

Commit

Permalink
update dist_neighbor_sample, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kgajdamo committed Sep 6, 2023
1 parent 6af62de commit 840e77d
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 184 deletions.
52 changes: 8 additions & 44 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -847,17 +847,11 @@ hetero_neighbor_sample_kernel(
edge_weight_dict, csc, temporal_strategy);
}

std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>>
std::tuple<at::Tensor, c10::optional<at::Tensor>, std::vector<int64_t>>
dist_neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const int64_t one_hop_num,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
Expand All @@ -867,37 +861,12 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col,
seed, num_neighbors, time, seed_time, edge_weight, csc,
temporal_strategy);
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
dist_hetero_neighbor_sample_kernel(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
const c10::Dict<rel_type, at::Tensor>& rowptr_dict,
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, node_types,
edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, time_dict, seed_time_dict,
edge_weight_dict, csc, temporal_strategy);
const auto out = [&] {
DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr,
col, seed, {one_hop_num}, time, seed_time, edge_weight,
csc, temporal_strategy);
}();
return std::make_tuple(std::get<2>(out), std::get<3>(out), std::get<6>(out));
}

TORCH_LIBRARY_IMPL(pyg, CPU, m) {
Expand All @@ -918,10 +887,5 @@ TORCH_LIBRARY_IMPL(pyg, CPU, m) {
TORCH_FN(dist_neighbor_sample_kernel));
}

TORCH_LIBRARY_IMPL(pyg, BackendSelect, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::dist_hetero_neighbor_sample"),
TORCH_FN(dist_hetero_neighbor_sample_kernel));
}

} // namespace sampler
} // namespace pyg
33 changes: 2 additions & 31 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,11 @@ hetero_neighbor_sample_kernel(
std::string temporal_strategy,
bool return_edge_id);

std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>>
std::tuple<at::Tensor, c10::optional<at::Tensor>, std::vector<int64_t>>
dist_neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const int64_t one_hop_num,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
Expand All @@ -69,28 +63,5 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr,
std::string temporal_strategy,
bool return_edge_id);

std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
dist_hetero_neighbor_sample_kernel(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
const c10::Dict<rel_type, at::Tensor>& rowptr_dict,
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id);

} // namespace sampler
} // namespace pyg
80 changes: 6 additions & 74 deletions pyg_lib/csrc/sampler/neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,11 @@ hetero_neighbor_sample(
temporal_strategy, return_edge_id);
}

std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>>
std::tuple<at::Tensor, c10::optional<at::Tensor>, std::vector<int64_t>>
dist_neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const int64_t one_hop_num,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
Expand All @@ -125,63 +119,11 @@ dist_neighbor_sample(const at::Tensor& rowptr,
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::dist_neighbor_sample", "")
.typed<decltype(dist_neighbor_sample)>();
return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight,
return op.call(rowptr, col, seed, one_hop_num, time, seed_time, edge_weight,
csc, replace, directed, disjoint, temporal_strategy,
return_edge_id);
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
dist_hetero_neighbor_sample(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
const c10::Dict<rel_type, at::Tensor>& rowptr_dict,
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
TORCH_CHECK(rowptr_dict.size() == col_dict.size(),
"Number of edge types in 'rowptr_dict' and 'col_dict' must match")

std::vector<at::TensorArg> rowptr_dict_args;
std::vector<at::TensorArg> col_dict_args;
std::vector<at::TensorArg> seed_dict_args;
pyg::utils::fill_tensor_args(rowptr_dict_args, rowptr_dict, "rowptr_dict", 0);
pyg::utils::fill_tensor_args(col_dict_args, col_dict, "col_dict", 0);
pyg::utils::fill_tensor_args(seed_dict_args, seed_dict, "seed_dict", 0);
at::CheckedFrom c{"dist_hetero_neighbor_sample"};

at::checkAllDefined(c, rowptr_dict_args);
at::checkAllDefined(c, col_dict_args);
at::checkAllDefined(c, seed_dict_args);
at::checkAllSameType(c, rowptr_dict_args);
at::checkAllSameType(c, col_dict_args);
at::checkAllSameType(c, seed_dict_args);
at::checkSameType(c, rowptr_dict_args[0], col_dict_args[0]);
at::checkSameType(c, rowptr_dict_args[0], seed_dict_args[0]);

static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::dist_hetero_neighbor_sample", "")
.typed<decltype(dist_hetero_neighbor_sample)>();
return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, time_dict, seed_time_dict,
edge_weight_dict, csc, replace, directed, disjoint,
temporal_strategy, return_edge_id);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] "
Expand All @@ -201,22 +143,12 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) {
"(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), "
"Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] "
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? "
"pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int "
"one_hop_num, Tensor? time = None, Tensor? seed_time = None, Tensor? "
"edge_weight = None, bool csc = False, bool replace = False, bool "
"directed = True, bool disjoint = False, str temporal_strategy = "
"'uniform', bool return_edge_id = True) "
"-> (Tensor, Tensor, Tensor, Tensor?, int[], int[], int[])"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::dist_hetero_neighbor_sample(str[] node_types, (str, str, str)[] "
"edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, "
"Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, "
"Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? seed_time_dict "
"= None, Dict(str, Tensor)? edge_weight_dict = None, bool csc = False, "
"bool replace = False, bool directed = True, bool disjoint = False, "
"str temporal_strategy = 'uniform', bool return_edge_id = True) -> "
"(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), "
"Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))"));
"-> (Tensor, Tensor?, int[])"));
}

} // namespace sampler
Expand Down
37 changes: 2 additions & 35 deletions pyg_lib/csrc/sampler/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,12 @@ hetero_neighbor_sample(
bool return_edge_id = true);

PYG_API
std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>>
std::tuple<at::Tensor, c10::optional<at::Tensor>, std::vector<int64_t>>
dist_neighbor_sample(
const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const int64_t one_hop_num,
const c10::optional<at::Tensor>& time = c10::nullopt,
const c10::optional<at::Tensor>& seed_time = c10::nullopt,
const c10::optional<at::Tensor>& edge_weight = c10::nullopt,
Expand All @@ -84,32 +78,5 @@ dist_neighbor_sample(
std::string strategy = "uniform",
bool return_edge_id = true);

PYG_API
std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
dist_hetero_neighbor_sample(
const std::vector<node_type>& node_types,
const std::vector<edge_type>& edge_types,
const c10::Dict<rel_type, at::Tensor>& rowptr_dict,
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict =
c10::nullopt,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict =
c10::nullopt,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict =
c10::nullopt,
bool csc = false,
bool replace = false,
bool directed = true,
bool disjoint = false,
std::string strategy = "uniform",
bool return_edge_id = true);

} // namespace sampler
} // namespace pyg
38 changes: 38 additions & 0 deletions pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,43 @@ def hetero_neighbor_sample(
num_nodes_per_hop_dict, num_edges_per_hop_dict)


def dist_neighbor_sample(
rowptr: Tensor,
col: Tensor,
seed: Tensor,
one_hop_num: int,
time: Optional[Tensor] = None,
seed_time: Optional[Tensor] = None,
edge_weight: Optional[Tensor] = None,
csc: bool = False,
replace: bool = False,
directed: bool = True,
disjoint: bool = False,
temporal_strategy: str = 'uniform',
return_edge_id: bool = True,
) -> Tuple[Tensor, Optional[Tensor], List[int]]:
r"""For distributed sampling purpose. Leverages the
:meth:`neighbor_sample`. Samples one hop neighborhood with duplicates from
all node indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`.
Args:
one_hop_num (int): Max number of neighbors to sample in the current
layer.
kwargs: Arguments of :meth:`neighbor_sample`.
Returns:
(torch.Tensor, Optional[torch.Tensor], List[int]):
Returns original node indices for all sampled nodes and in addition,
the indices of edges of the original graph. Lastly, returns cummulative
sum of the amount of sampled neighbors by each node in the :obj:`seed`.
"""
return torch.ops.pyg.dist_neighbor_sample(rowptr, col, seed, one_hop_num,
time, seed_time, edge_weight,
csc, replace, directed, disjoint,
temporal_strategy,
return_edge_id)


def subgraph(
rowptr: Tensor,
col: Tensor,
Expand Down Expand Up @@ -218,6 +255,7 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int,
__all__ = [
'neighbor_sample',
'hetero_neighbor_sample',
'dist_neighbor_sample',
'subgraph',
'random_walk',
]
Loading

0 comments on commit 840e77d

Please sign in to comment.