diff --git a/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp b/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp index 6277d613e..c09654e26 100644 --- a/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp @@ -4,7 +4,6 @@ #include "parallel_hashmap/phmap.h" -#include "pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.h" #include "pyg_lib/csrc/sampler/cpu/mapper.h" #include "pyg_lib/csrc/utils/cpu/convert.h" #include "pyg_lib/csrc/utils/types.h" @@ -32,8 +31,8 @@ std::tuple get_sampled_edges( template std::tuple relabel( const at::Tensor& seed, - const at::Tensor& sampled_nodes_with_dupl, - const std::vector& num_sampled_nbrs_per_node, + const at::Tensor& sampled_nodes_with_duplicates, + const std::vector& num_sampled_neighbors_per_node, const int64_t num_nodes, const c10::optional& batch, const bool csc) { @@ -41,195 +40,190 @@ std::tuple relabel( TORCH_CHECK(batch.has_value(), "Batch needs to be specified to create disjoint subgraphs"); TORCH_CHECK(batch.value().is_contiguous(), "Non-contiguous 'batch'"); - TORCH_CHECK(batch.value().numel() == sampled_nodes_with_dupl.numel(), - "Each node must belong to a subgraph.'"); + TORCH_CHECK(batch.value().numel() == sampled_nodes_with_duplicates.numel(), + "Each node must belong to a subgraph"); } TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'"); - TORCH_CHECK(sampled_nodes_with_dupl.is_contiguous(), - "Non-contiguous 'sampled_nodes_with_dupl'"); + TORCH_CHECK(sampled_nodes_with_duplicates.is_contiguous(), + "Non-contiguous 'sampled_nodes_with_duplicates'"); at::Tensor out_row, out_col; - AT_DISPATCH_INTEGRAL_TYPES( - seed.scalar_type(), "relabel_neighborhood_kernel", [&] { - typedef std::pair pair_scalar_t; - typedef std::conditional_t node_t; - - const auto sampled_nodes_data = - sampled_nodes_with_dupl.data_ptr(); - const auto batch_data = - !disjoint ? nullptr : batch.value().data_ptr(); - - std::vector sampled_rows; - std::vector sampled_cols; - auto mapper = Mapper(num_nodes); - - const auto seed_data = seed.data_ptr(); - if constexpr (!disjoint) { - mapper.fill(seed); - } else { - for (size_t i = 0; i < seed.numel(); ++i) { - mapper.insert({i, seed_data[i]}); - } - } - size_t begin = 0; - size_t end = 0; - for (auto i = 0; i < num_sampled_nbrs_per_node.size(); i++) { - end += num_sampled_nbrs_per_node[i]; - - for (auto j = begin; j < end; j++) { - std::pair res; - if constexpr (!disjoint) - res = mapper.insert(sampled_nodes_data[j]); - else - res = mapper.insert({batch_data[j], sampled_nodes_data[j]}); - sampled_rows.push_back(i); - sampled_cols.push_back(res.first); - } - - begin = end; - } - - std::tie(out_row, out_col) = - get_sampled_edges(sampled_rows, sampled_cols, csc); - }); + AT_DISPATCH_INTEGRAL_TYPES(seed.scalar_type(), "relabel_kernel", [&] { + typedef std::pair pair_scalar_t; + typedef std::conditional_t node_t; + + const auto sampled_nodes_data = + sampled_nodes_with_duplicates.data_ptr(); + const auto batch_data = + !disjoint ? nullptr : batch.value().data_ptr(); + + std::vector sampled_rows; + std::vector sampled_cols; + auto mapper = Mapper(num_nodes); + + const auto seed_data = seed.data_ptr(); + if constexpr (!disjoint) { + mapper.fill(seed); + } else { + for (size_t i = 0; i < seed.numel(); ++i) { + mapper.insert({i, seed_data[i]}); + } + } + size_t begin = 0, end = 0; + for (auto i = 0; i < num_sampled_neighbors_per_node.size(); i++) { + end += num_sampled_neighbors_per_node[i]; + + for (auto j = begin; j < end; j++) { + std::pair res; + if constexpr (!disjoint) + res = mapper.insert(sampled_nodes_data[j]); + else + res = mapper.insert({batch_data[j], sampled_nodes_data[j]}); + sampled_rows.push_back(i); + sampled_cols.push_back(res.first); + } + + begin = end; + } + + std::tie(out_row, out_col) = + get_sampled_edges(sampled_rows, sampled_cols, csc); + }); return std::make_tuple(out_row, out_col); } template std::tuple, c10::Dict> -relabel(const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& seed_dict, - const c10::Dict& sampled_nodes_with_dupl_dict, - const c10::Dict>& - num_sampled_nbrs_per_node_dict, - const c10::Dict& num_nodes_dict, - const c10::optional>& batch_dict, - const bool csc) { +relabel( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& seed_dict, + const c10::Dict& sampled_nodes_with_duplicates_dict, + const c10::Dict>& + num_sampled_neighbors_per_node_dict, + const c10::Dict& num_nodes_dict, + const c10::optional>& batch_dict, + const bool csc) { c10::Dict out_row_dict, out_col_dict; - AT_DISPATCH_INTEGRAL_TYPES( - seed_dict.begin()->value().scalar_type(), - "hetero_relabel_neighborhood_kernel", [&] { - typedef std::pair pair_scalar_t; - typedef std::conditional_t node_t; - - phmap::flat_hash_map sampled_nodes_data_dict; - phmap::flat_hash_map batch_data_dict; - phmap::flat_hash_map> - sampled_rows_dict; - phmap::flat_hash_map> - sampled_cols_dict; - - phmap::flat_hash_map> mapper_dict; - phmap::flat_hash_map> slice_dict; - - const bool parallel = - at::get_num_threads() > 1 && edge_types.size() > 1; - std::vector> threads_edge_types; - - for (const auto& k : edge_types) { - // Initialize empty vectors. - sampled_rows_dict[k]; - sampled_cols_dict[k]; - - if (parallel) { - // Each thread is assigned edge types that have the same dst node - // type. Thanks to this, each thread will operate on a separate - // mapper and separate sampler. - bool added = false; - const auto dst = !csc ? std::get<2>(k) : std::get<0>(k); - for (auto& e : threads_edge_types) { - if ((!csc ? std::get<2>(e[0]) : std::get<0>(e[0])) == dst) { - e.push_back(k); - added = true; - break; - } - } - if (!added) - threads_edge_types.push_back({k}); - } - } - if (!parallel) { - // If not parallel then one thread handles all edge types. - threads_edge_types.push_back({edge_types}); - } - - int64_t N = 0; - for (const auto& kv : num_nodes_dict) { - N += kv.value() > 0 ? kv.value() : 0; - } - - for (const auto& k : node_types) { - sampled_nodes_data_dict.insert( - {k, sampled_nodes_with_dupl_dict.at(k).data_ptr()}); - mapper_dict.insert({k, Mapper(N)}); - slice_dict[k] = {0, 0}; - if constexpr (disjoint) { - batch_data_dict.insert( - {k, batch_dict.value().at(k).data_ptr()}); + auto scalar_type = seed_dict.begin()->value().scalar_type(); + AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "hetero_relabel_kernel", [&] { + typedef std::pair pair_scalar_t; + typedef std::conditional_t node_t; + + phmap::flat_hash_map sampled_nodes_data_dict; + phmap::flat_hash_map batch_data_dict; + phmap::flat_hash_map> sampled_rows_dict; + phmap::flat_hash_map> sampled_cols_dict; + + phmap::flat_hash_map> mapper_dict; + phmap::flat_hash_map> slice_dict; + + const bool parallel = at::get_num_threads() > 1 && edge_types.size() > 1; + std::vector> threads_edge_types; + + for (const auto& k : edge_types) { + // Initialize empty vectors. + sampled_rows_dict[k]; + sampled_cols_dict[k]; + + if (parallel) { + // Each thread is assigned edge types that have the same dst node + // type. Thanks to this, each thread will operate on a separate + // mapper and separate sampler. + bool added = false; + const auto dst = !csc ? std::get<2>(k) : std::get<0>(k); + for (auto& e : threads_edge_types) { + if ((!csc ? std::get<2>(e[0]) : std::get<0>(e[0])) == dst) { + e.push_back(k); + added = true; + break; } } - for (const auto& kv : seed_dict) { - const at::Tensor& seed = kv.value(); - if constexpr (!disjoint) { - mapper_dict.at(kv.key()).fill(seed); - } else { - auto& mapper = mapper_dict.at(kv.key()); - const auto seed_data = seed.data_ptr(); - for (size_t i = 0; i < seed.numel(); ++i) { - mapper.insert({i, seed_data[i]}); - } - } + if (!added) + threads_edge_types.push_back({k}); + } + } + if (!parallel) { + // If not parallel then one thread handles all edge types. + threads_edge_types.push_back({edge_types}); + } + + int64_t N = 0; + for (const auto& kv : num_nodes_dict) { + N += kv.value() > 0 ? kv.value() : 0; + } + + for (const auto& k : node_types) { + sampled_nodes_data_dict.insert( + {k, sampled_nodes_with_duplicates_dict.at(k).data_ptr()}); + mapper_dict.insert({k, Mapper(N)}); + slice_dict[k] = {0, 0}; + if constexpr (disjoint) { + batch_data_dict.insert( + {k, batch_dict.value().at(k).data_ptr()}); + } + } + for (const auto& kv : seed_dict) { + const at::Tensor& seed = kv.value(); + if constexpr (!disjoint) { + mapper_dict.at(kv.key()).fill(seed); + } else { + auto& mapper = mapper_dict.at(kv.key()); + const auto seed_data = seed.data_ptr(); + for (size_t i = 0; i < seed.numel(); ++i) { + mapper.insert({i, seed_data[i]}); } - at::parallel_for( - 0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) { - for (auto j = _s; j < _e; j++) { - for (const auto& k : threads_edge_types[j]) { - const auto src = !csc ? std::get<0>(k) : std::get<2>(k); - const auto dst = !csc ? std::get<2>(k) : std::get<0>(k); - - const auto num_sampled_nbrs_size = - num_sampled_nbrs_per_node_dict.at(to_rel_type(k)).size(); - if (num_sampled_nbrs_size == 0) { - continue; - } + } + } + at::parallel_for( + 0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) { + for (auto j = _s; j < _e; j++) { + for (const auto& k : threads_edge_types[j]) { + const auto src = !csc ? std::get<0>(k) : std::get<2>(k); + const auto dst = !csc ? std::get<2>(k) : std::get<0>(k); + + const auto num_sampled_neighbors_size = + num_sampled_neighbors_per_node_dict.at(to_rel_type(k)).size(); + + if (num_sampled_neighbors_size == 0) { + continue; + } - for (auto i = 0; i < num_sampled_nbrs_size; i++) { - auto& dst_mapper = mapper_dict.at(dst); - auto& dst_sampled_nodes_data = - sampled_nodes_data_dict.at(dst); - - slice_dict.at(dst).second += - num_sampled_nbrs_per_node_dict.at(to_rel_type(k))[i]; - auto [begin, end] = slice_dict.at(dst); - - for (auto j = begin; j < end; j++) { - std::pair res; - if constexpr (!disjoint) { - res = dst_mapper.insert(dst_sampled_nodes_data[j]); - } else { - res = dst_mapper.insert({batch_data_dict.at(dst)[j], - dst_sampled_nodes_data[j]}); - } - sampled_rows_dict.at(k).push_back(i); - sampled_cols_dict.at(k).push_back(res.first); - } - slice_dict.at(dst).first = end; + for (auto i = 0; i < num_sampled_neighbors_size; i++) { + auto& dst_mapper = mapper_dict.at(dst); + auto& dst_sampled_nodes_data = sampled_nodes_data_dict.at(dst); + + slice_dict.at(dst).second += + num_sampled_neighbors_per_node_dict.at(to_rel_type(k))[i]; + auto [begin, end] = slice_dict.at(dst); + + for (auto j = begin; j < end; j++) { + std::pair res; + if constexpr (!disjoint) { + res = dst_mapper.insert(dst_sampled_nodes_data[j]); + } else { + res = dst_mapper.insert({batch_data_dict.at(dst)[j], + dst_sampled_nodes_data[j]}); } + sampled_rows_dict.at(k).push_back(i); + sampled_cols_dict.at(k).push_back(res.first); } + slice_dict.at(dst).first = end; } - }); + } + } + }); - for (const auto& k : edge_types) { - const auto edges = get_sampled_edges( - sampled_rows_dict.at(k), sampled_cols_dict.at(k), csc); - out_row_dict.insert(to_rel_type(k), std::get<0>(edges)); - out_col_dict.insert(to_rel_type(k), std::get<1>(edges)); - } - }); + for (const auto& k : edge_types) { + const auto edges = get_sampled_edges( + sampled_rows_dict.at(k), sampled_cols_dict.at(k), csc); + out_row_dict.insert(to_rel_type(k), std::get<0>(edges)); + out_col_dict.insert(to_rel_type(k), std::get<1>(edges)); + } + }); return std::make_tuple(out_row_dict, out_col_dict); } @@ -244,14 +238,14 @@ relabel(const std::vector& node_types, std::tuple relabel_neighborhood_kernel( const at::Tensor& seed, - const at::Tensor& sampled_nodes_with_dupl, - const std::vector& num_sampled_nbrs_per_node, + const at::Tensor& sampled_nodes_with_duplicates, + const std::vector& num_sampled_neighbors_per_node, const int64_t num_nodes, const c10::optional& batch, bool csc, bool disjoint) { - DISPATCH_RELABEL(disjoint, seed, sampled_nodes_with_dupl, - num_sampled_nbrs_per_node, num_nodes, batch, csc); + DISPATCH_RELABEL(disjoint, seed, sampled_nodes_with_duplicates, + num_sampled_neighbors_per_node, num_nodes, batch, csc); } std::tuple, c10::Dict> @@ -259,17 +253,18 @@ hetero_relabel_neighborhood_kernel( const std::vector& node_types, const std::vector& edge_types, const c10::Dict& seed_dict, - const c10::Dict& sampled_nodes_with_dupl_dict, + const c10::Dict& sampled_nodes_with_duplicates_dict, const c10::Dict>& - num_sampled_nbrs_per_node_dict, + num_sampled_neighbors_per_node_dict, const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict, bool csc, bool disjoint) { c10::Dict out_row_dict, out_col_dict; DISPATCH_RELABEL(disjoint, node_types, edge_types, seed_dict, - sampled_nodes_with_dupl_dict, num_sampled_nbrs_per_node_dict, - num_nodes_dict, batch_dict, csc); + sampled_nodes_with_duplicates_dict, + num_sampled_neighbors_per_node_dict, num_nodes_dict, + batch_dict, csc); } TORCH_LIBRARY_IMPL(pyg, CPU, m) { diff --git a/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.h b/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.h deleted file mode 100644 index b6d563d5c..000000000 --- a/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.h +++ /dev/null @@ -1,30 +0,0 @@ -#include -#include -#include "pyg_lib/csrc/utils/types.h" - -namespace pyg { -namespace sampler { - -std::tuple relabel_neighborhood_kernel( - const at::Tensor& seed, - const at::Tensor& sampled_nodes_with_dupl, - const std::vector& sampled_nbrs_per_node, - const int64_t num_nodes, - const c10::optional& batch, - bool csc, - bool disjoint); - -std::tuple, c10::Dict> -hetero_relabel_neighborhood_kernel( - const std::vector& node_types, - const std::vector& edge_types, - const c10::Dict& seed_dict, - const c10::Dict& sampled_nodes_with_dupl_dict, - const c10::Dict>& sampled_nbrs_per_node_dict, - const c10::Dict& num_nodes_dict, - const c10::optional>& batch_dict, - bool csc, - bool disjoint); - -} // namespace sampler -} // namespace pyg diff --git a/pyg_lib/csrc/sampler/dist_relabel.cpp b/pyg_lib/csrc/sampler/dist_relabel.cpp index 23fdd1f04..0ec6dbad9 100644 --- a/pyg_lib/csrc/sampler/dist_relabel.cpp +++ b/pyg_lib/csrc/sampler/dist_relabel.cpp @@ -10,25 +10,25 @@ namespace sampler { std::tuple relabel_neighborhood( const at::Tensor& seed, - const at::Tensor& sampled_nodes_with_dupl, - const std::vector& sampled_nbrs_per_node, + const at::Tensor& sampled_nodes_with_duplicates, + const std::vector& sampled_neighbors_per_node, const int64_t num_nodes, const c10::optional& batch, bool csc, bool disjoint) { at::TensorArg seed_t{seed, "seed", 1}; - at::TensorArg sampled_nodes_with_dupl_t{sampled_nodes_with_dupl, - "sampled_nodes_with_dupl", 1}; + at::TensorArg sampled_nodes_with_duplicates_t{ + sampled_nodes_with_duplicates, "sampled_nodes_with_duplicates", 1}; at::CheckedFrom c = "relabel_neighborhood"; - at::checkAllDefined(c, {sampled_nodes_with_dupl_t, seed_t}); - at::checkAllSameType(c, {sampled_nodes_with_dupl_t, seed_t}); + at::checkAllDefined(c, {sampled_nodes_with_duplicates_t, seed_t}); + at::checkAllSameType(c, {sampled_nodes_with_duplicates_t, seed_t}); static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::relabel_neighborhood", "") .typed(); - return op.call(seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, - num_nodes, batch, csc, disjoint); + return op.call(seed, sampled_nodes_with_duplicates, + sampled_neighbors_per_node, num_nodes, batch, csc, disjoint); } std::tuple, c10::Dict> @@ -36,49 +36,49 @@ hetero_relabel_neighborhood( const std::vector& node_types, const std::vector& edge_types, const c10::Dict& seed_dict, - const c10::Dict& sampled_nodes_with_dupl_dict, - const c10::Dict>& sampled_nbrs_per_node_dict, + const c10::Dict& sampled_nodes_with_duplicates_dict, + const c10::Dict>& + sampled_neighbors_per_node_dict, const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict, bool csc, bool disjoint) { std::vector seed_dict_args; - std::vector sampled_nodes_with_dupl_dict_args; + std::vector sampled_nodes_with_duplicates_dict_args; pyg::utils::fill_tensor_args(seed_dict_args, seed_dict, "seed_dict", 0); - pyg::utils::fill_tensor_args(sampled_nodes_with_dupl_dict_args, - sampled_nodes_with_dupl_dict, - "sampled_nodes_with_dupl_dict", 0); + pyg::utils::fill_tensor_args(sampled_nodes_with_duplicates_dict_args, + sampled_nodes_with_duplicates_dict, + "sampled_nodes_with_duplicates_dict", 0); at::CheckedFrom c{"hetero_relabel_neighborhood"}; at::checkAllDefined(c, seed_dict_args); - at::checkAllDefined(c, sampled_nodes_with_dupl_dict_args); - at::checkSameType(c, seed_dict_args[0], sampled_nodes_with_dupl_dict_args[0]); + at::checkAllDefined(c, sampled_nodes_with_duplicates_dict_args); + at::checkSameType(c, seed_dict_args[0], + sampled_nodes_with_duplicates_dict_args[0]); static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::hetero_relabel_neighborhood", "") .typed(); return op.call(node_types, edge_types, seed_dict, - sampled_nodes_with_dupl_dict, sampled_nbrs_per_node_dict, - num_nodes_dict, batch_dict, csc, disjoint); + sampled_nodes_with_duplicates_dict, + sampled_neighbors_per_node_dict, num_nodes_dict, batch_dict, + csc, disjoint); } TORCH_LIBRARY_FRAGMENT(pyg, m) { - m.def( - TORCH_SELECTIVE_SCHEMA("pyg::relabel_neighborhood(Tensor seed, Tensor " - "sampled_nodes_with_dupl, int[] " - "sampled_nbrs_per_node, int num_nodes, Tensor? " - "batch = None, bool csc = False, bool " - "disjoint = False) " - "-> (Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "pyg::relabel_neighborhood(Tensor seed, Tensor " + "sampled_nodes_with_duplicates, int[] sampled_neighbors_per_node, int " + "num_nodes, Tensor? batch = None, bool csc = False, bool disjoint = " + "False) -> (Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::hetero_relabel_neighborhood(str[] node_types, (str, str, str)[] " "edge_types, Dict(str, Tensor) seed_dict, Dict(str, Tensor) " - "sampled_nodes_with_dupl_dict, Dict(str, int[]) " - "sampled_nbrs_per_node_dict, Dict(str, int) num_nodes_dict, Dict(str, " - "Tensor)? batch_dict = None, bool csc = False, bool " - "disjoint = False) " - "-> (Dict(str, Tensor), Dict(str, Tensor))")); + "sampled_nodes_with_duplicates_dict, Dict(str, int[]) " + "sampled_neighbors_per_node_dict, Dict(str, int) num_nodes_dict, " + "Dict(str, Tensor)? batch_dict = None, bool csc = False, bool disjoint = " + "False) -> (Dict(str, Tensor), Dict(str, Tensor))")); } } // namespace sampler diff --git a/pyg_lib/csrc/sampler/dist_relabel.h b/pyg_lib/csrc/sampler/dist_relabel.h index 134fa96fa..780f5c99c 100644 --- a/pyg_lib/csrc/sampler/dist_relabel.h +++ b/pyg_lib/csrc/sampler/dist_relabel.h @@ -7,30 +7,33 @@ namespace pyg { namespace sampler { -// Relabel global indices of the `sampled_nodes_with_dupl` to the local -// subtree/subgraph indices. +// Relabels global indices from `sampled_nodes_with_duplicates` to the local +// subtree/subgraph indices in the homogeneous graph. +// Seed nodes should not be included. // Returns (row, col). PYG_API std::tuple relabel_neighborhood( const at::Tensor& seed, - const at::Tensor& sampled_nodes_with_dupl, - const std::vector& sampled_nbrs_per_node, + const at::Tensor& sampled_nodes_with_duplicates, + const std::vector& sampled_neighbors_per_node, const int64_t num_nodes, const c10::optional& batch = c10::nullopt, bool csc = false, bool disjoint = false); -// Relabel global indices of the `sampled_nodes_with_dupl` to the local +// Relabels global indices from `sampled_nodes_with_duplicates` to the local // subtree/subgraph indices in the heterogeneous graph. -// Returns src and dst indices for a given edge type as a (row_dict, col_dict). +// Seed nodes should not be included. +// Returns (row_dict, col_dict). PYG_API std::tuple, c10::Dict> hetero_relabel_neighborhood( const std::vector& node_types, const std::vector& edge_types, const c10::Dict& seed_dict, - const c10::Dict& sampled_nodes_with_dupl_dict, - const c10::Dict>& sampled_nbrs_per_node_dict, + const c10::Dict& sampled_nodes_with_duplicates_dict, + const c10::Dict>& + sampled_neighbors_per_node_dict, const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict = c10::nullopt, diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index eab491fa0..dd0cc1a0f 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -215,100 +215,9 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int, return torch.ops.pyg.random_walk(rowptr, col, seed, walk_length, p, q) -def relabel_neighborhood( - seed: Tensor, - sampled_nodes_with_dupl: Tensor, - sampled_nbrs_per_node: List[int], - num_nodes: int, - batch: Optional[Tensor] = None, - csc: bool = False, - disjoint: bool = False, -) -> Tuple[Tensor, Tensor]: - r"""Relabel global indices of the :obj:`sampled_nodes_with_dupl` to the - local subtree/subgraph indices. - - .. note:: - - For :obj:`disjoint`, the :obj:`batch` needs to be specified - and each node from :obj:`sampled_nodes_with_dupl` must be assigned - to a subgraph. - - Args: - seed (torch.Tensor): The seed node indices. - sampled_nodes_with_dupl (torch.Tensor): Sampled nodes with duplicates. - Should not include seed nodes. - sampled_nbrs_per_node (List[int]): The number of neighbors sampled by - each node from :obj:`sampled_nodes_with_dupl`. - num_nodes (int): Number of all nodes in a graph. - batch (torch.Tensor, optional): Stores information about which subgraph - the node from :obj:`sampled_nodes_with_dupl` belongs to. - Must be specified when :obj:`disjoint`. (default: :obj:`None`) - csc (bool, optional): If set to :obj:`True`, assumes that the graph is - given in CSC format :obj:`(colptr, row)`. (default: :obj:`False`) - disjoint (bool, optional): If set to :obj:`True` , will create disjoint - subgraphs for every seed node. (default: :obj:`False`) - - Returns: - (torch.Tensor, torch.Tensor): - Row indices, col indices of the returned subtree/subgraph. - """ - return torch.ops.pyg.relabel_neighborhood(seed, sampled_nodes_with_dupl, - sampled_nbrs_per_node, num_nodes, - batch, csc, disjoint) - - -def hetero_relabel_neighborhood( - seed_dict: Dict[NodeType, - Tensor], sampled_nodes_with_dupl_dict: Dict[NodeType, - Tensor], - sampled_nbrs_per_node_dict: Dict[EdgeType, - List[int]], num_nodes_dict: Dict[NodeType, - int], - batch_dict: Optional[Dict[NodeType, Tensor]] = None, csc: bool = False, - disjoint: bool = False -) -> Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor]]: - r"""Relabel global indices of the :obj:`sampled_nodes_with_dupl` to the - local subtree/subgraph indices in the heterogeneous graph. - - .. note :: - Similar to :meth:`relabel_neighborhood`, but expects a dictionary of - node types (:obj:`str`) and edge types (:obj:`Tuple[str, str, str]`) - for each non-boolean argument. - - Args: - kwargs: Arguments of :meth:`relabel_neighborhood`. - """ - - src_node_types = {k[0] for k in sampled_nodes_with_dupl_dict.keys()} - dst_node_types = {k[-1] for k in sampled_nodes_with_dupl_dict.keys()} - node_types = list(src_node_types | dst_node_types) - edge_types = list(sampled_nbrs_per_node_dict.keys()) - - TO_REL_TYPE = {key: '__'.join(key) for key in edge_types} - TO_EDGE_TYPE = {'__'.join(key): key for key in edge_types} - - sampled_nbrs_per_node_dict = { - TO_REL_TYPE[k]: v - for k, v in sampled_nbrs_per_node_dict.items() - } - - out = torch.ops.pyg.hetero_relabel_neighborhood( - node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, - sampled_nbrs_per_node_dict, num_nodes_dict, batch_dict, csc, disjoint) - - (row_dict, col_dict) = out - - row_dict = {TO_EDGE_TYPE[k]: v for k, v in row_dict.items()} - col_dict = {TO_EDGE_TYPE[k]: v for k, v in col_dict.items()} - - return (row_dict, col_dict) - - __all__ = [ 'neighbor_sample', 'hetero_neighbor_sample', 'subgraph', 'random_walk', - 'relabel_neighborhood', - 'hetero_relabel_neighborhood', ] diff --git a/test/csrc/sampler/test_dist_relabel.cpp b/test/csrc/sampler/test_dist_relabel.cpp index 01bfdb3a1..6e5c58a4a 100644 --- a/test/csrc/sampler/test_dist_relabel.cpp +++ b/test/csrc/sampler/test_dist_relabel.cpp @@ -9,32 +9,28 @@ TEST(DistRelabelNeighborhoodTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); - int num_nodes = 6; - auto graph = cycle_graph(num_nodes, options); auto seed = at::arange(2, 4, options); - std::vector num_neighbors = {-1}; + auto sampled_nodes_with_duplicates = at::tensor({1, 3, 2, 4}, options); + std::vector sampled_neighbors_per_node = {2, 2}; - // nodes with duplicates - auto nodes = at::tensor({2, 3, 1, 3, 2, 4}, options); - auto edges = at::tensor({4, 5, 6, 7}, options); - - std::vector sampled_nbrs_per_node = {2, 2}; - // without seed nodes - auto sampled_nodes_with_dupl = at::tensor({1, 3, 2, 4}, options); - - // get rows and cols auto relabel_out = pyg::sampler::relabel_neighborhood( - seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, num_nodes); + /*seed=*/seed, + /*sampled_nodes_with_duplicates=*/sampled_nodes_with_duplicates, + /*sampled_neighbors_per_node=*/sampled_neighbors_per_node, + /*num_nodes=*/6); auto expected_row = at::tensor({0, 0, 1, 1}, options); EXPECT_TRUE(at::equal(std::get<0>(relabel_out), expected_row)); auto expected_col = at::tensor({2, 1, 0, 3}, options); EXPECT_TRUE(at::equal(std::get<1>(relabel_out), expected_col)); - // check if rows and cols are correct + // Check if output is correct: + auto graph = cycle_graph(/*num_nodes=*/6, options); auto non_dist_out = pyg::sampler::neighbor_sample( - /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, - num_neighbors); + /*rowptr=*/std::get<0>(graph), + /*col=*/std::get<1>(graph), + /*seed=*/seed, + /*num_neighbors=*/{-1}); EXPECT_TRUE(at::equal(std::get<0>(relabel_out), std::get<0>(non_dist_out))); EXPECT_TRUE(at::equal(std::get<1>(relabel_out), std::get<1>(non_dist_out))); @@ -43,36 +39,39 @@ TEST(DistRelabelNeighborhoodTest, BasicAssertions) { TEST(DistDisjointRelabelNeighborhoodTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); - int num_nodes = 6; - auto graph = cycle_graph(num_nodes, options); auto seed = at::arange(2, 4, options); - std::vector num_neighbors = {2}; + auto sampled_nodes_with_duplicates = at::tensor({1, 3, 2, 4}, options); + std::vector sampled_neighbors_per_node = {2, 2}; + auto batch = at::tensor({0, 0, 1, 1}, options); - // nodes with duplicates - auto nodes = at::tensor({0, 2, 1, 3, 0, 1, 0, 3, 1, 2, 1, 4}, options); - auto edges = at::tensor({4, 5, 6, 7}, options); - - std::vector sampled_nbrs_per_node = {2, 2}; - // without seed nodes - auto sampled_nodes_with_dupl = at::tensor({1, 3, 2, 4}, options); - auto sampled_batch = at::tensor({0, 0, 1, 1}, options); - - // get rows and cols auto relabel_out = pyg::sampler::relabel_neighborhood( - seed, sampled_nodes_with_dupl, sampled_nbrs_per_node, num_nodes, - sampled_batch, /*csc=*/false, /*disjoint=*/true); + /*seed=*/seed, + /*sampled_nodes_with_duplicates=*/sampled_nodes_with_duplicates, + /*sampled_neighbors_per_node=*/sampled_neighbors_per_node, + /*num_nodes=*/6, + /*batch=*/batch, + /*csc=*/false, + /*disjoint=*/true); auto expected_row = at::tensor({0, 0, 1, 1}, options); EXPECT_TRUE(at::equal(std::get<0>(relabel_out), expected_row)); auto expected_col = at::tensor({2, 3, 4, 5}, options); EXPECT_TRUE(at::equal(std::get<1>(relabel_out), expected_col)); - // check if rows and cols are correct + // Check if output is correct: + auto graph = cycle_graph(/*num_nodes=*/6, options); auto non_dist_out = pyg::sampler::neighbor_sample( - /*rowptr=*/std::get<0>(graph), /*col=*/std::get<1>(graph), seed, - num_neighbors, /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, - /*edge_weight=*/c10::nullopt, /*csc*/ false, /*replace=*/false, - /*directed=*/true, /*disjoint=*/true); + /*rowptr=*/std::get<0>(graph), + /*col=*/std::get<1>(graph), + /*seed=*/seed, + /*num_neighbors=*/{2}, + /*time=*/c10::nullopt, + /*seed_time=*/c10::nullopt, + /*edge_weight=*/c10::nullopt, + /*csc*/ false, + /*replace=*/false, + /*directed=*/true, + /*disjoint=*/true); EXPECT_TRUE(at::equal(std::get<0>(relabel_out), std::get<0>(non_dist_out))); EXPECT_TRUE(at::equal(std::get<1>(relabel_out), std::get<1>(non_dist_out))); @@ -81,8 +80,7 @@ TEST(DistDisjointRelabelNeighborhoodTest, BasicAssertions) { TEST(DistHeteroRelabelNeighborhoodTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); - int num_nodes = 6; - auto graph = cycle_graph(num_nodes, options); + auto graph = cycle_graph(/*num_nodes=*/6, options); const auto node_key = "paper"; const auto edge_key = std::make_tuple("paper", "to", "paper"); const auto rel_key = "paper__to__paper"; @@ -98,28 +96,35 @@ TEST(DistHeteroRelabelNeighborhoodTest, BasicAssertions) { c10::Dict> num_neighbors_dict; num_neighbors_dict.insert(rel_key, num_neighbors); c10::Dict num_nodes_dict; - num_nodes_dict.insert(node_key, num_nodes); - - c10::Dict sampled_nodes_with_dupl_dict; - c10::Dict> sampled_nbrs_per_node_dict; - sampled_nodes_with_dupl_dict.insert(node_key, - at::tensor({1, 3, 2, 4}, options)); - sampled_nbrs_per_node_dict.insert(rel_key, std::vector(2, 2)); - // get rows and cols + num_nodes_dict.insert(node_key, 6); + + c10::Dict sampled_nodes_with_duplicates_dict; + c10::Dict> sampled_neighbors_per_node_dict; + sampled_nodes_with_duplicates_dict.insert(node_key, + at::tensor({1, 3, 2, 4}, options)); + sampled_neighbors_per_node_dict.insert(rel_key, std::vector(2, 2)); + auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( - node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, - sampled_nbrs_per_node_dict, num_nodes_dict, - /*batch_dict=*/c10::nullopt, /*csc=*/false, /*disjoint=*/false); + /*node_types=*/node_types, + /*edge_types=*/edge_types, + /*seed_dict=*/seed_dict, + /*sampled_nodes_with_duplicates_dict=*/sampled_nodes_with_duplicates_dict, + /*sampled_neighbors_per_node=*/sampled_neighbors_per_node_dict, + /*num_nodes_dict=*/num_nodes_dict); auto expected_row = at::tensor({0, 0, 1, 1}, options); EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), expected_row)); auto expected_col = at::tensor({2, 1, 0, 3}, options); EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), expected_col)); - // check if rows and cols are correct + // Check if output is correct: auto non_dist_out = pyg::sampler::hetero_neighbor_sample( - node_types, edge_types, rowptr_dict, col_dict, seed_dict, - num_neighbors_dict); + /*node_types=*/node_types, + /*edge_types=*/edge_types, + /*rowptr_dict=*/rowptr_dict, + /*col_dict=*/col_dict, + /*seed_dict=*/seed_dict, + /*num_neighbors_dict=*/num_neighbors_dict); EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), std::get<0>(non_dist_out).at(rel_key))); @@ -130,8 +135,7 @@ TEST(DistHeteroRelabelNeighborhoodTest, BasicAssertions) { TEST(DistHeteroRelabelNeighborhoodCscTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); - int num_nodes = 6; - auto graph = cycle_graph(num_nodes, options); + auto graph = cycle_graph(/*num_nodes=*/6, options); const auto node_key = "paper"; const auto edge_key = std::make_tuple("paper", "to", "paper"); const auto rel_key = "paper__to__paper"; @@ -147,29 +151,40 @@ TEST(DistHeteroRelabelNeighborhoodCscTest, BasicAssertions) { c10::Dict> num_neighbors_dict; num_neighbors_dict.insert(rel_key, num_neighbors); c10::Dict num_nodes_dict; - num_nodes_dict.insert(node_key, num_nodes); - - c10::Dict sampled_nodes_with_dupl_dict; - c10::Dict> sampled_nbrs_per_node_dict; - sampled_nodes_with_dupl_dict.insert(node_key, - at::tensor({1, 3, 2, 4}, options)); - sampled_nbrs_per_node_dict.insert(rel_key, std::vector(2, 2)); - // get rows and cols + num_nodes_dict.insert(node_key, 6); + + c10::Dict sampled_nodes_with_duplicates_dict; + c10::Dict> sampled_neighbors_per_node_dict; + sampled_nodes_with_duplicates_dict.insert(node_key, + at::tensor({1, 3, 2, 4}, options)); + sampled_neighbors_per_node_dict.insert(rel_key, std::vector(2, 2)); + auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( - node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, - sampled_nbrs_per_node_dict, num_nodes_dict, - /*batch_dict=*/c10::nullopt, /*csc=*/true, /*disjoint=*/false); + /*node_types=*/node_types, + /*edge_types=*/edge_types, + /*seed_dict=*/seed_dict, + /*sampled_nodes_with_duplicates_dict=*/sampled_nodes_with_duplicates_dict, + /*sampled_neighbors_per_node=*/sampled_neighbors_per_node_dict, + /*num_nodes_dict=*/num_nodes_dict, + /*batch_dict=*/c10::nullopt, + /*csc=*/true); auto expected_row = at::tensor({2, 1, 0, 3}, options); EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), expected_row)); auto expected_col = at::tensor({0, 0, 1, 1}, options); EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), expected_col)); - // check if rows and cols are correct + // Check if output is correct: auto non_dist_out = pyg::sampler::hetero_neighbor_sample( - node_types, edge_types, rowptr_dict, col_dict, seed_dict, - num_neighbors_dict, /*time_dict=*/c10::nullopt, - /*seed_time_dict=*/c10::nullopt, /*edge_weight_dict=*/c10::nullopt, + /*node_types=*/node_types, + /*edge_types=*/edge_types, + /*rowptr_dict=*/rowptr_dict, + /*col_dict=*/col_dict, + /*seed_dict=*/seed_dict, + /*num_neighbors_dict=*/num_neighbors_dict, + /*time_dict=*/c10::nullopt, + /*seed_time_dict=*/c10::nullopt, + /*edge_weight_dict=*/c10::nullopt, /*csc=*/true); EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), @@ -181,8 +196,7 @@ TEST(DistHeteroRelabelNeighborhoodCscTest, BasicAssertions) { TEST(DistHeteroDisjointRelabelNeighborhoodTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); - int num_nodes = 6; - auto graph = cycle_graph(num_nodes, options); + auto graph = cycle_graph(/*num_nodes=*/6, options); const auto node_key = "paper"; const auto edge_key = std::make_tuple("paper", "to", "paper"); const auto rel_key = "paper__to__paper"; @@ -198,19 +212,24 @@ TEST(DistHeteroDisjointRelabelNeighborhoodTest, BasicAssertions) { c10::Dict> num_neighbors_dict; num_neighbors_dict.insert(rel_key, num_neighbors); c10::Dict num_nodes_dict; - num_nodes_dict.insert(node_key, num_nodes); + num_nodes_dict.insert(node_key, 6); - c10::Dict sampled_nodes_with_dupl_dict; - c10::Dict> sampled_nbrs_per_node_dict; + c10::Dict sampled_nodes_with_duplicates_dict; + c10::Dict> sampled_neighbors_per_node_dict; c10::Dict batch_dict; - sampled_nodes_with_dupl_dict.insert(node_key, - at::tensor({1, 3, 2, 4}, options)); - sampled_nbrs_per_node_dict.insert(rel_key, std::vector(2, 2)); + sampled_nodes_with_duplicates_dict.insert(node_key, + at::tensor({1, 3, 2, 4}, options)); + sampled_neighbors_per_node_dict.insert(rel_key, std::vector(2, 2)); batch_dict.insert(node_key, at::tensor({0, 0, 1, 1}, options)); - // get rows and cols + auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( - node_types, edge_types, seed_dict, sampled_nodes_with_dupl_dict, - sampled_nbrs_per_node_dict, num_nodes_dict, batch_dict, + /*node_types=*/node_types, + /*edge_types=*/edge_types, + /*seed_dict=*/seed_dict, + /*sampled_nodes_with_duplicates_dict=*/sampled_nodes_with_duplicates_dict, + /*sampled_neighbors_per_node=*/sampled_neighbors_per_node_dict, + /*num_nodes_dict=*/num_nodes_dict, + /*batch_dict=*/batch_dict, /*csc=*/false, /*disjoint=*/true); auto expected_row = at::tensor({0, 0, 1, 1}, options); @@ -218,12 +237,21 @@ TEST(DistHeteroDisjointRelabelNeighborhoodTest, BasicAssertions) { auto expected_col = at::tensor({2, 3, 4, 5}, options); EXPECT_TRUE(at::equal(std::get<1>(relabel_out).at(rel_key), expected_col)); - // check if rows and cols are correct + // Check if output is correct: auto non_dist_out = pyg::sampler::hetero_neighbor_sample( - node_types, edge_types, rowptr_dict, col_dict, seed_dict, - num_neighbors_dict, /*time_dict=*/c10::nullopt, - /*seed_time_dict=*/c10::nullopt, /*edge_weight_dict=*/c10::nullopt, - /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true); + /*node_types=*/node_types, + /*edge_types=*/edge_types, + /*rowptr_dict=*/rowptr_dict, + /*col_dict=*/col_dict, + /*seed_dict=*/seed_dict, + /*num_neighbors_dict=*/num_neighbors_dict, + /*time_dict=*/c10::nullopt, + /*seed_time_dict=*/c10::nullopt, + /*edge_weight_dict=*/c10::nullopt, + /*csc=*/false, + /*replace=*/false, + /*directed=*/true, + /*disjoint=*/true); EXPECT_TRUE(at::equal(std::get<0>(relabel_out).at(rel_key), std::get<0>(non_dist_out).at(rel_key)));