diff --git a/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp index 856bfb17c..b4fd8acc6 100644 --- a/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp @@ -58,7 +58,7 @@ merge_outputs( } const auto p_size = partition_ids.size(); - std::vector sampled_neighbors_per_node(p_size); + std::vector num_sampled_neighbors_per_node(p_size); const auto scalar_type = node_ids[0].scalar_type(); AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "merge_outputs_kernel", [&] { @@ -106,7 +106,7 @@ merge_outputs( batch_data[j]); } - sampled_neighbors_per_node[j] = end_node - begin_node; + num_sampled_neighbors_per_node[j] = end_node - begin_node; } }); @@ -128,7 +128,7 @@ merge_outputs( }); return std::make_tuple(out_node_id, out_edge_id, out_batch, - sampled_neighbors_per_node); + num_sampled_neighbors_per_node); } #define DISPATCH_MERGE_OUTPUTS(disjoint, ...) \ diff --git a/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp b/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp index c09654e26..29ebfecf0 100644 --- a/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp @@ -101,7 +101,7 @@ relabel( const std::vector& edge_types, const c10::Dict& seed_dict, const c10::Dict& sampled_nodes_with_duplicates_dict, - const c10::Dict>& + const c10::Dict>>& num_sampled_neighbors_per_node_dict, const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict, @@ -117,9 +117,16 @@ relabel( phmap::flat_hash_map batch_data_dict; phmap::flat_hash_map> sampled_rows_dict; phmap::flat_hash_map> sampled_cols_dict; + // `srcs_slice_dict` defines the number of src nodes for each edge type in + // a given layer in the form of a range. Local src nodes (`sampled_rows`) + // will be created on its basis, so for a given edge type the ranges will + // not be repeated, and the starting value of the next layer will be the + // end value from the previous layer. + phmap::flat_hash_map> srcs_slice_dict; phmap::flat_hash_map> mapper_dict; phmap::flat_hash_map> slice_dict; + phmap::flat_hash_map srcs_offset_dict; const bool parallel = at::get_num_threads() > 1 && edge_types.size() > 1; std::vector> threads_edge_types; @@ -129,6 +136,14 @@ relabel( sampled_rows_dict[k]; sampled_cols_dict[k]; + // `num_sampled_neighbors_per_node_dict` is a dictionary where for + // each edge type it contains information about how many neighbors every + // src node has sampled. These values are saved in a separate vector for + // each layer. + size_t num_src_nodes = + num_sampled_neighbors_per_node_dict.at(to_rel_type(k))[0].size(); + srcs_slice_dict[k] = {0, num_src_nodes}; + if (parallel) { // Each thread is assigned edge types that have the same dst node // type. Thanks to this, each thread will operate on a separate @@ -161,6 +176,7 @@ relabel( {k, sampled_nodes_with_duplicates_dict.at(k).data_ptr()}); mapper_dict.insert({k, Mapper(N)}); slice_dict[k] = {0, 0}; + srcs_offset_dict[k] = 0; if constexpr (disjoint) { batch_data_dict.insert( {k, batch_dict.value().at(k).data_ptr()}); @@ -178,44 +194,71 @@ relabel( } } } - at::parallel_for( - 0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) { - for (auto j = _s; j < _e; j++) { - for (const auto& k : threads_edge_types[j]) { - const auto src = !csc ? std::get<0>(k) : std::get<2>(k); - const auto dst = !csc ? std::get<2>(k) : std::get<0>(k); - - const auto num_sampled_neighbors_size = - num_sampled_neighbors_per_node_dict.at(to_rel_type(k)).size(); - - if (num_sampled_neighbors_size == 0) { - continue; - } - for (auto i = 0; i < num_sampled_neighbors_size; i++) { - auto& dst_mapper = mapper_dict.at(dst); - auto& dst_sampled_nodes_data = sampled_nodes_data_dict.at(dst); - - slice_dict.at(dst).second += - num_sampled_neighbors_per_node_dict.at(to_rel_type(k))[i]; - auto [begin, end] = slice_dict.at(dst); - - for (auto j = begin; j < end; j++) { - std::pair res; - if constexpr (!disjoint) { - res = dst_mapper.insert(dst_sampled_nodes_data[j]); - } else { - res = dst_mapper.insert({batch_data_dict.at(dst)[j], - dst_sampled_nodes_data[j]}); + size_t num_layers = + num_sampled_neighbors_per_node_dict.at(to_rel_type(edge_types[0])) + .size(); + // Iterate over the layers + for (auto ell = 0; ell < num_layers; ++ell) { + at::parallel_for( + 0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) { + for (auto t = _s; t < _e; t++) { + for (const auto& k : threads_edge_types[t]) { + const auto dst = !csc ? std::get<2>(k) : std::get<0>(k); + + auto [src_begin, src_end] = srcs_slice_dict.at(k); + + for (auto i = src_begin; i < src_end; i++) { + auto& dst_mapper = mapper_dict.at(dst); + auto& dst_sampled_nodes_data = + sampled_nodes_data_dict.at(dst); + + // For each edge type `slice_dict` defines the number of + // nodes sampled by a src node `i` in the form of a range. + // The indices in the given range point to global dst nodes + // from `dst_sampled_nodes_data`. + slice_dict.at(dst).second += + num_sampled_neighbors_per_node_dict.at( + to_rel_type(k))[ell][i - src_begin]; + auto [begin, end] = slice_dict.at(dst); + + for (auto j = begin; j < end; j++) { + std::pair res; + if constexpr (!disjoint) { + res = dst_mapper.insert(dst_sampled_nodes_data[j]); + } else { + res = dst_mapper.insert({batch_data_dict.at(dst)[j], + dst_sampled_nodes_data[j]}); + } + sampled_rows_dict.at(k).push_back(i); + sampled_cols_dict.at(k).push_back(res.first); } - sampled_rows_dict.at(k).push_back(i); - sampled_cols_dict.at(k).push_back(res.first); + slice_dict.at(dst).first = end; } - slice_dict.at(dst).first = end; } } + }); + + // Get local src nodes ranges for the next layer + if (ell < num_layers - 1) { + for (const auto& k : edge_types) { + // Edges with the same src node types will have the same src node + // offsets. + const auto src = !csc ? std::get<0>(k) : std::get<2>(k); + if (srcs_offset_dict[src] < srcs_slice_dict.at(k).second) { + srcs_offset_dict[src] = srcs_slice_dict.at(k).second; } - }); + } + for (const auto& k : edge_types) { + const auto src = !csc ? std::get<0>(k) : std::get<2>(k); + srcs_slice_dict[k] = { + srcs_offset_dict.at(src), + srcs_offset_dict.at(src) + num_sampled_neighbors_per_node_dict + .at(to_rel_type(k))[ell + 1] + .size()}; + } + } + } for (const auto& k : edge_types) { const auto edges = get_sampled_edges( @@ -254,7 +297,7 @@ hetero_relabel_neighborhood_kernel( const std::vector& edge_types, const c10::Dict& seed_dict, const c10::Dict& sampled_nodes_with_duplicates_dict, - const c10::Dict>& + const c10::Dict>>& num_sampled_neighbors_per_node_dict, const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict, diff --git a/pyg_lib/csrc/sampler/dist_relabel.cpp b/pyg_lib/csrc/sampler/dist_relabel.cpp index 0ec6dbad9..78533336e 100644 --- a/pyg_lib/csrc/sampler/dist_relabel.cpp +++ b/pyg_lib/csrc/sampler/dist_relabel.cpp @@ -11,7 +11,7 @@ namespace sampler { std::tuple relabel_neighborhood( const at::Tensor& seed, const at::Tensor& sampled_nodes_with_duplicates, - const std::vector& sampled_neighbors_per_node, + const std::vector& num_sampled_neighbors_per_node, const int64_t num_nodes, const c10::optional& batch, bool csc, @@ -28,7 +28,8 @@ std::tuple relabel_neighborhood( .findSchemaOrThrow("pyg::relabel_neighborhood", "") .typed(); return op.call(seed, sampled_nodes_with_duplicates, - sampled_neighbors_per_node, num_nodes, batch, csc, disjoint); + num_sampled_neighbors_per_node, num_nodes, batch, csc, + disjoint); } std::tuple, c10::Dict> @@ -37,8 +38,8 @@ hetero_relabel_neighborhood( const std::vector& edge_types, const c10::Dict& seed_dict, const c10::Dict& sampled_nodes_with_duplicates_dict, - const c10::Dict>& - sampled_neighbors_per_node_dict, + const c10::Dict>>& + num_sampled_neighbors_per_node_dict, const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict, bool csc, @@ -62,21 +63,22 @@ hetero_relabel_neighborhood( .typed(); return op.call(node_types, edge_types, seed_dict, sampled_nodes_with_duplicates_dict, - sampled_neighbors_per_node_dict, num_nodes_dict, batch_dict, - csc, disjoint); + num_sampled_neighbors_per_node_dict, num_nodes_dict, + batch_dict, csc, disjoint); } TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( "pyg::relabel_neighborhood(Tensor seed, Tensor " - "sampled_nodes_with_duplicates, int[] sampled_neighbors_per_node, int " + "sampled_nodes_with_duplicates, int[] num_sampled_neighbors_per_node, " + "int " "num_nodes, Tensor? batch = None, bool csc = False, bool disjoint = " "False) -> (Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::hetero_relabel_neighborhood(str[] node_types, (str, str, str)[] " "edge_types, Dict(str, Tensor) seed_dict, Dict(str, Tensor) " - "sampled_nodes_with_duplicates_dict, Dict(str, int[]) " - "sampled_neighbors_per_node_dict, Dict(str, int) num_nodes_dict, " + "sampled_nodes_with_duplicates_dict, Dict(str, int[][]) " + "num_sampled_neighbors_per_node_dict, Dict(str, int) num_nodes_dict, " "Dict(str, Tensor)? batch_dict = None, bool csc = False, bool disjoint = " "False) -> (Dict(str, Tensor), Dict(str, Tensor))")); } diff --git a/pyg_lib/csrc/sampler/dist_relabel.h b/pyg_lib/csrc/sampler/dist_relabel.h index 780f5c99c..426339dcd 100644 --- a/pyg_lib/csrc/sampler/dist_relabel.h +++ b/pyg_lib/csrc/sampler/dist_relabel.h @@ -15,7 +15,7 @@ PYG_API std::tuple relabel_neighborhood( const at::Tensor& seed, const at::Tensor& sampled_nodes_with_duplicates, - const std::vector& sampled_neighbors_per_node, + const std::vector& num_sampled_neighbors_per_node, const int64_t num_nodes, const c10::optional& batch = c10::nullopt, bool csc = false, @@ -32,8 +32,8 @@ hetero_relabel_neighborhood( const std::vector& edge_types, const c10::Dict& seed_dict, const c10::Dict& sampled_nodes_with_duplicates_dict, - const c10::Dict>& - sampled_neighbors_per_node_dict, + const c10::Dict>>& + num_sampled_neighbors_per_node_dict, const c10::Dict& num_nodes_dict, const c10::optional>& batch_dict = c10::nullopt, diff --git a/test/csrc/sampler/test_dist_merge_outputs.cpp b/test/csrc/sampler/test_dist_merge_outputs.cpp index 33c093d6f..e5f7d65b3 100644 --- a/test/csrc/sampler/test_dist_merge_outputs.cpp +++ b/test/csrc/sampler/test_dist_merge_outputs.cpp @@ -41,8 +41,9 @@ TEST(DistMergeOutputsTest, BasicAssertions) { auto expected_edges = at::tensor({14, 15, 16, 17, 18, 19, 20}, options); EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges)); - const std::vector expected_sampled_neighbors_per_node = {2, 1, 2, 2}; - EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node); + const std::vector expected_num_sampled_neighbors_per_node = {2, 1, 2, + 2}; + EXPECT_EQ(std::get<3>(out), expected_num_sampled_neighbors_per_node); } TEST(DistMergeOutputsAllNeighborsTest, BasicAssertions) { @@ -82,8 +83,9 @@ TEST(DistMergeOutputsAllNeighborsTest, BasicAssertions) { auto expected_edges = at::tensor({14, 15, 16, 17, 18, 19, 20, 21}, options); EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges)); - const std::vector expected_sampled_neighbors_per_node = {2, 1, 2, 3}; - EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node); + const std::vector expected_num_sampled_neighbors_per_node = {2, 1, 2, + 3}; + EXPECT_EQ(std::get<3>(out), expected_num_sampled_neighbors_per_node); } TEST(DistDisjointMergeOutputsTest, BasicAssertions) { @@ -124,6 +126,7 @@ TEST(DistDisjointMergeOutputsTest, BasicAssertions) { auto expected_batch = at::tensor({0, 0, 1, 2, 2, 3, 3}, options); EXPECT_TRUE(at::equal(std::get<2>(out).value(), expected_batch)); - const std::vector expected_sampled_neighbors_per_node = {2, 1, 2, 2}; - EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node); + const std::vector expected_num_sampled_neighbors_per_node = {2, 1, 2, + 2}; + EXPECT_EQ(std::get<3>(out), expected_num_sampled_neighbors_per_node); } diff --git a/test/csrc/sampler/test_dist_relabel.cpp b/test/csrc/sampler/test_dist_relabel.cpp index 6c8389363..7bc60db3c 100644 --- a/test/csrc/sampler/test_dist_relabel.cpp +++ b/test/csrc/sampler/test_dist_relabel.cpp @@ -11,12 +11,12 @@ TEST(DistRelabelNeighborhoodTest, BasicAssertions) { auto seed = at::arange(2, 4, options); auto sampled_nodes_with_duplicates = at::tensor({1, 3, 2, 4}, options); - std::vector sampled_neighbors_per_node = {2, 2}; + std::vector num_sampled_neighbors_per_node = {2, 2}; auto relabel_out = pyg::sampler::relabel_neighborhood( /*seed=*/seed, /*sampled_nodes_with_duplicates=*/sampled_nodes_with_duplicates, - /*sampled_neighbors_per_node=*/sampled_neighbors_per_node, + /*num_sampled_neighbors_per_node=*/num_sampled_neighbors_per_node, /*num_nodes=*/6); auto expected_row = at::tensor({0, 0, 1, 1}, options); @@ -41,13 +41,13 @@ TEST(DistDisjointRelabelNeighborhoodTest, BasicAssertions) { auto seed = at::arange(2, 4, options); auto sampled_nodes_with_duplicates = at::tensor({1, 3, 2, 4}, options); - std::vector sampled_neighbors_per_node = {2, 2}; + std::vector num_sampled_neighbors_per_node = {2, 2}; auto batch = at::tensor({0, 0, 1, 1}, options); auto relabel_out = pyg::sampler::relabel_neighborhood( /*seed=*/seed, /*sampled_nodes_with_duplicates=*/sampled_nodes_with_duplicates, - /*sampled_neighbors_per_node=*/sampled_neighbors_per_node, + /*num_sampled_neighbors_per_node=*/num_sampled_neighbors_per_node, /*num_nodes=*/6, /*batch=*/batch, /*csc=*/false, @@ -100,17 +100,21 @@ TEST(DistHeteroRelabelNeighborhoodTest, BasicAssertions) { num_nodes_dict.insert(node_key, 6); c10::Dict sampled_nodes_with_duplicates_dict; - c10::Dict> sampled_neighbors_per_node_dict; + c10::Dict>> + num_sampled_neighbors_per_node_dict; sampled_nodes_with_duplicates_dict.insert(node_key, at::tensor({1, 3, 2, 4}, options)); - sampled_neighbors_per_node_dict.insert(rel_key, std::vector(2, 2)); + std::vector> num_sampled_neighbors_per_node_vec( + 2, std::vector(1, 2)); + num_sampled_neighbors_per_node_dict.insert( + rel_key, num_sampled_neighbors_per_node_vec); auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( /*node_types=*/node_types, /*edge_types=*/edge_types, /*seed_dict=*/seed_dict, /*sampled_nodes_with_duplicates_dict=*/sampled_nodes_with_duplicates_dict, - /*sampled_neighbors_per_node=*/sampled_neighbors_per_node_dict, + /*num_sampled_neighbors_per_node=*/num_sampled_neighbors_per_node_dict, /*num_nodes_dict=*/num_nodes_dict); auto expected_row = at::tensor({0, 0, 1, 1}, options); @@ -155,17 +159,21 @@ TEST(DistHeteroRelabelNeighborhoodCscTest, BasicAssertions) { num_nodes_dict.insert(node_key, 6); c10::Dict sampled_nodes_with_duplicates_dict; - c10::Dict> sampled_neighbors_per_node_dict; + c10::Dict>> + num_sampled_neighbors_per_node_dict; sampled_nodes_with_duplicates_dict.insert(node_key, at::tensor({1, 3, 2, 4}, options)); - sampled_neighbors_per_node_dict.insert(rel_key, std::vector(2, 2)); + std::vector> num_sampled_neighbors_per_node_vec( + 2, std::vector(1, 2)); + num_sampled_neighbors_per_node_dict.insert( + rel_key, num_sampled_neighbors_per_node_vec); auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( /*node_types=*/node_types, /*edge_types=*/edge_types, /*seed_dict=*/seed_dict, /*sampled_nodes_with_duplicates_dict=*/sampled_nodes_with_duplicates_dict, - /*sampled_neighbors_per_node=*/sampled_neighbors_per_node_dict, + /*num_sampled_neighbors_per_node=*/num_sampled_neighbors_per_node_dict, /*num_nodes_dict=*/num_nodes_dict, /*batch_dict=*/c10::nullopt, /*csc=*/true); @@ -217,11 +225,15 @@ TEST(DistHeteroDisjointRelabelNeighborhoodTest, BasicAssertions) { num_nodes_dict.insert(node_key, 6); c10::Dict sampled_nodes_with_duplicates_dict; - c10::Dict> sampled_neighbors_per_node_dict; + c10::Dict>> + num_sampled_neighbors_per_node_dict; c10::Dict batch_dict; sampled_nodes_with_duplicates_dict.insert(node_key, at::tensor({1, 3, 2, 4}, options)); - sampled_neighbors_per_node_dict.insert(rel_key, std::vector(2, 2)); + std::vector> num_sampled_neighbors_per_node_vec( + 2, std::vector(1, 2)); + num_sampled_neighbors_per_node_dict.insert( + rel_key, num_sampled_neighbors_per_node_vec); batch_dict.insert(node_key, at::tensor({0, 0, 1, 1}, options)); auto relabel_out = pyg::sampler::hetero_relabel_neighborhood( @@ -229,7 +241,7 @@ TEST(DistHeteroDisjointRelabelNeighborhoodTest, BasicAssertions) { /*edge_types=*/edge_types, /*seed_dict=*/seed_dict, /*sampled_nodes_with_duplicates_dict=*/sampled_nodes_with_duplicates_dict, - /*sampled_neighbors_per_node=*/sampled_neighbors_per_node_dict, + /*num_sampled_neighbors_per_node=*/num_sampled_neighbors_per_node_dict, /*num_nodes_dict=*/num_nodes_dict, /*batch_dict=*/batch_dict, /*csc=*/false,