Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Properly assign the chunks to the right worker #449

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions src/litdata/utilities/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,39 @@ def _associate_chunks_and_intervals_to_workers(
batch_size: int = 1,
) -> Tuple[List[List[int]], List[Any]]:
num_items = sum([(interval[2] - interval[1]) for interval in chunk_intervals])
world_size = distributed_env.world_size * num_workers
num_items_per_workers: List[int] = [
num_items // world_size + num_items % world_size
if rank == world_size - 1 and not drop_last
else num_items // world_size
for rank in range(world_size)
]
if drop_last:
num_items_per_workers = [batch_size * int(item // batch_size) for item in num_items_per_workers]

chunks_per_workers: List[List[int]] = [[] for _ in range(world_size)]
intervals_per_workers: List[List[List[int]]] = [[] for _ in range(world_size)]
max_batches = num_items // batch_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is a remainder here that’s a “partial” batch.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is why we are casting things to int, so it is an exact number.

global_num_workers = distributed_env.world_size * num_workers

num_items_per_workers: Any = []

for rank in range(distributed_env.world_size):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why only dist worksize here and not global num workers?
same question for below.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because, we want to ensure we fill up all the workers for each process rank in the same way.

tmp_arr = [0 for _ in range(num_workers)]

num_batches_per_rank = int(max_batches // distributed_env.world_size)
base_batches = num_batches_per_rank // num_workers
rem_batches = num_batches_per_rank % num_workers
tmp_arr = [base_batches + (1 if i < rem_batches else 0) for i in range(num_workers)]

if rank == distributed_env.world_size - 1:
# Find how batches were associated
num_assigned_items = batch_size * (sum(num_items_per_workers) + sum(tmp_arr))

# Multiply with the batch_size to get the number of items
if batch_size > 1:
tmp_arr = [x * batch_size for x in tmp_arr]
num_items_per_workers = [x * batch_size for x in num_items_per_workers]

# If there are items left to assign, let's give it the last worker
left_items = num_items - num_assigned_items
if not drop_last and left_items > 0:
tmp_arr[rem_batches % num_workers] += left_items

num_items_per_workers.extend(tmp_arr)
else:
num_items_per_workers.extend(tmp_arr)

chunks_per_workers: List[List[int]] = [[] for _ in range(global_num_workers)]
intervals_per_workers: List[List[List[int]]] = [[] for _ in range(global_num_workers)]

# 4. Assign the chunk & intervals to each rank
for chunk_index, chunk_interval in zip(indexes, chunk_intervals):
Expand Down
101 changes: 101 additions & 0 deletions tests/utilities/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,107 @@ def test_associate_chunks_and_intervals_to_workers():
[[0, 14, 27, 27], [0, 0, 50, 50], [0, 0, 1, 1]],
]

chunk_intervals = [
Interval(0, 0, 6, 6),
Interval(0, 0, 6, 6),
Interval(0, 0, 6, 6),
Interval(0, 0, 6, 6),
]

workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers(
_DistributedEnv(1, 0, 1), range(0, 4), chunk_intervals, False, 8, 6
)

assert workers_intervals == [[[0, 0, 6, 6]], [[0, 0, 6, 6]], [[0, 0, 6, 6]], [[0, 0, 6, 6]], [], [], [], []]
assert workers_chunks == [[0], [1], [2], [3], [], [], [], []]

workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers(
_DistributedEnv(2, 0, 1), range(0, 4), chunk_intervals, False, 8, 6
)

assert workers_chunks == [[0], [1], [], [], [], [], [], [], [2], [3], [], [], [], [], [], []]
assert workers_intervals == [
[[0, 0, 6, 6]],
[[0, 0, 6, 6]],
[],
[],
[],
[],
[],
[],
[[0, 0, 6, 6]],
[[0, 0, 6, 6]],
[],
[],
[],
[],
[],
[],
]

workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers(
_DistributedEnv(1, 0, 1), range(0, 4), chunk_intervals, False, 2, 8
)
assert workers_chunks == [[0, 1, 2], [2, 3]]
assert workers_intervals == [[[0, 0, 6, 6], [0, 0, 6, 6], [0, 0, 4, 6]], [[0, 4, 6, 6], [0, 0, 6, 6]]]

chunk_intervals = [
Interval(0, 0, 6, 6),
Interval(0, 0, 7, 7),
Interval(0, 0, 6, 6),
Interval(0, 0, 7, 8),
]

workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers(
_DistributedEnv(2, 0, 1), range(0, 4), chunk_intervals, False, 8, 6
)

assert sum([y[2] - y[1] for x in workers_intervals for y in x]) == 26
assert workers_chunks == [[0], [1], [], [], [], [], [], [], [1, 2], [2, 3], [3], [], [], [], [], []]
assert workers_intervals == [
[[0, 0, 6, 6]],
[[0, 0, 6, 7]],
[],
[],
[],
[],
[],
[],
[[0, 6, 7, 7], [0, 0, 5, 6]],
[[0, 5, 6, 6], [0, 0, 5, 8]],
[[0, 5, 7, 8]],
[],
[],
[],
[],
[],
]

workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers(
_DistributedEnv(2, 0, 1), range(0, 4), chunk_intervals, True, 8, 6
)

assert sum([y[2] - y[1] for x in workers_intervals for y in x]) == 24
assert workers_chunks == [[0], [1], [], [], [], [], [], [], [1, 2], [2, 3], [], [], [], [], [], []]
assert workers_intervals == [
[[0, 0, 6, 6]],
[[0, 0, 6, 7]],
[],
[],
[],
[],
[],
[],
[[0, 6, 7, 7], [0, 0, 5, 6]],
[[0, 5, 6, 6], [0, 0, 5, 8]],
[],
[],
[],
[],
[],
[],
]


def test_get_shared_chunks():
assert _get_shared_chunks([]) == {}
Expand Down
Loading