Skip to content

Commit

Permalink
Clean up self-loop and multi-edge removal logic (#4032)
Browse files Browse the repository at this point in the history
Existing mask utilities already perform some of the operations that were implemented separately for this cleanup.

Use the mask utilities instead of replicating functionality.

Authors:
  - Chuck Hastings (https://github.com/ChuckHastings)

Approvers:
  - Seunghwa Kang (https://github.com/seunghwak)

URL: #4032
  • Loading branch information
ChuckHastings authored Dec 5, 2023
1 parent 20145b4 commit 32eaa5e
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 212 deletions.
33 changes: 10 additions & 23 deletions cpp/src/structure/detail/structure_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cugraph/utilities/dataframe_buffer.hpp>
#include <cugraph/utilities/device_functors.cuh>
#include <cugraph/utilities/error.hpp>
#include <cugraph/utilities/mask_utils.cuh>
#include <cugraph/utilities/misc_utils.cuh>
#include <cugraph/utilities/packed_bool_utils.hpp>

Expand Down Expand Up @@ -524,35 +525,21 @@ std::tuple<size_t, rmm::device_uvector<uint32_t>> mark_entries(raft::handle_t co
return word;
});

// FIXME: use detail::count_set_bits
size_t bit_count = thrust::transform_reduce(
handle.get_thrust_policy(),
marked_entries.begin(),
marked_entries.end(),
[] __device__(auto word) { return __popc(word); },
size_t{0},
thrust::plus<size_t>());
size_t bit_count = detail::count_set_bits(handle, marked_entries.begin(), num_entries);

return std::make_tuple(bit_count, std::move(marked_entries));
}

template <typename T>
rmm::device_uvector<T> remove_flagged_elements(raft::handle_t const& handle,
rmm::device_uvector<T>&& vector,
raft::device_span<uint32_t const> remove_flags,
size_t remove_count)
rmm::device_uvector<T> keep_flagged_elements(raft::handle_t const& handle,
rmm::device_uvector<T>&& vector,
raft::device_span<uint32_t const> keep_flags,
size_t keep_count)
{
rmm::device_uvector<T> result(vector.size() - remove_count, handle.get_stream());

thrust::copy_if(
handle.get_thrust_policy(),
thrust::make_counting_iterator(size_t{0}),
thrust::make_counting_iterator(vector.size()),
thrust::make_transform_output_iterator(result.begin(),
indirection_t<size_t, T*>{vector.data()}),
[remove_flags] __device__(size_t i) {
return !(remove_flags[cugraph::packed_bool_offset(i)] & cugraph::packed_bool_mask(i));
});
rmm::device_uvector<T> result(keep_count, handle.get_stream());

detail::copy_if_mask_set(
handle, vector.begin(), vector.end(), keep_flags.begin(), result.begin());

return result;
}
Expand Down
53 changes: 25 additions & 28 deletions cpp/src/structure/remove_multi_edges_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -254,50 +254,47 @@ remove_multi_edges(raft::handle_t const& handle,
}
}

auto [multi_edge_count, multi_edges_to_delete] =
detail::mark_entries(handle,
edgelist_srcs.size(),
[d_edgelist_srcs = edgelist_srcs.data(),
d_edgelist_dsts = edgelist_dsts.data()] __device__(auto idx) {
return (idx > 0) && (d_edgelist_srcs[idx - 1] == d_edgelist_srcs[idx]) &&
(d_edgelist_dsts[idx - 1] == d_edgelist_dsts[idx]);
});

