Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update hetero dist relabel #284

Merged
merged 1 commit into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ merge_outputs(
}

const auto p_size = partition_ids.size();
std::vector<int64_t> sampled_neighbors_per_node(p_size);
std::vector<int64_t> num_sampled_neighbors_per_node(p_size);

const auto scalar_type = node_ids[0].scalar_type();
AT_DISPATCH_INTEGRAL_TYPES(scalar_type, "merge_outputs_kernel", [&] {
Expand Down Expand Up @@ -106,7 +106,7 @@ merge_outputs(
batch_data[j]);
}

sampled_neighbors_per_node[j] = end_node - begin_node;
num_sampled_neighbors_per_node[j] = end_node - begin_node;
}
});

Expand All @@ -128,7 +128,7 @@ merge_outputs(
});

return std::make_tuple(out_node_id, out_edge_id, out_batch,
sampled_neighbors_per_node);
num_sampled_neighbors_per_node);
}

#define DISPATCH_MERGE_OUTPUTS(disjoint, ...) \
Expand Down
111 changes: 77 additions & 34 deletions pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ relabel(
const std::vector<edge_type>& edge_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<node_type, at::Tensor>& sampled_nodes_with_duplicates_dict,
const c10::Dict<rel_type, std::vector<int64_t>>&
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict,
Expand All @@ -117,9 +117,16 @@ relabel(
phmap::flat_hash_map<node_type, scalar_t*> batch_data_dict;
phmap::flat_hash_map<edge_type, std::vector<scalar_t>> sampled_rows_dict;
phmap::flat_hash_map<edge_type, std::vector<scalar_t>> sampled_cols_dict;
// `srcs_slice_dict` holds, for each edge type, the range `[begin, end)` of
// local src node indices processed in the current layer. Local src nodes
// (`sampled_rows`) are created from this range, so for a given edge type
// the ranges of different layers never overlap: the next layer's range
// starts at the largest end value reached in the previous layer by any
// edge type with the same src node type.
phmap::flat_hash_map<edge_type, std::pair<size_t, size_t>> srcs_slice_dict;

phmap::flat_hash_map<node_type, Mapper<node_t, scalar_t>> mapper_dict;
phmap::flat_hash_map<node_type, std::pair<size_t, size_t>> slice_dict;
phmap::flat_hash_map<node_type, int64_t> srcs_offset_dict;

const bool parallel = at::get_num_threads() > 1 && edge_types.size() > 1;
std::vector<std::vector<edge_type>> threads_edge_types;
Expand All @@ -129,6 +136,14 @@ relabel(
sampled_rows_dict[k];
sampled_cols_dict[k];

// `num_sampled_neighbors_per_node_dict` is a dictionary where for
// each edge type it contains information about how many neighbors every
// src node has sampled. These values are saved in a separate vector for
// each layer.
size_t num_src_nodes =
num_sampled_neighbors_per_node_dict.at(to_rel_type(k))[0].size();
srcs_slice_dict[k] = {0, num_src_nodes};

if (parallel) {
// Each thread is assigned edge types that have the same dst node
// type. Thanks to this, each thread will operate on a separate
Expand Down Expand Up @@ -161,6 +176,7 @@ relabel(
{k, sampled_nodes_with_duplicates_dict.at(k).data_ptr<scalar_t>()});
mapper_dict.insert({k, Mapper<node_t, scalar_t>(N)});
slice_dict[k] = {0, 0};
srcs_offset_dict[k] = 0;
if constexpr (disjoint) {
batch_data_dict.insert(
{k, batch_dict.value().at(k).data_ptr<scalar_t>()});
Expand All @@ -178,44 +194,71 @@ relabel(
}
}
}
at::parallel_for(
0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) {
for (auto j = _s; j < _e; j++) {
for (const auto& k : threads_edge_types[j]) {
const auto src = !csc ? std::get<0>(k) : std::get<2>(k);
const auto dst = !csc ? std::get<2>(k) : std::get<0>(k);

const auto num_sampled_neighbors_size =
num_sampled_neighbors_per_node_dict.at(to_rel_type(k)).size();

if (num_sampled_neighbors_size == 0) {
continue;
}

for (auto i = 0; i < num_sampled_neighbors_size; i++) {
auto& dst_mapper = mapper_dict.at(dst);
auto& dst_sampled_nodes_data = sampled_nodes_data_dict.at(dst);

slice_dict.at(dst).second +=
num_sampled_neighbors_per_node_dict.at(to_rel_type(k))[i];
auto [begin, end] = slice_dict.at(dst);

for (auto j = begin; j < end; j++) {
std::pair<scalar_t, bool> res;
if constexpr (!disjoint) {
res = dst_mapper.insert(dst_sampled_nodes_data[j]);
} else {
res = dst_mapper.insert({batch_data_dict.at(dst)[j],
dst_sampled_nodes_data[j]});
size_t num_layers =
num_sampled_neighbors_per_node_dict.at(to_rel_type(edge_types[0]))
.size();
// Iterate over the layers
for (auto ell = 0; ell < num_layers; ++ell) {
at::parallel_for(
0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) {
for (auto t = _s; t < _e; t++) {
for (const auto& k : threads_edge_types[t]) {
const auto dst = !csc ? std::get<2>(k) : std::get<0>(k);

auto [src_begin, src_end] = srcs_slice_dict.at(k);

for (auto i = src_begin; i < src_end; i++) {
auto& dst_mapper = mapper_dict.at(dst);
auto& dst_sampled_nodes_data =
sampled_nodes_data_dict.at(dst);

// `slice_dict` is keyed by dst node type and holds a range
// `[begin, end)` into `dst_sampled_nodes_data`. Extending `end` by
// the number of neighbors sampled by src node `i` makes the range
// cover the global dst nodes sampled by `i`.
slice_dict.at(dst).second +=
num_sampled_neighbors_per_node_dict.at(
to_rel_type(k))[ell][i - src_begin];
auto [begin, end] = slice_dict.at(dst);

for (auto j = begin; j < end; j++) {
std::pair<scalar_t, bool> res;
if constexpr (!disjoint) {
res = dst_mapper.insert(dst_sampled_nodes_data[j]);
} else {
res = dst_mapper.insert({batch_data_dict.at(dst)[j],
dst_sampled_nodes_data[j]});
}
sampled_rows_dict.at(k).push_back(i);
sampled_cols_dict.at(k).push_back(res.first);
}
sampled_rows_dict.at(k).push_back(i);
sampled_cols_dict.at(k).push_back(res.first);
slice_dict.at(dst).first = end;
}
slice_dict.at(dst).first = end;
}
}
});

// Get local src nodes ranges for the next layer
if (ell < num_layers - 1) {
for (const auto& k : edge_types) {
// Edges with the same src node types will have the same src node
// offsets.
const auto src = !csc ? std::get<0>(k) : std::get<2>(k);
if (srcs_offset_dict[src] < srcs_slice_dict.at(k).second) {
srcs_offset_dict[src] = srcs_slice_dict.at(k).second;
}
});
}
for (const auto& k : edge_types) {
const auto src = !csc ? std::get<0>(k) : std::get<2>(k);
srcs_slice_dict[k] = {
srcs_offset_dict.at(src),
srcs_offset_dict.at(src) + num_sampled_neighbors_per_node_dict
.at(to_rel_type(k))[ell + 1]
.size()};
}
}
}

for (const auto& k : edge_types) {
const auto edges = get_sampled_edges<scalar_t>(
Expand Down Expand Up @@ -254,7 +297,7 @@ hetero_relabel_neighborhood_kernel(
const std::vector<edge_type>& edge_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<node_type, at::Tensor>& sampled_nodes_with_duplicates_dict,
const c10::Dict<rel_type, std::vector<int64_t>>&
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict,
Expand Down
20 changes: 11 additions & 9 deletions pyg_lib/csrc/sampler/dist_relabel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace sampler {
std::tuple<at::Tensor, at::Tensor> relabel_neighborhood(
const at::Tensor& seed,
const at::Tensor& sampled_nodes_with_duplicates,
const std::vector<int64_t>& sampled_neighbors_per_node,
const std::vector<int64_t>& num_sampled_neighbors_per_node,
const int64_t num_nodes,
const c10::optional<at::Tensor>& batch,
bool csc,
Expand All @@ -28,7 +28,8 @@ std::tuple<at::Tensor, at::Tensor> relabel_neighborhood(
.findSchemaOrThrow("pyg::relabel_neighborhood", "")
.typed<decltype(relabel_neighborhood)>();
return op.call(seed, sampled_nodes_with_duplicates,
sampled_neighbors_per_node, num_nodes, batch, csc, disjoint);
num_sampled_neighbors_per_node, num_nodes, batch, csc,
disjoint);
}

std::tuple<c10::Dict<rel_type, at::Tensor>, c10::Dict<rel_type, at::Tensor>>
Expand All @@ -37,8 +38,8 @@ hetero_relabel_neighborhood(
const std::vector<edge_type>& edge_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<node_type, at::Tensor>& sampled_nodes_with_duplicates_dict,
const c10::Dict<rel_type, std::vector<int64_t>>&
sampled_neighbors_per_node_dict,
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict,
bool csc,
Expand All @@ -62,21 +63,22 @@ hetero_relabel_neighborhood(
.typed<decltype(hetero_relabel_neighborhood)>();
return op.call(node_types, edge_types, seed_dict,
sampled_nodes_with_duplicates_dict,
sampled_neighbors_per_node_dict, num_nodes_dict, batch_dict,
csc, disjoint);
num_sampled_neighbors_per_node_dict, num_nodes_dict,
batch_dict, csc, disjoint);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::relabel_neighborhood(Tensor seed, Tensor "
"sampled_nodes_with_duplicates, int[] sampled_neighbors_per_node, int "
"sampled_nodes_with_duplicates, int[] num_sampled_neighbors_per_node, "
"int "
"num_nodes, Tensor? batch = None, bool csc = False, bool disjoint = "
"False) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::hetero_relabel_neighborhood(str[] node_types, (str, str, str)[] "
"edge_types, Dict(str, Tensor) seed_dict, Dict(str, Tensor) "
"sampled_nodes_with_duplicates_dict, Dict(str, int[]) "
"sampled_neighbors_per_node_dict, Dict(str, int) num_nodes_dict, "
"sampled_nodes_with_duplicates_dict, Dict(str, int[][]) "
"num_sampled_neighbors_per_node_dict, Dict(str, int) num_nodes_dict, "
"Dict(str, Tensor)? batch_dict = None, bool csc = False, bool disjoint = "
"False) -> (Dict(str, Tensor), Dict(str, Tensor))"));
}
Expand Down
6 changes: 3 additions & 3 deletions pyg_lib/csrc/sampler/dist_relabel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ PYG_API
std::tuple<at::Tensor, at::Tensor> relabel_neighborhood(
const at::Tensor& seed,
const at::Tensor& sampled_nodes_with_duplicates,
const std::vector<int64_t>& sampled_neighbors_per_node,
const std::vector<int64_t>& num_sampled_neighbors_per_node,
const int64_t num_nodes,
const c10::optional<at::Tensor>& batch = c10::nullopt,
bool csc = false,
Expand All @@ -32,8 +32,8 @@ hetero_relabel_neighborhood(
const std::vector<edge_type>& edge_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<node_type, at::Tensor>& sampled_nodes_with_duplicates_dict,
const c10::Dict<rel_type, std::vector<int64_t>>&
sampled_neighbors_per_node_dict,
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict =
c10::nullopt,
Expand Down
15 changes: 9 additions & 6 deletions test/csrc/sampler/test_dist_merge_outputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ TEST(DistMergeOutputsTest, BasicAssertions) {
auto expected_edges = at::tensor({14, 15, 16, 17, 18, 19, 20}, options);
EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges));

const std::vector<int64_t> expected_sampled_neighbors_per_node = {2, 1, 2, 2};
EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node);
const std::vector<int64_t> expected_num_sampled_neighbors_per_node = {2, 1, 2,
2};
EXPECT_EQ(std::get<3>(out), expected_num_sampled_neighbors_per_node);
}

TEST(DistMergeOutputsAllNeighborsTest, BasicAssertions) {
Expand Down Expand Up @@ -82,8 +83,9 @@ TEST(DistMergeOutputsAllNeighborsTest, BasicAssertions) {
auto expected_edges = at::tensor({14, 15, 16, 17, 18, 19, 20, 21}, options);
EXPECT_TRUE(at::equal(std::get<1>(out), expected_edges));

const std::vector<int64_t> expected_sampled_neighbors_per_node = {2, 1, 2, 3};
EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node);
const std::vector<int64_t> expected_num_sampled_neighbors_per_node = {2, 1, 2,
3};
EXPECT_EQ(std::get<3>(out), expected_num_sampled_neighbors_per_node);
}

TEST(DistDisjointMergeOutputsTest, BasicAssertions) {
Expand Down Expand Up @@ -124,6 +126,7 @@ TEST(DistDisjointMergeOutputsTest, BasicAssertions) {
auto expected_batch = at::tensor({0, 0, 1, 2, 2, 3, 3}, options);
EXPECT_TRUE(at::equal(std::get<2>(out).value(), expected_batch));

const std::vector<int64_t> expected_sampled_neighbors_per_node = {2, 1, 2, 2};
EXPECT_EQ(std::get<3>(out), expected_sampled_neighbors_per_node);
const std::vector<int64_t> expected_num_sampled_neighbors_per_node = {2, 1, 2,
2};
EXPECT_EQ(std::get<3>(out), expected_num_sampled_neighbors_per_node);
}
Loading
Loading