Skip to content

Commit

Permalink
add default storage_options
Browse files Browse the repository at this point in the history
  • Loading branch information
deependujha committed Sep 6, 2024
1 parent 5ef4004 commit 67205ea
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions src/litdata/streaming/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,17 @@ def download_s3_file_via_s5cmd(remote_filepath: str, local_filepath: str) -> Non
"": "file",
}

_DEFAULT_STORAGE_OPTIONS = {
"s3":{"config_kwargs": {"retries":{"max_attempts": 1000, "mode": "adaptive"}}},
}

def get_complete_storage_options(cloud_provider: str, storage_options: Optional[Dict] = {}) -> Dict:
if storage_options is None:
storage_options = {}
if cloud_provider in _DEFAULT_STORAGE_OPTIONS:
return {**_DEFAULT_STORAGE_OPTIONS[cloud_provider], **storage_options}
return storage_options


class FsspecDownloader(Downloader):
def __init__(
Expand All @@ -255,6 +266,7 @@ def __init__(
):
remote_dir = remote_dir.replace("local:", "")
self.is_local = False
storage_options = get_complete_storage_options(cloud_provider, storage_options)
super().__init__(cloud_provider, remote_dir, cache_dir, chunks, storage_options)
self.cloud_provider = cloud_provider
self.use_s5cmd = cloud_provider == "s3" and os.system("s5cmd > /dev/null 2>&1") == 0
Expand All @@ -281,7 +293,7 @@ def does_file_exist(
) -> bool:
if cloud_provider is None:
cloud_provider = get_cloud_provider(remote_filepath)

storage_options = get_complete_storage_options(cloud_provider, storage_options)
fs = fsspec.filesystem(cloud_provider, **storage_options)
return fs.exists(remote_filepath)

Expand All @@ -295,7 +307,7 @@ def list_directory(
"""Returns a list of filenames in a remote directory."""
if cloud_provider is None:
cloud_provider = get_cloud_provider(remote_directory)

storage_options = get_complete_storage_options(cloud_provider, storage_options)
fs = fsspec.filesystem(cloud_provider, **storage_options)
return fs.ls(remote_directory, detail=detail) # just return the filenames

Expand All @@ -309,6 +321,7 @@ def download_file_or_directory(remote_filepath: str, local_filepath: str, storag
return
try:
with FileLock(local_filepath + ".lock", timeout=3):
storage_options = get_complete_storage_options(fs_cloud_provider, storage_options)
fs = fsspec.filesystem(fs_cloud_provider, **storage_options)
fs.get(remote_filepath, local_filepath, recursive=True)
except Timeout:
Expand All @@ -321,6 +334,7 @@ def upload_file_or_directory(local_filepath: str, remote_filepath: str, storage_
try:
with FileLock(local_filepath + ".lock", timeout=3):
fs_cloud_provider = get_cloud_provider(remote_filepath)
storage_options = get_complete_storage_options(fs_cloud_provider, storage_options)
fs = fsspec.filesystem(fs_cloud_provider, **storage_options)
fs.put(local_filepath, remote_filepath, recursive=True)
except Timeout:
Expand All @@ -333,13 +347,15 @@ def copy_file_or_directory(
) -> None:
"""Copy a file from src to target on the remote cloud storage."""
fs_cloud_provider = get_cloud_provider(remote_filepath_src)
storage_options = get_complete_storage_options(fs_cloud_provider, storage_options)
fs = fsspec.filesystem(fs_cloud_provider, **storage_options)
fs.copy(remote_filepath_src, remote_filepath_tg, recursive=True)


def remove_file_or_directory(remote_filepath: str, storage_options: Optional[Dict] = {}) -> None:
"""Remove a file from the remote cloud storage."""
fs_cloud_provider = get_cloud_provider(remote_filepath)
storage_options = get_complete_storage_options(fs_cloud_provider, storage_options)
fs = fsspec.filesystem(fs_cloud_provider, **storage_options)
fs.rm(remote_filepath, recursive=True)

Expand Down

0 comments on commit 67205ea

Please sign in to comment.