Skip to content

Commit

Permalink
[GraphBolt][CUDA] Get world_size=1 somewhat for cooperative samplin…
Browse files Browse the repository at this point in the history
…g. (dmlc#7796)
  • Loading branch information
mfbalin authored Sep 12, 2024
1 parent 165e250 commit 189b83c
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 41 deletions.
134 changes: 125 additions & 9 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import partial

import torch
import torch.distributed as thd
from torch.utils.data import functional_datapipe
from torch.utils.data.datapipes.iter import Mapper

Expand All @@ -12,10 +13,14 @@
index_select,
ORIGINAL_EDGE_ID,
)
from ..internal import compact_csc_format, unique_and_compact_csc_formats
from ..internal import (
compact_csc_format,
unique_and_compact,
unique_and_compact_csc_formats,
)
from ..minibatch_transformer import MiniBatchTransformer

from ..subgraph_sampler import SubgraphSampler
from ..subgraph_sampler import all_to_all, revert_to_homo, SubgraphSampler
from .fused_csc_sampling_graph import fused_csc_sampling_graph
from .sampled_subgraph_impl import SampledSubgraphImpl

Expand Down Expand Up @@ -455,12 +460,32 @@ def _subtract_hetero_indices_offset(
class CompactPerLayer(MiniBatchTransformer):
"""Compact the sampled edges for a single layer."""

def __init__(self, datapipe, deduplicate, asynchronous=False):
def __init__(
self, datapipe, deduplicate, cooperative=False, asynchronous=False
):
self.deduplicate = deduplicate
self.cooperative = cooperative
if asynchronous and deduplicate:
datapipe = datapipe.transform(self._compact_per_layer_async)
datapipe = datapipe.buffer()
super().__init__(datapipe, self._compact_per_layer_wait_future)
datapipe = datapipe.transform(self._compact_per_layer_wait_future)
if cooperative:
datapipe = datapipe.transform(
self._seeds_cooperative_exchange_1
)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(
self._seeds_cooperative_exchange_2
)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(
self._seeds_cooperative_exchange_3
)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(
self._seeds_cooperative_exchange_4
)
super().__init__(datapipe)
else:
super().__init__(datapipe, self._compact_per_layer)

Expand Down Expand Up @@ -498,19 +523,20 @@ def _compact_per_layer_async(self, minibatch):
subgraph = minibatch.sampled_subgraphs[0]
seeds = minibatch._seed_nodes
assert self.deduplicate
rank = thd.get_rank() if self.cooperative else 0
world_size = thd.get_world_size() if self.cooperative else 1
minibatch._future = unique_and_compact_csc_formats(
subgraph.sampled_csc, seeds, async_op=True
subgraph.sampled_csc, seeds, rank, world_size, async_op=True
)
return minibatch

@staticmethod
def _compact_per_layer_wait_future(minibatch):
def _compact_per_layer_wait_future(self, minibatch):
subgraph = minibatch.sampled_subgraphs[0]
seeds = minibatch._seed_nodes
(
original_row_node_ids,
compacted_csc_format,
_,
seeds_offsets,
) = minibatch._future.wait()
delattr(minibatch, "_future")
subgraph = SampledSubgraphImpl(
Expand All @@ -521,6 +547,87 @@ def _compact_per_layer_wait_future(minibatch):
)
minibatch._seed_nodes = original_row_node_ids
minibatch.sampled_subgraphs[0] = subgraph
if self.cooperative:
subgraph._seeds_offsets = seeds_offsets
return minibatch

@staticmethod
def _seeds_cooperative_exchange_1(minibatch):
    """Kick off the asynchronous exchange of per-type seed counts.

    Reads ``_seeds_offsets`` from the first sampled subgraph, turns the
    offsets of every node type into counts via ``diff()`` and launches an
    ``all_to_all`` with ``async_op=True``. The returned handle is stored as
    ``_counts_future`` on the subgraph, next to the ``_counts_sent`` and
    ``_counts_received`` buffers. The minibatch itself is returned.
    """
    num_ranks = thd.get_world_size()
    layer = minibatch.sampled_subgraphs[0]
    offsets_by_type = layer._seeds_offsets
    if not isinstance(offsets_by_type, dict):
        # Homogeneous case: wrap the offsets under the "_N" placeholder type.
        offsets_by_type = {"_N": offsets_by_type}
    type_count = len(offsets_by_type)
    total_slots = num_ranks * type_count
    outgoing = torch.empty(total_slots, dtype=torch.int64)
    # Counts are interleaved: the entries of the type at position `pos`
    # live at pos, pos + type_count, pos + 2 * type_count, ...
    # NOTE(review): this presumes each offsets tensor has world_size + 1
    # entries so that diff() yields one count per rank -- confirm.
    for pos, typed_offsets in enumerate(offsets_by_type.values()):
        slots = torch.arange(pos, total_slots, type_count)
        outgoing[slots] = typed_offsets.diff()
    incoming = torch.empty_like(outgoing)
    # Splitting by type_count produces world_size chunks, each holding one
    # count per node type.
    layer._counts_future = all_to_all(
        incoming.split(type_count),
        outgoing.split(type_count),
        async_op=True,
    )
    layer._counts_sent = outgoing
    layer._counts_received = incoming
    return minibatch

@staticmethod
def _seeds_cooperative_exchange_2(minibatch):
    """Exchange the seeds themselves once the count exchange has finished.

    Waits on the ``_counts_future`` stored on the first sampled subgraph,
    then, for every node type, calls ``all_to_all`` (without ``async_op``)
    to send the local seeds split by the sent counts and to receive into a
    freshly allocated tensor split by the received counts.

    Stored on the first sampled subgraph afterwards:

    * ``_seeds_received``: dict mapping node type to the received seeds.
    * ``_counts_sent`` / ``_counts_received``: per-node-type lists of
      counts, passed through ``revert_to_homo``. These replace the
      interleaved count tensors that were stored there before.

    Returns the same minibatch.
    """
    world_size = thd.get_world_size()
    seeds = minibatch._seed_nodes
    is_homogenous = not isinstance(seeds, dict)
    if is_homogenous:
        # Wrap homogeneous seeds under the "_N" placeholder node type.
        seeds = {"_N": seeds}
    subgraph = minibatch.sampled_subgraphs[0]
    subgraph._counts_future.wait()
    delattr(subgraph, "_counts_future")
    num_ntypes = len(seeds)
    seeds_received = {}
    counts_sent = {}
    counts_received = {}
    for i, (ntype, typed_seeds) in enumerate(seeds.items()):
        # Counts are interleaved by node type; pick this type's entries.
        idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
        typed_counts_sent = subgraph._counts_sent[idx].tolist()
        typed_counts_received = subgraph._counts_received[idx].tolist()
        typed_seeds_received = typed_seeds.new_empty(
            sum(typed_counts_received)
        )
        all_to_all(
            typed_seeds_received.split(typed_counts_received),
            typed_seeds.split(typed_counts_sent),
        )
        seeds_received[ntype] = typed_seeds_received
        # Record the per-type counts. Without these two assignments the
        # dictionaries stay empty and the attributes written below would
        # discard the exchanged counts.
        counts_sent[ntype] = typed_counts_sent
        counts_received[ntype] = typed_counts_received
    subgraph._seeds_received = seeds_received
    subgraph._counts_sent = revert_to_homo(counts_sent)
    subgraph._counts_received = revert_to_homo(counts_received)
    return minibatch

@staticmethod
def _seeds_cooperative_exchange_3(minibatch):
    """Launch asynchronous deduplication of the received seeds.

    Takes ``_seeds_received`` from the first sampled subgraph, wraps each
    node type's tensor in a one-element list and passes the result to
    ``unique_and_compact`` with ``async_op=True``. The returned future is
    stored on the minibatch as ``_unique_future``; the minibatch is
    returned.
    """
    received = minibatch.sampled_subgraphs[0]._seeds_received
    per_type_lists = {}
    for ntype, typed_seeds in received.items():
        per_type_lists[ntype] = [typed_seeds]
    # NOTE(review): the positional 0 and 1 look like rank and world_size
    # (a single-rank call) -- confirm against unique_and_compact.
    minibatch._unique_future = unique_and_compact(
        per_type_lists, 0, 1, async_op=True
    )
    return minibatch

@staticmethod
def _seeds_cooperative_exchange_4(minibatch):
    """Collect the deduplication result and store it on the minibatch.

    Waits on ``minibatch._unique_future`` and removes that attribute. The
    unique seeds become ``minibatch._seed_nodes``; the first inverse-id
    entry of every node type is stored as ``_seed_inverse_ids`` on the
    first sampled subgraph. Both values go through ``revert_to_homo``
    first. Returns the same minibatch.
    """
    unique_seeds, inverse_lists, _ = minibatch._unique_future.wait()
    del minibatch._unique_future
    # Each node type maps to a sequence of inverse ids; keep only the first
    # entry of each.
    first_inverse = {}
    for ntype, typed_inv in inverse_lists.items():
        first_inverse[ntype] = typed_inv[0]
    minibatch._seed_nodes = revert_to_homo(unique_seeds)
    layer = minibatch.sampled_subgraphs[0]
    layer._seed_inverse_ids = revert_to_homo(first_inverse)
    return minibatch


Expand All @@ -541,6 +648,7 @@ def __init__(
overlap_fetch,
num_gpu_cached_edges,
gpu_cache_threshold,
cooperative,
asynchronous,
layer_dependency=None,
batch_dependency=None,
Expand All @@ -561,6 +669,7 @@ def __init__(
deduplicate,
sampler,
overlap_fetch,
cooperative=cooperative,
asynchronous=asynchronous,
layer_dependency=layer_dependency,
)
Expand Down Expand Up @@ -637,6 +746,7 @@ def sampling_stages(
deduplicate,
sampler,
overlap_fetch,
cooperative,
asynchronous,
layer_dependency,
):
Expand All @@ -653,7 +763,9 @@ def sampling_stages(
datapipe = datapipe.sample_per_layer(
sampler, fanout, replace, prob_name, overlap_fetch, asynchronous
)
datapipe = datapipe.compact_per_layer(deduplicate, asynchronous)
datapipe = datapipe.compact_per_layer(
deduplicate, cooperative, asynchronous
)
if is_labor and not layer_dependency:
datapipe = datapipe.transform(self._increment_seed)
if is_labor:
Expand Down Expand Up @@ -775,6 +887,7 @@ def __init__(
overlap_fetch=False,
num_gpu_cached_edges=0,
gpu_cache_threshold=1,
cooperative=False,
asynchronous=False,
):
super().__init__(
Expand All @@ -788,6 +901,7 @@ def __init__(
overlap_fetch,
num_gpu_cached_edges,
gpu_cache_threshold,
cooperative,
asynchronous,
)

Expand Down Expand Up @@ -937,6 +1051,7 @@ def __init__(
overlap_fetch=False,
num_gpu_cached_edges=0,
gpu_cache_threshold=1,
cooperative=False,
asynchronous=False,
):
super().__init__(
Expand All @@ -950,6 +1065,7 @@ def __init__(
overlap_fetch,
num_gpu_cached_edges,
gpu_cache_threshold,
cooperative,
asynchronous,
layer_dependency,
batch_dependency,
Expand Down
94 changes: 63 additions & 31 deletions python/dgl/graphbolt/subgraph_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

__all__ = [
"SubgraphSampler",
"all_to_all",
"revert_to_homo",
]


Expand All @@ -41,10 +43,48 @@ def all_to_all(outputs, inputs, group=None, async_op=False):
`rank, ..., world_size - 1, 0, ..., rank - 1` and we make it
`0, world_size - 1` before calling `thd.all_to_all`."""
shift_fn = partial(_shift, group=group)
return thd.all_to_all(shift_fn(outputs), shift_fn(inputs), group, async_op)


def _revert_to_homo(d: dict):
outputs = shift_fn(list(outputs))
inputs = shift_fn(list(inputs))
if outputs[0].is_cuda:
return thd.all_to_all(outputs, inputs, group, async_op)
# gloo backend will be used.
outputs_single = torch.cat(outputs)
output_split_sizes = [o.size(0) for o in outputs]
handle = thd.all_to_all_single(
outputs_single,
torch.cat(inputs),
output_split_sizes,
[i.size(0) for i in inputs],
group,
async_op,
)
temp_outputs = outputs_single.split(output_split_sizes)

class _Waiter:
def __init__(self, handle, outputs, temp_outputs):
self.handle = handle
self.outputs = outputs
self.temp_outputs = temp_outputs

def wait(self):
"""Returns the stored value when invoked."""
handle = self.handle
outputs = self.outputs
temp_outputs = self.temp_outputs
# Ensure that there is no leak
self.handle = self.outputs = self.temp_outputs = None

if handle is not None:
handle.wait()
for output, temp_output in zip(outputs, temp_outputs):
output.copy_(temp_output)

post_processor = _Waiter(handle, outputs, temp_outputs)
return post_processor if async_op else post_processor.wait()


def revert_to_homo(d: dict):
    """Unwrap a dictionary that merely wraps homogeneous data.

    Parameters
    ----------
    d : dict
        Mapping from node type to a value.

    Returns
    -------
    object
        ``d["_N"]`` when ``"_N"`` is the only key of ``d``; otherwise ``d``
        itself, unchanged. An empty dictionary is returned as is.
    """
    if len(d) == 1 and "_N" in d:
        # Direct lookup; no need to materialize all values to read one.
        return d["_N"]
    return d

Expand Down Expand Up @@ -148,45 +188,31 @@ def _wait_preprocess_future(minibatch, cooperative: bool):
def _seeds_cooperative_exchange_1(minibatch, group=None):
rank = thd.get_rank(group)
world_size = thd.get_world_size(group)
assert world_size > 1
seeds = minibatch._seed_nodes
is_homogeneous = not isinstance(seeds, dict)
if is_homogeneous:
seeds = {"_N": seeds}
if minibatch._seeds_offsets is None:
seeds_list = list(seeds.values())
(
sorted_seeds_list,
index_list,
offsets_list,
) = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size)
result = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size)
assert minibatch.compacted_seeds is None
sorted_seeds, sorted_compacted, sorted_offsets = {}, {}, {}
num_ntypes = len(seeds.keys())
for i, (
seed_type,
typed_sorted_seeds,
typed_index,
typed_offsets,
) in enumerate(
zip(
seeds.keys(),
sorted_seeds_list,
index_list,
offsets_list,
)
):
(typed_sorted_seeds, typed_index, typed_offsets),
) in enumerate(zip(seeds.keys(), result)):
sorted_seeds[seed_type] = typed_sorted_seeds
sorted_compacted[seed_type] = typed_index
sorted_offsets[seed_type] = typed_offsets.tolist()
sorted_offsets[seed_type] = typed_offsets

minibatch._seed_nodes = sorted_seeds
minibatch.compacted_seeds = sorted_compacted
minibatch.compacted_seeds = revert_to_homo(sorted_compacted)
minibatch._seeds_offsets = sorted_offsets
else:
minibatch._seeds_offsets = {"_N": minibatch._seeds_offsets}
counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
for i, offsets in enumerate(minibatch._seeds_offsets[0].values()):
for i, offsets in enumerate(minibatch._seeds_offsets.values()):
counts_sent[
torch.arange(i, world_size * num_ntypes, num_ntypes)
] = offsets.diff()
Expand All @@ -208,7 +234,6 @@ def _seeds_cooperative_exchange_2(minibatch, group=None):
seeds = minibatch._seed_nodes
minibatch._counts_future.wait()
delattr(minibatch, "_counts_future")
counts_received = minibatch._counts_received
num_ntypes = len(seeds.keys())
seeds_received = {}
counts_sent = {}
Expand All @@ -226,24 +251,31 @@ def _seeds_cooperative_exchange_2(minibatch, group=None):
group,
)
seeds_received[ntype] = typed_seeds_received
minibatch._seed_nodes = _revert_to_homo(seeds_received)
minibatch._counts_sent = _revert_to_homo(counts_sent)
minibatch._counts_received = _revert_to_homo(counts_received)
minibatch._seed_nodes = seeds_received
minibatch._counts_sent = revert_to_homo(counts_sent)
minibatch._counts_received = revert_to_homo(counts_received)
return minibatch

@staticmethod
def _seeds_cooperative_exchange_3(minibatch):
nodes = {
ntype: [typed_seeds]
for ntype, typed_seeds in minibatch._seed_nodes.items()
}
minibatch._unique_future = unique_and_compact(
minibatch._seed_nodes, 0, 1, async_op=True
nodes, 0, 1, async_op=True
)
return minibatch

@staticmethod
def _seeds_cooperative_exchange_4(minibatch):
unique_seeds, inverse_seeds, _ = minibatch._unique_future.wait()
delattr(minibatch, "_unique_future")
minibatch._seed_nodes = _revert_to_homo(unique_seeds)
minibatch._seed_inverse_ids = _revert_to_homo(inverse_seeds)
inverse_seeds = {
ntype: typed_inv[0] for ntype, typed_inv in inverse_seeds.items()
}
minibatch._seed_nodes = revert_to_homo(unique_seeds)
minibatch._seed_inverse_ids = revert_to_homo(inverse_seeds)
return minibatch

def _sample(self, minibatch):
Expand Down
Loading

0 comments on commit 189b83c

Please sign in to comment.