diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py
index 8a4591df..046361ba 100644
--- a/src/litdata/streaming/downloader.py
+++ b/src/litdata/streaming/downloader.py
@@ -243,6 +243,17 @@ def download_s3_file_via_s5cmd(remote_filepath: str, local_filepath: str) -> Non
     "": "file",
 }
 
+_DEFAULT_STORAGE_OPTIONS = {
+    "s3": {"config_kwargs": {"retries": {"max_attempts": 1000, "mode": "adaptive"}}},
+}
+
+def get_complete_storage_options(cloud_provider: str, storage_options: Optional[Dict] = {}) -> Dict:
+    if storage_options is None:
+        storage_options = {}
+    if cloud_provider in _DEFAULT_STORAGE_OPTIONS:
+        return {**_DEFAULT_STORAGE_OPTIONS[cloud_provider], **storage_options}
+    return storage_options
+
 
 class FsspecDownloader(Downloader):
     def __init__(
@@ -255,6 +266,7 @@ def __init__(
     ):
         remote_dir = remote_dir.replace("local:", "")
         self.is_local = False
+        storage_options = get_complete_storage_options(cloud_provider, storage_options)
         super().__init__(cloud_provider, remote_dir, cache_dir, chunks, storage_options)
         self.cloud_provider = cloud_provider
         self.use_s5cmd = cloud_provider == "s3" and os.system("s5cmd > /dev/null 2>&1") == 0
@@ -281,7 +293,7 @@ def does_file_exist(
 ) -> bool:
     if cloud_provider is None:
         cloud_provider = get_cloud_provider(remote_filepath)
-
+    storage_options = get_complete_storage_options(cloud_provider, storage_options)
     fs = fsspec.filesystem(cloud_provider, **storage_options)
     return fs.exists(remote_filepath)
 
@@ -295,7 +307,7 @@ def list_directory(
     """Returns a list of filenames in a remote directory."""
     if cloud_provider is None:
         cloud_provider = get_cloud_provider(remote_directory)
-
+    storage_options = get_complete_storage_options(cloud_provider, storage_options)
     fs = fsspec.filesystem(cloud_provider, **storage_options)
     return fs.ls(remote_directory, detail=detail)  # just return the filenames
 
@@ -309,6 +321,7 @@ def download_file_or_directory(remote_filepath: str, local_filepath: str, storag
         return
     try:
         with FileLock(local_filepath + ".lock", timeout=3):
+            storage_options = get_complete_storage_options(fs_cloud_provider, storage_options)
             fs = fsspec.filesystem(fs_cloud_provider, **storage_options)
             fs.get(remote_filepath, local_filepath, recursive=True)
     except Timeout:
@@ -321,6 +334,7 @@ def upload_file_or_directory(local_filepath: str, remote_filepath: str, storage_
     try:
         with FileLock(local_filepath + ".lock", timeout=3):
             fs_cloud_provider = get_cloud_provider(remote_filepath)
+            storage_options = get_complete_storage_options(fs_cloud_provider, storage_options)
             fs = fsspec.filesystem(fs_cloud_provider, **storage_options)
             fs.put(local_filepath, remote_filepath, recursive=True)
     except Timeout:
@@ -333,6 +347,7 @@ def copy_file_or_directory(
 ) -> None:
     """Copy a file from src to target on the remote cloud storage."""
     fs_cloud_provider = get_cloud_provider(remote_filepath_src)
+    storage_options = get_complete_storage_options(fs_cloud_provider, storage_options)
     fs = fsspec.filesystem(fs_cloud_provider, **storage_options)
     fs.copy(remote_filepath_src, remote_filepath_tg, recursive=True)
 
@@ -340,6 +355,7 @@ def copy_file_or_directory(
 def remove_file_or_directory(remote_filepath: str, storage_options: Optional[Dict] = {}) -> None:
     """Remove a file from the remote cloud storage."""
     fs_cloud_provider = get_cloud_provider(remote_filepath)
+    storage_options = get_complete_storage_options(fs_cloud_provider, storage_options)
     fs = fsspec.filesystem(fs_cloud_provider, **storage_options)
     fs.rm(remote_filepath, recursive=True)
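
For reviewers, here is a small self-contained sketch (not part of the diff) of the merge semantics introduced by `get_complete_storage_options`. The helper body is copied verbatim from the patch above; the example calls and their expected results are illustrative assumptions, not existing tests in the repository.

```python
from typing import Dict, Optional

# Reproduced from the patch: per-provider defaults applied before user options.
_DEFAULT_STORAGE_OPTIONS = {
    "s3": {"config_kwargs": {"retries": {"max_attempts": 1000, "mode": "adaptive"}}},
}


def get_complete_storage_options(cloud_provider: str, storage_options: Optional[Dict] = {}) -> Dict:
    if storage_options is None:
        storage_options = {}
    if cloud_provider in _DEFAULT_STORAGE_OPTIONS:
        return {**_DEFAULT_STORAGE_OPTIONS[cloud_provider], **storage_options}
    return storage_options


# Providers without defaults pass the user options through unchanged.
assert get_complete_storage_options("gs", {"token": "anon"}) == {"token": "anon"}

# For s3 with no user options, the adaptive-retry default is applied.
assert get_complete_storage_options("s3") == {
    "config_kwargs": {"retries": {"max_attempts": 1000, "mode": "adaptive"}}
}

# The merge is shallow: a user-supplied top-level key replaces the default wholesale,
# so passing config_kwargs here drops the retry defaults rather than merging into them.
assert get_complete_storage_options("s3", {"config_kwargs": {"read_timeout": 60}}) == {
    "config_kwargs": {"read_timeout": 60}
}
```

One consequence of the shallow top-level merge, as shown in the last assertion, is that callers who pass their own `config_kwargs` opt out of the retry defaults entirely; if preserving them is desired, a deep merge would be needed.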