Skip to content

Commit

Permalink
Merge pull request #98 from clustree/s3_driver-sse-kms
Browse files Browse the repository at this point in the history
s3_driver: manage s3 with sse kms encryption
  • Loading branch information
victorbenichoux authored Sep 24, 2021
2 parents 76f56fa + 3ebfc76 commit 56c06c7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/assets/storage_provider.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ The authentication information here is passed to the `boto3.client` object:

Typically, if you use AWS: having `AWS_DEFAULT_PROFILE`, `AWS_DEFAULT_REGION` and valid credentials in `~/.aws` is enough.

S3 storage driver is compatible with KMS encrypted s3 volumes.
Use `AWS_KMS_KEY_ID` environment variable to set your key and be able to upload files to such volume.

### GCS storage

Expand Down
15 changes: 14 additions & 1 deletion modelkit/assets/drivers/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@ def __init__(
aws_secret_access_key: Optional[str] = None,
aws_default_region: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_kms_key_id: Optional[str] = None,
s3_endpoint: Optional[str] = None,
):

self.bucket = bucket or os.environ.get("MODELKIT_STORAGE_BUCKET") or ""
if not self.bucket:
raise ValueError("Bucket needs to be set for S3 storage driver")
self.endpoint_url = s3_endpoint or os.environ.get("S3_ENDPOINT")
self.aws_kms_key_id = aws_kms_key_id or os.environ.get("AWS_KMS_KEY_ID")
self.client = boto3.client(
"s3",
endpoint_url=self.endpoint_url,
Expand All @@ -52,7 +54,18 @@ def iterate_objects(self, prefix=None):

@retry(**RETRY_POLICY)
def upload_object(self, file_path, object_name):
self.client.upload_file(file_path, self.bucket, object_name)
if self.aws_kms_key_id:
self.client.upload_file( # pragma: no cover
file_path,
self.bucket,
object_name,
ExtraArgs={
"ServerSideEncryption": "aws:kms",
"SSEKMSKeyId": self.aws_kms_key_id,
},
)
else:
self.client.upload_file(file_path, self.bucket, object_name)

@retry(**RETRY_POLICY)
def download_object(self, object_name, destination_path):
Expand Down

0 comments on commit 56c06c7

Please sign in to comment.