[GraphBolt] Rewrite DistributedItemSampler logic (dmlc#6565)
RamonZhou authored Nov 19, 2023
1 parent 81c7781 commit 46d7b1d
Showing 4 changed files with 260 additions and 126 deletions.
167 changes: 47 additions & 120 deletions python/dgl/graphbolt/item_sampler.py
@@ -16,6 +16,7 @@
from ..heterograph import DGLGraph
from .itemset import ItemSet, ItemSetDict
from .minibatch import MiniBatch
from .utils import calculate_range

__all__ = ["ItemSampler", "DistributedItemSampler", "minibatcher_default"]

@@ -125,9 +126,8 @@ def __init__(
)
self._distributed = distributed
self._drop_uneven_inputs = drop_uneven_inputs
if distributed:
self._num_replicas = world_size
self._rank = rank
self._num_replicas = world_size
self._rank = rank

def _collate_batch(self, buffer, indices, offsets=None):
"""Collate a batch from the buffer. For internal use only."""
@@ -184,101 +184,33 @@ def __iter__(self):
num_workers = 1
worker_id = 0
buffer = None
if not self._distributed:
num_items = len(self._item_set)
start_offset = 0
else:
total_count = len(self._item_set)
big_batch_size = self._num_replicas * self._batch_size
big_batch_count, big_batch_remain = divmod(
total_count, big_batch_size
)
last_batch_count, batch_remain = divmod(
big_batch_remain, self._batch_size
)
if self._rank < last_batch_count:
last_batch = self._batch_size
elif self._rank == last_batch_count:
last_batch = batch_remain
else:
last_batch = 0
num_items = big_batch_count * self._batch_size + last_batch
start_offset = (
big_batch_count * self._batch_size * self._rank
+ min(self._rank * self._batch_size, big_batch_remain)
)
if not self._drop_uneven_inputs or (
not self._drop_last and last_batch_count == self._num_replicas
):
# No need to drop uneven batches.
num_evened_items = num_items
if num_workers > 1:
total_batch_count = (
num_items + self._batch_size - 1
) // self._batch_size
split_batch_count = total_batch_count // num_workers + (
worker_id < total_batch_count % num_workers
)
split_num_items = split_batch_count * self._batch_size
num_items = (
min(num_items, split_num_items * (worker_id + 1))
- split_num_items * worker_id
)
num_evened_items = num_items
start_offset = (
big_batch_count * self._batch_size * self._rank
+ min(self._rank * self._batch_size, big_batch_remain)
+ self._batch_size
* (
total_batch_count // num_workers * worker_id
+ min(worker_id, total_batch_count % num_workers)
)
)
else:
# Needs to drop uneven batches. As many items as `last_batch`
# size will be dropped. It would be better not to let those
# dropped items come from the same worker.
num_evened_items = big_batch_count * self._batch_size
if num_workers > 1:
total_batch_count = big_batch_count
split_batch_count = total_batch_count // num_workers + (
worker_id < total_batch_count % num_workers
)
split_num_items = split_batch_count * self._batch_size
split_item_remain = last_batch // num_workers + (
worker_id < last_batch % num_workers
)
num_items = split_num_items + split_item_remain
num_evened_items = split_num_items
start_offset = (
big_batch_count * self._batch_size * self._rank
+ min(self._rank * self._batch_size, big_batch_remain)
+ self._batch_size
* (
total_batch_count // num_workers * worker_id
+ min(worker_id, total_batch_count % num_workers)
)
+ last_batch // num_workers * worker_id
+ min(worker_id, last_batch % num_workers)
)
total = len(self._item_set)
start_offset, assigned_count, output_count = calculate_range(
self._distributed,
total,
self._num_replicas,
self._rank,
num_workers,
worker_id,
self._batch_size,
self._drop_last,
self._drop_uneven_inputs,
)
start = 0
while start < num_items:
end = min(start + self._buffer_size, num_items)
while start < assigned_count:
end = min(start + self._buffer_size, assigned_count)
buffer = self._item_set[start_offset + start : start_offset + end]
indices = torch.arange(end - start)
if self._shuffle:
np.random.shuffle(indices.numpy())
offsets = self._calculate_offsets(buffer)
for i in range(0, len(indices), self._batch_size):
if self._drop_last and i + self._batch_size > len(indices):
if output_count <= 0:
break
if (
self._distributed
and self._drop_uneven_inputs
and i >= num_evened_items
):
break
batch_indices = indices[i : i + self._batch_size]
batch_indices = indices[
i : i + min(self._batch_size, output_count)
]
output_count -= self._batch_size
yield self._collate_batch(buffer, batch_indices, offsets)
buffer = None
start = end
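A minimal, self-contained sketch of how the three values returned by `calculate_range` drive the rewritten batching loop above (not part of the diff; `items`, the hard-coded numbers, and the `print` in place of `yield self._collate_batch(...)` are hypothetical stand-ins, and shuffling/collation are omitted):

# Hypothetical stand-ins: `items` replaces self._item_set, and the three
# values below are what calculate_range() might return for one worker.
items = list(range(100))
start_offset, assigned_count, output_count = 24, 26, 24
batch_size, buffer_size = 4, 16

start = 0
while start < assigned_count:
    end = min(start + buffer_size, assigned_count)
    # Only the range assigned to this worker is ever sliced from the item set.
    buffer = items[start_offset + start : start_offset + end]
    for i in range(0, len(buffer), batch_size):
        if output_count <= 0:
            break  # items beyond output_count are dropped (drop_* semantics)
        print(buffer[i : i + min(batch_size, output_count)])
        output_count -= batch_size
    start = end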
@@ -592,10 +524,6 @@ class DistributedItemSampler(ItemSampler):
counterparts. The original item set is split such that each replica
(process) receives an exclusive subset.
Note: DistributedItemSampler may not work as expected when it is the last
datapipe before the data is fetched. Please wrap it with a
SingleProcessDataLoader or another datapipe.
Note: The items are first split across the replicas, then shuffled
(if needed) and batched. Therefore, each replica always receives the same
set of items.
@@ -638,15 +566,14 @@ class DistributedItemSampler(ItemSampler):
Examples
--------
TODO[Kaicheng]: Modify examples here.
0. Preparation: DistributedItemSampler needs a multi-processing environment to
work. You need to spawn subprocesses and initialize the process group before
running the following examples. Due to randomness, the output is not always
the same as listed below.
>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 14))
>>> item_set = gb.ItemSet(torch.arange(15))
>>> num_replicas = 4
>>> batch_size = 2
>>> mp.spawn(...)
@@ -659,10 +586,10 @@ class DistributedItemSampler(ItemSampler):
>>> )
>>> data_loader = gb.SingleProcessDataLoader(item_sampler)
>>> print(f"Replica#{proc_id}: {list(data_loader)})
Replica#0: [tensor([0, 4]), tensor([ 8, 12])]
Replica#1: [tensor([1, 5]), tensor([ 9, 13])]
Replica#2: [tensor([2, 6]), tensor([10])]
Replica#3: [tensor([3, 7]), tensor([11])]
Replica#0: [tensor([0, 1]), tensor([2, 3])]
Replica#1: [tensor([4, 5]), tensor([6, 7])]
Replica#2: [tensor([8, 9]), tensor([10, 11])]
Replica#3: [tensor([12, 13]), tensor([14])]
2. shuffle = False, drop_last = True, drop_uneven_inputs = False.
@@ -672,10 +599,10 @@ class DistributedItemSampler(ItemSampler):
>>> )
>>> data_loader = gb.SingleProcessDataLoader(item_sampler)
>>> print(f"Replica#{proc_id}: {list(data_loader)})
Replica#0: [tensor([0, 4]), tensor([ 8, 12])]
Replica#1: [tensor([1, 5]), tensor([ 9, 13])]
Replica#2: [tensor([2, 6])]
Replica#3: [tensor([3, 7])]
Replica#0: [tensor([0, 1]), tensor([2, 3])]
Replica#1: [tensor([4, 5]), tensor([6, 7])]
Replica#2: [tensor([8, 9]), tensor([10, 11])]
Replica#3: [tensor([12, 13])]
3. shuffle = False, drop_last = False, drop_uneven_inputs = True.
@@ -685,10 +612,10 @@ class DistributedItemSampler(ItemSampler):
>>> )
>>> data_loader = gb.SingleProcessDataLoader(item_sampler)
>>> print(f"Replica#{proc_id}: {list(data_loader)})
Replica#0: [tensor([0, 4]), tensor([ 8, 12])]
Replica#1: [tensor([1, 5]), tensor([ 9, 13])]
Replica#2: [tensor([2, 6]), tensor([10])]
Replica#3: [tensor([3, 7]), tensor([11])]
Replica#0: [tensor([0, 1]), tensor([2, 3])]
Replica#1: [tensor([4, 5]), tensor([6, 7])]
Replica#2: [tensor([8, 9]), tensor([10, 11])]
Replica#3: [tensor([12, 13]), tensor([14])]
4. shuffle = False, drop_last = True, drop_uneven_inputs = True.
@@ -698,10 +625,10 @@ class DistributedItemSampler(ItemSampler):
>>> )
>>> data_loader = gb.SingleProcessDataLoader(item_sampler)
>>> print(f"Replica#{proc_id}: {list(data_loader)})
Replica#0: [tensor([0, 4])]
Replica#1: [tensor([1, 5])]
Replica#2: [tensor([2, 6])]
Replica#3: [tensor([3, 7])]
Replica#0: [tensor([0, 1])]
Replica#1: [tensor([4, 5])]
Replica#2: [tensor([8, 9])]
Replica#3: [tensor([12, 13])]
5. shuffle = True, drop_last = True, drop_uneven_inputs = False.
@@ -712,10 +639,10 @@ class DistributedItemSampler(ItemSampler):
>>> data_loader = gb.SingleProcessDataLoader(item_sampler)
>>> print(f"Replica#{proc_id}: {list(data_loader)})
(One possible output:)
Replica#0: [tensor([0, 8]), tensor([ 4, 12])]
Replica#1: [tensor([ 5, 13]), tensor([9, 1])]
Replica#2: [tensor([ 2, 10])]
Replica#3: [tensor([11, 7])]
Replica#0: [tensor([3, 2]), tensor([0, 1])]
Replica#1: [tensor([6, 5]), tensor([7, 4])]
Replica#2: [tensor([8, 10])]
Replica#3: [tensor([14, 12])]
6. shuffle = True, drop_last = True, drop_uneven_inputs = True.
@@ -726,10 +653,10 @@ class DistributedItemSampler(ItemSampler):
>>> data_loader = gb.SingleProcessDataLoader(item_sampler)
>>> print(f"Replica#{proc_id}: {list(data_loader)})
(One possible output:)
Replica#0: [tensor([8, 0])]
Replica#1: [tensor([ 1, 13])]
Replica#2: [tensor([10, 6])]
Replica#3: [tensor([ 3, 11])]
Replica#0: [tensor([1, 3])]
Replica#1: [tensor([7, 5])]
Replica#2: [tensor([11, 9])]
Replica#3: [tensor([13, 14])]
"""

def __init__(
1 change: 1 addition & 0 deletions python/dgl/graphbolt/utils/__init__.py
@@ -2,3 +2,4 @@
from .internal import *
from .sample_utils import *
from .datapipe_utils import *
from .item_sampler_utils import *
112 changes: 112 additions & 0 deletions python/dgl/graphbolt/utils/item_sampler_utils.py
@@ -0,0 +1,112 @@
"""Utility functions for DistributedItemSampler."""


def count_split(total, num_workers, worker_id, batch_size=1):
"""Calculate the number of assigned items after splitting them by batch
size evenly. It will return the number for this worker and also a sum of
previous workers.
"""
quotient, remainder = divmod(total, num_workers * batch_size)
if batch_size == 1:
assigned = quotient + (worker_id < remainder)
else:
batch_count, last_batch = divmod(remainder, batch_size)
assigned = quotient * batch_size + (
batch_size
if worker_id < batch_count
else (last_batch if worker_id == batch_count else 0)
)
prefix_sum = quotient * worker_id * batch_size + min(
worker_id * batch_size, remainder
)
return (assigned, prefix_sum)


def calculate_range(
distributed,
total,
num_replicas,
rank,
num_workers,
worker_id,
batch_size,
drop_last,
drop_uneven_inputs,
):
"""Calculates the range of items to be assigned to the current worker.
This function evenly distributes `total` items among multiple workers,
batching them using `batch_size`. Each replica has `num_workers` workers.
The batches generated by workers within the same replica are combined into
the replica`s output. The `drop_last` parameter determines whether
incomplete batches should be dropped. If `drop_last` is True, incomplete
batches are discarded. The `drop_uneven_inputs` parameter determines if the
number of batches assigned to each replica should be the same. If
`drop_uneven_inputs` is True, excessive batches for some replicas will be
dropped.
Args:
distributed (bool): Whether it's in distributed mode.
total (int): The total number of items.
num_replicas (int): The total number of replicas.
rank (int): The rank of the current replica.
num_workers (int): The number of workers per replica.
worker_id (int): The ID of the current worker.
batch_size (int): The desired batch size.
drop_last (bool): Whether to drop incomplete batches.
drop_uneven_inputs (bool): Whether to drop excessive batches for some
replicas.
Returns:
tuple: A tuple containing three numbers:
- start_offset (int): The starting offset of the range assigned to
the current worker.
- assigned_count (int): The length of the range assigned to the
current worker.
- output_count (int): The number of items that the current worker
will produce after dropping.
"""
# Check if it's distributed mode.
if not distributed:
if not drop_last:
return (0, total, total)
else:
return (0, total, total // batch_size * batch_size)
# First, equally distribute items into all replicas.
assigned_count, start_offset = count_split(
total, num_replicas, rank, batch_size
)
# Calculate the number of items to output after dropping.
# `assigned_count` is the number of items distributed to the current
# process. `output_count` is the number of items that this process
# should output after dropping.
if not drop_uneven_inputs:
if not drop_last:
output_count = assigned_count
else:
output_count = assigned_count // batch_size * batch_size
else:
if not drop_last:
min_item_count, _ = count_split(
total, num_replicas, num_replicas - 1, batch_size
)
min_batch_count = (min_item_count + batch_size - 1) // batch_size
output_count = min(min_batch_count * batch_size, assigned_count)
else:
output_count = total // (batch_size * num_replicas) * batch_size
# If there are multiple workers, equally distribute the batches to
# all workers.
if num_workers > 1:
# Equally distribute the dropped number too.
dropped_items, prev_dropped_items = count_split(
assigned_count - output_count, num_workers, worker_id
)
output_count, prev_output_count = count_split(
output_count,
num_workers,
worker_id,
batch_size,
)
assigned_count = output_count + dropped_items
start_offset += prev_output_count + prev_dropped_items
return (start_offset, assigned_count, output_count)
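A quick sanity check of the two new helpers, using the same numbers as the docstring examples above (15 items, 4 replicas, batch size 2, one worker per replica). This is hypothetical usage rather than part of the committed files; the import path assumes the wildcard re-export added to `utils/__init__.py` makes both helpers importable from `dgl.graphbolt.utils`.

from dgl.graphbolt.utils import calculate_range, count_split

# Replica 3 owns the tail of the item set: 3 items starting at offset 12.
assert count_split(total=15, num_workers=4, worker_id=3, batch_size=2) == (3, 12)

# Example 1 above (no dropping): replica 3 yields [12, 13] and [14].
assert calculate_range(True, 15, 4, 3, 1, 0, 2, False, False) == (12, 3, 3)

# Example 4 above (drop_last and drop_uneven_inputs): every replica yields
# exactly one full batch, so replica 3 keeps only [12, 13].
assert calculate_range(True, 15, 4, 3, 1, 0, 2, True, True) == (12, 3, 2)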