Skip to content

Commit

Permalink
Multithreading function for merge_datasets (#413)
Browse files (browse the repository at this point in the history)
* Added multithreaded file copying to merge_datasets via a new max_workers argument

* Update functions.py

---------

Co-authored-by: thomas chaton <[email protected]>
  • Loading branch information
yhl48 and tchaton authored Nov 10, 2024
1 parent ac0c89b commit cedc6a6
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/litdata/processing/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,12 +521,13 @@ class CopyInfo:
new_filename: str


def merge_datasets(input_dirs: List[str], output_dir: str) -> None:
def merge_datasets(input_dirs: List[str], output_dir: str, max_workers: Optional[int] = os.cpu_count()) -> None:
"""Enables to merge multiple existing optimized datasets into a single optimized dataset.
Args:
input_dirs: A list of directories pointing to the existing optimized datasets.
output_dir: The directory where the merged dataset would be stored.
max_workers: Number of workers for multithreading
"""
if len(input_dirs) == 0:
Expand All @@ -537,6 +538,7 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None:

resolved_input_dirs = [_resolve_dir(input_dir) for input_dir in input_dirs]
resolved_output_dir = _resolve_dir(output_dir)
max_workers = max_workers or 1

if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs):
raise ValueError("The provided output_dir was found within the input_dirs. This isn't supported.")
Expand Down Expand Up @@ -580,8 +582,11 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None:

_tqdm = _get_tqdm_iterator_if_available()

for copy_info in _tqdm(copy_infos):
_apply_copy(copy_info, resolved_output_dir)
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures: List[concurrent.futures.Future] = []
for copy_info in _tqdm(copy_infos):
future = executor.submit(_apply_copy, copy_info, resolved_output_dir)
futures.append(future)

_save_index(index_json, resolved_output_dir)

Expand Down

0 comments on commit cedc6a6

Please sign in to comment.