if (multi_edge_count > 0) {
edgelist_srcs = detail::remove_flagged_elements(
auto [keep_count, keep_flags] = detail::mark_entries(
handle,
edgelist_srcs.size(),
[d_edgelist_srcs = edgelist_srcs.data(),
d_edgelist_dsts = edgelist_dsts.data()] __device__(auto idx) {
return !((idx > 0) && (d_edgelist_srcs[idx - 1] == d_edgelist_srcs[idx]) &&
(d_edgelist_dsts[idx - 1] == d_edgelist_dsts[idx]));
});

if (keep_count < edgelist_srcs.size()) {
edgelist_srcs = detail::keep_flagged_elements(
handle,
std::move(edgelist_srcs),
raft::device_span<uint32_t const>{multi_edges_to_delete.data(), multi_edges_to_delete.size()},
multi_edge_count);
edgelist_dsts = detail::remove_flagged_elements(
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);
edgelist_dsts = detail::keep_flagged_elements(
handle,
std::move(edgelist_dsts),
raft::device_span<uint32_t const>{multi_edges_to_delete.data(), multi_edges_to_delete.size()},
multi_edge_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);

if (edgelist_weights)
edgelist_weights = detail::remove_flagged_elements(
edgelist_weights = detail::keep_flagged_elements(
handle,
std::move(*edgelist_weights),
raft::device_span<uint32_t const>{multi_edges_to_delete.data(),
multi_edges_to_delete.size()},
multi_edge_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);

if (edgelist_edge_ids)
edgelist_edge_ids = detail::remove_flagged_elements(
edgelist_edge_ids = detail::keep_flagged_elements(
handle,
std::move(*edgelist_edge_ids),
raft::device_span<uint32_t const>{multi_edges_to_delete.data(),
multi_edges_to_delete.size()},
multi_edge_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);

if (edgelist_edge_types)
edgelist_edge_types = detail::remove_flagged_elements(
edgelist_edge_types = detail::keep_flagged_elements(
handle,
std::move(*edgelist_edge_types),
raft::device_span<uint32_t const>{multi_edges_to_delete.data(),
multi_edges_to_delete.size()},
multi_edge_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);
}

return std::make_tuple(std::move(edgelist_srcs),
Expand Down
36 changes: 18 additions & 18 deletions cpp/src/structure/remove_self_loops_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,44 +44,44 @@ remove_self_loops(raft::handle_t const& handle,
std::optional<rmm::device_uvector<edge_t>>&& edgelist_edge_ids,
std::optional<rmm::device_uvector<edge_type_t>>&& edgelist_edge_types)
{
auto [self_loop_count, self_loops_to_delete] =
auto [keep_count, keep_flags] =
detail::mark_entries(handle,
edgelist_srcs.size(),
[d_srcs = edgelist_srcs.data(), d_dsts = edgelist_dsts.data()] __device__(
size_t i) { return d_srcs[i] == d_dsts[i]; });
size_t i) { return d_srcs[i] != d_dsts[i]; });

if (self_loop_count > 0) {
edgelist_srcs = detail::remove_flagged_elements(
if (keep_count < edgelist_srcs.size()) {
edgelist_srcs = detail::keep_flagged_elements(
handle,
std::move(edgelist_srcs),
raft::device_span<uint32_t const>{self_loops_to_delete.data(), self_loops_to_delete.size()},
self_loop_count);
edgelist_dsts = detail::remove_flagged_elements(
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);
edgelist_dsts = detail::keep_flagged_elements(
handle,
std::move(edgelist_dsts),
raft::device_span<uint32_t const>{self_loops_to_delete.data(), self_loops_to_delete.size()},
self_loop_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);

if (edgelist_weights)
edgelist_weights = detail::remove_flagged_elements(
edgelist_weights = detail::keep_flagged_elements(
handle,
std::move(*edgelist_weights),
raft::device_span<uint32_t const>{self_loops_to_delete.data(), self_loops_to_delete.size()},
self_loop_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);

if (edgelist_edge_ids)
edgelist_edge_ids = detail::remove_flagged_elements(
edgelist_edge_ids = detail::keep_flagged_elements(
handle,
std::move(*edgelist_edge_ids),
raft::device_span<uint32_t const>{self_loops_to_delete.data(), self_loops_to_delete.size()},
self_loop_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);

if (edgelist_edge_types)
edgelist_edge_types = detail::remove_flagged_elements(
edgelist_edge_types = detail::keep_flagged_elements(
handle,
std::move(*edgelist_edge_types),
raft::device_span<uint32_t const>{self_loops_to_delete.data(), self_loops_to_delete.size()},
self_loop_count);
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);
}

return std::make_tuple(std::move(edgelist_srcs),
Expand Down
4 changes: 2 additions & 2 deletions cpp/tests/community/triangle_count_test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -232,7 +232,7 @@ class Tests_TriangleCount
for (size_t i = 0; i < h_cugraph_vertices.size(); ++i) {
auto v = h_cugraph_vertices[i];
auto count = h_cugraph_triangle_counts[i];
ASSERT_TRUE(count == h_reference_triangle_counts[v])
ASSERT_EQ(count, h_reference_triangle_counts[v])
<< "Triangle count values do not match with the reference values.";
}
}
Expand Down
20 changes: 18 additions & 2 deletions cpp/tests/utilities/test_graphs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,9 +621,25 @@ construct_graph(raft::handle_t const& handle,

CUGRAPH_EXPECTS(d_src_v.size() <= static_cast<size_t>(std::numeric_limits<edge_t>::max()),
"Invalid template parameter: edge_t overflow.");
if (drop_self_loops) { remove_self_loops(handle, d_src_v, d_dst_v, d_weights_v); }
if (drop_self_loops) {
std::tie(d_src_v, d_dst_v, d_weights_v, std::ignore, std::ignore) =
cugraph::remove_self_loops<vertex_t, edge_t, weight_t, int32_t>(handle,
std::move(d_src_v),
std::move(d_dst_v),
std::move(d_weights_v),
std::nullopt,
std::nullopt);
}

if (drop_multi_edges) { sort_and_remove_multi_edges(handle, d_src_v, d_dst_v, d_weights_v); }
if (drop_multi_edges) {
std::tie(d_src_v, d_dst_v, d_weights_v, std::ignore, std::ignore) =
cugraph::remove_multi_edges<vertex_t, edge_t, weight_t, int32_t>(handle,
std::move(d_src_v),
std::move(d_dst_v),
std::move(d_weights_v),
std::nullopt,
std::nullopt);
}

graph_t<vertex_t, edge_t, store_transposed, multi_gpu> graph(handle);
std::optional<
Expand Down
126 changes: 0 additions & 126 deletions cpp/tests/utilities/thrust_wrapper.cu
Original file line number Diff line number Diff line change
Expand Up @@ -206,131 +206,5 @@ template void populate_vertex_ids(raft::handle_t const& handle,
rmm::device_uvector<int64_t>& d_vertices_v,
int64_t vertex_id_offset);

// Test utility: removes every self-loop (an edge whose source equals its destination) from an
// edge list stored in device vectors, compacting the vectors in place.
//
// handle     - supplies the Thrust execution policy and the stream used for every resize and
//              shrink_to_fit call below.
// d_src_v    - [INOUT] edge sources; compacted and resized to the number of surviving edges.
// d_dst_v    - [INOUT] edge destinations; compacted and resized to match d_src_v.
// d_weight_v - [INOUT] optional edge weights; when present they are compacted together with the
//              endpoints and resized to match d_src_v.
template <typename vertex_t, typename weight_t>
void remove_self_loops(raft::handle_t const& handle,
rmm::device_uvector<vertex_t>& d_src_v /* [INOUT] */,
rmm::device_uvector<vertex_t>& d_dst_v /* [INOUT] */,
std::optional<rmm::device_uvector<weight_t>>& d_weight_v /* [INOUT] */)
{
if (d_weight_v) {
// Weighted case: zip (src, dst, weight) so that a removed edge takes its weight with it.
auto edge_first = thrust::make_zip_iterator(
thrust::make_tuple(d_src_v.begin(), d_dst_v.begin(), (*d_weight_v).begin()));
// remove_if returns the new logical end; its distance from edge_first is the surviving edge
// count, which becomes the new size of d_src_v.
d_src_v.resize(
thrust::distance(edge_first,
thrust::remove_if(
handle.get_thrust_policy(),
edge_first,
edge_first + d_src_v.size(),
[] __device__(auto e) { return thrust::get<0>(e) == thrust::get<1>(e); })),
handle.get_stream());
// d_src_v already has the new size; trim the other vectors to the same length.
d_dst_v.resize(d_src_v.size(), handle.get_stream());
(*d_weight_v).resize(d_src_v.size(), handle.get_stream());
} else {
// Unweighted case: same compaction over (src, dst) pairs only.
auto edge_first =
thrust::make_zip_iterator(thrust::make_tuple(d_src_v.begin(), d_dst_v.begin()));
d_src_v.resize(
thrust::distance(edge_first,
thrust::remove_if(
handle.get_thrust_policy(),
edge_first,
edge_first + d_src_v.size(),
[] __device__(auto e) { return thrust::get<0>(e) == thrust::get<1>(e); })),
handle.get_stream());
d_dst_v.resize(d_src_v.size(), handle.get_stream());
}

// Ask each vector to shrink its allocation to the reduced size.
d_src_v.shrink_to_fit(handle.get_stream());
d_dst_v.shrink_to_fit(handle.get_stream());
if (d_weight_v) { (*d_weight_v).shrink_to_fit(handle.get_stream()); }
}

// Explicit instantiations of remove_self_loops for each combination of
// vertex_t in {int32_t, int64_t} and weight_t in {float, double}.
template void remove_self_loops(
raft::handle_t const& handle,
rmm::device_uvector<int32_t>& d_src_v /* [INOUT] */,
rmm::device_uvector<int32_t>& d_dst_v /* [INOUT] */,
std::optional<rmm::device_uvector<float>>& d_weight_v /* [INOUT] */);

template void remove_self_loops(
raft::handle_t const& handle,
rmm::device_uvector<int32_t>& d_src_v /* [INOUT] */,
rmm::device_uvector<int32_t>& d_dst_v /* [INOUT] */,
std::optional<rmm::device_uvector<double>>& d_weight_v /* [INOUT] */);

template void remove_self_loops(
raft::handle_t const& handle,
rmm::device_uvector<int64_t>& d_src_v /* [INOUT] */,
rmm::device_uvector<int64_t>& d_dst_v /* [INOUT] */,
std::optional<rmm::device_uvector<float>>& d_weight_v /* [INOUT] */);

template void remove_self_loops(
raft::handle_t const& handle,
rmm::device_uvector<int64_t>& d_src_v /* [INOUT] */,
rmm::device_uvector<int64_t>& d_dst_v /* [INOUT] */,
std::optional<rmm::device_uvector<double>>& d_weight_v /* [INOUT] */);

// Test utility: sorts an edge list stored in device vectors and then collapses runs of adjacent
// duplicate edges, so the edge order on return differs from the order on entry.
//
// handle     - supplies the Thrust execution policy and the stream used for every resize and
//              shrink_to_fit call below.
// d_src_v    - [INOUT] edge sources; sorted, de-duplicated and resized to the surviving count.
// d_dst_v    - [INOUT] edge destinations; reordered with d_src_v and resized to match it.
// d_weight_v - [INOUT] optional edge weights; when present they take part in the sort but are
//              ignored when deciding whether two adjacent edges are duplicates.
template <typename vertex_t, typename weight_t>
void sort_and_remove_multi_edges(
raft::handle_t const& handle,
rmm::device_uvector<vertex_t>& d_src_v /* [INOUT] */,
rmm::device_uvector<vertex_t>& d_dst_v /* [INOUT] */,
std::optional<rmm::device_uvector<weight_t>>& d_weight_v /* [INOUT] */)
{
if (d_weight_v) {
// Weighted case: sort the zipped (src, dst, weight) tuples with the default ordering.
auto edge_first = thrust::make_zip_iterator(
thrust::make_tuple(d_src_v.begin(), d_dst_v.begin(), (*d_weight_v).begin()));
thrust::sort(handle.get_thrust_policy(), edge_first, edge_first + d_src_v.size());
// The equality predicate compares only src and dst, so edges that share both endpoints are
// treated as duplicates even when their weights differ.
// NOTE(review): presumably the first entry of each run (the smallest weight after the sort)
// is the one kept -- confirm against thrust::unique semantics.
d_src_v.resize(
thrust::distance(edge_first,
thrust::unique(handle.get_thrust_policy(),
edge_first,
edge_first + d_src_v.size(),
[] __device__(auto lhs, auto rhs) {
return (thrust::get<0>(lhs) == thrust::get<0>(rhs)) &&
(thrust::get<1>(lhs) == thrust::get<1>(rhs));
})),
handle.get_stream());
// d_src_v already has the new size; trim the other vectors to the same length.
d_dst_v.resize(d_src_v.size(), handle.get_stream());
(*d_weight_v).resize(d_src_v.size(), handle.get_stream());
} else {
// Unweighted case: sort (src, dst) pairs and drop adjacent equal pairs using the default
// equality comparison.
auto edge_first =
thrust::make_zip_iterator(thrust::make_tuple(d_src_v.begin(), d_dst_v.begin()));
thrust::sort(handle.get_thrust_policy(), edge_first, edge_first + d_src_v.size());
d_src_v.resize(
thrust::distance(
edge_first,
thrust::unique(handle.get_thrust_policy(), edge_first, edge_first + d_src_v.size())),
handle.get_stream());
d_dst_v.resize(d_src_v.size(), handle.get_stream());
}

// Ask each vector to shrink its allocation to the reduced size.
d_src_v.shrink_to_fit(handle.get_stream());
d_dst_v.shrink_to_fit(handle.get_stream());
if (d_weight_v) { (*d_weight_v).shrink_to_fit(handle.get_stream()); }
}

// Explicit instantiations of sort_and_remove_multi_edges for each combination of
// vertex_t in {int32_t, int64_t} and weight_t in {float, double}.
template void sort_and_remove_multi_edges(
raft::handle_t const& handle,
rmm::device_uvector<int32_t>& d_src_v /* [INOUT] */,
rmm::device_uvector<int32_t>& d_dst_v /* [INOUT] */,
std::optional<rmm::device_uvector<float>>& d_weight_v /* [INOUT] */);

template void sort_and_remove_multi_edges(
raft::handle_t const& handle,
rmm::device_uvector<int32_t>& d_src_v /* [INOUT] */,
rmm::device_uvector<int32_t>& d_dst_v /* [INOUT] */,
std::optional<rmm::device_uvector<double>>& d_weight_v /* [INOUT] */);

template void sort_and_remove_multi_edges(
raft::handle_t const& handle,
rmm::device_uvector<int64_t>& d_src_v /* [INOUT] */,
rmm::device_uvector<int64_t>& d_dst_v /* [INOUT] */,
std::optional<rmm::device_uvector<float>>& d_weight_v /* [INOUT] */);

template void sort_and_remove_multi_edges(
raft::handle_t const& handle,
rmm::device_uvector<int64_t>& d_src_v /* [INOUT] */,
rmm::device_uvector<int64_t>& d_dst_v /* [INOUT] */,
std::optional<rmm::device_uvector<double>>& d_weight_v /* [INOUT] */);

} // namespace test
} // namespace cugraph
13 changes: 0 additions & 13 deletions cpp/tests/utilities/thrust_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,5 @@ void populate_vertex_ids(raft::handle_t const& handle,
rmm::device_uvector<vertex_t>& d_vertices_v /* [INOUT] */,
vertex_t vertex_id_offset);

// Declaration of the test helper that removes edges whose source equals their destination.
// All three edge-list arguments are modified in place ([INOUT]); d_weight_v is optional.
template <typename vertex_t, typename weight_t>
void remove_self_loops(raft::handle_t const& handle,
rmm::device_uvector<vertex_t>& d_src_v /* [INOUT] */,
rmm::device_uvector<vertex_t>& d_dst_v /* [INOUT] */,
std::optional<rmm::device_uvector<weight_t>>& d_weight_v /* [INOUT] */);

// Declaration of the test helper that sorts the edge list and removes duplicate
// (source, destination) edges. All three edge-list arguments are modified in place ([INOUT]);
// d_weight_v is optional.
template <typename vertex_t, typename weight_t>
void sort_and_remove_multi_edges(
raft::handle_t const& handle,
rmm::device_uvector<vertex_t>& d_src_v /* [INOUT] */,
rmm::device_uvector<vertex_t>& d_dst_v /* [INOUT] */,
std::optional<rmm::device_uvector<weight_t>>& d_weight_v /* [INOUT] */);

} // namespace test
} // namespace cugraph

0 comments on commit 32eaa5e

Please sign in to comment.