diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 56480a9f..ba100794 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -255,7 +255,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_ output_filepath = remove_uuid_from_filename(output_filepath) # remove unique id from checkpoints os.makedirs(os.path.dirname(output_filepath), exist_ok=True) - shutil.move(local_filepath, output_filepath) + shutil.copy(local_filepath, output_filepath) else: raise ValueError(f"The provided {output_dir.path} isn't supported.") diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py index 4741b75d..8f789949 100644 --- a/src/litdata/streaming/combined.py +++ b/src/litdata/streaming/combined.py @@ -118,6 +118,17 @@ def set_shuffle(self, shuffle: bool) -> None: for dataset in self._datasets: dataset.set_shuffle(shuffle) + def set_batch_size(self, batch_size: int) -> None: + """Set the current batch size to the datasets.""" + self.batch_size = batch_size + for dataset in self._datasets: + dataset.set_batch_size(batch_size) + + def set_num_workers(self, num_workers: int) -> None: + """Set the current number of workers to the datasets.""" + for dataset in self._datasets: + dataset.set_num_workers(num_workers) + def set_drop_last(self, drop_last: bool) -> None: """Set the current drop_last to the datasets.""" for dataset in self._datasets: diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 05286ff1..4ad656db 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -555,7 +555,7 @@ def __init__( profile_dir: Optional[str] = None, prefetch_factor: Optional[int] = None, shuffle: Optional[bool] = None, - drop_last: Optional[bool] = False, + drop_last: Optional[bool] = None, collate_fn: Optional[Callable] = None, **kwargs: Any, ) -> None: # pyright: ignore @@ -571,6 +571,9 @@ def __init__( if drop_last is not None: dataset.set_drop_last(drop_last) + dataset.set_batch_size(batch_size) + dataset.set_num_workers(num_workers) + shuffle = None if profile_batches and not _VIZ_TRACKER_AVAILABLE: diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 579ccf6b..69526a37 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -32,7 +32,10 @@ from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset from litdata.utilities.encryption import Encryption from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv -from litdata.utilities.shuffle import _find_chunks_per_ranks_on_which_to_skip_deletion +from litdata.utilities.shuffle import ( + _find_chunks_per_workers_on_which_to_skip_deletion, + _map_node_worker_rank_to_chunk_indexes_to_not_delete, +) logger = Logger(__name__) @@ -120,8 +123,10 @@ def __init__( self.shuffler: Optional[Shuffle] = None self.serializers = serializers self._state_dict: Optional[Dict[str, Any]] = None - self.num_workers: Optional[int] = None - self.batch_size: Optional[int] = None + # These have a slightly different meaning in the context of the dataset: + # We consider `num_workers = 0` from `torch.utils.data.DataLoader` still as 1 worker (the main process) + self.num_workers: int = 1 + self.batch_size: int = 1 self._encryption = encryption def set_shuffle(self, shuffle: bool) -> None: @@ -179,7 +184,13 @@ def 
_create_shuffler(self, cache: Cache) -> Shuffle: return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last) def __len__(self) -> int: - return self.get_len(1, 1) + return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1) + + def set_batch_size(self, batch_size: int) -> None: + self.batch_size = batch_size + + def set_num_workers(self, num_workers: int) -> None: + self.num_workers = num_workers or 1 def get_len(self, num_workers: int, batch_size: int) -> int: self.num_workers = num_workers @@ -205,35 +216,46 @@ def __iter__(self) -> "StreamingDataset": state: Dict[str, Any] = self._state_dict self.current_epoch = state["current_epoch"] - chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks( - self.distributed_env, self.worker_env.world_size, self.batch_size or 1, self.current_epoch + workers_chunks, workers_intervals = self.shuffler.get_chunks_and_intervals_per_workers( + self.distributed_env, self.worker_env.world_size, self.batch_size, self.current_epoch ) - chunks_replica = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size] - intervals_replica = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size] + + worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank + self.worker_chunks = workers_chunks[worker_rank] + self.worker_intervals = workers_intervals[worker_rank] + + # The maximum number of samples to return from `__next__` (in this worker) + self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) # Handle restart if self._state_dict: - self._resume(workers_chunks, workers_intervals) else: - # Find the chunks shared across multiple ranks. - # For each shared chunk, find the rank to use the chunk last and prevent deletion - # for the other ranks. - chunks_indexes_skip_deletion = _find_chunks_per_ranks_on_which_to_skip_deletion( - self.worker_env.world_size, chunks_per_replica, intervals_per_replica + # Find the chunks shared across all workers of the current node. + # For each shared chunk, find the rank and worker that read the chunk last and prevent + # premature deletion by the other workers.
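# Illustrative sketch (assumed toy values, not part of the patch): with one entry
# per (rank, worker) pair returned by `get_chunks_and_intervals_per_workers`, the
# flattened worker id computed above is `global_rank * num_workers + worker_env.rank`,
# and `stop_length` is the total number of samples that worker will yield.
num_ranks, num_workers = 2, 2
workers_intervals = [
    [[0, 0, 50, 50]],        # rank 0, worker 0
    [[50, 50, 100, 100]],    # rank 0, worker 1
    [[100, 100, 150, 150]],  # rank 1, worker 0
    [[150, 150, 200, 200]],  # rank 1, worker 1
]
for global_rank in range(num_ranks):
    for worker in range(num_workers):
        worker_rank = global_rank * num_workers + worker
        stop_length = sum(interval[2] - interval[1] for interval in workers_intervals[worker_rank])
        assert stop_length == 50  # each toy worker owns exactly one 50-sample interval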
+ node_size = self.distributed_env.world_size // self.distributed_env.num_nodes + first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size + num_workers_per_node = node_size * self.num_workers + worker_start = first_rank_this_node * self.num_workers + worker_end = worker_start + num_workers_per_node + local_rank = self.distributed_env.global_rank % node_size + + chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion( + self.num_workers, + self.batch_size, + workers_chunks[worker_start:worker_end], + workers_intervals[worker_start:worker_end], ) - if self.distributed_env.global_rank in chunks_indexes_skip_deletion: - self.cache._reader.config.skip_chunk_indexes_deletion = chunks_indexes_skip_deletion[ - self.distributed_env.global_rank - ] - - workers_chunks, workers_intervals = _associate_chunks_to_workers( - self.worker_env, - chunks_per_replica[self.distributed_env.global_rank], - intervals_per_replica[self.distributed_env.global_rank], + worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete( + chunks_indexes_skip_deletion ) - self.worker_chunks = workers_chunks[self.worker_env.rank] - self.worker_intervals = workers_intervals[self.worker_env.rank] + worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank + if worker_rank_local_node in worker_node_rank_to_chunk_indexes: + self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[ + worker_rank_local_node + ] self.num_chunks = len(self.worker_chunks) self.current_indexes = [] @@ -246,7 +268,7 @@ def __iter__(self) -> "StreamingDataset": return self - def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> None: + def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any]) -> None: assert self._state_dict assert self.worker_env assert self.shuffler @@ -259,17 +281,22 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No # TODO: Implement elastic sampling where the number of workers, ranks can change.
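# Illustrative sketch of the sampling replay used below: a simplified
# re-implementation of `_replay_sampling` (defined later in this file) that
# splits `num_samples_yielded` back into per-worker counts; values are assumed.
def _replay_sampling_sketch(num_samples_yielded: int, batch_size: int, num_workers: int) -> dict:
    # Full rounds of `num_workers` batches contribute equally to every worker
    full_rounds = num_samples_yielded // (num_workers * batch_size)
    indexes = {i: full_rounds * batch_size for i in range(num_workers)}
    remainder = num_samples_yielded - num_workers * full_rounds * batch_size
    worker = 0
    while remainder >= batch_size:  # leftover whole batches were taken round-robin
        indexes[worker] += batch_size
        worker = (worker + 1) % num_workers
        remainder -= batch_size
    indexes[worker] += remainder  # a final partial batch ends the replay
    return indexes

# 27 samples yielded with batch_size=2 across 3 workers: 4 full rounds (24 samples),
# then one whole batch to worker 0 and a partial batch to worker 1
assert _replay_sampling_sketch(27, batch_size=2, num_workers=3) == {0: 10, 1: 9, 2: 8}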
num_samples_yielded = self._state_dict["num_samples_yielded"] + worker_start = self.distributed_env.global_rank * num_workers + worker_end = worker_start + num_workers + # replay sampling from each worker / chunks using the batch size - workers_chunks, workers_intervals = _associate_chunks_to_workers( - self.worker_env, chunks_replica, intervals_replica - ) indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers) - chunks_index, indexes = _replay_chunks_sampling(workers_intervals, indexes) + chunks_index, indexes = _replay_chunks_sampling( + workers_intervals={i: workers_intervals[j] for i, j in enumerate(range(worker_start, worker_end))}, + indexes=indexes, + ) # select the chunks and intervals associated to this worker - worker_rank = self.worker_env.rank + worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank + worker_local_rank = self.worker_env.rank + self.num_chunks = len(workers_intervals[worker_rank]) - self.chunk_index = chunks_index[worker_rank] + self.chunk_index = chunks_index[worker_local_rank] self.worker_chunks = workers_chunks[worker_rank] self.worker_intervals = workers_intervals[worker_rank] @@ -281,10 +308,10 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index) # skip any indexes already consumed - current_indexes = current_indexes[indexes[worker_rank] :] + current_indexes = current_indexes[indexes[worker_local_rank] :] self.current_indexes = current_indexes - self.global_index = num_samples_yielded + self.global_index = indexes[worker_local_rank] # bump the chunk_index self.chunk_index += 1 @@ -305,7 +332,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: def __next__(self) -> Any: # Prevent creating more batches on a given process - if self.global_index >= len(self): + if self.global_index >= self.stop_length: self.current_epoch += 1 raise StopIteration @@ -454,8 +481,8 @@ def reset(self) -> None: "random_state": None, "shuffler": None, "_state_dict": None, - "num_workers": None, - "batch_size": None, + "num_workers": 1, + "batch_size": 1, } for prop, value in default_properties.items(): @@ -470,28 +497,6 @@ def is_integer(value: str) -> bool: return False -def _associate_chunks_to_workers( - worker_env: _WorkerEnv, chunks_replica: List[int], intervals_replica: List[Any] -) -> Any: - workers_chunks = {} - workers_intervals = {} - - for worker_idx in range(worker_env.world_size): - worker_chunks = [] - worker_intervals = [] - for i, (chunk_index, chunk_interval) in enumerate(zip(chunks_replica, intervals_replica)): - if i % worker_env.world_size != worker_idx: - continue - - worker_chunks.append(chunk_index) - worker_intervals.append(chunk_interval) - - workers_chunks[worker_idx] = worker_chunks - workers_intervals[worker_idx] = worker_intervals - - return workers_chunks, workers_intervals - - def _replay_sampling(num_samples_yielded: int, batch_size: int, num_workers: int) -> Dict[int, int]: """This function replays the sampling from the dataloader.""" divisible_num_batches_yielded = num_samples_yielded // (num_workers * batch_size) diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index 2f28ca07..a61e2f24 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -19,7 +19,10 @@ from litdata.streaming import Cache from litdata.utilities.env import _DistributedEnv -from litdata.utilities.shuffle import 
_associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle +from litdata.utilities.shuffle import ( + _associate_chunks_and_intervals_to_workers, + _intra_node_chunk_shuffle, +) class Shuffle(ABC): @@ -32,23 +35,19 @@ def __init__(self, cache: Cache, seed: int, drop_last: bool): @lru_cache(maxsize=10) def get_len(self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int) -> int: - _, intervals_per_ranks = self.get_chunks_and_intervals_per_ranks( + _, workers_intervals = self.get_chunks_and_intervals_per_workers( distributed_env, num_workers, batch_size, current_epoch ) - - if self.drop_last: - items_per_process = [ - sum((interval[2] - interval[1]) for interval in intervals) for intervals in intervals_per_ranks - ] - # Validate each processes gets the exact number of elements - if len(items_per_process) > 1: - assert all(items_per_process[0] == items_to_process for items_to_process in items_per_process[:1]) - return items_per_process[0] - - return sum((interval[2] - interval[1]) for interval in intervals_per_ranks[distributed_env.global_rank]) + worker_start = distributed_env.global_rank * num_workers + worker_end = worker_start + num_workers + return sum( + (interval[2] - interval[1]) + for intervals in workers_intervals[worker_start:worker_end] + for interval in intervals + ) @abstractmethod - def get_chunks_and_intervals_per_ranks( + def get_chunks_and_intervals_per_workers( self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int ) -> Any: pass @@ -63,7 +62,7 @@ class NoShuffle(Shuffle): is True.""" @lru_cache(maxsize=10) - def get_chunks_and_intervals_per_ranks( + def get_chunks_and_intervals_per_workers( self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int ) -> Any: # 1. Get the intervals @@ -71,11 +70,10 @@ def get_chunks_and_intervals_per_ranks( indexes = range(len(chunk_intervals)) # 2. Compute the items budget of each rank - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks( + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( distributed_env, indexes, chunk_intervals, self.drop_last, num_workers, batch_size ) - - return chunks_per_ranks, intervals_per_ranks + return workers_chunks, workers_intervals def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]: return array.tolist() @@ -100,7 +98,7 @@ class FullShuffle(Shuffle): """ @lru_cache(maxsize=10) - def get_chunks_and_intervals_per_ranks( + def get_chunks_and_intervals_per_workers( self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int ) -> Any: # 1. Get the intervals @@ -120,24 +118,24 @@ def get_chunks_and_intervals_per_ranks( shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist() # 3. Compute the items budget of each rank - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks( + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size ) # For the first epoch, no need of further shuffling if current_epoch == 1 or distributed_env.num_nodes == 1: - return chunks_per_ranks, intervals_per_ranks + return workers_chunks, workers_intervals # Perform shuffle within the nodes to avoid cache miss. # Note: It is possible for the overlapping chunks to change due to the changing order. 
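# Illustrative sketch (assumed toy values) of the per-rank length computed by
# `get_len` above: the rank's length is the sum of interval sizes across the
# slice of workers it owns, `[global_rank * num_workers, global_rank * num_workers + num_workers)`.
workers_intervals = [
    [[0, 0, 40, 40]], [[40, 40, 80, 80]],          # rank 0 -> workers 0 and 1
    [[80, 80, 120, 120]], [[120, 120, 160, 160]],  # rank 1 -> workers 2 and 3
]
num_workers, global_rank = 2, 1
worker_start = global_rank * num_workers
worker_end = worker_start + num_workers
length = sum(
    interval[2] - interval[1]
    for intervals in workers_intervals[worker_start:worker_end]
    for interval in intervals
)
assert length == 80  # rank 1 yields 40 + 40 samples in total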
- shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, chunks_per_ranks, self.seed, current_epoch) + shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, workers_chunks, self.seed, current_epoch) shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist() - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks( + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size ) - return chunks_per_ranks, intervals_per_ranks + return workers_chunks, workers_intervals def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]: return np.random.RandomState([self.seed, num_chunks * current_epoch, chunk_index]).permutation(array).tolist() diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index aa8b519b..a8ab3930 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy from typing import Any, Dict, List, Tuple import numpy as np @@ -42,37 +43,37 @@ def _intra_node_chunk_shuffle( return [index for chunks in chunk_indexes_per_nodes for index in chunks] -def _associate_chunks_and_internals_to_ranks( +def _associate_chunks_and_intervals_to_workers( distributed_env: _DistributedEnv, indexes: Any, chunk_intervals: List[Interval], - drop_last: bool, + drop_last: bool = False, num_workers: int = 1, batch_size: int = 1, ) -> Tuple[List[List[int]], List[Any]]: num_items = sum([(interval[2] - interval[1]) for interval in chunk_intervals]) - num_items_per_ranks: List[int] = [ - num_items // distributed_env.world_size + num_items % distributed_env.world_size - if rank == distributed_env.world_size - 1 and not drop_last - else num_items // distributed_env.world_size - for rank in range(distributed_env.world_size) + world_size = distributed_env.world_size * num_workers + num_items_per_workers: List[int] = [ + num_items // world_size + num_items % world_size + if rank == world_size - 1 and not drop_last + else num_items // world_size + for rank in range(world_size) ] if drop_last: - ratio = num_workers * batch_size - num_items_per_ranks = [ratio * int(item // ratio) for item in num_items_per_ranks] + num_items_per_workers = [batch_size * int(item // batch_size) for item in num_items_per_workers] - chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)] - intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)] + chunks_per_workers: List[List[int]] = [[] for _ in range(world_size)] + intervals_per_workers: List[List[List[int]]] = [[] for _ in range(world_size)] # 4. 
Assign the chunk & intervals to each rank for chunk_index, chunk_interval in zip(indexes, chunk_intervals): rank = 0 while True: - if rank == len(num_items_per_ranks): + if rank == len(num_items_per_workers): break - items_left_to_assign = num_items_per_ranks[rank] + items_left_to_assign = num_items_per_workers[rank] if items_left_to_assign == 0: rank += 1 @@ -84,83 +85,172 @@ def _associate_chunks_and_internals_to_ranks( break if items_in_chunk > items_left_to_assign: - chunks_per_ranks[rank].append(chunk_index) + chunks_per_workers[rank].append(chunk_index) chunk_start, chunk_roi_start, chunk_roi_end, chunk_end = chunk_interval - intervals_per_ranks[rank].append( + intervals_per_workers[rank].append( [chunk_start, chunk_roi_start, chunk_roi_start + items_left_to_assign, chunk_end] ) chunk_interval = Interval(chunk_start, chunk_roi_start + items_left_to_assign, chunk_roi_end, chunk_end) - num_items_per_ranks[rank] = 0 + num_items_per_workers[rank] = 0 rank += 1 else: - chunks_per_ranks[rank].append(chunk_index) - intervals_per_ranks[rank].append(list(chunk_interval)) - num_items_per_ranks[rank] -= items_in_chunk + chunks_per_workers[rank].append(chunk_index) + intervals_per_workers[rank].append(list(chunk_interval)) + num_items_per_workers[rank] -= items_in_chunk break - return chunks_per_ranks, intervals_per_ranks + return chunks_per_workers, intervals_per_workers -def _find_chunks_per_ranks_on_which_to_skip_deletion( - num_workers: int, chunks_per_ranks: List[List[int]], intervals_per_ranks: List[Any] +def _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers: int, + batch_size: int, + workers_chunks: List[List[int]], + workers_intervals: List[List[Interval]], ) -> Dict[int, List[int]]: - # TODO: Add support for the real batch size - batch_size = 1 - shared_chunks = {} - for rank, chunks in enumerate(chunks_per_ranks): - for c in chunks: - if c not in shared_chunks: - shared_chunks[c] = [rank] - else: - shared_chunks[c].append(rank) - - shared_chunks = {c: ranks for c, ranks in shared_chunks.items() if len(ranks) > 1} - - disable_deletion_ranks = {} + """Returns a dictionary mapping a chunk index to a list of workers that should not delete that chunk. + + If a worker is included in this list, it should not delete the chunk after fully reading it, because another worker + will still have items left to read and therefore needs the chunk to be present. This mapping is used in the dataset + to only let the worker delete a chunk when that worker is the last to read from it. 
+ + """ + + # Shared chunks across all workers and ranks + shared_chunks = _get_shared_chunks(workers_chunks) + + # Shared chunks grouped together by rank + shared_chunks_aggregated_by_rank = _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) + + max_trackers = {} + for chunk_index, map_local_rank_to_worker_ids in shared_chunks_aggregated_by_rank.items(): + for local_rank, workers_index_sharing_chunks_for_this_rank in map_local_rank_to_worker_ids.items(): + # Get all the worker chunks and intervals for this distributed rank + workers_slice = slice(local_rank * num_workers, (local_rank + 1) * num_workers) + workers_chunks_for_this_rank = copy.deepcopy(workers_chunks[workers_slice]) + workers_interval_sizes_for_this_rank = copy.deepcopy( + [ + [interval[2] - interval[1] for interval in worker_intervals] + for worker_intervals in workers_intervals[workers_slice] + ] + ) + + num_shared_workers_for_this_rank = len(workers_index_sharing_chunks_for_this_rank) + worker_tracker_idx = 0 + num_of_samples_to_carry_to_next_chunk = None + counter = 0 - for shared_chunk, ranks in shared_chunks.items(): - counters = [] - for rank in ranks: - chunks = chunks_per_ranks[rank] - intervals = [interval[2] - interval[1] for interval in intervals_per_ranks[rank]] + while True: + interval_size_of_current_worker = workers_interval_sizes_for_this_rank[worker_tracker_idx % num_workers] + if len(interval_size_of_current_worker) == 0: + worker_tracker_idx += 1 + continue + + num_samples_left_for_this_worker_chunk = interval_size_of_current_worker[0] + + # To consume a batch, we want to subtract `batch_size` from the size we have left, + # unless we had a remainder (< batch size) from the previous iteration/chunk + remover = ( + batch_size + if num_of_samples_to_carry_to_next_chunk is None + else num_of_samples_to_carry_to_next_chunk + ) - workers_chunks: Any = [[] for _ in range(num_workers)] - workers_intervals: Any = [[] for _ in range(num_workers)] - for interval_idx, (c, i) in enumerate(zip(chunks, intervals)): - workers_chunks[interval_idx % num_workers].append(c) - workers_intervals[interval_idx % num_workers].append(i) + if num_samples_left_for_this_worker_chunk > remover: + # There are samples left to consume, so we subtract the batch size (or a remainder) + workers_interval_sizes_for_this_rank[worker_tracker_idx % num_workers][0] -= remover + counter += remover + num_of_samples_to_carry_to_next_chunk = None + else: + # There are fewer samples left in this chunk than we would like to consume for a full batch + # So we take what's left from the chunk and move to the next chunk to complete the batch + current_worker_chunk_index = workers_chunks_for_this_rank[worker_tracker_idx % num_workers].pop(0) + workers_interval_sizes_for_this_rank[worker_tracker_idx % num_workers].pop(0) + counter += remover + + if current_worker_chunk_index == chunk_index: + num_shared_workers_for_this_rank -= 1 + + # TODO: Maybe, we can prevent loading over and over for each worker + if num_shared_workers_for_this_rank == 0 and current_worker_chunk_index == chunk_index: + # We consumed entirely the chunk of the worker we were tracking + # Keep track of how many samples this worker consumed for this chunk and which worker + # has consumed the most samples for this chunk + if chunk_index not in max_trackers: + max_trackers[chunk_index] = ( + local_rank * num_workers + worker_tracker_idx % num_workers, + counter, + ) + else: + if max_trackers[chunk_index][1] < counter: + max_trackers[chunk_index] = ( + local_rank * num_workers + 
worker_tracker_idx % num_workers, + counter, + ) + break - counter = 0 - worker_idx = 0 # reset the worker_idx - while True: - current_chunks = workers_chunks[worker_idx] - current_intervals = workers_intervals[worker_idx] + if num_samples_left_for_this_worker_chunk != batch_size: + # If a batch was not assembled completely because we're at the end of a chunk, + # we need to complete the assembly from samples in the next chunk and carry + # over that remainder to the next loop iteration + num_of_samples_to_carry_to_next_chunk = batch_size - num_samples_left_for_this_worker_chunk - if len(current_intervals) == 0: - break + if remover != batch_size: + # We've handled the remainder, reset it. Next iteration will start a fresh batch. + num_of_samples_to_carry_to_next_chunk = None - if current_intervals[0] > batch_size: - current_intervals[0] -= batch_size - counter += batch_size - worker_idx = (worker_idx + 1) % num_workers - else: - counter += current_intervals[0] - current_intervals.pop(0) - current_chunk = current_chunks.pop(0) - worker_idx = (worker_idx + 1) % num_workers + if num_of_samples_to_carry_to_next_chunk is None: + # Only go to the next worker if we assembled a full batch. If we have a remainder, + # we need to go to the next chunk with the same worker and complete the batch. + worker_tracker_idx += 1 - if current_chunk == shared_chunk: - break + to_disable = {} + for chunk_index, worker_ids in shared_chunks.items(): + last_worker_idx = max_trackers[chunk_index][0] + to_disable[chunk_index] = [worker_idx for worker_idx in worker_ids if worker_idx != last_worker_idx] + return to_disable - counters.append(counter) - max_counter = np.argmax(counters) - disable_ranks = [rank for rank in ranks if rank != ranks[max_counter]] - for rank in disable_ranks: - if rank not in disable_deletion_ranks: - disable_deletion_ranks[rank] = [shared_chunk] +def _get_shared_chunks(workers_chunks: List[List[int]]) -> Dict[int, List[int]]: + """Returns a dictionary mapping a chunk index to a list of workers that share that same chunk.""" + shared_chunks = {} + for worker, chunks in enumerate(workers_chunks): + for chunk in chunks: + if chunk not in shared_chunks: + shared_chunks[chunk] = [worker] else: - disable_deletion_ranks[rank].append(shared_chunk) - return disable_deletion_ranks + shared_chunks[chunk].append(worker) + # Remove chunk indexes that are only read by a single worker (and thus not shared) + return {chunk: workers for chunk, workers in shared_chunks.items() if len(workers) > 1} + + +def _aggregate_shared_chunks_per_rank( + shared_chunks: Dict[int, List[int]], num_workers: int +) -> Dict[int, Dict[int, List[int]]]: + """Groups together shared chunks by rank. + + The output is a dictionary mapping a chunk index to a dictionary that maps a rank to a list of workers. 
+ + """ + aggregated_shared_chunks_per_rank: Dict[int, Dict[int, List[int]]] = {} + for chunk_index, workers_ids in shared_chunks.items(): + aggregated_shared_chunks_per_rank[chunk_index] = {} + for worker_idx in workers_ids: + if (worker_idx // num_workers) not in aggregated_shared_chunks_per_rank[chunk_index]: + aggregated_shared_chunks_per_rank[chunk_index][worker_idx // num_workers] = [] + aggregated_shared_chunks_per_rank[chunk_index][worker_idx // num_workers].append(worker_idx) + return aggregated_shared_chunks_per_rank + + +def _map_node_worker_rank_to_chunk_indexes_to_not_delete(to_disable: Dict[int, List[int]]) -> Dict[int, List[int]]: + """Takes a dictionary mapping a chunk index to a list of workers and inverts the map such that it returns a + dictionary mapping a worker to a list of chunk indexes (that should not be deleted by that worker).""" + map_node_worker_rank_to_chunk_indexes: Dict[int, List[int]] = {} + for chunk_index, worker_ids in to_disable.items(): + for worker_idx in worker_ids: + if worker_idx not in map_node_worker_rank_to_chunk_indexes: + map_node_worker_rank_to_chunk_indexes[worker_idx] = [] + map_node_worker_rank_to_chunk_indexes[worker_idx].append(chunk_index) + return map_node_worker_rank_to_chunk_indexes diff --git a/tests/streaming/test_combined.py b/tests/streaming/test_combined.py index 5630d042..7079f564 100644 --- a/tests/streaming/test_combined.py +++ b/tests/streaming/test_combined.py @@ -93,9 +93,16 @@ def test_drop_last_and_shuffle(): dataset_mock_1.set_shuffle.assert_called() dataset_mock_2.set_shuffle.assert_called() + dataset_mock_1.set_drop_last.assert_called() dataset_mock_2.set_drop_last.assert_called() + dataset_mock_1.set_num_workers.assert_called() + dataset_mock_2.set_num_workers.assert_called() + + dataset_mock_1.set_batch_size.assert_called() + dataset_mock_2.set_batch_size.assert_called() + class TestStatefulDataset: def __init__(self, size, step): @@ -227,6 +234,12 @@ def set_shuffle(self, _): def set_drop_last(self, _): pass + def set_batch_size(self, _): + pass + + def set_num_workers(self, _): + pass + def test_combined_dataset(): dataset1 = SimpleDataset(0, 10) @@ -555,43 +568,6 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "latest_worker_idx": 1, "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 0]}, }, - { - "dataset": { - "0": { - "num_samples_yielded": 8, - "num_workers": 3, - "batch_size": 2, - "current_epoch": 1, - "input_dir_path": ANY, - "input_dir_url": ANY, - "item_loader": None, - "drop_last": False, - "seed": 42, - "world_size": 1, - "shuffle": True, - "subsampled_files": ANY, - "region_of_interest": ANY, - }, - "1": { - "num_samples_yielded": 3, - "num_workers": 3, - "batch_size": 2, - "current_epoch": 1, - "input_dir_path": ANY, - "input_dir_url": ANY, - "item_loader": None, - "drop_last": False, - "seed": 42, - "world_size": 1, - "shuffle": True, - "subsampled_files": ANY, - "region_of_interest": ANY, - }, - }, - "current_epoch": 0, - "latest_worker_idx": 2, - "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 1]}, - }, { "dataset": { "0": { @@ -626,8 +602,8 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, }, "current_epoch": 0, - "latest_worker_idx": 0, - "num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [2, 1]}, + "latest_worker_idx": 2, + "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]}, }, { "dataset": { @@ -663,8 +639,8 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, }, "current_epoch": 0, - "latest_worker_idx": 1, - "num_samples_yielded": {0: 
[4, 1], 1: [4, 1], 2: [2, 1]}, + "latest_worker_idx": 0, + "num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [3, 1]}, }, ] @@ -854,43 +830,6 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): "latest_worker_idx": 1, "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 0]}, }, - { - "dataset": { - "0": { - "num_samples_yielded": 8, - "num_workers": 3, - "batch_size": 2, - "current_epoch": 2, - "input_dir_path": ANY, - "input_dir_url": ANY, - "item_loader": None, - "drop_last": False, - "seed": 42, - "world_size": 1, - "shuffle": True, - "subsampled_files": ANY, - "region_of_interest": ANY, - }, - "1": { - "num_samples_yielded": 3, - "num_workers": 3, - "batch_size": 2, - "current_epoch": 2, - "input_dir_path": ANY, - "input_dir_url": ANY, - "item_loader": None, - "drop_last": False, - "seed": 42, - "world_size": 1, - "shuffle": True, - "subsampled_files": ANY, - "region_of_interest": ANY, - }, - }, - "current_epoch": 1, - "latest_worker_idx": 2, - "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 1]}, - }, { "dataset": { "0": { @@ -925,8 +864,8 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, }, "current_epoch": 1, - "latest_worker_idx": 0, - "num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [2, 1]}, + "latest_worker_idx": 2, + "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]}, }, { "dataset": { @@ -962,8 +901,8 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, }, "current_epoch": 1, - "latest_worker_idx": 1, - "num_samples_yielded": {0: [4, 1], 1: [4, 1], 2: [2, 1]}, + "latest_worker_idx": 0, + "num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [3, 1]}, }, ] diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index 5690e145..f0a5e138 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -45,6 +45,12 @@ def set_epoch(self, current_epoch): def set_drop_last(self, drop_last): self.drop_last = drop_last + def set_batch_size(self, batch_size): + self.batch_size = batch_size + + def set_num_workers(self, num_workers): + self.num_workers = num_workers + class TestCombinedStreamingDataset(CombinedStreamingDataset): def _check_datasets(self, datasets) -> None: @@ -88,6 +94,7 @@ def test_streaming_dataloader(): } +@pytest.mark.skip(reason="Profiling patches torch which leads to undesired test interactions") @pytest.mark.skipif(not _VIZ_TRACKER_AVAILABLE, reason="viz tracker required") @pytest.mark.parametrize("profile", [2, True]) def test_dataloader_profiling(profile, tmpdir, monkeypatch): diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 92d1068b..7b64f863 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -32,7 +32,6 @@ _INDEX_FILENAME, Dir, StreamingDataset, - _associate_chunks_to_workers, _replay_chunks_sampling, _replay_sampling, ) @@ -40,6 +39,7 @@ from litdata.streaming.shuffle import FullShuffle, NoShuffle from litdata.utilities import dataset_utilities as dataset_utilities_module from litdata.utilities.env import _DistributedEnv, _WorkerEnv +from litdata.utilities.shuffle import _associate_chunks_and_intervals_to_workers from torch.utils.data import DataLoader @@ -158,7 +158,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir, compression assert len(process_2_2) == 50 + int(not drop_last) - _, intervals_per_ranks = dataset.shuffler.get_chunks_and_intervals_per_ranks( + _, workers_intervals = dataset.shuffler.get_chunks_and_intervals_per_workers( dataset.distributed_env, 1, 1, 
dataset.current_epoch ) @@ -167,7 +167,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir, compression found_list = [] for i in process_1_1: found = False - for interval in intervals_per_ranks[0]: + for interval in workers_intervals[0]: if interval[1] <= i <= interval[2]: found = True break @@ -178,7 +178,7 @@ def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir, compression found_list = [] for i in process_2_1: found = False - for interval in intervals_per_ranks[1]: + for interval in workers_intervals[1]: if interval[1] <= i <= interval[2]: found = True break @@ -506,19 +506,21 @@ def test_dataset_for_text_tokens_multiple_workers(tmpdir): expected = [ [0, 10], - [40, 50], + [100, 110], [20, 30], - [60, 70], - [80, 90], [120, 130], - [100, 110], + [40, 50], [140, 150], + [60, 70], [160, 170], + [80, 90], [180, 190], ] - for result, batch in zip(expected, dataloader): - assert [batch[0][0].item(), batch[1][0].item()] == result + result = [] + for batch in dataloader: + result.append(batch[:, 0].tolist()) + assert result == expected def test_dataset_for_text_tokens_distributed_num_workers(tmpdir): @@ -601,34 +603,46 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False) L = len(dataset) - assert len(dataset) == L + assert L == 20 for i in range(L): sequence = dataset[i] assert sequence[0].item() == i * block_size assert sequence[-1].item() == (i + 1) * block_size - 1 + monkeypatch.setenv("WORLD_SIZE", "2") + monkeypatch.setenv("GLOBAL_RANK", "0") + monkeypatch.setenv("NNODES", "1") dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False) + dataloader = StreamingDataLoader(dataset, batch_size=2, shuffle=False, num_workers=2) + assert dataset.drop_last # in distributed setting, this is forced automatically - dataset.distributed_env = _DistributedEnv(2, 0, 1) - dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2) - - assert len(dataloader) == 5 - - expected = [[0, 10], [20, 30], [40, 50], [60, 70], [80, 90]] + # L = 20, world size 2, num workers 2 + # L / (2 * 2) = 5 items per worker + # drop last -> 4 items per worker + # batch size = 2 -> 2 batches per worker -> len(dataloader) = 4 + assert len(dataloader) == 4 - for batch_idx, batch in enumerate(dataloader): - assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] - - dataset.distributed_env = _DistributedEnv(2, 1, 1) - dataloader = DataLoader(dataset, batch_size=2, shuffle=False) + expected = [[0, 10], [40, 50], [20, 30], [60, 70]] + returned = [] + for batch in dataloader: + returned.append(batch[:, 0].tolist()) + assert returned == expected - assert len(dataloader) == 5 + monkeypatch.setenv("WORLD_SIZE", "2") + monkeypatch.setenv("GLOBAL_RANK", "1") + monkeypatch.setenv("NNODES", "1") + dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False) + dataloader = StreamingDataLoader(dataset, batch_size=2, shuffle=False, num_workers=2) + assert dataset.drop_last # in distributed setting, this is forced automatically - expected = [[100, 110], [120, 130], [140, 150], [160, 170], [180, 190]] + assert len(dataloader) == 4 - for batch_idx, batch in enumerate(dataloader): - assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] + expected = [[80, 90], [120, 130], [100, 110], [140, 150]] + returned = [] + for batch in dataloader: + 
returned.append(batch[:, 0].tolist()) + assert returned == expected @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @@ -673,7 +687,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache: @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") -def test_resumable_dataset_two_workers(tmpdir): +def test_dataset_reshuffling_every_epoch(tmpdir): seed_everything(42) data_dir = os.path.join(tmpdir, "data") @@ -775,7 +789,6 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir): input_dir=Dir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=True ) - dataset.current_epoch = 1 dataloader = StreamingDataLoader(dataset, num_workers=2, batch_size=2, prefetch_factor=1, persistent_workers=True) batches_epoch_1 = [] @@ -789,9 +802,7 @@ def test_resumable_dataset_two_workers_2_epochs(tmpdir): batches_epoch_2.append(batch) assert len(os.listdir(cache_dir)) == 51 - - for batch_1, batch_2 in zip(batches_epoch_1, batches_epoch_2): - assert not torch.equal(batch_1, batch_2) + assert not all(torch.equal(b1, b2) for b1, b2 in zip(batches_epoch_1, batches_epoch_2)) def _simple_preprocess(_): @@ -799,39 +810,56 @@ def _simple_preprocess(_): yield torch.randint(0, 100, size=(10,), dtype=torch.int64) -def _get_simulated_s3_dataloader(cache_dir, data_dir): +def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False): dataset = EmulateS3StreamingDataset( input_dir=Dir(cache_dir, data_dir), item_loader=TokensLoader(block_size=10), + shuffle=shuffle, ) - return StreamingDataLoader(dataset, batch_size=2, num_workers=1) + return StreamingDataLoader(dataset, batch_size=2, num_workers=2) -@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="Not tested on windows and MacOs") +@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") @mock.patch.dict(os.environ, {}, clear=True) @pytest.mark.timeout(60) -def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch): +@pytest.mark.parametrize("shuffle", [True, False]) +def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch): """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have the same size.""" s3_cache_dir = str(tmpdir / "s3cache") + optimize_data_cache_dir = str(tmpdir / "optimize_data_cache") optimize_cache_dir = str(tmpdir / "optimize_cache") data_dir = str(tmpdir / "optimized") + monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", optimize_data_cache_dir) monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", optimize_cache_dir) optimize( fn=_simple_preprocess, - inputs=list(range(5)), - output_dir=str(tmpdir / "optimized"), + inputs=list(range(8)), + output_dir=data_dir, chunk_size=190, - num_workers=4, + num_workers=1, # TODO: Want 4 here, but optimize() has deletion race condition + num_uploaders=1, ) - assert len(os.listdir(tmpdir / "optimized")) > 0 + sleep(5) # wait for copier/remover threads to complete + # assert set(os.listdir(data_dir)) == { + # "chunk-0-0.bin", + # "chunk-0-1.bin", + # "chunk-1-0.bin", + # "chunk-1-1.bin", + # "chunk-2-0.bin", + # "chunk-2-1.bin", + # "chunk-3-0.bin", + # "chunk-3-1.bin", + # "index.json", + # } os.mkdir(s3_cache_dir) - train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) + train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir, shuffle=shuffle) batches_to_fetch = 16 batch_to_resume_from = None dataloader_state = None + for i, batch in 
enumerate(train_dataloader): if i == batches_to_fetch: dataloader_state = train_dataloader.state_dict() @@ -841,7 +869,7 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch): shutil.rmtree(s3_cache_dir) os.mkdir(s3_cache_dir) - train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir) + train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir, shuffle=shuffle) assert dataloader_state is not None assert batch_to_resume_from is not None train_dataloader.load_state_dict(dataloader_state) @@ -972,14 +1000,15 @@ def test_replay_sampling(): def test_replay_chunks_sampling(): chunks_replica = range(10) intervals_replica = [(i, i, i + 5, i + 5) for i in range(0, 50, 5)] - workers_chunks, workers_intervals = _associate_chunks_to_workers( - _WorkerEnv(2, 0), chunks_replica, intervals_replica + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( + _DistributedEnv(2, 0, 1), chunks_replica, intervals_replica ) - assert workers_chunks == {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]} - assert workers_intervals == { - 0: [(0, 0, 5, 5), (10, 10, 15, 15), (20, 20, 25, 25), (30, 30, 35, 35), (40, 40, 45, 45)], - 1: [(5, 5, 10, 10), (15, 15, 20, 20), (25, 25, 30, 30), (35, 35, 40, 40), (45, 45, 50, 50)], - } + assert workers_chunks == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] + assert workers_intervals == [ + [[0, 0, 5, 5], [5, 5, 10, 10], [10, 10, 15, 15], [15, 15, 20, 20], [20, 20, 25, 25]], + [[25, 25, 30, 30], [30, 30, 35, 35], [35, 35, 40, 40], [40, 40, 45, 45], [45, 45, 50, 50]], + ] + workers_intervals = {i: workers_intervals[i] for i in range(len(workers_intervals))} assert _replay_chunks_sampling(workers_intervals, {0: 16, 1: 11}) == ({0: 3, 1: 2}, {0: 1, 1: 1}) assert _replay_chunks_sampling(workers_intervals, {0: 14, 1: 13}) == ({0: 2, 1: 2}, {0: 4, 1: 3}) assert _replay_chunks_sampling(workers_intervals, {0: 15, 1: 12}) == ({0: 3, 1: 2}, {0: 0, 1: 2}) diff --git a/tests/utilities/test_shuffle.py b/tests/utilities/test_shuffle.py index 2645f487..15532a21 100644 --- a/tests/utilities/test_shuffle.py +++ b/tests/utilities/test_shuffle.py @@ -1,9 +1,12 @@ from litdata.streaming.item_loader import Interval from litdata.utilities.env import _DistributedEnv from litdata.utilities.shuffle import ( - _associate_chunks_and_internals_to_ranks, - _find_chunks_per_ranks_on_which_to_skip_deletion, + _aggregate_shared_chunks_per_rank, + _associate_chunks_and_intervals_to_workers, + _find_chunks_per_workers_on_which_to_skip_deletion, + _get_shared_chunks, _intra_node_chunk_shuffle, + _map_node_worker_rank_to_chunk_indexes_to_not_delete, ) @@ -21,7 +24,7 @@ def test_intra_node_chunk_shuffle(): assert shuffled_indexes == [5, 2, 0, 7, 6, 1, 3, 4, 13, 10, 8, 15, 14, 9, 11, 12] -def test_associate_chunks_and_internals_to_ranks(): +def test_associate_chunks_and_intervals_to_workers(): indexes = [0, 1, 2, 3, 4, 5, 6, 7] chunk_intervals = [ Interval(0, 0, 50, 50), @@ -34,15 +37,15 @@ def test_associate_chunks_and_internals_to_ranks(): Interval(0, 0, 50, 50), ] - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks( + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, ) - assert chunks_per_ranks == [[0, 1], [2, 3], [4, 5], [6, 7]] - assert intervals_per_ranks == [ + assert workers_chunks == [[0, 1], [2, 3], [4, 5], [6, 7]] + assert workers_intervals == [ [[0, 0, 50, 50], [0, 0, 50, 50]], [[0, 0, 50, 50], [0, 0, 50, 50]], [[0, 0, 50, 50], [0, 0, 
50, 50]], @@ -60,20 +63,20 @@ def test_associate_chunks_and_internals_to_ranks(): Interval(0, 0, 33, 33), ] - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks( + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, ) - assert chunks_per_ranks == [[0, 1], [1, 2], [2, 3, 4, 5], [5, 6, 7]] - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[0]]) == 105 - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[1]]) == 105 - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[2]]) == 105 - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[3]]) == 105 + assert workers_chunks == [[0, 1], [1, 2], [2, 3, 4, 5], [5, 6, 7]] + assert sum([interval[2] - interval[1] for interval in workers_intervals[0]]) == 105 + assert sum([interval[2] - interval[1] for interval in workers_intervals[1]]) == 105 + assert sum([interval[2] - interval[1] for interval in workers_intervals[2]]) == 105 + assert sum([interval[2] - interval[1] for interval in workers_intervals[3]]) == 105 - assert intervals_per_ranks == [ + assert workers_intervals == [ [[0, 0, 50, 50], [0, 0, 55, 150]], [[0, 55, 150, 150], [0, 0, 10, 50]], [[0, 10, 50, 50], [0, 0, 12, 12], [0, 0, 50, 50], [0, 0, 3, 27]], @@ -91,24 +94,133 @@ def test_associate_chunks_and_internals_to_ranks(): Interval(0, 0, 1, 1), ] - chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks( + workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers( _DistributedEnv(4, 1, 2), indexes, chunk_intervals, drop_last=True, ) - assert chunks_per_ranks == [[0, 1], [1], [1, 2, 3, 4, 5], [5, 6, 7]] - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[0]]) == 64 - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[1]]) == 64 - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[2]]) == 64 - assert sum([interval[2] - interval[1] for interval in intervals_per_ranks[3]]) == 64 - assert intervals_per_ranks == [ + assert workers_chunks == [[0, 1], [1], [1, 2, 3, 4, 5], [5, 6, 7]] + assert sum([interval[2] - interval[1] for interval in workers_intervals[0]]) == 64 + assert sum([interval[2] - interval[1] for interval in workers_intervals[1]]) == 64 + assert sum([interval[2] - interval[1] for interval in workers_intervals[2]]) == 64 + assert sum([interval[2] - interval[1] for interval in workers_intervals[3]]) == 64 + assert workers_intervals == [ [[0, 0, 5, 5], [0, 0, 59, 150]], [[0, 59, 123, 150]], [[0, 123, 150, 150], [0, 0, 7, 7], [0, 0, 12, 12], [0, 0, 4, 4], [0, 0, 14, 27]], [[0, 14, 27, 27], [0, 0, 50, 50], [0, 0, 1, 1]], ] - disable_deletion_ranks = _find_chunks_per_ranks_on_which_to_skip_deletion(1, chunks_per_ranks, intervals_per_ranks) - assert disable_deletion_ranks == {1: [1], 2: [1], 3: [5]} + +def test_get_shared_chunks(): + assert _get_shared_chunks([]) == {} + assert _get_shared_chunks([[0]]) == {} + assert _get_shared_chunks([[0], [1]]) == {} + assert _get_shared_chunks([[0], [0, 1]]) == {0: [0, 1]} # chunk 0 is shared by worker 0 and 1 + assert _get_shared_chunks([[0, 1], [1]]) == {1: [0, 1]} # chunk 1 is shared by worker 0 and 1 + assert _get_shared_chunks([[2], [0, 1], [2, 3]]) == {2: [0, 2]} + assert _get_shared_chunks([[2], [0, 1], [2, 3], [1, 4], [1]]) == {1: [1, 3, 4], 2: [0, 2]} + + +def test_find_chunks_per_workers_on_which_to_skip_deletion(): + 
# world size = 1, single worker + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=1, + batch_size=1, + workers_chunks=[[0]], + workers_intervals=[[(0, 0, 50, 50)]], + ) + assert chunks_to_disable == {} + + # world size = 1, multiple workers, no shared chunks + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=2, + batch_size=5, + workers_chunks=[[0], [1]], + workers_intervals=[[(0, 0, 50, 50)], [(0, 0, 50, 50)]], + ) + assert chunks_to_disable == {} + + # world size = 1, 2 workers sharing one chunk + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=2, + batch_size=5, + workers_chunks=[[0, 1], [1, 2]], + workers_intervals=[[(0, 0, 50, 50), (0, 0, 25, 50)], [(0, 25, 50, 50), (0, 0, 50, 50)]], + ) + assert chunks_to_disable == {1: [1]} + + # world size = 1, 4 workers sharing one chunk + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=4, + batch_size=5, + workers_chunks=[[0], [0], [0], [0]], + workers_intervals=[[(0, 0, 50, 50)], [(0, 50, 100, 50)], [(0, 100, 150, 50)], [(0, 150, 200, 50)]], + ) + assert chunks_to_disable == {0: [0, 1, 2]} + + # world size = 1, 4 workers sharing one chunk, different size + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=4, + batch_size=5, + workers_chunks=[[0], [0], [0], [0]], + workers_intervals=[[(0, 0, 50, 50)], [(0, 50, 95, 50)], [(0, 95, 150, 50)], [(0, 150, 200, 50)]], + ) + assert chunks_to_disable == {0: [0, 1, 3]} + + # world size 2, 2 workers per rank, varying batch size + for batch_size in range(1, 6): + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=2, + batch_size=batch_size, + workers_chunks=[[0], [0], [0], [0]], + workers_intervals=[ + [(0, 0, 50, 50)], + [(0, 50, 95, 50)], + [(0, 95, 145, 50)], + [(0, 145, 205, 50)], # last to access chunk 0 + ], + ) + assert chunks_to_disable == {0: [0, 1, 2]} + + # world size 2, 2 workers per rank, sharing multiple chunks + chunks_to_disable = _find_chunks_per_workers_on_which_to_skip_deletion( + num_workers=2, + batch_size=5, + workers_chunks=[[0, 1], [3, 4], [1, 2], [4, 5]], + workers_intervals=[ + [(0, 0, 50, 50), (0, 0, 50, 50)], + [(0, 0, 50, 50), (0, 0, 50, 50)], + [(0, 50, 100, 100), (0, 0, 50, 50)], + [(0, 50, 100, 100), (0, 0, 50, 50)], + ], + ) + assert chunks_to_disable == {1: [2], 4: [3]} + + +def test_aggregate_shared_chunks_per_rank(): + # world size = 1, num workers per rank = 1 + num_workers = 1 + shared_chunks = {0: [0], 1: [0], 2: [0]} # 3 chunks shared by 1 worker + expected = {0: {0: [0]}, 1: {0: [0]}, 2: {0: [0]}} + assert _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) == expected + + # world size = 1, num workers per rank = 2 + num_workers = 2 + shared_chunks = {0: [0, 1], 1: [0, 1], 2: [0]} # 3 chunks shared by 2 workers + expected = {0: {0: [0, 1]}, 1: {0: [0, 1]}, 2: {0: [0]}} + assert _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) == expected + + # world size = 4, num workers per rank = 2 + num_workers = 2 + shared_chunks = {0: [0, 2], 1: [1, 3], 2: [2, 3]} # 3 chunks distributed among 2 * 2 workers + expected = {0: {0: [0], 1: [2]}, 1: {0: [1], 1: [3]}, 2: {1: [2, 3]}} + assert _aggregate_shared_chunks_per_rank(shared_chunks, num_workers) == expected + + +def test_map_node_worker_rank_to_chunk_indexes_to_not_delete(): + chunks_to_workers = {10: [2, 3, 4], 20: [1, 2, 3], 30: [3, 4], 40: [4, 5, 6]} + workers_to_chunks = 
_map_node_worker_rank_to_chunk_indexes_to_not_delete(chunks_to_workers) + assert workers_to_chunks == {1: [20], 2: [10, 20], 3: [10, 20, 30], 4: [10, 30, 40], 5: [40], 6: [40]}
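# End-to-end sketch of the deletion-skip helpers introduced in this patch, using
# the same toy inputs as the unit test above (illustrative only).
from litdata.utilities.shuffle import (
    _find_chunks_per_workers_on_which_to_skip_deletion,
    _get_shared_chunks,
    _map_node_worker_rank_to_chunk_indexes_to_not_delete,
)

workers_chunks = [[0, 1], [1, 2]]  # chunk 1 is read by both workers
workers_intervals = [[(0, 0, 50, 50), (0, 0, 25, 50)], [(0, 25, 50, 50), (0, 0, 50, 50)]]

assert _get_shared_chunks(workers_chunks) == {1: [0, 1]}

to_disable = _find_chunks_per_workers_on_which_to_skip_deletion(
    num_workers=2, batch_size=5, workers_chunks=workers_chunks, workers_intervals=workers_intervals
)
assert to_disable == {1: [1]}  # worker 0 reads chunk 1 last, so worker 1 must not delete it

# Inverted view consumed by the dataset: per worker, the chunks it must keep
assert _map_node_worker_rank_to_chunk_indexes_to_not_delete(to_disable) == {1: [1]}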