From 225785d9d406812de0d38ad8cddf0728436ef60c Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Tue, 31 Oct 2023 07:20:48 +0000 Subject: [PATCH 1/7] Added optimization of biased sampling --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 22 +++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index cfc679d01..83a850aa3 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -223,11 +223,23 @@ class NeighborSampler { // Case 2: Multinomial sampling: else { - const auto index = at::multinomial(weight, count, replace); - const auto index_data = index.data_ptr(); - for (size_t i = 0; i < index.numel(); ++i) { - add(row_start + index_data[i], global_src_node, local_src_node, - dst_mapper, out_global_dst_nodes); + if (replace) { + const auto index = at::multinomial(weight, count, replace); + const auto index_data = index.data_ptr(); + for (size_t i = 0; i < index.numel(); ++i) { + add(row_start + index_data[i], global_src_node, local_src_node, + dst_mapper, out_global_dst_nodes); + } + } + else { + const auto rand = at::empty_like(weight).uniform_(); + const auto a = (rand.log() / weight); + const auto index = std::get<1>(a.topk(count)); + const auto index_data = index.data_ptr(); + for (size_t i = 0; i < index.numel(); ++i) { + add(row_start + index_data[i], global_src_node, local_src_node, + dst_mapper, out_global_dst_nodes); + } } } } From 80d58346e3c0bba00f0e59c3a18483825e744c0a Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Tue, 31 Oct 2023 07:50:17 +0000 Subject: [PATCH 2/7] Small changes --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 83a850aa3..8719ec980 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -233,8 +233,8 @@ class NeighborSampler { } else { const auto rand = at::empty_like(weight).uniform_(); - const auto a = (rand.log() / weight); - const auto index = std::get<1>(a.topk(count)); + const auto key = (rand.log() / weight); + const auto index = std::get<1>(key.topk(count)); const auto index_data = index.data_ptr(); for (size_t i = 0; i < index.numel(); ++i) { add(row_start + index_data[i], global_src_node, local_src_node, From 2722c17856ef211ad3625f2944d8135488e22645 Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Tue, 31 Oct 2023 07:53:08 +0000 Subject: [PATCH 3/7] Updated CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f8ca4155..f4d8b7fc4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added ### Changed - Added `--biased` parameter to run benchmarks for biased sampling ([#267](https://github.com/pyg-team/pyg-lib/pull/267)) +- Improved speed of biased sampling ([#270](https://github.com/pyg-team/pyg-lib/pull/270)) ### Removed ## [0.3.0] - 2023-10-11 From 2b7c2fb09cfef8bb76db6d8ddf0dfc2613202271 Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Tue, 31 Oct 2023 08:38:15 +0000 Subject: [PATCH 4/7] Updated neighbor_kernel.cpp --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 22 ++++++++------------ 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 8719ec980..df51be63f 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -223,23 +223,19 @@ class NeighborSampler { // Case 2: Multinomial sampling: else { + at::Tensor index; if (replace) { - const auto index = at::multinomial(weight, count, replace); - const auto index_data = index.data_ptr(); - for (size_t i = 0; i < index.numel(); ++i) { - add(row_start + index_data[i], global_src_node, local_src_node, - dst_mapper, out_global_dst_nodes); - } + index = at::multinomial(weight, count, replace); } - else { + else { // An Efficient Algorithm for Biased Sampling: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf const auto rand = at::empty_like(weight).uniform_(); const auto key = (rand.log() / weight); - const auto index = std::get<1>(key.topk(count)); - const auto index_data = index.data_ptr(); - for (size_t i = 0; i < index.numel(); ++i) { - add(row_start + index_data[i], global_src_node, local_src_node, - dst_mapper, out_global_dst_nodes); - } + index = std::get<1>(key.topk(count)); + } + const auto index_data = index.data_ptr(); + for (size_t i = 0; i < index.numel(); ++i) { + add(row_start + index_data[i], global_src_node, local_src_node, + dst_mapper, out_global_dst_nodes); } } } From cf440e89ce0c6f19b1043a9e83f91d2ea3795b7d Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Tue, 31 Oct 2023 08:44:47 +0000 Subject: [PATCH 5/7] Added comments --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index df51be63f..202bf906e 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -224,7 +224,7 @@ class NeighborSampler { // Case 2: Multinomial sampling: else { at::Tensor index; - if (replace) { + if (replace) { // at::multinomial has good perfomance only when replace=true, e.g. https://github.com/pytorch/pytorch/issues/11931#top index = at::multinomial(weight, count, replace); } else { // An Efficient Algorithm for Biased Sampling: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf From 3fb34e93c490be09a4d9fe02c0b3e5c8fced3de3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 31 Oct 2023 08:45:24 +0000 Subject: [PATCH 6/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 202bf906e..0bc4cace9 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -224,10 +224,12 @@ class NeighborSampler { // Case 2: Multinomial sampling: else { at::Tensor index; - if (replace) { // at::multinomial has good perfomance only when replace=true, e.g. https://github.com/pytorch/pytorch/issues/11931#top + if (replace) { // at::multinomial has good perfomance only when + // replace=true, e.g. + // https://github.com/pytorch/pytorch/issues/11931#top index = at::multinomial(weight, count, replace); - } - else { // An Efficient Algorithm for Biased Sampling: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf + } else { // An Efficient Algorithm for Biased Sampling: + // https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf const auto rand = at::empty_like(weight).uniform_(); const auto key = (rand.log() / weight); index = std::get<1>(key.topk(count)); From 633db5b756b6bcaf088080c030024c8357c856bf Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 31 Oct 2023 08:59:09 +0000 Subject: [PATCH 7/7] update --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 0bc4cace9..f26ee13af 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -224,12 +224,14 @@ class NeighborSampler { // Case 2: Multinomial sampling: else { at::Tensor index; - if (replace) { // at::multinomial has good perfomance only when - // replace=true, e.g. - // https://github.com/pytorch/pytorch/issues/11931#top + if (replace) { + // at::multinomial only has good perfomance for `replace=true`, see: + // https://github.com/pytorch/pytorch/issues/11931 index = at::multinomial(weight, count, replace); - } else { // An Efficient Algorithm for Biased Sampling: - // https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf + } else { + // For `replace=false`, we make use of the implementation of the + // "Weighted Random Sampling" paper: + // https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf const auto rand = at::empty_like(weight).uniform_(); const auto key = (rand.log() / weight); index = std::get<1>(key.topk(count));