From 70d8e8f5207458a1dd1c85da188e0a68ce24e97a Mon Sep 17 00:00:00 2001 From: xiaohan Date: Sun, 4 Feb 2024 21:12:01 -0800 Subject: [PATCH 01/14] parallel merge index --- streaming/base/util.py | 81 +++++++++++++++-------- tests/test_util.py | 142 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 183 insertions(+), 40 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 3be5b729a..6c3437299 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -14,6 +14,7 @@ import urllib.parse from collections import OrderedDict from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory +from multiprocessing import Pool from pathlib import Path from time import sleep, time from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload @@ -253,6 +254,46 @@ def merge_index(*args: Any, **kwargs: Any): raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}') +def _download_url(url_info): + """Download a file given URL information.""" + from streaming.base.storage.download import download_file + src, dest, download_timeout = url_info + try: + download_file(src, dest, download_timeout) + except Exception as ex: + return f'Failed to download index.json: {src} to {dest}: {str(ex)}', ex + return dest, None + +def _merge_partition_indices(partition_indices): + """Function to be executed by each process to merge a subset of partition indices.""" + shards = [] + for partition_index in partition_indices: + p = Path(partition_index) + with open(partition_index, 'r') as f: + obj = json.load(f) + for shard in obj['shards']: + for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): + if shard.get(key): + basename = shard[key]['basename'] + shard[key]['basename'] = os.path.join(os.path.basename(p.parent), basename) + shards.extend(obj['shards']) + return shards + +def _parallel_merge_partitions(partitions, n_processes=4): + """Divide the list of partitions among multiple processes and merge them in parallel.""" + with Pool(processes=n_processes) as pool: + # Split the list of partitions into N chunks where N is the number of processes + chunk_size = len(partitions) // n_processes + (len(partitions) % n_processes > 0) + partition_chunks = [partitions[i:i + chunk_size] for i in range(0, len(partitions), chunk_size)] + + # Process each chunk in parallel + results = pool.map(_merge_partition_indices, partition_chunks) + + # Combine the results from all processes + final_shards = [shard for result in results for shard in result] + return final_shards + + def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]], out: Union[str, Tuple[str, str]], keep_local: bool = True, @@ -273,7 +314,6 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` download_timeout (int): The allowed time for downloading each json file. Defaults to 60. """ - from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader if not index_file_urls or not out: @@ -297,10 +337,10 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. 
with tempfile.TemporaryDirectory() as temp_root: - logging.warning(f'A temporary folder {temp_root} is created to store index files') + logging.info(f'A temporary folder {temp_root} is created to store index files') # Copy files to a temporary directory. Download if necessary - partitions = [] + download_tasks = [] for url in urls: if isinstance(url, tuple): src = url[0] if os.path.exists(url[0]) else url[1] @@ -314,30 +354,18 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] f'Check data availability! local index {url[0]} is not accessible.' + f'remote index {url[1]} does not have a valid url format') dest = os.path.join(temp_root, path.lstrip('/')) + download_tasks.append((src, dest, download_timeout)) - try: - download_file(src, dest, download_timeout) - except Exception as ex: - raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex - - if not os.path.exists(dest): - raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.') - - partitions.append(dest) - - # merge shards from all index files - shards = [] - for partition_index in partitions: - p = Path(partition_index) - obj = json.load(open(partition_index)) - for i in range(len(obj['shards'])): - shard = obj['shards'][i] - for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): - if shard.get(key): - basename = shard[key]['basename'] - obj['shards'][i][key]['basename'] = os.path.join( - os.path.basename(p.parent), basename) - shards += obj['shards'] + with Pool(processes=os.cpu_count()) as pool: + results = pool.map(_download_url, download_tasks) + + partitions = [] + for partition_index, error in results: + if error: + raise RuntimeError(partition_index) + partitions.append(partition_index) + + shards = _parallel_merge_partitions(partitions) # Save merged index locally obj = { @@ -358,7 +386,6 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] if not keep_local: shutil.rmtree(cu.local, ignore_errors=True) - def _not_merged_index(index_file_path: str, out: str): """Check if index_file_path is the merged index at folder out. 
diff --git a/tests/test_util.py b/tests/test_util.py index 5aa8cabd7..ab024fb05 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -7,7 +7,7 @@ import time import urllib.parse from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Sequence import pytest @@ -194,9 +194,9 @@ def test_format_remote_index_files(scheme: str): assert obj.scheme == scheme -@pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3]) -@pytest.mark.parametrize('keep_local', [True, False]) -@pytest.mark.parametrize('scheme', ['gs://', 's3://', 'oci://']) +@pytest.mark.parametrize('index_file_urls_pattern', [1]) # , 2, 3]) +@pytest.mark.parametrize('keep_local', [True]) # , False]) +@pytest.mark.parametrize('scheme', ['gs://']) # , 's3://', 'oci://']) def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_local: bool, index_file_urls_pattern: int, scheme: str): """Validate the final merge index json for following patterns of index_file_urls: @@ -212,6 +212,8 @@ def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_loc from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType from streaming.base.converters import dataframeToMDS + import random + import string def not_merged_index(index_file_path: str, out: str): """Check if index_file_path is the merged index at folder out.""" @@ -223,15 +225,18 @@ def not_merged_index(index_file_path: str, out: str): mds_out = out = local spark = SparkSession.builder.getOrCreate() # pyright: ignore - schema = StructType([ - StructField('id', IntegerType(), nullable=False), - StructField('name', StringType(), nullable=False), - StructField('amount', DecimalType(10, 2), nullable=False) - ]) - data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')), - (3, 'Charlie', Decimal('987.65'))] - df = spark.createDataFrame(data=data, schema=schema).repartition(3) - mds_kwargs = {'out': mds_out, 'columns': {'id': 'int', 'name': 'str'}, 'keep_local': True} + + def random_string(length=1000): + """Generate a random string of fixed length.""" + letters = string.ascii_letters + string.digits + string.punctuation + ' ' + return ''.join(random.choice(letters) for i in range(length)) + + # Generate a DataFrame with 10000 rows of random text + num_rows = 100 + data = [(i, random_string(),random_string()) for i in range(num_rows)] + df = spark.createDataFrame(data, ["id", "name", "amount"]) + + mds_kwargs = {'out': mds_out, 'columns': {'id': 'int64', 'name': 'str'}, 'keep_local': True} dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) local_cu = CloudUploader.get(local, exist_ok=True, keep_local=True) @@ -241,6 +246,16 @@ def not_merged_index(index_file_path: str, out: str): if index_file_urls_pattern == 1: merge_index(local_index_files, out, keep_local=keep_local) + d1 = json.load(open(os.path.join(out, 'index.json'))) + + _merge_index_from_list_serial(local_index_files, out, keep_local=keep_local) + d2 = json.load(open(os.path.join(out, 'index.json'))) + + print('d1 = ', d1) + print('d2 = ', d2) + + assert len(d1['shards']) == len(d2['shards']), 'parallel and serial results different' + assert d1['shards'] == d2['shards'], 'parallel and serial results different' if index_file_urls_pattern == 2: with tempfile.TemporaryDirectory() as a_temporary_folder: @@ -323,3 +338,104 @@ def flaky_function(): return "Third time's a charm" assert flaky_function() == "Third time's a 
charm" + + +def _merge_index_from_list_serial(index_file_urls: Sequence[Union[str, Tuple[str, str]]], + out: Union[str, Tuple[str, str]], + keep_local: bool = True, + download_timeout: int = 60, + merge = False) -> None: + import urllib.parse + from streaming.base.storage.download import download_file + from streaming.base.storage.upload import CloudUploader + from streaming.base.util import _not_merged_index, _format_remote_index_files + from streaming.base.format.index import get_index_basename + from collections import OrderedDict + import logging + + if not index_file_urls or not out: + logger.warning('Either index_file_urls or out are None. ' + + 'Need to specify both `index_file_urls` and `out`. ' + 'No index merged') + return + + # This is the index json file name, e.g., it is index.json as of 0.6.0 + index_basename = get_index_basename() + + print('i am here 1.1') + cu = CloudUploader.get(out, keep_local=True, exist_ok=True) + print('i am here 1.2') + + # Remove duplicates, and strip '/' from right if any + index_file_urls = list(OrderedDict.fromkeys(index_file_urls)) + urls = [] + for url in index_file_urls: + if isinstance(url, str): + urls.append(url.rstrip('/').strip()) + else: + urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) + + # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. + with tempfile.TemporaryDirectory() as temp_root: + logging.warning(f'A temporary folder {temp_root} is created to store index files') + + # Copy files to a temporary directory. Download if necessary + partitions = [] + for url in urls: + if isinstance(url, tuple): + src = url[0] if os.path.exists(url[0]) else url[1] + else: + src = url + + obj = urllib.parse.urlparse(src) + scheme, bucket, path = obj.scheme, obj.netloc, obj.path + if scheme == '' and bucket == '' and path == '': + raise FileNotFoundError( + f'Check data availability! local index {url[0]} is not accessible.' 
+ + f'remote index {url[1]} does not have a valid url format') + dest = os.path.join(temp_root, path.lstrip('/')) + + try: + download_file(src, dest, download_timeout) + except Exception as ex: + raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex + + if not os.path.exists(dest): + raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.') + + partitions.append(dest) + + if not merge: + return + + # merge shards from all index files + shards = [] + for partition_index in partitions: + p = Path(partition_index) + obj = json.load(open(partition_index)) + for i in range(len(obj['shards'])): + shard = obj['shards'][i] + for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): + if shard.get(key): + basename = shard[key]['basename'] + obj['shards'][i][key]['basename'] = os.path.join( + os.path.basename(p.parent), basename) + shards += obj['shards'] + + # Save merged index locally + obj = { + 'version': 2, + 'shards': shards, + } + merged_index_path = os.path.join(temp_root, index_basename) + with open(merged_index_path, 'w') as outfile: + json.dump(obj, outfile) + + # Move merged index from temp path to local part in out + # Upload merged index to remote if out has remote part + shutil.move(merged_index_path, cu.local) + if cu.remote is not None: + cu.upload_file(index_basename) + + # Clean up + if not keep_local: + shutil.rmtree(cu.local, ignore_errors=True) From feee52d287ce8d949cd8e5fc6a1a7832524645cc Mon Sep 17 00:00:00 2001 From: xiaohan Date: Mon, 5 Feb 2024 09:54:29 -0800 Subject: [PATCH 02/14] fix lints --- streaming/base/util.py | 9 +++++++-- tests/test_util.py | 40 ++++++++++++++++------------------------ 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 6c3437299..1804b373c 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -13,8 +13,8 @@ import tempfile import urllib.parse from collections import OrderedDict -from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from multiprocessing import Pool +from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from pathlib import Path from time import sleep, time from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload @@ -264,6 +264,7 @@ def _download_url(url_info): return f'Failed to download index.json: {src} to {dest}: {str(ex)}', ex return dest, None + def _merge_partition_indices(partition_indices): """Function to be executed by each process to merge a subset of partition indices.""" shards = [] @@ -279,12 +280,15 @@ def _merge_partition_indices(partition_indices): shards.extend(obj['shards']) return shards + def _parallel_merge_partitions(partitions, n_processes=4): """Divide the list of partitions among multiple processes and merge them in parallel.""" with Pool(processes=n_processes) as pool: # Split the list of partitions into N chunks where N is the number of processes chunk_size = len(partitions) // n_processes + (len(partitions) % n_processes > 0) - partition_chunks = [partitions[i:i + chunk_size] for i in range(0, len(partitions), chunk_size)] + partition_chunks = [ + partitions[i:i + chunk_size] for i in range(0, len(partitions), chunk_size) + ] # Process each chunk in parallel results = pool.map(_merge_partition_indices, partition_chunks) @@ -386,6 +390,7 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] if not keep_local: shutil.rmtree(cu.local, ignore_errors=True) + def 
_not_merged_index(index_file_path: str, out: str): """Check if index_file_path is the merged index at folder out. diff --git a/tests/test_util.py b/tests/test_util.py index ab024fb05..b05da8612 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -7,7 +7,7 @@ import time import urllib.parse from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory -from typing import List, Optional, Tuple, Union, Sequence +from typing import List, Optional, Sequence, Tuple, Union import pytest @@ -194,9 +194,9 @@ def test_format_remote_index_files(scheme: str): assert obj.scheme == scheme -@pytest.mark.parametrize('index_file_urls_pattern', [1]) # , 2, 3]) -@pytest.mark.parametrize('keep_local', [True]) # , False]) -@pytest.mark.parametrize('scheme', ['gs://']) # , 's3://', 'oci://']) +@pytest.mark.parametrize('index_file_urls_pattern', [1]) # , 2, 3]) +@pytest.mark.parametrize('keep_local', [True]) # , False]) +@pytest.mark.parametrize('scheme', ['gs://']) # , 's3://', 'oci://']) def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_local: bool, index_file_urls_pattern: int, scheme: str): """Validate the final merge index json for following patterns of index_file_urls: @@ -206,14 +206,12 @@ def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_loc 4. All URLs are tuple (local, remote). At least one url is not accessible locally -> download all 5. All URLs are str (remote) -> download all """ - from decimal import Decimal + import random + import string from pyspark.sql import SparkSession - from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType from streaming.base.converters import dataframeToMDS - import random - import string def not_merged_index(index_file_path: str, out: str): """Check if index_file_path is the merged index at folder out.""" @@ -229,12 +227,12 @@ def not_merged_index(index_file_path: str, out: str): def random_string(length=1000): """Generate a random string of fixed length.""" letters = string.ascii_letters + string.digits + string.punctuation + ' ' - return ''.join(random.choice(letters) for i in range(length)) + return ''.join(random.choice(letters) for _ in range(length)) # Generate a DataFrame with 10000 rows of random text num_rows = 100 - data = [(i, random_string(),random_string()) for i in range(num_rows)] - df = spark.createDataFrame(data, ["id", "name", "amount"]) + data = [(i, random_string(), random_string()) for i in range(num_rows)] + df = spark.createDataFrame(data, ['id', 'name', 'amount']) mds_kwargs = {'out': mds_out, 'columns': {'id': 'int64', 'name': 'str'}, 'keep_local': True} dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) @@ -343,27 +341,24 @@ def flaky_function(): def _merge_index_from_list_serial(index_file_urls: Sequence[Union[str, Tuple[str, str]]], out: Union[str, Tuple[str, str]], keep_local: bool = True, - download_timeout: int = 60, - merge = False) -> None: + download_timeout: int = 60) -> None: + import logging + import shutil import urllib.parse + from collections import OrderedDict + from pathlib import Path + + from streaming.base.format.index import get_index_basename from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader - from streaming.base.util import _not_merged_index, _format_remote_index_files - from streaming.base.format.index import get_index_basename - from collections import OrderedDict - import logging if not index_file_urls or not out: - 
logger.warning('Either index_file_urls or out are None. ' + - 'Need to specify both `index_file_urls` and `out`. ' + 'No index merged') return # This is the index json file name, e.g., it is index.json as of 0.6.0 index_basename = get_index_basename() - print('i am here 1.1') cu = CloudUploader.get(out, keep_local=True, exist_ok=True) - print('i am here 1.2') # Remove duplicates, and strip '/' from right if any index_file_urls = list(OrderedDict.fromkeys(index_file_urls)) @@ -404,9 +399,6 @@ def _merge_index_from_list_serial(index_file_urls: Sequence[Union[str, Tuple[str partitions.append(dest) - if not merge: - return - # merge shards from all index files shards = [] for partition_index in partitions: From a0605a2935481d5cca7229c38b8945a225afb128 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Mon, 5 Feb 2024 22:30:19 -0800 Subject: [PATCH 03/14] test --- streaming/base/util.py | 28 +++++++++----- tests/test_util.py | 86 +++++++++++++++++++++++++++++++++++------- 2 files changed, 91 insertions(+), 23 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 1804b373c..110f1a230 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -17,7 +17,8 @@ from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from pathlib import Path from time import sleep, time -from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload +from typing import (Any, Callable, List, Optional, Sequence, Tuple, Type, TypeVar, Union, cast, + overload) import torch.distributed as dist @@ -254,7 +255,7 @@ def merge_index(*args: Any, **kwargs: Any): raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}') -def _download_url(url_info): +def _download_url(url_info: Tuple[str, str, int]): """Download a file given URL information.""" from streaming.base.storage.download import download_file src, dest, download_timeout = url_info @@ -265,7 +266,7 @@ def _download_url(url_info): return dest, None -def _merge_partition_indices(partition_indices): +def _merge_partition_indices(partition_indices: List[str]): """Function to be executed by each process to merge a subset of partition indices.""" shards = [] for partition_index in partition_indices: @@ -281,7 +282,7 @@ def _merge_partition_indices(partition_indices): return shards -def _parallel_merge_partitions(partitions, n_processes=4): +def _parallel_merge_partitions(partitions: List[str], n_processes: Optional[int] = 1): """Divide the list of partitions among multiple processes and merge them in parallel.""" with Pool(processes=n_processes) as pool: # Split the list of partitions into N chunks where N is the number of processes @@ -301,7 +302,8 @@ def _parallel_merge_partitions(partitions, n_processes=4): def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]], out: Union[str, Tuple[str, str]], keep_local: bool = True, - download_timeout: int = 60) -> None: + download_timeout: int = 60, + n_processes: int = 8) -> None: """Merge index.json from a list of index files of MDS directories to create joined index. Args: @@ -317,6 +319,7 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] out (Union[str, Tuple[str, str]]): path to put the merged index file keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` download_timeout (int): The allowed time for downloading each json file. Defaults to 60. 
+ n_processes (int): The number of cores to run the function in parallel """ from streaming.base.storage.upload import CloudUploader @@ -339,6 +342,8 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] else: urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) + n_processes = n_processes if (n_processes is not None and 1 <= n_processes <= os.cpu_count()) else 1 + # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. with tempfile.TemporaryDirectory() as temp_root: logging.info(f'A temporary folder {temp_root} is created to store index files') @@ -360,7 +365,7 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] dest = os.path.join(temp_root, path.lstrip('/')) download_tasks.append((src, dest, download_timeout)) - with Pool(processes=os.cpu_count()) as pool: + with Pool(processes=n_processes) as pool: results = pool.map(_download_url, download_tasks) partitions = [] @@ -369,7 +374,7 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] raise RuntimeError(partition_index) partitions.append(partition_index) - shards = _parallel_merge_partitions(partitions) + shards = _parallel_merge_partitions(partitions, n_processes) # Save merged index locally obj = { @@ -466,18 +471,21 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], _merge_index_from_list(list(zip(local_index_files, remote_index_files)), out, keep_local=keep_local, - download_timeout=download_timeout) + download_timeout=download_timeout, + n_processes = os.cpu_count()) else: _merge_index_from_list(remote_index_files, out, keep_local=keep_local, - download_timeout=download_timeout) + download_timeout=download_timeout, + n_processes = os.cpu_count()) return _merge_index_from_list(local_index_files, out, keep_local=keep_local, - download_timeout=download_timeout) + download_timeout=download_timeout, + n_processes = os.cpu_count()) @overload diff --git a/tests/test_util.py b/tests/test_util.py index b05da8612..fc9059a53 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -10,6 +10,7 @@ from typing import List, Optional, Sequence, Tuple, Union import pytest +from unittest.mock import MagicMock from streaming.base.constant import RESUME from streaming.base.shared.prefix import _get_path @@ -193,12 +194,69 @@ def test_format_remote_index_files(scheme: str): obj = urllib.parse.urlparse(file) assert obj.scheme == scheme +@pytest.mark.parametrize('cpu_count', [0, 1,4,2000]) +def test_merge_index_from_list_local_cpucount(local_remote_dir: Tuple[str, str], cpu_count: int): + """Validate the multiprocessing setting""" + import random + import string + from pyspark.sql import SparkSession + from streaming.base.converters import dataframeToMDS + + def not_merged_index(index_file_path: str, out: str): + """Check if index_file_path is the merged index at folder out.""" + prefix = str(urllib.parse.urlparse(out).path) + return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') + + keep_local = True + + local, _ = local_remote_dir + + mds_out = out = local + + os.cpu_count = MagicMock() + os.cpu_count.return_value = cpu_count + + spark = SparkSession.builder.getOrCreate() # pyright: ignore + + def random_string(length: int = 1000): + """Generate a random string of fixed length.""" + letters = string.ascii_letters + string.digits + string.punctuation + ' ' + return ''.join(random.choice(letters) for _ in range(length)) -@pytest.mark.parametrize('index_file_urls_pattern', [1]) 
# , 2, 3]) -@pytest.mark.parametrize('keep_local', [True]) # , False]) -@pytest.mark.parametrize('scheme', ['gs://']) # , 's3://', 'oci://']) + # Generate a DataFrame with 10000 rows of random text + num_rows = 100 + data = [(i, random_string(), random_string()) for i in range(num_rows)] + df = spark.createDataFrame(data, ['id', 'name', 'amount']) + + mds_kwargs = {'out': mds_out, 'columns': {'id': 'int64', 'name': 'str'}, 'keep_local': True} + dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) + + local_cu = CloudUploader.get(local, exist_ok=True, keep_local=True) + local_index_files = [ + o for o in local_cu.list_objects() if o.endswith('.json') and not_merged_index(o, local) + ] + + merge_index(local_index_files, out, keep_local=keep_local) + + d1 = json.load(open(os.path.join(out, 'index.json'))) + + _merge_index_from_list_serial(local_index_files, out, keep_local=keep_local) + d2 = json.load(open(os.path.join(out, 'index.json'))) + + print('d1 = ', d1) + print('d2 = ', d2) + + assert len(d1['shards']) == len(d2['shards']), 'parallel and serial results different' + assert d1['shards'] == d2['shards'], 'parallel and serial results different' + + + +@pytest.mark.parametrize('cpu_count', [1,4,2000]) +@pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3]) +@pytest.mark.parametrize('keep_local', [True, False]) +@pytest.mark.parametrize('scheme', ['gs://', 's3://', 'oci://']) def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_local: bool, - index_file_urls_pattern: int, scheme: str): + index_file_urls_pattern: int, scheme: str, cpu_count: int): """Validate the final merge index json for following patterns of index_file_urls: 1. All URLs are str (local). All URLs are accessible locally -> no download 2. All URLs are str (local). 
At least one url is unaccessible locally -> Error @@ -224,7 +282,7 @@ def not_merged_index(index_file_path: str, out: str): spark = SparkSession.builder.getOrCreate() # pyright: ignore - def random_string(length=1000): + def random_string(length: int = 1000): """Generate a random string of fixed length.""" letters = string.ascii_letters + string.digits + string.punctuation + ' ' return ''.join(random.choice(letters) for _ in range(length)) @@ -244,16 +302,18 @@ def random_string(length=1000): if index_file_urls_pattern == 1: merge_index(local_index_files, out, keep_local=keep_local) - d1 = json.load(open(os.path.join(out, 'index.json'))) - _merge_index_from_list_serial(local_index_files, out, keep_local=keep_local) - d2 = json.load(open(os.path.join(out, 'index.json'))) + if keep_local: + d1 = json.load(open(os.path.join(out, 'index.json'))) + + _merge_index_from_list_serial(local_index_files, out, keep_local=keep_local) + d2 = json.load(open(os.path.join(out, 'index.json'))) - print('d1 = ', d1) - print('d2 = ', d2) + print('d1 = ', d1) + print('d2 = ', d2) - assert len(d1['shards']) == len(d2['shards']), 'parallel and serial results different' - assert d1['shards'] == d2['shards'], 'parallel and serial results different' + assert len(d1['shards']) == len(d2['shards']), 'parallel and serial results different' + assert d1['shards'] == d2['shards'], 'parallel and serial results different' if index_file_urls_pattern == 2: with tempfile.TemporaryDirectory() as a_temporary_folder: @@ -424,7 +484,7 @@ def _merge_index_from_list_serial(index_file_urls: Sequence[Union[str, Tuple[str # Move merged index from temp path to local part in out # Upload merged index to remote if out has remote part - shutil.move(merged_index_path, cu.local) + shutil.move(merged_index_path, os.path.join(cu.local, index_basename)) if cu.remote is not None: cu.upload_file(index_basename) From e7edd52fbd7bf3796342231383caf52cd40469a5 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Mon, 12 Feb 2024 11:09:04 -0800 Subject: [PATCH 04/14] update --- .coveragerc | 8 -------- streaming/base/util.py | 25 +++++++++++++++---------- tests/test_util.py | 4 +++- 3 files changed, 18 insertions(+), 19 deletions(-) delete mode 100644 .coveragerc diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 53166bbe7..000000000 --- a/.coveragerc +++ /dev/null @@ -1,8 +0,0 @@ -[run] -branch = True -omit = streaming/text/convert/enwiki/mds/*,streaming/text/convert/enwiki/tfrecord/* - -[report] -show_missing = True -precision = 2 -exclude_lines = raise NotImplementedError.* diff --git a/streaming/base/util.py b/streaming/base/util.py index 110f1a230..433fda847 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -9,6 +9,7 @@ import logging import os import random +import numpy as np import shutil import tempfile import urllib.parse @@ -258,12 +259,12 @@ def merge_index(*args: Any, **kwargs: Any): def _download_url(url_info: Tuple[str, str, int]): """Download a file given URL information.""" from streaming.base.storage.download import download_file - src, dest, download_timeout = url_info + src, dst, download_timeout = url_info try: - download_file(src, dest, download_timeout) + download_file(src, dst, download_timeout) except Exception as ex: - return f'Failed to download index.json: {src} to {dest}: {str(ex)}', ex - return dest, None + return f'Failed to download index.json: {src} to {dst}: {str(ex)}', ex + return dst, None def _merge_partition_indices(partition_indices: List[str]): @@ -286,13 +287,15 @@ def 
_parallel_merge_partitions(partitions: List[str], n_processes: Optional[int] """Divide the list of partitions among multiple processes and merge them in parallel.""" with Pool(processes=n_processes) as pool: # Split the list of partitions into N chunks where N is the number of processes - chunk_size = len(partitions) // n_processes + (len(partitions) % n_processes > 0) + chunk_size = int(np.ceil(len(partitions)/n_processes)) partition_chunks = [ partitions[i:i + chunk_size] for i in range(0, len(partitions), chunk_size) ] # Process each chunk in parallel - results = pool.map(_merge_partition_indices, partition_chunks) + results = pool.imap_unordered(_merge_partition_indices, partition_chunks) + pool.close() + pool.join() # Combine the results from all processes final_shards = [shard for result in results for shard in result] @@ -346,7 +349,7 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. with tempfile.TemporaryDirectory() as temp_root: - logging.info(f'A temporary folder {temp_root} is created to store index files') + logging.info(f'Created temporary folder {temp_root} to store index files') # Copy files to a temporary directory. Download if necessary download_tasks = [] @@ -362,11 +365,13 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] raise FileNotFoundError( f'Check data availability! local index {url[0]} is not accessible.' + f'remote index {url[1]} does not have a valid url format') - dest = os.path.join(temp_root, path.lstrip('/')) - download_tasks.append((src, dest, download_timeout)) + dst = os.path.join(temp_root, path.lstrip('/')) + download_tasks.append((src, dst, download_timeout)) with Pool(processes=n_processes) as pool: - results = pool.map(_download_url, download_tasks) + results = pool.imap_unordered(_download_url, download_tasks) + pool.close() + pool.join() partitions = [] for partition_index, error in results: diff --git a/tests/test_util.py b/tests/test_util.py index fc9059a53..e13348a68 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -8,6 +8,7 @@ import urllib.parse from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from typing import List, Optional, Sequence, Tuple, Union +import numpy as np import pytest from unittest.mock import MagicMock @@ -221,7 +222,8 @@ def not_merged_index(index_file_path: str, out: str): def random_string(length: int = 1000): """Generate a random string of fixed length.""" letters = string.ascii_letters + string.digits + string.punctuation + ' ' - return ''.join(random.choice(letters) for _ in range(length)) + return ''.join(map(chr, np.random.choice(0x10FFFF - 1, length))) + #return ''.join(random.choice(letters) for _ in range(length)) # Generate a DataFrame with 10000 rows of random text num_rows = 100 From a58332906365ff7cb82287756b363895a52cce81 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Mon, 12 Feb 2024 21:28:41 -0800 Subject: [PATCH 05/14] update tests --- tests/test_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_util.py b/tests/test_util.py index e13348a68..95f2e1cdb 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -195,7 +195,7 @@ def test_format_remote_index_files(scheme: str): obj = urllib.parse.urlparse(file) assert obj.scheme == scheme -@pytest.mark.parametrize('cpu_count', [0, 1,4,2000]) +@pytest.mark.parametrize('cpu_count', [0, 1,4]) def 
test_merge_index_from_list_local_cpucount(local_remote_dir: Tuple[str, str], cpu_count: int): """Validate the multiprocessing setting""" import random From d6206f05e3be588696b2910c0c9005e9d7280a58 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Mon, 12 Feb 2024 22:51:04 -0800 Subject: [PATCH 06/14] updates --- streaming/base/util.py | 67 ++++++++++++++++++++++-------------------- tests/test_util.py | 15 ++++------ 2 files changed, 41 insertions(+), 41 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 433fda847..d92f2c803 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -9,7 +9,6 @@ import logging import os import random -import numpy as np import shutil import tempfile import urllib.parse @@ -18,9 +17,10 @@ from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from pathlib import Path from time import sleep, time -from typing import (Any, Callable, List, Optional, Sequence, Tuple, Type, TypeVar, Union, cast, - overload) +from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload +import numpy as np +import psutil import torch.distributed as dist from streaming.base.constant import SHM_TO_CLEAN @@ -220,7 +220,7 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: def merge_index(*args: Any, **kwargs: Any): - r"""Merge index.json from partitions to form a global index.json. + r"""Merge index.json from shards to form a global index.json. This can be called as @@ -228,18 +228,18 @@ def merge_index(*args: Any, **kwargs: Any): merge_index(out, keep_local, download_timeout) - The first signature takes in a list of index files URLs of MDS partitions. - The second takes the root of a MDS dataset and parse the partition folders from there. + The first signature takes in a list of index files URLs of MDS shards. + The second takes the root of a MDS dataset and parse the shards folders from there. Args: - index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the partitions. + index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the shards. Each element can take the form of a single path string or a tuple string. 1. If ``index_file_urls`` is a List of local URLs, merge locally without download. 2. If ``index_file_urls`` is a List of tuple (local, remote) URLs, check if local index.json are missing, download before merging. 3. If ``index_file_urls`` is a List of remote URLs, download all and merge. - out (Union[str, Tuple[str,str]]): folder that contain MDS partitions and to put the merged index file + out (Union[str, Tuple[str,str]]): folder that contain MDS shards and to put the merged index file 1. A local directory, merge index happens locally. 2. A remote directory, download all the sub-directories index.json, merge locally and upload. 
@@ -267,12 +267,12 @@ def _download_url(url_info: Tuple[str, str, int]): return dst, None -def _merge_partition_indices(partition_indices: List[str]): - """Function to be executed by each process to merge a subset of partition indices.""" +def _merge_shard_indices(shard_indices: List[str]): + """Function to be executed by each process to merge a subset of shard indices.""" shards = [] - for partition_index in partition_indices: - p = Path(partition_index) - with open(partition_index, 'r') as f: + for shard_index in shard_indices: + p = Path(shard_index) + with open(shard_index, 'r') as f: obj = json.load(f) for shard in obj['shards']: for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): @@ -283,17 +283,17 @@ def _merge_partition_indices(partition_indices: List[str]): return shards -def _parallel_merge_partitions(partitions: List[str], n_processes: Optional[int] = 1): - """Divide the list of partitions among multiple processes and merge them in parallel.""" +def _parallel_merge_shards(shards: List[str], n_processes: int = 1): + """Divide the list of shards among multiple processes and merge them in parallel.""" with Pool(processes=n_processes) as pool: - # Split the list of partitions into N chunks where N is the number of processes - chunk_size = int(np.ceil(len(partitions)/n_processes)) - partition_chunks = [ - partitions[i:i + chunk_size] for i in range(0, len(partitions), chunk_size) + # Split the list of shards into N chunks where N is the number of processes + chunk_size = int(np.ceil(len(shards) / n_processes)) + shard_chunks = [ + shards[i:i + chunk_size] for i in range(0, len(shards), chunk_size) ] # Process each chunk in parallel - results = pool.imap_unordered(_merge_partition_indices, partition_chunks) + results = pool.map(_merge_shard_indices, shard_chunks) pool.close() pool.join() @@ -310,7 +310,7 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] """Merge index.json from a list of index files of MDS directories to create joined index. Args: - index_file_urls (Union[str, Tuple[str,str]]): index.json from all the partitions + index_file_urls (Union[str, Tuple[str,str]]): index.json from all the shards each element can take the form of a single path string or a tuple string. The pattern of index_file_urls and corresponding reaction is one of: @@ -345,7 +345,8 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] else: urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) - n_processes = n_processes if (n_processes is not None and 1 <= n_processes <= os.cpu_count()) else 1 + cpu_count = max(psutil.cpu_count() - 2, 1) + n_processes = n_processes if (n_processes is not None and 1 <= n_processes <= cpu_count) else 1 # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. 
with tempfile.TemporaryDirectory() as temp_root: @@ -369,17 +370,17 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] download_tasks.append((src, dst, download_timeout)) with Pool(processes=n_processes) as pool: - results = pool.imap_unordered(_download_url, download_tasks) + results = pool.map(_download_url, download_tasks) pool.close() pool.join() - partitions = [] - for partition_index, error in results: + shards = [] + for shard_index, error in results: if error: - raise RuntimeError(partition_index) - partitions.append(partition_index) + raise RuntimeError(shard_index) + shards.append(shard_index) - shards = _parallel_merge_partitions(partitions, n_processes) + shards = _parallel_merge_shards(shards, n_processes) # Save merged index locally obj = { @@ -445,7 +446,7 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], """Merge index.json given the root of MDS dataset. Write merged index to the root folder. Args: - out (Union[str, Tuple[str,str]]): folder that contain MDS partitions. + out (Union[str, Tuple[str,str]]): folder that contain MDS shards. :A local directory, merge index happens locally :A remote directory, download all the sub-directories index.json in a temporary sub-directories, merge locally, and then upload it to out location @@ -470,6 +471,8 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], if file.endswith('.json') and _not_merged_index(file, cu.local): local_index_files.append(file) + cpu_count = max(psutil.cpu_count() - 2, 1) + if cu.remote: remote_index_files = _format_remote_index_files(cu.remote, cu.list_objects()) if len(local_index_files) == len(remote_index_files): @@ -477,20 +480,20 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], out, keep_local=keep_local, download_timeout=download_timeout, - n_processes = os.cpu_count()) + n_processes=cpu_count) else: _merge_index_from_list(remote_index_files, out, keep_local=keep_local, download_timeout=download_timeout, - n_processes = os.cpu_count()) + n_processes=cpu_count) return _merge_index_from_list(local_index_files, out, keep_local=keep_local, download_timeout=download_timeout, - n_processes = os.cpu_count()) + n_processes=cpu_count) @overload diff --git a/tests/test_util.py b/tests/test_util.py index 95f2e1cdb..dee2614be 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -8,10 +8,10 @@ import urllib.parse from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from typing import List, Optional, Sequence, Tuple, Union -import numpy as np +from unittest.mock import MagicMock +import numpy as np import pytest -from unittest.mock import MagicMock from streaming.base.constant import RESUME from streaming.base.shared.prefix import _get_path @@ -195,12 +195,12 @@ def test_format_remote_index_files(scheme: str): obj = urllib.parse.urlparse(file) assert obj.scheme == scheme -@pytest.mark.parametrize('cpu_count', [0, 1,4]) + +@pytest.mark.parametrize('cpu_count', [0, 1, 4]) def test_merge_index_from_list_local_cpucount(local_remote_dir: Tuple[str, str], cpu_count: int): """Validate the multiprocessing setting""" - import random - import string from pyspark.sql import SparkSession + from streaming.base.converters import dataframeToMDS def not_merged_index(index_file_path: str, out: str): @@ -221,9 +221,7 @@ def not_merged_index(index_file_path: str, out: str): def random_string(length: int = 1000): """Generate a random string of fixed length.""" - letters = string.ascii_letters + string.digits + string.punctuation + ' ' 
return ''.join(map(chr, np.random.choice(0x10FFFF - 1, length))) - #return ''.join(random.choice(letters) for _ in range(length)) # Generate a DataFrame with 10000 rows of random text num_rows = 100 @@ -252,8 +250,7 @@ def random_string(length: int = 1000): assert d1['shards'] == d2['shards'], 'parallel and serial results different' - -@pytest.mark.parametrize('cpu_count', [1,4,2000]) +@pytest.mark.parametrize('cpu_count', [1, 4]) @pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3]) @pytest.mark.parametrize('keep_local', [True, False]) @pytest.mark.parametrize('scheme', ['gs://', 's3://', 'oci://']) From 21c591a8fd253cee73572e24a285f74389434fe8 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Tue, 13 Feb 2024 21:18:34 -0800 Subject: [PATCH 07/14] Fix lints --- streaming/base/util.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index d92f2c803..ca31cd3d2 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -288,9 +288,7 @@ def _parallel_merge_shards(shards: List[str], n_processes: int = 1): with Pool(processes=n_processes) as pool: # Split the list of shards into N chunks where N is the number of processes chunk_size = int(np.ceil(len(shards) / n_processes)) - shard_chunks = [ - shards[i:i + chunk_size] for i in range(0, len(shards), chunk_size) - ] + shard_chunks = [shards[i:i + chunk_size] for i in range(0, len(shards), chunk_size)] # Process each chunk in parallel results = pool.map(_merge_shard_indices, shard_chunks) From 168d3dd6d9eec76281165d46f4c1abe760741c44 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Tue, 13 Feb 2024 21:20:16 -0800 Subject: [PATCH 08/14] add psutil --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 618f6c2b2..a66adb182 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,7 @@ 'azure-storage-blob>=12.0.0,<13', 'azure-storage-file-datalake>=12.11.0,<13', 'azure-identity>=1.13.0', + 'psutil>=5.8.0,<6', ] extra_deps = {} From f82d47e069c719eb42ef7e4f8abcd2509772daf4 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Thu, 15 Feb 2024 06:53:41 -0800 Subject: [PATCH 09/14] update --- streaming/base/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index ca31cd3d2..c5ff862a0 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -291,7 +291,7 @@ def _parallel_merge_shards(shards: List[str], n_processes: int = 1): shard_chunks = [shards[i:i + chunk_size] for i in range(0, len(shards), chunk_size)] # Process each chunk in parallel - results = pool.map(_merge_shard_indices, shard_chunks) + results = pool.imap_unordered(_merge_shard_indices, shard_chunks) pool.close() pool.join() @@ -368,7 +368,7 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] download_tasks.append((src, dst, download_timeout)) with Pool(processes=n_processes) as pool: - results = pool.map(_download_url, download_tasks) + results = pool.imap_unordered(_download_url, download_tasks) pool.close() pool.join() From 6add8eafbbbaf7f5d2db13af9c8c4da09ab5f8c7 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Thu, 15 Feb 2024 06:55:40 -0800 Subject: [PATCH 10/14] update --- streaming/base/util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/streaming/base/util.py b/streaming/base/util.py index c5ff862a0..53baa94a4 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -346,6 +346,8 @@ def _merge_index_from_list(index_file_urls: 
Sequence[Union[str, Tuple[str, str]] cpu_count = max(psutil.cpu_count() - 2, 1) n_processes = n_processes if (n_processes is not None and 1 <= n_processes <= cpu_count) else 1 + logger.warning(f'Got n_processes = {n_processes}. download and merge index in parallel') + # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. with tempfile.TemporaryDirectory() as temp_root: logging.info(f'Created temporary folder {temp_root} to store index files') From 18a2f97f4669fe91c86f2b9cd0a5cd39238b972c Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 24 Feb 2024 15:53:19 -0800 Subject: [PATCH 11/14] Add warning --- streaming/base/util.py | 49 ++++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 53baa94a4..2211271e6 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -220,7 +220,7 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: def merge_index(*args: Any, **kwargs: Any): - r"""Merge index.json from shards to form a global index.json. + r"""Merge index.json from streams to form a global index.json. This can be called as @@ -228,18 +228,18 @@ def merge_index(*args: Any, **kwargs: Any): merge_index(out, keep_local, download_timeout) - The first signature takes in a list of index files URLs of MDS shards. - The second takes the root of a MDS dataset and parse the shards folders from there. + The first signature takes in a list of index files URLs of MDS streams. + The second takes the root of a MDS dataset and parse the streams folders from there. Args: - index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the shards. + index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the streams. Each element can take the form of a single path string or a tuple string. 1. If ``index_file_urls`` is a List of local URLs, merge locally without download. 2. If ``index_file_urls`` is a List of tuple (local, remote) URLs, check if local index.json are missing, download before merging. 3. If ``index_file_urls`` is a List of remote URLs, download all and merge. - out (Union[str, Tuple[str,str]]): folder that contain MDS shards and to put the merged index file + out (Union[str, Tuple[str,str]]): folder that contain MDS streams and to put the merged index file 1. A local directory, merge index happens locally. 2. A remote directory, download all the sub-directories index.json, merge locally and upload. 
@@ -267,12 +267,12 @@ def _download_url(url_info: Tuple[str, str, int]): return dst, None -def _merge_shard_indices(shard_indices: List[str]): - """Function to be executed by each process to merge a subset of shard indices.""" +def _merge_stream_indices(stream_indices: List[str]): + """Function to be executed by each process to merge a subset of stream indices.""" shards = [] - for shard_index in shard_indices: - p = Path(shard_index) - with open(shard_index, 'r') as f: + for stream_index in stream_indices: + p = Path(stream_index) + with open(stream_index, 'r') as f: obj = json.load(f) for shard in obj['shards']: for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): @@ -283,15 +283,15 @@ def _merge_shard_indices(shard_indices: List[str]): return shards -def _parallel_merge_shards(shards: List[str], n_processes: int = 1): - """Divide the list of shards among multiple processes and merge them in parallel.""" +def _parallel_merge_streams(streams: List[str], n_processes: int = 1): + """Divide the list of streams among multiple processes and merge their shards in parallel.""" with Pool(processes=n_processes) as pool: - # Split the list of shards into N chunks where N is the number of processes - chunk_size = int(np.ceil(len(shards) / n_processes)) - shard_chunks = [shards[i:i + chunk_size] for i in range(0, len(shards), chunk_size)] + # Split the list of streams into N chunks where N is the number of processes + chunk_size = int(np.ceil(len(streams) / n_processes)) + stream_chunks = [streams[i:i + chunk_size] for i in range(0, len(streams), chunk_size)] # Process each chunk in parallel - results = pool.imap_unordered(_merge_shard_indices, shard_chunks) + results = pool.map(_merge_stream_indices, stream_chunks) pool.close() pool.join() @@ -308,7 +308,7 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] """Merge index.json from a list of index files of MDS directories to create joined index. Args: - index_file_urls (Union[str, Tuple[str,str]]): index.json from all the shards + index_file_urls (Union[str, Tuple[str,str]]): index.json from all the streams each element can take the form of a single path string or a tuple string. The pattern of index_file_urls and corresponding reaction is one of: @@ -370,17 +370,17 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] download_tasks.append((src, dst, download_timeout)) with Pool(processes=n_processes) as pool: - results = pool.imap_unordered(_download_url, download_tasks) + results = pool.map(_download_url, download_tasks) pool.close() pool.join() - shards = [] - for shard_index, error in results: + streams = [] + for stream_index, error in results: if error: - raise RuntimeError(shard_index) - shards.append(shard_index) + raise RuntimeError(stream_index) + streams.append(stream_index) - shards = _parallel_merge_shards(shards, n_processes) + shards = _parallel_merge_streams(streams, n_processes) # Save merged index locally obj = { @@ -467,6 +467,9 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], local_index_files = [] cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True) + + logger.warning(f"We will be listing objects from {out}, which may take a long time if the number of stream folders is large. 
Consider provide the list of path/to/index.json directly.") + for file in cl.list_objects(): if file.endswith('.json') and _not_merged_index(file, cu.local): local_index_files.append(file) From eb5a16f1760eec6fbb62fb8248b3b9f695a9a463 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Sat, 24 Feb 2024 16:41:39 -0800 Subject: [PATCH 12/14] Fix lints --- streaming/base/util.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 2211271e6..94f06e116 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -468,7 +468,9 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], local_index_files = [] cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True) - logger.warning(f"We will be listing objects from {out}, which may take a long time if the number of stream folders is large. Consider provide the list of path/to/index.json directly.") + logger.warning( + f'We will be listing objects from {out}, which may take a long time if the number of stream folders is large. Consider provide the list of path/to/index.json directly.' + ) for file in cl.list_objects(): if file.endswith('.json') and _not_merged_index(file, cu.local): From 1a0c458fdc1f7cd4c7775ef454bcb2205b903789 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Tue, 27 Feb 2024 09:07:45 -0800 Subject: [PATCH 13/14] Fix comments --- streaming/base/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 94f06e116..5e7a13baf 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -344,9 +344,9 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) cpu_count = max(psutil.cpu_count() - 2, 1) - n_processes = n_processes if (n_processes is not None and 1 <= n_processes <= cpu_count) else 1 + n_processes = n_processes if (1 <= n_processes <= cpu_count) else 1 - logger.warning(f'Got n_processes = {n_processes}. download and merge index in parallel') + logger.warning(f'Using n_processes = {n_processes} to download and merge index in parallel') # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. with tempfile.TemporaryDirectory() as temp_root: From 6004e81c71aee8244be20bd472e3214b202a3a50 Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Wed, 20 Mar 2024 10:07:54 -0700 Subject: [PATCH 14/14] Change default --- streaming/base/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 5e7a13baf..d6f29fc11 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -304,7 +304,7 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] out: Union[str, Tuple[str, str]], keep_local: bool = True, download_timeout: int = 60, - n_processes: int = 8) -> None: + n_processes: int = 1) -> None: """Merge index.json from a list of index files of MDS directories to create joined index. Args:
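
For reference, below is a minimal usage sketch of the API this series parallelizes, assuming a purely hypothetical local layout of /tmp/mds_dataset/<stream>/index.json. Per the merge_index docstring in the patches, the function accepts either an explicit list of per-stream index.json paths or the dataset root; the new n_processes knob stays on the internal _merge_index_from_list helper (default changed to 1 in patch 14/14), so it is not passed here.

    # Illustrative sketch only -- the dataset path below is hypothetical.
    import json
    import os

    from streaming.base.util import merge_index

    out = '/tmp/mds_dataset'  # root folder holding one sub-directory per MDS stream

    # Collect the per-stream index files explicitly (first merge_index signature).
    # Passing them directly avoids the object listing that _merge_index_from_root
    # performs, which the later patches warn can be slow for many stream folders.
    index_files = [
        os.path.join(out, d, 'index.json')
        for d in sorted(os.listdir(out))
        if os.path.isdir(os.path.join(out, d))
    ]

    # Downloads/copies each index.json and merges their shard lists in parallel,
    # then writes a merged index.json at the root of `out`.
    merge_index(index_files, out, keep_local=True, download_timeout=60)

    # The second signature would instead be:
    #     merge_index(out, keep_local=True, download_timeout=60)

    # The merged file is a version-2 index whose shard basenames are prefixed
    # with their stream folder name, as produced by the merge helpers above.
    with open(os.path.join(out, 'index.json')) as f:
        merged = json.load(f)
    print(f"merged {len(merged['shards'])} shards")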