Skip to content

Commit

Permalink
Applied code review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
stefannica committed Dec 2, 2024
1 parent 75c0953 commit cbe5a80
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 54 deletions.
34 changes: 12 additions & 22 deletions src/zenml/integrations/aws/flavors/aws_image_builder_flavor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
AWS_RESOURCE_TYPE,
)
from zenml.models import ServiceConnectorRequirements
from zenml.utils.secret_utils import SecretField

if TYPE_CHECKING:
from zenml.integrations.aws.image_builders import AWSImageBuilder
Expand All @@ -34,33 +33,24 @@ class AWSImageBuilderConfig(BaseImageBuilderConfig):
Attributes:
code_build_project: The name of the AWS CodeBuild project to use to
build the image.
aws_access_key_id: The AWS access key ID to use to authenticate to AWS.
If not provided, the value from the default AWS config will be used.
aws_secret_access_key: The AWS secret access key to use to authenticate
to AWS. If not provided, the value from the default AWS config will
be used.
aws_auth_role_arn: The ARN of an intermediate IAM role to assume when
authenticating to AWS.
region: The AWS region where the processing job will be run. If not
provided, the value from the default AWS config will be used.
implicit_auth: Whether to use implicit authentication to authenticate
the AWS Code Build build to the container registry. If set to False,
the container registry credentials must be explicitly configured for
the container registry stack component or the container registry
stack component must be linked to a service connector.
NOTE: When implicit_auth is set to False, the container registry
credentials will be passed to the AWS Code Build build as
environment variables. This is not recommended for production use
unless your service connector is configured to generate short-lived
credentials.
implicit_container_registry_auth: Whether to use implicit authentication
to authenticate the AWS Code Build build to the container registry
when pushing container images. If set to False, the container
registry credentials must be explicitly configured for the container
registry stack component or the container registry stack component
must be linked to a service connector.
NOTE: When implicit_container_registry_auth is set to False, the
container registry credentials will be passed to the AWS Code Build
build as environment variables. This is not recommended for
production use unless your service connector is configured to
generate short-lived credentials.
"""

code_build_project: str
aws_access_key_id: Optional[str] = SecretField(default=None)
aws_secret_access_key: Optional[str] = SecretField(default=None)
aws_auth_role_arn: Optional[str] = None
region: Optional[str] = None
implicit_auth: bool = True
implicit_container_registry_auth: bool = True


class AWSImageBuilderFlavor(BaseImageBuilderFlavor):
Expand Down
20 changes: 2 additions & 18 deletions src/zenml/integrations/aws/image_builders/aws_image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,11 @@ def code_build_client(self) -> Any:
f"Expected to receive a `boto3.Session` object from the "
f"linked connector, but got type `{type(boto_session)}`."
)
# Option 2: Explicit or implicit configuration
# Option 2: Implicit configuration
else:
boto_session = boto3.Session(
aws_access_key_id=self.config.aws_access_key_id,
aws_secret_access_key=self.config.aws_secret_access_key,
region_name=self.config.region,
)
# If a role ARN is provided for authentication, assume the role
if self.config.aws_auth_role_arn:
sts = boto_session.client("sts")
response = sts.assume_role(
RoleArn=self.config.aws_auth_role_arn,
RoleSessionName="zenml-code-build-session",
)
credentials = response["Credentials"]
boto_session = boto3.Session(
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
region_name=self.config.region,
)

self._code_build_client = boto_session.client("codebuild")
return self._code_build_client
Expand Down Expand Up @@ -186,7 +170,7 @@ def build(
# is disabled
environment_variables_override = []
pre_build_commands = []
if not self.config.implicit_auth:
if not self.config.implicit_container_registry_auth:
credentials = container_registry.credentials
if credentials:
environment_variables_override = [
Expand Down
12 changes: 3 additions & 9 deletions src/zenml/service_connectors/service_connector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,9 @@ def _raise_specific_cloud_exception_if_needed(
orchestrators: List[ResourcesInfo],
container_registries: List[ResourcesInfo],
) -> None:
AWS_DOCS = (
"https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/aws-service-connector"
)
GCP_DOCS = (
"https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/gcp-service-connector"
)
AZURE_DOCS = (
"https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/azure-service-connector"
)
AWS_DOCS = "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/aws-service-connector"
GCP_DOCS = "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/gcp-service-connector"
AZURE_DOCS = "https://docs.zenml.io/how-to/infrastructure-deployment/auth-management/azure-service-connector"

if not artifact_stores:
error_msg = (
Expand Down
11 changes: 6 additions & 5 deletions src/zenml/utils/archivable.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import zipfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import IO, Any, Dict, Optional
from typing import IO, Any, Dict

from zenml.io import fileio
from zenml.utils.enum_utils import StrEnum
Expand Down Expand Up @@ -103,7 +103,7 @@ def write_archive(
"""
files = self.get_files()
extra_files = self.get_extra_files()
intermediate_fileobj: Optional[Any] = None
close_fileobj: bool = False
fileobj: Any = output_file

if archive_type == ArchiveType.ZIP:
Expand All @@ -118,7 +118,8 @@ def write_archive(
# file to be different each time. We use this hash to avoid
# duplicate uploads, which is why we pass empty values for filename
# and mtime here.
fileobj = intermediate_fileobj = GzipFile(
close_fileobj = True
fileobj = GzipFile(
filename="", mode="wb", fileobj=output_file, mtime=0.0
)
fileobj = tarfile.open(mode="w", fileobj=fileobj)
Expand Down Expand Up @@ -154,8 +155,8 @@ def write_archive(
info.size = len(contents_encoded)
af.addfile(info, io.BytesIO(contents_encoded))
finally:
if intermediate_fileobj:
intermediate_fileobj.close()
if close_fileobj:
fileobj.close()

output_file.seek(0)

Expand Down

0 comments on commit cbe5a80

Please sign in to comment.