Fix uneven batches in distributed dataloading #237

Merged · 68 commits · Jul 19, 2024
Changes from 22 commits

Commits:
f749192  update  (tchaton, Jul 16, 2024)
ed18cfe  update  (tchaton, Jul 16, 2024)
c77821b  update  (tchaton, Jul 16, 2024)
2732202  update  (tchaton, Jul 16, 2024)
bef1698  Merge branch 'main' into fix_uneven_number_of_batches  (tchaton, Jul 16, 2024)
a8dd576  update  (tchaton, Jul 16, 2024)
c4f3f4e  Merge branch 'fix_uneven_number_of_batches' of https://github.com/Lig…  (tchaton, Jul 16, 2024)
9f69690  fix with thomas  (awaelchli, Jul 16, 2024)
34a9d74  stop length  (awaelchli, Jul 16, 2024)
a80e430  remove redundant drop_last code  (awaelchli, Jul 16, 2024)
f38e8ff  debug resume  (awaelchli, Jul 16, 2024)
6b7578a  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 16, 2024)
0623680  update resuming logic  (awaelchli, Jul 16, 2024)
265c4e9  update  (awaelchli, Jul 16, 2024)
c9ecec7  length and resume fixes  (awaelchli, Jul 16, 2024)
b0096c5  assert length in test  (awaelchli, Jul 16, 2024)
24653b2  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 16, 2024)
1734629  update  (awaelchli, Jul 16, 2024)
c3edbb4  rename variables  (awaelchli, Jul 16, 2024)
99ca280  clean up dataset.py  (awaelchli, Jul 16, 2024)
4130f50  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 16, 2024)
0053486  clean up shuffle  (awaelchli, Jul 16, 2024)
3e94cd6  Fix set_drop_last and test  (awaelchli, Jul 17, 2024)
a3b9457  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 17, 2024)
f52a501  fix epoch reshuffling test  (awaelchli, Jul 17, 2024)
e621185  update combined test  (awaelchli, Jul 17, 2024)
0641666  update replay test  (awaelchli, Jul 17, 2024)
9b4b807  set default  (awaelchli, Jul 17, 2024)
4aa15da  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 17, 2024)
f8fe74c  clean up  (awaelchli, Jul 17, 2024)
0eec395  update test  (awaelchli, Jul 17, 2024)
4d9befd  disable profiler test for now  (awaelchli, Jul 17, 2024)
fa4237b  update  (awaelchli, Jul 17, 2024)
92d52d3  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 17, 2024)
ec9db93  Merge branch 'main' into fix_uneven_number_of_batches2  (awaelchli, Jul 17, 2024)
ace58fb  fix type  (awaelchli, Jul 17, 2024)
53c2cf4  fix with thomas  (awaelchli, Jul 19, 2024)
322a697  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 19, 2024)
28117a7  num_workers_or_1  (awaelchli, Jul 19, 2024)
641ecae  update  (tchaton, Jul 19, 2024)
73f376f  update  (tchaton, Jul 19, 2024)
d7d6dfa  extend test and delete duplicated test  (awaelchli, Jul 19, 2024)
af74a3b  simplify `num_workers or 1` logic  (awaelchli, Jul 19, 2024)
5446fd6  mypy  (awaelchli, Jul 19, 2024)
44572cf  todo rename  (awaelchli, Jul 19, 2024)
533ba42  mypy  (awaelchli, Jul 19, 2024)
dba782e  Fix typeerror  (awaelchli, Jul 19, 2024)
5812b39  Merge branch 'main' into fix_uneven_number_of_batches2  (awaelchli, Jul 19, 2024)
1ea878d  debug  (awaelchli, Jul 19, 2024)
59d870f  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 19, 2024)
6590311  mypy  (awaelchli, Jul 19, 2024)
a6cf041  debug  (awaelchli, Jul 19, 2024)
dbdeb8a  debug  (awaelchli, Jul 19, 2024)
e0720d7  debug  (awaelchli, Jul 19, 2024)
c20a0ec  debug  (awaelchli, Jul 19, 2024)
5755998  debug  (awaelchli, Jul 19, 2024)
a9c688f  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 19, 2024)
18afb5e  debug  (awaelchli, Jul 19, 2024)
209e0ec  Merge branch 'main' into fix_uneven_number_of_batches2  (awaelchli, Jul 19, 2024)
6b04c22  debug  (awaelchli, Jul 19, 2024)
46daa79  debug  (awaelchli, Jul 19, 2024)
aafab96  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Jul 19, 2024)
d27fb34  debug  (awaelchli, Jul 19, 2024)
06bf414  debug  (awaelchli, Jul 19, 2024)
bc64b77  debug  (awaelchli, Jul 19, 2024)
ac17f3e  Update src/litdata/utilities/shuffle.py  (awaelchli, Jul 19, 2024)
66017e8  comments and test  (awaelchli, Jul 19, 2024)
e2e9ff8  internals -> intervals  (awaelchli, Jul 19, 2024)
11 changes: 11 additions & 0 deletions src/litdata/streaming/combined.py
@@ -118,6 +118,17 @@ def set_shuffle(self, shuffle: bool) -> None:
for dataset in self._datasets:
dataset.set_shuffle(shuffle)

def set_batch_size(self, batch_size: int) -> None:
"""Set the current batch size to the datasets."""
self.batch_size = batch_size
for dataset in self._datasets:
dataset.set_batch_size(batch_size)

def set_num_workers(self, num_workers: int) -> None:
"""Set the current number of workers to the datasets."""
for dataset in self._datasets:
dataset.set_num_workers(num_workers)

def set_drop_last(self, drop_last: bool) -> None:
"""Set the current drop_last to the datasets."""
for dataset in self._datasets:
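Note: the new setters just fan the dataloader-provided values out to every wrapped dataset. A minimal, self-contained sketch of that pattern with toy classes (illustrative only, not the actual litdata implementation):

```python
from typing import List


class _ToyDataset:
    def __init__(self) -> None:
        self.batch_size = 1
        self.num_workers = 1

    def set_batch_size(self, batch_size: int) -> None:
        self.batch_size = batch_size

    def set_num_workers(self, num_workers: int) -> None:
        self.num_workers = num_workers


class _ToyCombinedDataset:
    def __init__(self, datasets: List[_ToyDataset]) -> None:
        self._datasets = datasets
        self.batch_size = 1

    def set_batch_size(self, batch_size: int) -> None:
        # Keep a local copy and propagate to every wrapped dataset.
        self.batch_size = batch_size
        for dataset in self._datasets:
            dataset.set_batch_size(batch_size)

    def set_num_workers(self, num_workers: int) -> None:
        for dataset in self._datasets:
            dataset.set_num_workers(num_workers)


combined = _ToyCombinedDataset([_ToyDataset(), _ToyDataset()])
combined.set_batch_size(8)
combined.set_num_workers(4)
assert all(d.batch_size == 8 and d.num_workers == 4 for d in combined._datasets)
```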
3 changes: 3 additions & 0 deletions src/litdata/streaming/dataloader.py
@@ -571,6 +571,9 @@ def __init__(
if drop_last is not None:
dataset.set_drop_last(drop_last)

dataset.set_batch_size(batch_size)
dataset.set_num_workers(num_workers)

shuffle = None

if profile_batches and not _VIZ_TRACKER_AVAILABLE:
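With the loader now pushing `batch_size` and `num_workers` into the dataset before iteration, `len(dataset)` can take per-worker batching into account. A toy sketch of this wiring; the classes and the rounding rule below are illustrative assumptions, not the real litdata length computation:

```python
# Illustrative sketch: the loader forwards batch_size / num_workers into the
# dataset so that __len__ reflects what the workers will actually yield.
class _ToyStreamingDataset:
    def __init__(self, num_items: int) -> None:
        self.num_items = num_items
        self.batch_size = 1
        self.num_workers = 1

    def set_batch_size(self, batch_size: int) -> None:
        self.batch_size = batch_size

    def set_num_workers(self, num_workers: int) -> None:
        self.num_workers = num_workers

    def __len__(self) -> int:
        # Hypothetical drop_last-style length: each worker yields full batches only.
        per_worker = self.num_items // max(self.num_workers, 1)
        return (per_worker // self.batch_size) * self.batch_size * max(self.num_workers, 1)


class _ToyStreamingDataLoader:
    def __init__(self, dataset: _ToyStreamingDataset, batch_size: int, num_workers: int) -> None:
        dataset.set_batch_size(batch_size)
        dataset.set_num_workers(num_workers)
        self.dataset = dataset


loader = _ToyStreamingDataLoader(_ToyStreamingDataset(103), batch_size=4, num_workers=2)
print(len(loader.dataset))  # 96 with this toy rounding, not 103
```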
95 changes: 41 additions & 54 deletions src/litdata/streaming/dataset.py
@@ -31,7 +31,6 @@
from litdata.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset
from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv
from litdata.utilities.shuffle import _find_chunks_per_ranks_on_which_to_skip_deletion

logger = Logger(__name__)

@@ -174,7 +173,13 @@ def _create_shuffler(self, cache: Cache) -> Shuffle:
return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last)

def __len__(self) -> int:
return self.get_len(1, 1)
return self.get_len(self.num_workers if self.num_workers else 1, self.batch_size if self.batch_size else 1)

def set_batch_size(self, batch_size: int) -> None:
self.batch_size = batch_size

def set_num_workers(self, num_workers: int) -> None:
self.num_workers = num_workers

def get_len(self, num_workers: int, batch_size: int) -> int:
self.num_workers = num_workers
@@ -200,35 +205,34 @@ def __iter__(self) -> "StreamingDataset":
state: Dict[str, Any] = self._state_dict
self.current_epoch = state["current_epoch"]

chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
workers_chunks, workers_intervals = self.shuffler.get_chunks_and_intervals_per_workers(
self.distributed_env, self.worker_env.world_size, self.batch_size or 1, self.current_epoch
)
chunks_replica = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
intervals_replica = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]

worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
self.worker_chunks = workers_chunks[worker_rank]
self.worker_intervals = workers_intervals[worker_rank]

# The max number of samples to return from `__next__` (in worker)
self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals)

# Handle restart
if self._state_dict:
self._resume(chunks_replica, intervals_replica)
self._resume(workers_chunks, workers_intervals)
else:
# TODO: Reimplement this logic
# Find the chunks shared across multiple ranks.
# For each shared chunk, find the rank to use the chunk last and prevent deletion
# for the other ranks.
chunks_indexes_skip_deletion = _find_chunks_per_ranks_on_which_to_skip_deletion(
self.worker_env.world_size, chunks_per_replica, intervals_per_replica
)
if self.distributed_env.global_rank in chunks_indexes_skip_deletion:
self.cache._reader.config.skip_chunk_indexes_deletion = chunks_indexes_skip_deletion[
self.distributed_env.global_rank
]

workers_chunks, workers_intervals = _associate_chunks_to_workers(
self.worker_env,
chunks_per_replica[self.distributed_env.global_rank],
intervals_per_replica[self.distributed_env.global_rank],
)

self.worker_chunks = workers_chunks[self.worker_env.rank]
self.worker_intervals = workers_intervals[self.worker_env.rank]
# worker_start = self.distributed_env.global_rank * self.num_workers
# worker_end = worker_start + self.num_workers
# chunks_indexes_skip_deletion = _find_chunks_per_ranks_on_which_to_skip_deletion(
# self.worker_env.world_size, workers_chunks[worker_start: worker_end], workers_intervals[worker_start: worker_end]
# )
# if self.distributed_env.global_rank in chunks_indexes_skip_deletion:
# self.cache._reader.config.skip_chunk_indexes_deletion = chunks_indexes_skip_deletion[
# self.distributed_env.global_rank
# ]

self.num_chunks = len(self.worker_chunks)
self.current_indexes = []
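This is the core of the fix: chunks and intervals are now computed per worker across the whole job, so each worker picks its own slice via a flat index `global_rank * workers_per_rank + local_worker_rank`, and `stop_length` bounds `__next__` to exactly the samples assigned to that worker. A small worked example with invented numbers:

```python
# Worked example of the flat worker index and stop_length (made-up values).
global_rank = 1          # second process in the job
workers_per_rank = 3     # worker_env.world_size
local_worker_rank = 2    # worker_env.rank

worker_rank = global_rank * workers_per_rank + local_worker_rank
assert worker_rank == 5  # 6 workers total (2 ranks x 3 workers); this is the last one

# Intervals assigned to this worker: (chunk_start, roi_start, roi_end, chunk_end)
worker_intervals = [(0, 10, 25, 30), (30, 30, 42, 60)]
stop_length = sum(interval[2] - interval[1] for interval in worker_intervals)
assert stop_length == 15 + 12 == 27  # __next__ raises StopIteration after 27 samples
```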
@@ -241,7 +245,7 @@ def __iter__(self) -> "StreamingDataset":

return self

def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> None:
def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> None:
assert self._state_dict
assert self.worker_env
assert self.shuffler
@@ -254,17 +258,22 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No
# TODO: Implement elastic sampling where the number of workers, ranks can change.
num_samples_yielded = self._state_dict["num_samples_yielded"]

worker_start = self.distributed_env.global_rank * num_workers
worker_end = worker_start + num_workers

# replay sampling from each worker / chunks using the batch size
workers_chunks, workers_intervals = _associate_chunks_to_workers(
self.worker_env, chunks_replica, intervals_replica
)
indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers)
chunks_index, indexes = _replay_chunks_sampling(workers_intervals, indexes)
chunks_index, indexes = _replay_chunks_sampling(
workers_intervals={i: workers_intervals[i] for i in range(worker_start, worker_end)},
indexes=indexes,
)

# select the chunks and intervals associated to this worker
worker_rank = self.worker_env.rank
worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
worker_local_rank = self.worker_env.rank

self.num_chunks = len(workers_intervals[worker_rank])
self.chunk_index = chunks_index[worker_rank]
self.chunk_index = chunks_index[worker_local_rank]
self.worker_chunks = workers_chunks[worker_rank]
self.worker_intervals = workers_intervals[worker_rank]

@@ -276,10 +285,10 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No
current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index)

# skip any indexes already consumed
current_indexes = current_indexes[indexes[worker_rank] :]
current_indexes = current_indexes[indexes[worker_local_rank] :]
self.current_indexes = current_indexes

self.global_index = num_samples_yielded
self.global_index = indexes[worker_local_rank]

# bump the chunk_index
self.chunk_index += 1
@@ -300,7 +309,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:

def __next__(self) -> Any:
# Prevent to create more batch on a given process
if self.global_index >= len(self):
if self.global_index >= self.stop_length:
self.current_epoch += 1
raise StopIteration

@@ -465,28 +474,6 @@ def is_integer(value: str) -> bool:
return False


def _associate_chunks_to_workers(
worker_env: _WorkerEnv, chunks_replica: List[int], intervals_replica: List[Any]
) -> Any:
workers_chunks = {}
workers_intervals = {}

for worker_idx in range(worker_env.world_size):
worker_chunks = []
worker_intervals = []
for i, (chunk_index, chunk_interval) in enumerate(zip(chunks_replica, intervals_replica)):
if i % worker_env.world_size != worker_idx:
continue

worker_chunks.append(chunk_index)
worker_intervals.append(chunk_interval)

workers_chunks[worker_idx] = worker_chunks
workers_intervals[worker_idx] = worker_intervals

return workers_chunks, workers_intervals


def _replay_sampling(num_samples_yielded: int, batch_size: int, num_workers: int) -> Dict[int, int]:
"""This function replays the sampling from the dataloader."""
divisible_num_batches_yielded = num_samples_yielded // (num_workers * batch_size)
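For resuming, `_replay_sampling` splits the globally yielded sample count back across this rank's workers (the dataloader hands batches to workers round-robin), and `_replay_chunks_sampling` then walks each worker's intervals to recover the chunk index and in-chunk offset. A hedged re-implementation of the first step, assuming round-robin batch assignment; names and details are illustrative, not the exact litdata code:

```python
from typing import Dict


def replay_sampling_sketch(num_samples_yielded: int, batch_size: int, num_workers: int) -> Dict[int, int]:
    # Full "rounds" in which every worker produced one complete batch.
    full_rounds = num_samples_yielded // (num_workers * batch_size)
    indexes = {worker: full_rounds * batch_size for worker in range(num_workers)}
    remainder = num_samples_yielded - full_rounds * num_workers * batch_size

    # The remaining samples were taken batch by batch, starting again from worker 0.
    worker = 0
    while remainder > 0:
        taken = min(batch_size, remainder)
        indexes[worker] += taken
        remainder -= taken
        worker = (worker + 1) % num_workers
    return indexes


# 10 samples yielded with batch_size=4 and 2 workers:
# worker 0 produced samples 0-3 and 8-9, worker 1 produced samples 4-7.
assert replay_sampling_sketch(10, batch_size=4, num_workers=2) == {0: 6, 1: 4}
```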
46 changes: 22 additions & 24 deletions src/litdata/streaming/shuffle.py
@@ -19,7 +19,10 @@

from litdata.streaming import Cache
from litdata.utilities.env import _DistributedEnv
from litdata.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle
from litdata.utilities.shuffle import (
_associate_chunks_and_internals_to_workers,
_intra_node_chunk_shuffle,
)


class Shuffle(ABC):
@@ -32,23 +35,19 @@ def __init__(self, cache: Cache, seed: int, drop_last: bool):

@lru_cache(maxsize=10)
def get_len(self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int) -> int:
_, intervals_per_ranks = self.get_chunks_and_intervals_per_ranks(
_, workers_intervals = self.get_chunks_and_intervals_per_workers(
distributed_env, num_workers, batch_size, current_epoch
)

if self.drop_last:
items_per_process = [
sum((interval[2] - interval[1]) for interval in intervals) for intervals in intervals_per_ranks
]
# Validate each processes gets the exact number of elements
if len(items_per_process) > 1:
assert all(items_per_process[0] == items_to_process for items_to_process in items_per_process[:1])
return items_per_process[0]

return sum((interval[2] - interval[1]) for interval in intervals_per_ranks[distributed_env.global_rank])
worker_start = distributed_env.global_rank * num_workers
worker_end = worker_start + num_workers
return sum(
(interval[2] - interval[1])
for intervals in workers_intervals[worker_start:worker_end]
for interval in intervals
)

@abstractmethod
def get_chunks_and_intervals_per_ranks(
def get_chunks_and_intervals_per_workers(
self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int
) -> Any:
pass
@@ -63,19 +62,18 @@ class NoShuffle(Shuffle):
is True."""

@lru_cache(maxsize=10)
def get_chunks_and_intervals_per_ranks(
def get_chunks_and_intervals_per_workers(
self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int
) -> Any:
# 1. Get the intervals
chunk_intervals = self.cache.get_chunk_intervals()
indexes = range(len(chunk_intervals))

# 2. Compute the items budget of each rank
chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers(
distributed_env, indexes, chunk_intervals, self.drop_last, num_workers, batch_size
)

return chunks_per_ranks, intervals_per_ranks
return workers_chunks, workers_intervals

def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
return array.tolist()
@@ -100,7 +98,7 @@ class FullShuffle(Shuffle):
"""

@lru_cache(maxsize=10)
def get_chunks_and_intervals_per_ranks(
def get_chunks_and_intervals_per_workers(
self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int
) -> Any:
# 1. Get the intervals
@@ -120,24 +118,24 @@ def get_chunks_and_intervals_per_ranks(
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist()

# 3. Compute the items budget of each rank
chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers(
distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size
)

# For the first epoch, no need of further shuffling
if current_epoch == 1 or distributed_env.num_nodes == 1:
return chunks_per_ranks, intervals_per_ranks
return workers_chunks, workers_intervals

# Perform shuffle within the nodes to avoid cache miss.
# Note: It is possible for the overlapping chunks to change due to the changing order.
shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, chunks_per_ranks, self.seed, current_epoch)
shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, workers_chunks, self.seed, current_epoch)
shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist()

chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
workers_chunks, workers_intervals = _associate_chunks_and_internals_to_workers(
distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size
)

return chunks_per_ranks, intervals_per_ranks
return workers_chunks, workers_intervals

def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
return np.random.RandomState([self.seed, num_chunks * current_epoch, chunk_index]).permutation(array).tolist()
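The shuffler now returns one chunk list and one interval list per worker across the whole job, i.e. `distributed_env.world_size * num_workers` entries, and `get_len` sums over this rank's slice of workers. A small sketch of that slicing with invented interval values:

```python
# Invented example: 2 ranks x 2 workers = 4 worker entries, each a list of
# (chunk_start, roi_start, roi_end, chunk_end) intervals.
workers_intervals = [
    [(0, 0, 10, 10)],    # rank 0, worker 0
    [(10, 10, 18, 20)],  # rank 0, worker 1
    [(20, 20, 30, 30)],  # rank 1, worker 0
    [(30, 30, 37, 40)],  # rank 1, worker 1
]

num_workers = 2
global_rank = 1
worker_start = global_rank * num_workers
worker_end = worker_start + num_workers

rank_length = sum(
    interval[2] - interval[1]
    for intervals in workers_intervals[worker_start:worker_end]
    for interval in intervals
)
assert rank_length == 10 + 7  # rank 1 yields 17 samples in total
```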
40 changes: 21 additions & 19 deletions src/litdata/utilities/shuffle.py
@@ -19,6 +19,7 @@
from litdata.utilities.env import _DistributedEnv


# TODO: Logic needs to be updated? chunks_per_ranks -> workers_chunks
def _intra_node_chunk_shuffle(
distributed_env: _DistributedEnv,
chunks_per_ranks: List[List[int]],
@@ -42,7 +43,7 @@ def _intra_node_chunk_shuffle(
return [index for chunks in chunk_indexes_per_nodes for index in chunks]


def _associate_chunks_and_internals_to_ranks(
def _associate_chunks_and_internals_to_workers(
distributed_env: _DistributedEnv,
indexes: Any,
chunk_intervals: List[Interval],
@@ -51,28 +52,29 @@ def _associate_chunks_and_internals_to_ranks(
batch_size: int = 1,
) -> Tuple[List[List[int]], List[Any]]:
num_items = sum([(interval[2] - interval[1]) for interval in chunk_intervals])
num_items_per_ranks: List[int] = [
num_items // distributed_env.world_size + num_items % distributed_env.world_size
if rank == distributed_env.world_size - 1 and not drop_last
else num_items // distributed_env.world_size
for rank in range(distributed_env.world_size)
world_size = distributed_env.world_size * num_workers
num_items_per_workers: List[int] = [
num_items // world_size + num_items % world_size
if rank == world_size - 1 and not drop_last
else num_items // world_size
for rank in range(world_size)
]
if drop_last:
ratio = num_workers * batch_size
num_items_per_ranks = [ratio * int(item // ratio) for item in num_items_per_ranks]
ratio = batch_size
num_items_per_workers = [ratio * int(item // ratio) for item in num_items_per_workers]

chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
chunks_per_workers: List[List[int]] = [[] for _ in range(world_size)]
intervals_per_workers: List[List[List[int]]] = [[] for _ in range(world_size)]

# 4. Assign the chunk & intervals to each rank
for chunk_index, chunk_interval in zip(indexes, chunk_intervals):
rank = 0

while True:
if rank == len(num_items_per_ranks):
if rank == len(num_items_per_workers):
break

items_left_to_assign = num_items_per_ranks[rank]
items_left_to_assign = num_items_per_workers[rank]

if items_left_to_assign == 0:
rank += 1
@@ -84,23 +86,23 @@ def _associate_chunks_and_internals_to_ranks(
break

if items_in_chunk > items_left_to_assign:
chunks_per_ranks[rank].append(chunk_index)
chunks_per_workers[rank].append(chunk_index)

chunk_start, chunk_roi_start, chunk_roi_end, chunk_end = chunk_interval

intervals_per_ranks[rank].append(
intervals_per_workers[rank].append(
[chunk_start, chunk_roi_start, chunk_roi_start + items_left_to_assign, chunk_end]
)
chunk_interval = Interval(chunk_start, chunk_roi_start + items_left_to_assign, chunk_roi_end, chunk_end)
num_items_per_ranks[rank] = 0
num_items_per_workers[rank] = 0
rank += 1
else:
chunks_per_ranks[rank].append(chunk_index)
intervals_per_ranks[rank].append(list(chunk_interval))
num_items_per_ranks[rank] -= items_in_chunk
chunks_per_workers[rank].append(chunk_index)
intervals_per_workers[rank].append(list(chunk_interval))
num_items_per_workers[rank] -= items_in_chunk
break

return chunks_per_ranks, intervals_per_ranks
return chunks_per_workers, intervals_per_workers


def _find_chunks_per_ranks_on_which_to_skip_deletion(
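The renamed helper budgets items over `world_size * num_workers` slots and, under `drop_last`, rounds each worker's budget down to a multiple of `batch_size` (previously `num_workers * batch_size` per rank). A rough, simplified re-implementation on toy data, showing how a chunk that straddles a budget boundary gets split between two workers; the greedy assignment below is an approximation of the real function, not a drop-in replacement:

```python
from typing import List, Tuple

Interval = Tuple[int, int, int, int]  # (chunk_start, roi_start, roi_end, chunk_end)


def associate_to_workers_sketch(
    chunk_intervals: List[Interval], world_size: int, num_workers: int, batch_size: int, drop_last: bool
) -> Tuple[List[List[int]], List[List[Interval]]]:
    slots = world_size * num_workers
    num_items = sum(roi_end - roi_start for _, roi_start, roi_end, _ in chunk_intervals)
    # Remainder goes to the last slot only when drop_last is False.
    budgets = [
        num_items // slots + (num_items % slots if slot == slots - 1 and not drop_last else 0)
        for slot in range(slots)
    ]
    if drop_last:
        budgets = [batch_size * (b // batch_size) for b in budgets]

    chunks: List[List[int]] = [[] for _ in range(slots)]
    intervals: List[List[Interval]] = [[] for _ in range(slots)]
    slot = 0
    for chunk_index, (start, roi_start, roi_end, end) in enumerate(chunk_intervals):
        # Greedily hand out the chunk's region of interest, splitting it when a
        # worker's budget runs out part-way through.
        while roi_end > roi_start and slot < slots:
            if budgets[slot] == 0:
                slot += 1
                continue
            take = min(budgets[slot], roi_end - roi_start)
            chunks[slot].append(chunk_index)
            intervals[slot].append((start, roi_start, roi_start + take, end))
            budgets[slot] -= take
            roi_start += take
    return chunks, intervals


# 3 chunks of 10 items each, 2 ranks x 1 worker, batch_size 4, drop_last:
# each worker's budget is 15, rounded down to 12, so chunk 1 is split between workers.
chunks, intervals = associate_to_workers_sketch(
    [(0, 0, 10, 10), (10, 10, 20, 20), (20, 20, 30, 30)],
    world_size=2, num_workers=1, batch_size=4, drop_last=True,
)
assert chunks == [[0, 1], [1, 2]]
assert intervals[0] == [(0, 0, 10, 10), (10, 10, 12, 20)]
assert intervals[1] == [(10, 12, 20, 20), (20, 20, 24, 30)]
```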