Skip to content

Commit

Permalink
Refactor maybe_create_object_store_from_uri (#3679)
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Oct 24, 2024
1 parent 5aaa8c9 commit af5dea4
Showing 1 changed file with 35 additions and 23 deletions.
58 changes: 35 additions & 23 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tempfile
import uuid
import warnings
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from urllib.parse import urlparse

import requests
Expand Down Expand Up @@ -393,13 +393,31 @@ def parse_uri(uri: str) -> tuple[str, str, str]:
return backend, bucket_name, path.lstrip('/')


# Dictionary mapping backend names to ObjectStore factory functions
BACKEND_TO_OBJECT_STORE_FACTORY: dict[str, Callable[[str, str], ObjectStore]] = {
's3':
lambda bucket, path: S3ObjectStore(bucket=bucket),
'gs':
lambda bucket, path: GCSObjectStore(bucket=bucket),
'oci':
lambda bucket, path: OCIObjectStore(bucket=bucket),
'azure':
lambda bucket, path: LibcloudObjectStore(
provider='AZURE_BLOBS',
container=bucket,
key_environ='AZURE_ACCOUNT_NAME',
secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
),
}


def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
"""Automatically creates an :class:`composer.utils.ObjectStore` from supported URI formats.
Currently supported backends are ``s3://``, ``oci://``, and local paths (in which case ``None`` will be returned)
Args:
uri (str): The path to (maybe) create an :class:`composer.utils.ObjectStore` from
uri (str): The path to (maybe) create an :class:`composer.utils.ObjectStore` from.
Raises:
NotImplementedError: Raises when the URI format is not supported.
Expand All @@ -408,25 +426,15 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
Optional[ObjectStore]: Returns an :class:`composer.utils.ObjectStore` if the URI is of a supported format, otherwise None
"""
backend, bucket_name, path = parse_uri(uri)

# If backend is empty, assume local path and return None
if backend == '':
return None
if backend == 's3':
return S3ObjectStore(bucket=bucket_name)

# Handle special cases like WandB, MLFlow, etc.
elif backend == 'wandb':
raise NotImplementedError(
f'There is no implementation for WandB load_object_store via URI. Please use '
'WandBLogger',
)
elif backend == 'gs':
return GCSObjectStore(bucket=bucket_name)
elif backend == 'oci':
return OCIObjectStore(bucket=bucket_name)
elif backend == 'azure':
return LibcloudObjectStore(
provider='AZURE_BLOBS',
container=bucket_name,
key_environ='AZURE_ACCOUNT_NAME',
secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
f'There is no implementation for WandB load_object_store via URI. Please use WandBLogger',
)
elif backend == 'dbfs':
if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
Expand All @@ -445,17 +453,21 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
# Create the object store for all other ranks
if dist.get_global_rank() != 0:
store = MLFlowObjectStore(path)

return store
else:
# validate if the path conforms to the requirements for UC volume paths
UCObjectStore.validate_path(path)
return UCObjectStore(path=path)
else:
raise NotImplementedError(
f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported object stores',
)

# Check if backend is registered
elif backend in BACKEND_TO_OBJECT_STORE_FACTORY:
return BACKEND_TO_OBJECT_STORE_FACTORY[backend](bucket_name, path)

# If backend is unknown, raise NotImplementedError
raise NotImplementedError(
f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported object stores',
)


def maybe_create_remote_uploader_downloader_from_uri(
Expand Down

0 comments on commit af5dea4

Please sign in to comment.