update
rusty1s committed Sep 15, 2023
1 parent 5d6d4ee commit d391ea7
Showing 7 changed files with 37 additions and 148 deletions.
59 changes: 21 additions & 38 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
@@ -755,39 +755,23 @@ sample(const std::vector<node_type>& node_types,
if (!replace && !directed && !disjoint && !return_edge_id) \
return sample<false, false, false, false, false>(__VA_ARGS__);

#define DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, ...) \
if (replace && directed && disjoint && return_edge_id) \
return sample<true, true, true, true, true>(__VA_ARGS__); \
if (replace && directed && disjoint && !return_edge_id) \
return sample<true, true, true, false, true>(__VA_ARGS__); \
if (replace && directed && !disjoint && return_edge_id) \
return sample<true, true, false, true, true>(__VA_ARGS__); \
if (replace && directed && !disjoint && !return_edge_id) \
return sample<true, true, false, false, true>(__VA_ARGS__); \
if (replace && !directed && disjoint && return_edge_id) \
return sample<true, false, true, true, true>(__VA_ARGS__); \
if (replace && !directed && disjoint && !return_edge_id) \
return sample<true, false, true, false, true>(__VA_ARGS__); \
if (replace && !directed && !disjoint && return_edge_id) \
return sample<true, false, false, true, true>(__VA_ARGS__); \
if (replace && !directed && !disjoint && !return_edge_id) \
return sample<true, false, false, false, true>(__VA_ARGS__); \
if (!replace && directed && disjoint && return_edge_id) \
return sample<false, true, true, true, true>(__VA_ARGS__); \
if (!replace && directed && disjoint && !return_edge_id) \
return sample<false, true, true, false, true>(__VA_ARGS__); \
if (!replace && directed && !disjoint && return_edge_id) \
return sample<false, true, false, true, true>(__VA_ARGS__); \
if (!replace && directed && !disjoint && !return_edge_id) \
return sample<false, true, false, false, true>(__VA_ARGS__); \
if (!replace && !directed && disjoint && return_edge_id) \
return sample<false, false, true, true, true>(__VA_ARGS__); \
if (!replace && !directed && disjoint && !return_edge_id) \
return sample<false, false, true, false, true>(__VA_ARGS__); \
if (!replace && !directed && !disjoint && return_edge_id) \
return sample<false, false, false, true, true>(__VA_ARGS__); \
if (!replace && !directed && !disjoint && !return_edge_id) \
return sample<false, false, false, false, true>(__VA_ARGS__);
#define DISPATCH_DIST_SAMPLE(replace, directed, disjoint, ...) \
if (replace && directed && disjoint) \
return sample<true, true, true, true, true>(__VA_ARGS__); \
if (replace && directed && !disjoint) \
return sample<true, true, false, true, true>(__VA_ARGS__); \
if (replace && !directed && disjoint) \
return sample<true, false, true, true, true>(__VA_ARGS__); \
if (replace && !directed && !disjoint) \
return sample<true, false, false, true, true>(__VA_ARGS__); \
if (!replace && directed && disjoint) \
return sample<false, true, true, true, true>(__VA_ARGS__); \
if (!replace && directed && !disjoint) \
return sample<false, true, false, true, true>(__VA_ARGS__); \
if (!replace && !directed && disjoint) \
return sample<false, false, true, true, true>(__VA_ARGS__); \
if (!replace && !directed && !disjoint) \
return sample<false, false, false, true, true>(__VA_ARGS__);

} // namespace
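The DISPATCH_DIST_SAMPLE macro above now maps only three runtime flags (replace, directed, disjoint) onto a compile-time specialization of sample; the fourth template argument (return_edge_id) is always true on the distributed path, and the last one marks the distributed variant (it is false in DISPATCH_SAMPLE, true here), so the truth table shrinks from sixteen cases to eight. A minimal Python sketch of that mapping, with hypothetical names used only for illustration:

import itertools

def dist_specialization(replace: bool, directed: bool, disjoint: bool):
    # Mirrors DISPATCH_DIST_SAMPLE: three runtime flags pick the first three
    # template arguments; return_edge_id and the distributed flag are fixed.
    return (replace, directed, disjoint, True, True)

# Eight remaining cases instead of the previous sixteen:
for flags in itertools.product([True, False], repeat=3):
    print(flags, '->', dist_specialization(*flags))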

@@ -859,12 +843,11 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
std::string temporal_strategy) {
const auto out = [&] {
DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr,
col, seed, {num_neighbors}, time, seed_time,
edge_weight, csc, temporal_strategy);
DISPATCH_DIST_SAMPLE(replace, directed, disjoint, rowptr, col, seed,
{num_neighbors}, time, seed_time, edge_weight, csc,
temporal_strategy);
}();
return std::make_tuple(std::get<2>(out), std::get<3>(out).value(),
std::get<6>(out));
3 changes: 1 addition & 2 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.h
@@ -60,8 +60,7 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id);
std::string temporal_strategy);

} // namespace sampler
} // namespace pyg
7 changes: 5 additions & 2 deletions pyg_lib/csrc/sampler/dist_merge_outputs.h
@@ -7,10 +7,13 @@
namespace pyg {
namespace sampler {

// For distributed training purpose. Merges samplers outputs from different
// For distributed training purposes. Merges sampler outputs from different
// partitions, so that they are sorted according to the sampling order.
// Removes seed nodes from sampled nodes and calculates how many neighbors
// were sampled by each src node based on the :obj:`cumsum_neighbors_per_node`.
// were sampled by each source node based on the cumulative sum of sampled
// neighbors for each input node.
// Returns the unified node, edge and batch indices as well as the merged
// cumulative sum of sampled neighbors.
PYG_API
std::tuple<at::Tensor,
c10::optional<at::Tensor>,
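The comment above describes recovering per-source-node neighbor counts from a cumulative sum. A self-contained Python sketch with made-up example values (not taken from any real partition output):

# Differences of consecutive cumulative-sum entries give the number of
# neighbors sampled by each source node.
cumsum_neighbors_per_node = [0, 3, 3, 7]  # assumed example for 3 source nodes
num_sampled_per_node = [
    hi - lo
    for lo, hi in zip(cumsum_neighbors_per_node, cumsum_neighbors_per_node[1:])
]
print(num_sampled_per_node)  # [3, 0, 4]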
8 changes: 3 additions & 5 deletions pyg_lib/csrc/sampler/neighbor.cpp
@@ -106,8 +106,7 @@ std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
std::string temporal_strategy) {
at::TensorArg rowptr_t{rowptr, "rowtpr", 1};
at::TensorArg col_t{col, "col", 1};
at::TensorArg seed_t{seed, "seed", 1};
@@ -120,8 +119,7 @@ std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
.findSchemaOrThrow("pyg::dist_neighbor_sample", "")
.typed<decltype(dist_neighbor_sample)>();
return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight,
csc, replace, directed, disjoint, temporal_strategy,
return_edge_id);
csc, replace, directed, disjoint, temporal_strategy);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
@@ -147,7 +145,7 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) {
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? "
"edge_weight = None, bool csc = False, bool replace = False, bool "
"directed = True, bool disjoint = False, str temporal_strategy = "
"'uniform', bool return_edge_id = True) -> (Tensor, Tensor, int[])"));
"'uniform') -> (Tensor, Tensor, int[])"));
}

} // namespace sampler
8 changes: 6 additions & 2 deletions pyg_lib/csrc/sampler/neighbor.h
@@ -61,6 +61,11 @@ hetero_neighbor_sample(
std::string strategy = "uniform",
bool return_edge_id = true);

// For distributed sampling purposes. Leverages the `neighbor_sample` function
// internally. Samples one-hop neighborhoods with duplicates from all node
// indices in `seed` in the graph given by `(rowptr, col)`.
// Returns the original node and edge indices for all sampled nodes and edges.
// Lastly, returns the cumulative sum of sampled neighbors for each input node.
PYG_API
std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
const at::Tensor& rowptr,
@@ -74,8 +79,7 @@ std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
bool replace = false,
bool directed = true,
bool disjoint = false,
std::string strategy = "uniform",
bool return_edge_id = true);
std::string strategy = "uniform");

} // namespace sampler
} // namespace pyg
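With return_edge_id dropped from the signature, the operator always reports edge indices and returns a three-tuple of node indices, edge indices, and the cumulative neighbor counts. A hedged usage sketch of the underlying TorchScript op on a toy CSR graph, assuming a pyg_lib build that includes this commit (importing pyg_lib is assumed to register the pyg operator namespace; the tensor values are made up):

import torch
import pyg_lib  # noqa: F401  # loads the extension and registers the pyg ops

# Toy CSR graph with 3 nodes: 0 -> {1, 2}, 1 -> {2}, 2 -> {}.
rowptr = torch.tensor([0, 2, 3, 3])
col = torch.tensor([1, 2, 2])
seed = torch.tensor([0, 1])

# Sample up to 2 neighbors per seed node; the remaining arguments keep the
# defaults from the registered schema (csc=False, replace=False, ...).
node, edge_id, cumsum_neighbors_per_node = torch.ops.pyg.dist_neighbor_sample(
    rowptr, col, seed, 2)
print(node, edge_id, cumsum_neighbors_per_node)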
98 changes: 0 additions & 98 deletions pyg_lib/sampler/__init__.py
@@ -165,52 +165,6 @@ def hetero_neighbor_sample(
num_nodes_per_hop_dict, num_edges_per_hop_dict)


def dist_neighbor_sample(
rowptr: Tensor,
col: Tensor,
seed: Tensor,
num_neighbors: int,
time: Optional[Tensor] = None,
seed_time: Optional[Tensor] = None,
edge_weight: Optional[Tensor] = None,
csc: bool = False,
replace: bool = False,
directed: bool = True,
disjoint: bool = False,
temporal_strategy: str = 'uniform',
return_edge_id: bool = True,
) -> Tuple[Tensor, Tensor, List[int]]:
r"""For distributed sampling purpose. Leverages the
:meth:`neighbor_sample`. Samples one hop neighborhood with duplicates from
all node indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`.
Args:
num_neighbors (int): Maximum number of neighbors to sample in the
current layer.
kwargs: Arguments of :meth:`neighbor_sample`.
Returns:
(torch.Tensor, torch.Tensor, List[int]): Returns the original node and
edge indices for all sampled nodes and edges. Lastly, returns the
cummulative sum of the amount of sampled neighbors for each input node.
"""
return torch.ops.pyg.dist_neighbor_sample(
rowptr,
col,
seed,
num_neighbors,
time,
seed_time,
edge_weight,
csc,
replace,
directed,
disjoint,
temporal_strategy,
return_edge_id,
)


def subgraph(
rowptr: Tensor,
col: Tensor,
@@ -261,61 +215,9 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int,
return torch.ops.pyg.random_walk(rowptr, col, seed, walk_length, p, q)


def merge_sampler_outputs(
nodes: List[Tensor],
cumm_sampled_nbrs_per_node: List[List[int]],
partition_ids: List[int],
partition_orders: List[int],
partitions_num: int,
one_hop_num: int,
edge_ids: Optional[List[Tensor]] = None,
batch: Optional[List[Tensor]] = None,
disjoint: bool = False,
with_edge: bool = True,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], List[int]]:
r""" For distributed training purpose. Merges samplers outputs from
different partitions, so that they are sorted according to the sampling
order. Removes seed nodes from sampled nodes and calculates how many
neighbors were sampled by each src node based on the
:obj:`cumm_sampled_nbrs_per_node`.
Args:
nodes (List[torch.Tensor]): A list of nodes sampled by all machines.
cumm_sampled_nbrs_per_node (List[List[int]]): For each sampled node,
it contains information of how many neighbors it has sampled.
Represented as a cumulative sum for the nodes in a given partition.
partition_ids (torch.Tensor): Contains information on which
partition src nodes are located on.
partition_orders (torch.Tensor): Contains information about the
order of src nodes in each partition.
one_hop_num (int): Max number of neighbors sampled in the current
layer.
edge_ids (List[Tensor], optional): A list of edge_ids sampled by all
machines. (default: :obj:`None`)
batch (List[Tensor], optional): A list of subgraph ids that the sampled
:obj:`nodes` belong to. (default: :obj:`None`)
disjoint (bool, optional): Informs whether it is a disjoint sampling.
(default: :obj:`False`)
with_edge (bool, optional): Informs whether it is a sampling with edge.
(default: :obj:`True`)
Returns:
(torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
List[int]):
Returns sorted and merged nodes, edge_ids and subgraph ids (batch),
as well as number of sampled neighbors per each :obj:`node`.
"""
return torch.ops.pyg.merge_sampler_outputs(nodes, edge_ids, batch,
cumm_sampled_nbrs_per_node,
partition_ids, partition_orders,
partitions_num, one_hop_num,
disjoint, with_edge)


__all__ = [
'neighbor_sample',
'hetero_neighbor_sample',
'dist_neighbor_sample',
'subgraph',
'random_walk',
'merge_sampler_outputs',
]
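Because this commit also removes the dist_neighbor_sample and merge_sampler_outputs wrappers from pyg_lib.sampler, distributed callers can no longer import them from the Python package. A possible migration sketch, assuming the underlying ops remain registered by the C++ extension:

import torch
import pyg_lib  # noqa: F401  # loading the package registers the pyg ops

# Only these helpers stay in the public Python API after this change:
from pyg_lib.sampler import (neighbor_sample, hetero_neighbor_sample,
                             subgraph, random_walk)

# The distributed helpers remain reachable as raw TorchScript ops:
dist_neighbor_sample = torch.ops.pyg.dist_neighbor_sample
merge_sampler_outputs = torch.ops.pyg.merge_sampler_outputs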
2 changes: 1 addition & 1 deletion setup.py
@@ -55,7 +55,7 @@ def build_extension(self, ext):
WITH_CUDA = bool(int(os.getenv('FORCE_CUDA', WITH_CUDA)))

cmake_args = [
'-DBUILD_TEST=ON',
'-DBUILD_TEST=OFF',
'-DBUILD_BENCHMARK=OFF',
'-DUSE_PYTHON=ON',
f'-DWITH_CUDA={"ON" if WITH_CUDA else "OFF"}',
