diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
index 97135e204..cfc679d01 100644
--- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
+++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
@@ -755,39 +755,23 @@ sample(const std::vector& node_types,
   if (!replace && !directed && !disjoint && !return_edge_id)           \
     return sample(__VA_ARGS__);
 
-#define DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, ...) \
-  if (replace && directed && disjoint && return_edge_id)               \
-    return sample(__VA_ARGS__);                                        \
-  if (replace && directed && disjoint && !return_edge_id)              \
-    return sample(__VA_ARGS__);                                        \
-  if (replace && directed && !disjoint && return_edge_id)              \
-    return sample(__VA_ARGS__);                                        \
-  if (replace && directed && !disjoint && !return_edge_id)             \
-    return sample(__VA_ARGS__);                                        \
-  if (replace && !directed && disjoint && return_edge_id)              \
-    return sample(__VA_ARGS__);                                        \
-  if (replace && !directed && disjoint && !return_edge_id)             \
-    return sample(__VA_ARGS__);                                        \
-  if (replace && !directed && !disjoint && return_edge_id)             \
-    return sample(__VA_ARGS__);                                        \
-  if (replace && !directed && !disjoint && !return_edge_id)            \
-    return sample(__VA_ARGS__);                                        \
-  if (!replace && directed && disjoint && return_edge_id)              \
-    return sample(__VA_ARGS__);                                        \
-  if (!replace && directed && disjoint && !return_edge_id)             \
-    return sample(__VA_ARGS__);                                        \
-  if (!replace && directed && !disjoint && return_edge_id)             \
-    return sample(__VA_ARGS__);                                        \
-  if (!replace && directed && !disjoint && !return_edge_id)            \
-    return sample(__VA_ARGS__);                                        \
-  if (!replace && !directed && disjoint && return_edge_id)             \
-    return sample(__VA_ARGS__);                                        \
-  if (!replace && !directed && disjoint && !return_edge_id)            \
-    return sample(__VA_ARGS__);                                        \
-  if (!replace && !directed && !disjoint && return_edge_id)            \
-    return sample(__VA_ARGS__);                                        \
-  if (!replace && !directed && !disjoint && !return_edge_id)           \
-    return sample(__VA_ARGS__);
+#define DISPATCH_DIST_SAMPLE(replace, directed, disjoint, ...)         \
+  if (replace && directed && disjoint)                                 \
+    return sample(__VA_ARGS__);                                        \
+  if (replace && directed && !disjoint)                                \
+    return sample(__VA_ARGS__);                                        \
+  if (replace && !directed && disjoint)                                \
+    return sample(__VA_ARGS__);                                        \
+  if (replace && !directed && !disjoint)                               \
+    return sample(__VA_ARGS__);                                        \
+  if (!replace && directed && disjoint)                                \
+    return sample(__VA_ARGS__);                                        \
+  if (!replace && directed && !disjoint)                               \
+    return sample(__VA_ARGS__);                                        \
+  if (!replace && !directed && disjoint)                               \
+    return sample(__VA_ARGS__);                                        \
+  if (!replace && !directed && !disjoint)                              \
+    return sample(__VA_ARGS__);
 
 } // namespace
 
@@ -859,12 +843,11 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr,
     bool replace,
     bool directed,
     bool disjoint,
-    std::string temporal_strategy,
-    bool return_edge_id) {
+    std::string temporal_strategy) {
   const auto out = [&] {
-    DISPATCH_DIST_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr,
-                         col, seed, {num_neighbors}, time, seed_time,
-                         edge_weight, csc, temporal_strategy);
+    DISPATCH_DIST_SAMPLE(replace, directed, disjoint, rowptr, col, seed,
+                         {num_neighbors}, time, seed_time, edge_weight, csc,
+                         temporal_strategy);
   }();
   return std::make_tuple(std::get<2>(out), std::get<3>(out).value(),
                          std::get<6>(out));
diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h
index dbc1860fd..2d1dfa831 100644
--- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h
+++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h
@@ -60,8 +60,7 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr,
     bool replace,
     bool directed,
     bool disjoint,
-    std::string temporal_strategy,
-    bool return_edge_id);
+    std::string temporal_strategy);
 
 } // namespace sampler
 } // namespace pyg
diff --git a/pyg_lib/csrc/sampler/dist_merge_outputs.h b/pyg_lib/csrc/sampler/dist_merge_outputs.h
index 99680a679..5a6c54c09 100644
--- a/pyg_lib/csrc/sampler/dist_merge_outputs.h
+++ b/pyg_lib/csrc/sampler/dist_merge_outputs.h
@@ -7,10 +7,13 @@
 namespace pyg {
 namespace sampler {
 
-// For distributed training purpose. Merges samplers outputs from different
+// For distributed training purposes. Merges sampler outputs from different
 // partitions, so that they are sorted according to the sampling order.
 // Removes seed nodes from sampled nodes and calculates how many neighbors
-// were sampled by each src node based on the :obj:`cumsum_neighbors_per_node`.
+// were sampled by each source node based on the cumulative sum of sampled
+// neighbors for each input node.
+// Returns the unified node, edge and batch indices as well as the merged
+// cumulative sum of sampled neighbors.
 PYG_API
 std::tuple,
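
The comment above describes the merge step in terms of a cumulative sum of
sampled neighbors per input node. As a minimal, self-contained Python sketch of
that layout (illustrative values only, not part of this patch, and assuming the
cumulative sum starts at zero):

    # Hypothetical output for 3 seed nodes and 9 sampled neighbors in total.
    cumsum_neighbors_per_node = [0, 3, 5, 9]
    sampled_nodes = list(range(9))  # flat list of sampled neighbor indices

    # Per-seed counts are consecutive differences of the cumulative sum.
    counts = [b - a for a, b in zip(cumsum_neighbors_per_node,
                                    cumsum_neighbors_per_node[1:])]
    # counts == [3, 2, 4]

    # The same offsets slice the flat node list back into per-seed groups,
    # which is the ordering the merge step restores across partitions.
    per_seed = [sampled_nodes[a:b] for a, b in zip(cumsum_neighbors_per_node,
                                                   cumsum_neighbors_per_node[1:])]
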
diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp
index 6de1dcedf..38c0cc3cb 100644
--- a/pyg_lib/csrc/sampler/neighbor.cpp
+++ b/pyg_lib/csrc/sampler/neighbor.cpp
@@ -106,8 +106,7 @@ std::tuple> dist_neighbor_sample(
     bool replace,
     bool directed,
     bool disjoint,
-    std::string temporal_strategy,
-    bool return_edge_id) {
+    std::string temporal_strategy) {
   at::TensorArg rowptr_t{rowptr, "rowtpr", 1};
   at::TensorArg col_t{col, "col", 1};
   at::TensorArg seed_t{seed, "seed", 1};
@@ -120,8 +119,7 @@ std::tuple> dist_neighbor_sample(
                        .findSchemaOrThrow("pyg::dist_neighbor_sample", "")
                        .typed();
   return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight,
-                 csc, replace, directed, disjoint, temporal_strategy,
-                 return_edge_id);
+                 csc, replace, directed, disjoint, temporal_strategy);
 }
 
 TORCH_LIBRARY_FRAGMENT(pyg, m) {
@@ -147,7 +145,7 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) {
       "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? "
       "edge_weight = None, bool csc = False, bool replace = False, bool "
       "directed = True, bool disjoint = False, str temporal_strategy = "
-      "'uniform', bool return_edge_id = True) -> (Tensor, Tensor, int[])"));
+      "'uniform') -> (Tensor, Tensor, int[])"));
 }
 
 } // namespace sampler
diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h
index 07de4fa36..66282e38d 100644
--- a/pyg_lib/csrc/sampler/neighbor.h
+++ b/pyg_lib/csrc/sampler/neighbor.h
@@ -61,6 +61,11 @@ hetero_neighbor_sample(
     std::string strategy = "uniform",
     bool return_edge_id = true);
 
+// For distributed sampling purposes. Leverages the `neighbor_sample` function
+// internally. Samples one-hop neighborhoods with duplicates from all node
+// indices in `seed` in the graph given by `(rowptr, col)`.
+// Returns the original node and edge indices for all sampled nodes and edges.
+// Lastly, returns the cumulative sum of sampled neighbors for each input node.
 PYG_API
 std::tuple> dist_neighbor_sample(
     const at::Tensor& rowptr,
@@ -74,8 +79,7 @@ std::tuple> dist_neighbor_sample(
     bool replace = false,
     bool directed = true,
     bool disjoint = false,
-    std::string strategy = "uniform",
-    bool return_edge_id = true);
+    std::string strategy = "uniform");
 
 } // namespace sampler
 } // namespace pyg
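
With `return_edge_id` removed from the schema, a call to the operator from
Python goes through `torch.ops` directly. A minimal usage sketch under the new
signature (the tiny CSR graph and variable names are made up for illustration,
and it assumes a built `pyg_lib` so that the `pyg::*` operators are registered):

    import torch

    import pyg_lib  # noqa: F401  (loads the shared library and registers ops)

    rowptr = torch.tensor([0, 2, 4, 6])      # toy 3-node graph in CSR form
    col = torch.tensor([1, 2, 0, 2, 0, 1])
    seed = torch.tensor([0, 1])

    # Returns the sampled node ids, the sampled edge ids, and the cumulative
    # sum of sampled neighbors per input node, per the header comment above.
    node, edge_ids, cumsum_neighbors_per_node = torch.ops.pyg.dist_neighbor_sample(
        rowptr, col, seed, 2)  # sample up to 2 neighbors per seed
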
diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py
index 3937f49c0..dd0cc1a0f 100644
--- a/pyg_lib/sampler/__init__.py
+++ b/pyg_lib/sampler/__init__.py
@@ -165,52 +165,6 @@ def hetero_neighbor_sample(
                              num_nodes_per_hop_dict, num_edges_per_hop_dict)
 
 
-def dist_neighbor_sample(
-    rowptr: Tensor,
-    col: Tensor,
-    seed: Tensor,
-    num_neighbors: int,
-    time: Optional[Tensor] = None,
-    seed_time: Optional[Tensor] = None,
-    edge_weight: Optional[Tensor] = None,
-    csc: bool = False,
-    replace: bool = False,
-    directed: bool = True,
-    disjoint: bool = False,
-    temporal_strategy: str = 'uniform',
-    return_edge_id: bool = True,
-) -> Tuple[Tensor, Tensor, List[int]]:
-    r"""For distributed sampling purpose. Leverages the
-    :meth:`neighbor_sample`. Samples one hop neighborhood with duplicates from
-    all node indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`.
-
-    Args:
-        num_neighbors (int): Maximum number of neighbors to sample in the
-            current layer.
-        kwargs: Arguments of :meth:`neighbor_sample`.
-
-    Returns:
-        (torch.Tensor, torch.Tensor, List[int]): Returns the original node and
-        edge indices for all sampled nodes and edges. Lastly, returns the
-        cummulative sum of the amount of sampled neighbors for each input node.
-    """
-    return torch.ops.pyg.dist_neighbor_sample(
-        rowptr,
-        col,
-        seed,
-        num_neighbors,
-        time,
-        seed_time,
-        edge_weight,
-        csc,
-        replace,
-        directed,
-        disjoint,
-        temporal_strategy,
-        return_edge_id,
-    )
-
-
 def subgraph(
     rowptr: Tensor,
     col: Tensor,
@@ -261,61 +215,9 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int,
     return torch.ops.pyg.random_walk(rowptr, col, seed, walk_length, p, q)
 
 
-def merge_sampler_outputs(
-    nodes: List[Tensor],
-    cumm_sampled_nbrs_per_node: List[List[int]],
-    partition_ids: List[int],
-    partition_orders: List[int],
-    partitions_num: int,
-    one_hop_num: int,
-    edge_ids: Optional[List[Tensor]] = None,
-    batch: Optional[List[Tensor]] = None,
-    disjoint: bool = False,
-    with_edge: bool = True,
-) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], List[int]]:
-    r""" For distributed training purpose. Merges samplers outputs from
-    different partitions, so that they are sorted according to the sampling
-    order. Removes seed nodes from sampled nodes and calculates how many
-    neighbors were sampled by each src node based on the
-    :obj:`cumm_sampled_nbrs_per_node`.
-
-    Args:
-        nodes (List[torch.Tensor]): A list of nodes sampled by all machines.
-        cumm_sampled_nbrs_per_node (List[List[int]]): For each sampled node,
-            it contains information of how many neighbors it has sampled.
-            Represented as a cumulative sum for the nodes in a given partition.
-        partition_ids (torch.Tensor): Contains information on which
-            partition src nodes are located on.
-        partition_orders (torch.Tensor): Contains information about the
-            order of src nodes in each partition.
-        one_hop_num (int): Max number of neighbors sampled in the current
-            layer.
-        edge_ids (List[Tensor], optional): A list of edge_ids sampled by all
-            machines. (default: :obj:`None`)
-        batch (List[Tensor], optional): A list of subgraph ids that the sampled
-            :obj:`nodes` belong to. (default: :obj:`None`)
-        disjoint (bool, optional): Informs whether it is a disjoint sampling.
-            (default: :obj:`False`)
-        with_edge (bool, optional): Informs whether it is a sampling with edge.
-            (default: :obj:`True`)
-    Returns:
-        (torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-        List[int]):
-        Returns sorted and merged nodes, edge_ids and subgraph ids (batch),
-        as well as number of sampled neighbors per each :obj:`node`.
-    """
-    return torch.ops.pyg.merge_sampler_outputs(nodes, edge_ids, batch,
-                                               cumm_sampled_nbrs_per_node,
-                                               partition_ids, partition_orders,
-                                               partitions_num, one_hop_num,
-                                               disjoint, with_edge)
-
-
 __all__ = [
     'neighbor_sample',
     'hetero_neighbor_sample',
-    'dist_neighbor_sample',
     'subgraph',
     'random_walk',
-    'merge_sampler_outputs',
 ]
diff --git a/setup.py b/setup.py
index 7e601d57b..43b898020 100644
--- a/setup.py
+++ b/setup.py
@@ -55,7 +55,7 @@ def build_extension(self, ext):
         WITH_CUDA = bool(int(os.getenv('FORCE_CUDA', WITH_CUDA)))
 
         cmake_args = [
-            '-DBUILD_TEST=ON',
+            '-DBUILD_TEST=OFF',
            '-DBUILD_BENCHMARK=OFF',
            '-DUSE_PYTHON=ON',
            f'-DWITH_CUDA={"ON" if WITH_CUDA else "OFF"}',
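
Taken together, the Python wrappers `dist_neighbor_sample` and
`merge_sampler_outputs` are no longer exported from `pyg_lib.sampler`, while
the C++ registration of `pyg::dist_neighbor_sample` (with the trimmed schema)
remains in place. A quick check of the resulting surface, as a sketch assuming
a build of this branch:

    import torch

    import pyg_lib.sampler as sampler

    print(sampler.__all__)
    # ['neighbor_sample', 'hetero_neighbor_sample', 'subgraph', 'random_walk']

    # The distributed operator is still reachable through the dispatcher:
    print(torch.ops.pyg.dist_neighbor_sample)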