Skip to content

Commit

Permalink
Change to dictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Oct 23, 2024
1 parent 10db562 commit c719226
Showing 1 changed file with 18 additions and 30 deletions.
48 changes: 18 additions & 30 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,34 +393,22 @@ def parse_uri(uri: str) -> tuple[str, str, str]:
return backend, bucket_name, path.lstrip('/')


# Registry for object store creation functions
object_store_registry: dict[str, Callable[[str, str], ObjectStore]] = {}


def register_object_store(backend: str, factory_func: Callable[[str, str], ObjectStore]):
"""Registers a new object store backend to the registry.
Args:
backend (str): The backend name (e.g., 's3', 'oci').
factory_func (Callable): A function that accepts bucket_name and path and returns an ObjectStore instance.
"""
object_store_registry[backend] = factory_func


# Register default object stores
register_object_store('s3', lambda bucket, path: S3ObjectStore(bucket=bucket))
register_object_store('gs', lambda bucket, path: GCSObjectStore(bucket=bucket))
register_object_store('oci', lambda bucket, path: OCIObjectStore(bucket=bucket))
register_object_store(
'azure',
lambda bucket,
path: LibcloudObjectStore(
provider='AZURE_BLOBS',
container=bucket,
key_environ='AZURE_ACCOUNT_NAME',
secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
),
)
# Dictionary mapping backend names to ObjectStore factory functions
BACKEND_TO_OBJECT_STORE_FACTORY: dict[str, Callable[[str, str], ObjectStore]] = {
's3':
lambda bucket, path: S3ObjectStore(bucket=bucket),
'gs':
lambda bucket, path: GCSObjectStore(bucket=bucket),
'oci':
lambda bucket, path: OCIObjectStore(bucket=bucket),
'azure':
lambda bucket, path: LibcloudObjectStore(
provider='AZURE_BLOBS',
container=bucket,
key_environ='AZURE_ACCOUNT_NAME',
secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
),
}


def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
Expand All @@ -442,8 +430,8 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
return None

# Check if backend is registered
if backend in object_store_registry:
return object_store_registry[backend](bucket_name, path)
if backend in BACKEND_TO_OBJECT_STORE_FACTORY:
return BACKEND_TO_OBJECT_STORE_FACTORY[backend](bucket_name, path)

# Handle special cases like WandB, MLFlow, etc.
if backend == 'wandb':
Expand Down

0 comments on commit c719226

Please sign in to comment.