support downloading from azure blob storage #262

Merged: 6 commits, merged Jul 26, 2024 (diff shown from 3 of 6 commits)
requirements/test.txt (1 change: 1 addition & 0 deletions)

@@ -10,3 +10,4 @@ pandas
 lightning
 lightning-cloud == 0.5.70 # Must be pinned to ensure compatibility
 zstd
+numpy < 2.0
src/litdata/constants.py (1 change: 1 addition & 0 deletions)

@@ -34,6 +34,7 @@
 _ZSTD_AVAILABLE = RequirementCache("zstd")
 _CRYPTOGRAPHY_AVAILABLE = RequirementCache("cryptography")
 _GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")
+_AZURE_STORAGE_AVAILABLE = RequirementCache("azure.storage.blob")
 _TQDM_AVAILABLE = RequirementCache("tqdm")
 _LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
 
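The new flag follows the file's existing availability-check pattern: RequirementCache (from lightning_utilities, which litdata builds on) is truthy only when the named package can be imported, and its string form carries a human-readable install hint. A minimal sketch of the guard this enables, assuming that API:

from lightning_utilities.core.imports import RequirementCache

_AZURE_STORAGE_AVAILABLE = RequirementCache("azure.storage.blob")

if not _AZURE_STORAGE_AVAILABLE:
    # str() of the cache names the missing requirement and how to install it
    raise ModuleNotFoundError(str(_AZURE_STORAGE_AVAILABLE))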
src/litdata/streaming/downloader.py (47 changes: 45 additions & 2 deletions)

@@ -20,7 +20,7 @@
 
 from filelock import FileLock, Timeout
 
-from litdata.constants import _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME
+from litdata.constants import _AZURE_STORAGE_AVAILABLE, _GOOGLE_STORAGE_AVAILABLE, _INDEX_FILENAME
 from litdata.streaming.client import S3Client
@@ -129,6 +129,43 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         pass
 
 
+class AzureDownloader(Downloader):
+    def __init__(
+        self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+    ):
+        if not _AZURE_STORAGE_AVAILABLE:
+            raise ModuleNotFoundError(str(_AZURE_STORAGE_AVAILABLE))
+
+        super().__init__(remote_dir, cache_dir, chunks, storage_options)
+
+    def download_file(self, remote_filepath: str, local_filepath: str) -> None:
+        from azure.storage.blob import BlobServiceClient
+
+        obj = parse.urlparse(remote_filepath)
+
+        if obj.scheme != "azure":
+            raise ValueError(
+                f"Expected obj.scheme to be `azure`, instead, got {obj.scheme} for remote={remote_filepath}"
+            )
+
+        if os.path.exists(local_filepath):
+            return
+
+        try:
+            with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0):
+                directories = obj.path.lstrip("/").split("/")
+                container_name = directories[0]
+                blob_name = os.path.join(*directories[1:])
+                service_client = BlobServiceClient.from_connection_string(self._storage_options["connection_string"])
+                blob_client = service_client.get_blob_client(container=container_name, blob=blob_name)
+                with open(local_filepath, "wb") as download_file:
+                    download_file.write(blob_client.download_blob().readall())
+
+        except Timeout:
+            # another process is responsible to download that file, continue
+            pass
+
+
 class LocalDownloader(Downloader):
     def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if not os.path.exists(remote_filepath):
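Note on the locking in download_file above: like the other cloud downloaders, it coordinates concurrent workers through a sidecar .lock file, so only one process fetches a given chunk while the rest skip it. A minimal sketch of that pattern in isolation (the function name is illustrative):

from filelock import FileLock, Timeout

def download_once(local_filepath: str) -> None:
    try:
        # timeout=0 gives up immediately if another process holds the lock
        with FileLock(local_filepath + ".lock", timeout=0):
            ...  # perform the actual download into local_filepath here
    except Timeout:
        # a peer process is already downloading this file; nothing to do
        pass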
@@ -144,7 +181,13 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         super().download_file(remote_filepath, local_filepath)
 
 
-_DOWNLOADERS = {"s3://": S3Downloader, "gs://": GCPDownloader, "local:": LocalDownloaderWithCache, "": LocalDownloader}
+_DOWNLOADERS = {
+    "s3://": S3Downloader,
+    "gs://": GCPDownloader,
+    "azure://": AzureDownloader,
+    "local:": LocalDownloaderWithCache,
+    "": LocalDownloader,
+}
 
 
 def get_downloader_cls(
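With the registry entry in place, a dataset in Azure Blob Storage should be streamable by pointing at an azure:// URI and threading the connection string through storage_options. A usage sketch (the container and path are made up, and the StreamingDataset signature is assumed from how AzureDownloader reads self._storage_options):

import os

from litdata import StreamingDataset

dataset = StreamingDataset(
    "azure://my-container/optimized_dataset",  # azure://<container>/<blob prefix>
    storage_options={"connection_string": os.environ["AZURE_STORAGE_CONNECTION_STRING"]},
)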
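The body of get_downloader_cls is collapsed in this view; judging from the _DOWNLOADERS mapping above, the dispatch is presumably a first-matching-prefix lookup over the remote path, roughly:

from typing import Any, Dict, List, Optional

def get_downloader_cls(
    remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
) -> "Downloader":
    # sketch only; the actual body is not shown in this diff
    for prefix, downloader_cls in _DOWNLOADERS.items():
        if str(remote_dir).startswith(prefix):
            return downloader_cls(remote_dir, cache_dir, chunks, storage_options or {})
    # unreachable: the "" entry matches any path, so LocalDownloader is the fallback

Since dicts preserve insertion order, the empty-string entry must stay last so LocalDownloader acts only as the catch-all.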