
Make asset requirements for remote storage optional (#167)
* assets: make s3 optional, add error when it is not installed

* assets: make gcs optional, add error when it is not installed

* assets: make azure optional, add error when it is not installed

* lint: lint remote.py

* update docs with optional dependencies

* remove tf pinned version on optional requirements

* update requirements files
move optional dependencies to requirements-optional

* update cli and asset drivers for optional imports

Co-authored-by: Victor Benichoux <[email protected]>
Co-authored-by: Cyril Le Mat <[email protected]>
3 people authored Sep 1, 2022
1 parent ef900c2 commit c0a6e5c
Showing 15 changed files with 203 additions and 237 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -70,5 +70,7 @@ Install with `pip`:
pip install modelkit
```

+Optional dependencies are available for remote storage providers ([see documentation](https://cornerstone-ondemand.github.io/modelkit/assets/storage_provider/#using-different-providers))

## Community
Join our [community](https://discord.gg/ayj5wdAArV) on Discord to get support and leave feedback
13 changes: 8 additions & 5 deletions docs/assets/storage_provider.md
@@ -28,11 +28,14 @@ Developers may additionally need to be able to push new assets and/or update existing assets.

## Using different providers

-The flavor of the remote store that is used depends on the `MODELKIT_STORAGE_PROVIDER` environment variables
+The flavor of the remote store that is used depends on the optional dependencies chosen at `pip install` time and on the `MODELKIT_STORAGE_PROVIDER` environment variable.

+The default `pip install modelkit` will only allow you to target a local directory.
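
For example, the provider extras introduced by this commit:

```
pip install modelkit[assets-s3]    # S3, via boto3
pip install modelkit[assets-gcs]   # GCS, via google-cloud-storage
pip install modelkit[assets-az]    # Azure blob storage, via azure-storage-blob
```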


### Using AWS S3 storage

-Use `MODELKIT_STORAGE_PROVIDER=s3` to connect to S3 storage.
+Use `pip install modelkit[assets-s3]` and set the environment variable `MODELKIT_STORAGE_PROVIDER=s3` to connect to S3 storage.

We use [boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html) under the hood.

@@ -53,7 +56,7 @@ Use `AWS_KMS_KEY_ID` environment variable to set your key and be able to upload

### GCS storage

-Use `MODELKIT_STORAGE_PROVIDER=gcs` to connect to GCS storage.
+Use `pip install modelkit[assets-gcs]` and set the environment variable `MODELKIT_STORAGE_PROVIDER=gcs` to connect to GCS storage.

We use [google-cloud-storage](https://googleapis.dev/python/storage/latest/index.html).

Expand All @@ -67,7 +70,7 @@ If `GOOGLE_APPLICATION_CREDENTIALS` is provided, it should point to a local JSON

### Using Azure blob storage

-Use `MODELKIT_STORAGE_PROVIDER=az` to connect to Azure blob storage.
+Use `pip install modelkit[assets-az]` and set the environment variable `MODELKIT_STORAGE_PROVIDER=az` to connect to Azure blob storage.

We use [azure-storage-blob](https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python) under the hood.

Expand All @@ -80,7 +83,7 @@ The client is created by passing the authentication information to `BlobServiceC

### `local` mode

-Use `MODELKIT_STORAGE_PROVIDER=local` to treat a local folder as a remote source.
+Set the environment variable `MODELKIT_STORAGE_PROVIDER=local` to treat a local folder as a remote source.

Assets will be downloaded from this folder to the configured asset dir.

32 changes: 29 additions & 3 deletions modelkit/assets/cli.py
Original file line number Diff line number Diff line change
@@ -10,11 +10,21 @@
from rich.table import Table
from rich.tree import Tree

-from modelkit.assets.drivers.gcs import GCSStorageDriver
-from modelkit.assets.drivers.s3 import S3StorageDriver
+try:
+    from modelkit.assets.drivers.gcs import GCSStorageDriver
+
+    has_gcs = True
+except ModuleNotFoundError:
+    has_gcs = False
+try:
+    from modelkit.assets.drivers.s3 import S3StorageDriver
+
+    has_s3 = True
+except ModuleNotFoundError:
+    has_s3 = False
from modelkit.assets.errors import ObjectDoesNotExistError
from modelkit.assets.manager import AssetsManager
-from modelkit.assets.remote import StorageProvider
+from modelkit.assets.remote import DriverNotInstalledError, StorageProvider
from modelkit.assets.settings import AssetSpec


@@ -121,8 +131,16 @@ def new_(asset_path, asset_spec, storage_prefix, dry_run):
if not os.path.exists(asset_path):
    parsed_path = parse_remote_url(asset_path)
    if parsed_path["storage_prefix"] == "gs":
+        if not has_gcs:
+            raise DriverNotInstalledError(
+                "GCS driver not installed, install modelkit[assets-gcs]"
+            )
        driver = GCSStorageDriver(bucket=parsed_path["bucket_name"])
    elif parsed_path["storage_prefix"] == "s3":
+        if not has_s3:
+            raise DriverNotInstalledError(
+                "S3 driver not installed, install modelkit[assets-s3]"
+            )
        driver = S3StorageDriver(bucket=parsed_path["bucket_name"])
    else:
        raise ValueError(
@@ -212,8 +230,16 @@ def update_(asset_path, asset_spec, storage_prefix, bump_major, dry_run):
if not os.path.exists(asset_path):
    parsed_path = parse_remote_url(asset_path)
    if parsed_path["storage_prefix"] == "gs":
+        if not has_gcs:
+            raise DriverNotInstalledError(
+                "GCS driver not installed, install modelkit[assets-gcs]"
+            )
        driver = GCSStorageDriver(bucket=parsed_path["bucket_name"])
    elif parsed_path["storage_prefix"] == "s3":
+        if not has_s3:
+            raise DriverNotInstalledError(
+                "S3 driver not installed, install modelkit[assets-s3]"
+            )
        driver = S3StorageDriver(bucket=parsed_path["bucket_name"])
    else:
        raise ValueError(
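
Together with the guarded imports at the top of the file, the two command bodies above defer the failure from import time to use time: each driver module is imported inside a `try`/`except ModuleNotFoundError`, a `has_*` flag records the outcome, and an actionable error is raised only when a command actually needs the missing driver. A minimal self-contained sketch of the pattern, with `RuntimeError` and `get_s3_driver` as hypothetical stand-ins for modelkit's `DriverNotInstalledError` and the CLI plumbing:

```python
# Sketch of the optional-dependency guard used in cli.py.
# RuntimeError stands in for modelkit's DriverNotInstalledError.
try:
    from modelkit.assets.drivers.s3 import S3StorageDriver

    has_s3 = True
except ModuleNotFoundError:
    has_s3 = False


def get_s3_driver(bucket: str):
    # Fail at call time with an actionable message, not at import time.
    if not has_s3:
        raise RuntimeError("S3 driver not installed, install modelkit[assets-s3]")
    return S3StorageDriver(bucket=bucket)
```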
14 changes: 8 additions & 6 deletions modelkit/assets/drivers/azure.py
@@ -7,10 +7,12 @@

from modelkit.assets import errors
from modelkit.assets.drivers.abc import StorageDriver
-from modelkit.assets.drivers.retry import RETRY_POLICY
+from modelkit.assets.drivers.retry import retry_policy

logger = get_logger(__name__)

+AZURE_RETRY_POLICY = retry_policy()


class AzureStorageDriver(StorageDriver):
bucket: str
@@ -34,13 +36,13 @@ def __init__(
os.environ["AZURE_STORAGE_CONNECTION_STRING"]
)

-@retry(**RETRY_POLICY)
+@retry(**AZURE_RETRY_POLICY)
def iterate_objects(self, prefix=None):
container = self.client.get_container_client(self.bucket)
for blob in container.list_blobs(prefix=prefix):
yield blob["name"]

-@retry(**RETRY_POLICY)
+@retry(**AZURE_RETRY_POLICY)
def upload_object(self, file_path, object_name):
blob_client = self.client.get_blob_client(
container=self.bucket, blob=object_name
@@ -50,7 +52,7 @@ def upload_object(self, file_path, object_name):
with open(file_path, "rb") as f:
blob_client.upload_blob(f)

-@retry(**RETRY_POLICY)
+@retry(**AZURE_RETRY_POLICY)
def download_object(self, object_name, destination_path):
blob_client = self.client.get_blob_client(
container=self.bucket, blob=object_name
@@ -67,14 +69,14 @@ def download_object(self, object_name, destination_path):
with open(destination_path, "wb") as f:
f.write(blob_client.download_blob().readall())

-@retry(**RETRY_POLICY)
+@retry(**AZURE_RETRY_POLICY)
def delete_object(self, object_name):
blob_client = self.client.get_blob_client(
container=self.bucket, blob=object_name
)
blob_client.delete_blob()

-@retry(**RETRY_POLICY)
+@retry(**AZURE_RETRY_POLICY)
def exists(self, object_name):
blob_client = self.client.get_blob_client(
container=self.bucket, blob=object_name
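
Note that `AZURE_RETRY_POLICY` is built with no argument, so for Azure only `requests.exceptions.ChunkedEncodingError` counts as retriable; the GCS and S3 hunks below pass their SDK's error class instead. A side-by-side sketch of the three policies, assuming the relevant extras are installed:

```python
# The three driver policies built from the shared factory (per this diff).
# Assumes the assets-s3 and assets-gcs extras are installed.
import botocore.exceptions
from google.api_core.exceptions import GoogleAPIError

from modelkit.assets.drivers.retry import retry_policy

AZURE_RETRY_POLICY = retry_policy()  # only ChunkedEncodingError is retried
GCS_RETRY_POLICY = retry_policy(GoogleAPIError)
S3_RETRY_POLICY = retry_policy(botocore.exceptions.ClientError)
```
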
16 changes: 9 additions & 7 deletions modelkit/assets/drivers/gcs.py
@@ -1,18 +1,20 @@
import os
from typing import Optional

-from google.api_core.exceptions import NotFound
+from google.api_core.exceptions import GoogleAPIError, NotFound
from google.cloud import storage
from google.cloud.storage import Client
from structlog import get_logger
from tenacity import retry

from modelkit.assets import errors
from modelkit.assets.drivers.abc import StorageDriver
-from modelkit.assets.drivers.retry import RETRY_POLICY
+from modelkit.assets.drivers.retry import retry_policy

logger = get_logger(__name__)

+GCS_RETRY_POLICY = retry_policy(GoogleAPIError)


class GCSStorageDriver(StorageDriver):
bucket: str
@@ -35,13 +37,13 @@ def __init__(
else:
self.client = Client()

-@retry(**RETRY_POLICY)
+@retry(**GCS_RETRY_POLICY)
def iterate_objects(self, prefix=None):
bucket = self.client.bucket(self.bucket)
for blob in bucket.list_blobs(prefix=prefix):
yield blob.name

-@retry(**RETRY_POLICY)
+@retry(**GCS_RETRY_POLICY)
def upload_object(self, file_path, object_name):
bucket = self.client.bucket(self.bucket)
blob = bucket.blob(object_name)
@@ -50,7 +52,7 @@ def upload_object(self, file_path, object_name):
with open(file_path, "rb") as f:
blob.upload_from_file(f)

-@retry(**RETRY_POLICY)
+@retry(**GCS_RETRY_POLICY)
def download_object(self, object_name, destination_path):
bucket = self.client.bucket(self.bucket)
blob = bucket.blob(object_name)
@@ -66,13 +68,13 @@
driver=self, bucket=self.bucket, object_name=object_name
)

-@retry(**RETRY_POLICY)
+@retry(**GCS_RETRY_POLICY)
def delete_object(self, object_name):
bucket = self.client.bucket(self.bucket)
blob = bucket.blob(object_name)
blob.delete()

-@retry(**RETRY_POLICY)
+@retry(**GCS_RETRY_POLICY)
def exists(self, object_name):
bucket = self.client.bucket(self.bucket)
blob = bucket.blob(object_name)
37 changes: 20 additions & 17 deletions modelkit/assets/drivers/retry.py
@@ -1,20 +1,10 @@
-import botocore
-import google
import requests
from structlog import get_logger
from tenacity import retry_if_exception, stop_after_attempt, wait_random_exponential

logger = get_logger(__name__)


-def retriable_error(exception):
-    return (
-        isinstance(exception, botocore.exceptions.ClientError)
-        or isinstance(exception, google.api_core.exceptions.GoogleAPIError)
-        or isinstance(exception, requests.exceptions.ChunkedEncodingError)
-    )


def log_after_retry(retry_state):
logger.info(
"Retrying",
@@ -24,10 +14,23 @@ def log_after_retry(retry_state):
)


-RETRY_POLICY = {
-    "wait": wait_random_exponential(multiplier=1, min=4, max=10),
-    "stop": stop_after_attempt(5),
-    "retry": retry_if_exception(retriable_error),
-    "after": log_after_retry,
-    "reraise": True,
-}
+def retry_policy(type_error=None):
+    if not type_error:
+
+        def is_retry_eligible(error):
+            return isinstance(error, requests.exceptions.ChunkedEncodingError)
+
+    else:
+
+        def is_retry_eligible(error):
+            return isinstance(error, type_error) or isinstance(
+                error, requests.exceptions.ChunkedEncodingError
+            )
+
+    return {
+        "wait": wait_random_exponential(multiplier=1, min=4, max=10),
+        "stop": stop_after_attempt(5),
+        "retry": retry_if_exception(is_retry_eligible),
+        "after": log_after_retry,
+        "reraise": True,
+    }
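
The factory returns a plain dict of keyword arguments for tenacity's `@retry`, exactly as the driver modules consume it. A minimal usage sketch (`flaky_download` is a hypothetical stand-in for a driver method):

```python
# Minimal usage sketch; flaky_download is a hypothetical stand-in.
from tenacity import retry

from modelkit.assets.drivers.retry import retry_policy

POLICY = retry_policy()  # no SDK error type: only ChunkedEncodingError is retriable


@retry(**POLICY)
def flaky_download():
    # Retried for up to 5 attempts with random exponential waits clamped to
    # [4, 10] seconds; reraise=True surfaces the last exception unchanged.
    ...
```
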
14 changes: 8 additions & 6 deletions modelkit/assets/drivers/s3.py
@@ -8,10 +8,12 @@

from modelkit.assets import errors
from modelkit.assets.drivers.abc import StorageDriver
-from modelkit.assets.drivers.retry import RETRY_POLICY
+from modelkit.assets.drivers.retry import retry_policy

logger = get_logger(__name__)

+S3_RETRY_POLICY = retry_policy(botocore.exceptions.ClientError)


class S3StorageDriver(StorageDriver):
bucket: str
@@ -44,15 +46,15 @@ def __init__(
region_name=aws_default_region or os.environ.get("AWS_DEFAULT_REGION"),
)

-@retry(**RETRY_POLICY)
+@retry(**S3_RETRY_POLICY)
def iterate_objects(self, prefix=None):
paginator = self.client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket, Prefix=prefix or "")
for page in pages:
for obj in page.get("Contents", []):
yield obj["Key"]

-@retry(**RETRY_POLICY)
+@retry(**S3_RETRY_POLICY)
def upload_object(self, file_path, object_name):
if self.aws_kms_key_id:
self.client.upload_file( # pragma: no cover
@@ -67,7 +69,7 @@ def upload_object(self, file_path, object_name):
else:
self.client.upload_file(file_path, self.bucket, object_name)

-@retry(**RETRY_POLICY)
+@retry(**S3_RETRY_POLICY)
def download_object(self, object_name, destination_path):
try:
with open(destination_path, "wb") as f:
@@ -81,11 +83,11 @@ def download_object(self, object_name, destination_path):
driver=self, bucket=self.bucket, object_name=object_name
)

-@retry(**RETRY_POLICY)
+@retry(**S3_RETRY_POLICY)
def delete_object(self, object_name):
self.client.delete_object(Bucket=self.bucket, Key=object_name)

-@retry(**RETRY_POLICY)
+@retry(**S3_RETRY_POLICY)
def exists(self, object_name):
try:
self.client.head_object(Bucket=self.bucket, Key=object_name)
(Diffs for the remaining 8 changed files are not shown.)
