
[FEA] Heterogeneous Distributed Sampling #4795

Merged
merged 73 commits, Jan 13, 2025
e7862a1
h
alexbarghi-nv Nov 19, 2024
22155fc
remove file
alexbarghi-nv Nov 19, 2024
174f32c
Merge branch 'branch-24.12' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Nov 19, 2024
95c5ba2
homogeneous neighborhood sampling doesn't support edge renumber_map a…
jnke2016 Nov 19, 2024
af465f9
update docstrings
jnke2016 Nov 19, 2024
c3861e1
fix batch ids
alexbarghi-nv Nov 20, 2024
19161b1
Merge branch 'branch-24.12_remove-edge-renumber-map' of https://githu…
alexbarghi-nv Nov 20, 2024
61a0926
fix bug in heterogeneous renumbering
jnke2016 Nov 20, 2024
707cf3d
fix style
jnke2016 Nov 20, 2024
0b47bc1
Merge branch 'branch-24.12_remove-edge-renumber-map' of https://githu…
alexbarghi-nv Nov 21, 2024
4f42727
dist sampling
alexbarghi-nv Nov 21, 2024
ef848e9
handle case where there is no sampling result prior to renumbering
jnke2016 Nov 22, 2024
535df5d
expose 'label_type_hop_offsets'
jnke2016 Nov 22, 2024
3501a45
fix style
jnke2016 Nov 22, 2024
8e9d64f
Merge branch 'branch-24.12_remove-edge-renumber-map' of https://githu…
alexbarghi-nv Nov 22, 2024
288b8fc
remove debug print and properly compute the 'result_offsets'
jnke2016 Nov 22, 2024
9f459de
update docstrings
jnke2016 Nov 22, 2024
fc0426d
directly return the result
jnke2016 Nov 22, 2024
d997656
fix illegal memory access
jnke2016 Nov 22, 2024
b7405e0
Merge remote-tracking branch 'upstream/branch-24.12' into branch-24.1…
jnke2016 Nov 22, 2024
fc70811
rename variable for consistency
jnke2016 Nov 22, 2024
463f12b
fix style
jnke2016 Nov 22, 2024
3767b94
rename variable for consistency
jnke2016 Nov 22, 2024
e662bc2
Merge branch 'branch-24.12_remove-edge-renumber-map' of https://githu…
alexbarghi-nv Nov 22, 2024
6aeee28
add support for vertex type
jnke2016 Nov 22, 2024
223a73b
add support for vertex type at the plc layer
jnke2016 Nov 22, 2024
061c8cc
properly handle missing edge types
jnke2016 Nov 22, 2024
ed4c069
fix style
jnke2016 Nov 22, 2024
900a3fe
Merge remote-tracking branch 'upstream/branch-24.12' into branch-24.1…
jnke2016 Nov 23, 2024
61504a9
Merge remote-tracking branch 'upstream/branch-24.12' into branch-24.1…
jnke2016 Nov 25, 2024
b33d071
Merge branch 'branch-24.12_remove-edge-renumber-map' of https://githu…
alexbarghi-nv Nov 26, 2024
4d3f8c1
properly handle sampling with multiple edge types
jnke2016 Nov 26, 2024
d88eebd
remove debug print
jnke2016 Nov 27, 2024
77a28c2
undo changes to 'prepare_next_frontier'
jnke2016 Nov 27, 2024
b760a61
properly handle sampling with multiple edge types
jnke2016 Nov 27, 2024
580d3e9
properly compute the number of hops for heterogeneous renumbering
jnke2016 Nov 27, 2024
529955a
fix style
jnke2016 Nov 27, 2024
c3a58a0
Merge remote-tracking branch 'upstream/branch-24.12' into branch-24.1…
jnke2016 Nov 27, 2024
bf17251
update docstrings
jnke2016 Nov 27, 2024
88b35ba
simplify code and re-order statements
jnke2016 Nov 27, 2024
f8c576a
fix style
jnke2016 Nov 27, 2024
9648d69
update docstrings
jnke2016 Nov 27, 2024
d7a05a5
Merge branch 'branch-24.12_remove-edge-renumber-map' of https://githu…
alexbarghi-nv Nov 27, 2024
9d84558
fix bug when creating struct
jnke2016 Nov 30, 2024
a7a224c
fix style
jnke2016 Nov 30, 2024
d64dc66
add missing arguments
jnke2016 Dec 1, 2024
bcfc99c
update label list if some are missing from the result
jnke2016 Dec 1, 2024
19c37a8
fix style
jnke2016 Dec 1, 2024
7b7c648
x
alexbarghi-nv Dec 2, 2024
bbc32cd
c
alexbarghi-nv Dec 2, 2024
8e82263
M
alexbarghi-nv Dec 5, 2024
e3a45cc
pull in changes
alexbarghi-nv Dec 10, 2024
f71ad9c
update to 25.02
alexbarghi-nv Dec 10, 2024
c94c1cb
revert cpp change
alexbarghi-nv Dec 10, 2024
9754f9d
fix disk writing with new api
alexbarghi-nv Dec 10, 2024
f87578e
add tests for sg
alexbarghi-nv Dec 10, 2024
8c8891b
Merge branch 'branch-25.02' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Dec 10, 2024
580a7de
Merge branch 'branch-25.02' into hetero-dist-sampler
alexbarghi-nv Dec 10, 2024
4c87153
Merge branch 'hetero-dist-sampler' of https://github.com/alexbarghi-n…
alexbarghi-nv Dec 10, 2024
42a2d4a
fix style
jnke2016 Dec 10, 2024
a8cef9b
Merge branch 'branch-25.02_fix-bug-mg-nbr-sampling' of https://github…
alexbarghi-nv Dec 10, 2024
ba52838
fix rank issue in dist sampler
alexbarghi-nv Dec 10, 2024
d934b72
change empty input handling to work with new plc api
alexbarghi-nv Dec 11, 2024
b9437ea
fix bulk sampler tests, re-enable other tests
alexbarghi-nv Dec 11, 2024
6d82b03
Merge branch 'branch-25.02' into hetero-dist-sampler
nv-rliu Dec 12, 2024
a8f9bd7
Merge branch 'branch-25.02' of https://github.com/rapidsai/cugraph in…
alexbarghi-nv Dec 23, 2024
85c9cc7
Merge branch 'branch-25.02' into hetero-dist-sampler
alexbarghi-nv Jan 2, 2025
eff5025
update branch
alexbarghi-nv Jan 2, 2025
8aa7604
Merge branch 'hetero-dist-sampler' of https://github.com/alexbarghi-n…
alexbarghi-nv Jan 2, 2025
4659bd2
fix style
alexbarghi-nv Jan 2, 2025
e1e2d34
change copyright
alexbarghi-nv Jan 3, 2025
a21e673
update branch
alexbarghi-nv Jan 9, 2025
0ccf4f1
Merge branch 'branch-25.02' into hetero-dist-sampler
alexbarghi-nv Jan 10, 2025
79 changes: 49 additions & 30 deletions python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -79,9 +79,15 @@ def get_reader(
return DistSampleReader(self._directory, format=self._format, rank=rank)

def __write_minibatches_coo(self, minibatch_dict):
has_edge_ids = minibatch_dict["edge_id"] is not None
has_edge_types = minibatch_dict["edge_type"] is not None
has_weights = minibatch_dict["weight"] is not None
has_edge_ids = (
"edge_id" in minibatch_dict and minibatch_dict["edge_id"] is not None
)
has_edge_types = (
"edge_type" in minibatch_dict and minibatch_dict["edge_type"] is not None
)
has_weights = (
"weight" in minibatch_dict and minibatch_dict["weight"] is not None
)
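The new guards above check key presence before testing for `None`, since a heterogeneous sampling result may omit a field entirely rather than setting it to `None`. A minimal sketch of that pattern, using a hypothetical `has_field` helper (not part of the PR):

```python
# Hypothetical helper mirroring the guard pattern above: a field counts as
# present only if the key exists AND its value is not None.
def has_field(d, key):
    return key in d and d[key] is not None

minibatch = {"edge_id": [0, 1], "edge_type": None}
print(has_field(minibatch, "edge_id"))    # True: present and not None
print(has_field(minibatch, "edge_type"))  # False: present but None
print(has_field(minibatch, "weight"))     # False: key absent entirely
```

The same effect could be written as `minibatch.get(key) is not None`; the explicit form matches the diff's style.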

if minibatch_dict["renumber_map"] is None:
raise ValueError(
@@ -92,22 +98,22 @@ def __write_minibatches_coo(self, minibatch_dict):
if len(minibatch_dict["batch_id"]) == 0:
return

fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len(
minibatch_dict["batch_id"]
)
fanout_length = len(minibatch_dict["fanout"])
total_num_batches = (
len(minibatch_dict["label_hop_offsets"]) - 1
) / fanout_length

for p in range(
0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition))
):
for p in range(0, int(ceil(total_num_batches / self.__batches_per_partition))):
partition_start = p * (self.__batches_per_partition)
partition_end = (p + 1) * (self.__batches_per_partition)

label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][
partition_start * fanout_length : partition_end * fanout_length + 1
]

batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end]
start_batch_id = batch_id_array_p[0]
num_batches_p = len(label_hop_offsets_array_p) - 1

start_batch_id = minibatch_dict["batch_start"]

input_offsets_p = minibatch_dict["input_offsets"][
partition_start : (partition_end + 1)
@@ -171,7 +177,7 @@ def __write_minibatches_coo(self, minibatch_dict):
}
)

end_batch_id = start_batch_id + len(batch_id_array_p) - 1
end_batch_id = start_batch_id + num_batches_p - 1
rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0

full_output_path = os.path.join(
@@ -188,9 +194,15 @@ def __write_minibatches_coo(self, minibatch_dict):
)
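The partition arithmetic in the COO writer above now derives the batch count from the fanout list: `label_hop_offsets` carries one offset per (batch, hop) segment plus a terminal entry, so each partition's slice spans `fanout_length` entries per batch. A standalone sketch of that slicing, with illustrative values (the array contents here are made up):

```python
from math import ceil

# Illustration of the partition arithmetic in __write_minibatches_coo:
# 12 segments with a 2-hop fanout means 6 batches in total.
fanout = [10, 10]
fanout_length = len(fanout)
label_hop_offsets = list(range(13))  # 12 segments + terminal offset
total_num_batches = (len(label_hop_offsets) - 1) / fanout_length
batches_per_partition = 4

num_partitions = int(ceil(total_num_batches / batches_per_partition))
for p in range(num_partitions):
    partition_start = p * batches_per_partition
    partition_end = (p + 1) * batches_per_partition
    # Each batch contributes fanout_length offsets; +1 keeps the closing bound.
    offsets_p = label_hop_offsets[
        partition_start * fanout_length : partition_end * fanout_length + 1
    ]
    print(p, offsets_p)
```

With 6 batches and 4 per partition, the loop yields two partitions; the final slice is shorter, which is why the code computes `num_batches_p` from the slice length rather than assuming a full partition.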

def __write_minibatches_csr(self, minibatch_dict):
has_edge_ids = minibatch_dict["edge_id"] is not None
has_edge_types = minibatch_dict["edge_type"] is not None
has_weights = minibatch_dict["weight"] is not None
has_edge_ids = (
"edge_id" in minibatch_dict and minibatch_dict["edge_id"] is not None
)
has_edge_types = (
"edge_type" in minibatch_dict and minibatch_dict["edge_type"] is not None
)
has_weights = (
"weight" in minibatch_dict and minibatch_dict["weight"] is not None
)

if minibatch_dict["renumber_map"] is None:
raise ValueError(
@@ -201,22 +213,22 @@ def __write_minibatches_csr(self, minibatch_dict):
if len(minibatch_dict["batch_id"]) == 0:
return

fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len(
minibatch_dict["batch_id"]
)
fanout_length = len(minibatch_dict["fanout"])
total_num_batches = (
len(minibatch_dict["label_hop_offsets"]) - 1
) / fanout_length

for p in range(
0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition))
):
for p in range(0, int(ceil(total_num_batches / self.__batches_per_partition))):
partition_start = p * (self.__batches_per_partition)
partition_end = (p + 1) * (self.__batches_per_partition)

label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][
partition_start * fanout_length : partition_end * fanout_length + 1
]

batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end]
start_batch_id = batch_id_array_p[0]
num_batches_p = len(label_hop_offsets_array_p) - 1

start_batch_id = minibatch_dict["batch_start"]

input_offsets_p = minibatch_dict["input_offsets"][
partition_start : (partition_end + 1)
@@ -292,7 +304,7 @@ def __write_minibatches_csr(self, minibatch_dict):
}
)

end_batch_id = start_batch_id + len(batch_id_array_p) - 1
end_batch_id = start_batch_id + num_batches_p - 1
rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0

full_output_path = os.path.join(
@@ -309,12 +321,19 @@ def __write_minibatches_csr(self, minibatch_dict):
)

def write_minibatches(self, minibatch_dict):
if (minibatch_dict["majors"] is not None) and (
minibatch_dict["minors"] is not None
):
if "minors" not in minibatch_dict:
raise ValueError("invalid columns")

# PLC API specifies this behavior for empty input
# This needs to be handled here to avoid causing a hang
if len(minibatch_dict["minors"]) == 0:
return

if "majors" in minibatch_dict and minibatch_dict["majors"] is not None:
self.__write_minibatches_coo(minibatch_dict)
elif (minibatch_dict["major_offsets"] is not None) and (
minibatch_dict["minors"] is not None
elif (
"major_offsets" in minibatch_dict
and minibatch_dict["major_offsets"] is not None
):
self.__write_minibatches_csr(minibatch_dict)
else:
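The dispatch in `write_minibatches` above selects the COO path when `majors` is present and the CSR path when `major_offsets` is, after returning early on an empty `minors` array (the diff's comment notes the PLC API produces empty input this way, and writing it would cause a hang). A sketch of that control flow as a standalone function — `select_format` and its return values are illustrative, not the PR's API; the diff's final `else` body is truncated, so the last branch here is an assumption:

```python
# Sketch of the write_minibatches dispatch: which writer path a
# minibatch dict would take, based on which keys are populated.
def select_format(minibatch_dict):
    if "minors" not in minibatch_dict:
        raise ValueError("invalid columns")
    if len(minibatch_dict["minors"]) == 0:
        return "empty"  # nothing to write; early return avoids a hang
    if minibatch_dict.get("majors") is not None:
        return "coo"
    if minibatch_dict.get("major_offsets") is not None:
        return "csr"
    raise ValueError("invalid columns")  # assumed fallback; else body truncated

print(select_format({"minors": [1, 2], "majors": [0, 0]}))         # coo
print(select_format({"minors": [1, 2], "major_offsets": [0, 2]}))  # csr
print(select_format({"minors": []}))                               # empty
```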