From 5a77dbe4c18ebe4bf30be5224ad562c28e736b55 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Fri, 2 Aug 2024 12:31:56 +0200 Subject: [PATCH] AWS Image Builder implementation --- .../image_builders/base_image_builder.py | 7 +- src/zenml/image_builders/build_context.py | 23 +- src/zenml/integrations/aws/__init__.py | 3 + .../integrations/aws/flavors/__init__.py | 6 + .../aws/flavors/aws_image_builder_flavor.py | 141 ++++++++ .../aws/image_builders/__init__.py | 20 ++ .../aws/image_builders/aws_image_builder.py | 313 ++++++++++++++++++ .../image_builders/kaniko_image_builder.py | 3 +- src/zenml/new/pipelines/code_archive.py | 157 --------- src/zenml/utils/archivable.py | 101 ++++-- src/zenml/utils/code_utils.py | 12 +- 11 files changed, 570 insertions(+), 216 deletions(-) create mode 100644 src/zenml/integrations/aws/flavors/aws_image_builder_flavor.py create mode 100644 src/zenml/integrations/aws/image_builders/__init__.py create mode 100644 src/zenml/integrations/aws/image_builders/aws_image_builder.py delete mode 100644 src/zenml/new/pipelines/code_archive.py diff --git a/src/zenml/image_builders/base_image_builder.py b/src/zenml/image_builders/base_image_builder.py index b99bb277ec8..4ad38cd8652 100644 --- a/src/zenml/image_builders/base_image_builder.py +++ b/src/zenml/image_builders/base_image_builder.py @@ -25,6 +25,7 @@ from zenml.logger import get_logger from zenml.stack import Flavor, StackComponent from zenml.stack.stack_component import StackComponentConfig +from zenml.utils.archivable import ArchiveType if TYPE_CHECKING: from zenml.container_registries import BaseContainerRegistry @@ -100,6 +101,7 @@ def build( def _upload_build_context( build_context: "BuildContext", parent_path_directory_name: str, + archive_type: ArchiveType = ArchiveType.TAR_GZ, ) -> str: """Uploads a Docker image build context to a remote location. @@ -109,6 +111,7 @@ def _upload_build_context( the build context to. 
It will be appended to the artifact store path to create the parent path where the build context will be uploaded to. + archive_type: The type of archive to create. Returns: The path to the uploaded build context. """ @@ -119,7 +122,7 @@ def _upload_build_context( hash_ = hashlib.sha1() # nosec with tempfile.NamedTemporaryFile(mode="w+b", delete=False) as f: - build_context.write_archive(f, use_gzip=True) + build_context.write_archive(f, archive_type) while True: data = f.read(64 * 1024) @@ -127,7 +130,7 @@ break hash_.update(data) - filename = f"{hash_.hexdigest()}.tar.gz" + filename = f"{hash_.hexdigest()}.{archive_type.value}" filepath = f"{parent_path}/{filename}" if not fileio.exists(filepath): logger.info("Uploading build context to `%s`.", filepath) diff --git a/src/zenml/image_builders/build_context.py b/src/zenml/image_builders/build_context.py index e8284cfb446..610348ef1a1 100644 --- a/src/zenml/image_builders/build_context.py +++ b/src/zenml/image_builders/build_context.py @@ -20,7 +20,7 @@ from zenml.io import fileio from zenml.logger import get_logger from zenml.utils import io_utils, string_utils -from zenml.utils.archivable import Archivable +from zenml.utils.archivable import Archivable, ArchiveType logger = get_logger(__name__) @@ -69,28 +69,19 @@ def dockerignore_file(self) -> Optional[str]: return None def write_archive( - self, output_file: IO[bytes], use_gzip: bool = True + self, + output_file: IO[bytes], + archive_type: ArchiveType = ArchiveType.TAR_GZ, ) -> None: """Writes an archive of the build context to the given file. Args: output_file: The file to write the archive to. - use_gzip: Whether to use `gzip` to compress the file. + archive_type: The type of archive to create. 
""" - from docker.utils import build as docker_build_utils - - files = self.get_files() - extra_files = self.get_extra_files() - - context_archive = docker_build_utils.create_archive( - fileobj=output_file, - root=self._root, - files=sorted(files.keys()), - gzip=use_gzip, - extra_files=list(extra_files.items()), - ) + super().write_archive(output_file, archive_type) - build_context_size = os.path.getsize(context_archive.name) + build_context_size = os.path.getsize(output_file.name) if ( self._root and build_context_size > 50 * 1024 * 1024 diff --git a/src/zenml/integrations/aws/__init__.py b/src/zenml/integrations/aws/__init__.py index 206297f542c..4dd3db2e75c 100644 --- a/src/zenml/integrations/aws/__init__.py +++ b/src/zenml/integrations/aws/__init__.py @@ -33,6 +33,7 @@ AWS_CONNECTOR_TYPE = "aws" AWS_RESOURCE_TYPE = "aws-generic" S3_RESOURCE_TYPE = "s3-bucket" +AWS_IMAGE_BUILDER_FLAVOR = "aws" class AWSIntegration(Integration): """Definition of AWS integration for ZenML.""" @@ -59,12 +60,14 @@ def flavors(cls) -> List[Type[Flavor]]: """ from zenml.integrations.aws.flavors import ( AWSContainerRegistryFlavor, + AWSImageBuilderFlavor, SagemakerOrchestratorFlavor, SagemakerStepOperatorFlavor, ) return [ AWSContainerRegistryFlavor, + AWSImageBuilderFlavor, SagemakerStepOperatorFlavor, SagemakerOrchestratorFlavor, ] diff --git a/src/zenml/integrations/aws/flavors/__init__.py b/src/zenml/integrations/aws/flavors/__init__.py index 0e674dd9b5d..a2cdc428add 100644 --- a/src/zenml/integrations/aws/flavors/__init__.py +++ b/src/zenml/integrations/aws/flavors/__init__.py @@ -17,6 +17,10 @@ AWSContainerRegistryConfig, AWSContainerRegistryFlavor, ) +from zenml.integrations.aws.flavors.aws_image_builder_flavor import ( + AWSImageBuilderConfig, + AWSImageBuilderFlavor, +) from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import ( SagemakerOrchestratorConfig, SagemakerOrchestratorFlavor, @@ -29,6 +33,8 @@ __all__ = [ "AWSContainerRegistryFlavor", 
"AWSContainerRegistryConfig", + "AWSImageBuilderConfig", + "AWSImageBuilderFlavor", "SagemakerStepOperatorFlavor", "SagemakerStepOperatorConfig", "SagemakerOrchestratorFlavor", diff --git a/src/zenml/integrations/aws/flavors/aws_image_builder_flavor.py b/src/zenml/integrations/aws/flavors/aws_image_builder_flavor.py new file mode 100644 index 00000000000..f2d7e413955 --- /dev/null +++ b/src/zenml/integrations/aws/flavors/aws_image_builder_flavor.py @@ -0,0 +1,141 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""AWS Code Build image builder flavor.""" + +from typing import TYPE_CHECKING, Optional, Type + +from zenml.image_builders import BaseImageBuilderConfig, BaseImageBuilderFlavor +from zenml.integrations.aws import ( + AWS_CONNECTOR_TYPE, + AWS_IMAGE_BUILDER_FLAVOR, + AWS_RESOURCE_TYPE, +) +from zenml.models import ServiceConnectorRequirements +from zenml.utils.secret_utils import SecretField + +if TYPE_CHECKING: + from zenml.integrations.aws.image_builders import AWSImageBuilder + + +class AWSImageBuilderConfig(BaseImageBuilderConfig): + """AWS Code Build image builder configuration. + + Attributes: + code_build_project: The name of the AWS CodeBuild project to use to + build the image. + aws_access_key_id: The AWS access key ID to use to authenticate to AWS. + If not provided, the value from the default AWS config will be used. 
+ aws_secret_access_key: The AWS secret access key to use to authenticate + to AWS. If not provided, the value from the default AWS config will + be used. + aws_auth_role_arn: The ARN of an intermediate IAM role to assume when + authenticating to AWS. + region: The AWS region where the processing job will be run. If not + provided, the value from the default AWS config will be used. + implicit_auth: Whether to use implicit authentication to authenticate + the AWS Code Build build to the container registry. If set to False, + the container registry credentials must be explicitly configured for + the container registry stack component or the container registry + stack component must be linked to a service connector. + NOTE: When implicit_auth is set to False, the container registry + credentials will be passed to the AWS Code Build build as + environment variables. This is not recommended for production use + unless your service connector is configured to generate short-lived + credentials. + """ + + code_build_project: str + aws_access_key_id: Optional[str] = SecretField(default=None) + aws_secret_access_key: Optional[str] = SecretField(default=None) + aws_auth_role_arn: Optional[str] = None + region: Optional[str] = None + implicit_auth: bool = True + + +class AWSImageBuilderFlavor(BaseImageBuilderFlavor): + """AWS Code Build image builder flavor.""" + + @property + def name(self) -> str: + """The flavor name. + + Returns: + The name of the flavor. + """ + return AWS_IMAGE_BUILDER_FLAVOR + + @property + def service_connector_requirements( + self, + ) -> Optional[ServiceConnectorRequirements]: + """Service connector resource requirements for service connectors. + + Specifies resource requirements that are used to filter the available + service connector types that are compatible with this flavor. + + Returns: + Requirements for compatible service connectors, if a service + connector is required for this flavor. 
+ """ + return ServiceConnectorRequirements( + connector_type=AWS_CONNECTOR_TYPE, + resource_type=AWS_RESOURCE_TYPE, + ) + + @property + def docs_url(self) -> Optional[str]: + """A url to point at docs explaining this flavor. + + Returns: + A flavor docs url. + """ + return self.generate_default_docs_url() + + @property + def sdk_docs_url(self) -> Optional[str]: + """A url to point at SDK docs explaining this flavor. + + Returns: + A flavor SDK docs url. + """ + return self.generate_default_sdk_docs_url() + + @property + def logo_url(self) -> str: + """A url to represent the flavor in the dashboard. + + Returns: + The flavor logo. + """ + return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/image_builder/aws.png" + + @property + def config_class(self) -> Type[BaseImageBuilderConfig]: + """The config class. + + Returns: + The config class. + """ + return AWSImageBuilderConfig + + @property + def implementation_class(self) -> Type["AWSImageBuilder"]: + """Implementation class. + + Returns: + The implementation class. + """ + from zenml.integrations.aws.image_builders import AWSImageBuilder + + return AWSImageBuilder diff --git a/src/zenml/integrations/aws/image_builders/__init__.py b/src/zenml/integrations/aws/image_builders/__init__.py new file mode 100644 index 00000000000..667ae28e50a --- /dev/null +++ b/src/zenml/integrations/aws/image_builders/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
See the License for the specific language governing +# permissions and limitations under the License. +"""Initialization for the AWS image builder.""" + +from zenml.integrations.aws.image_builders.aws_image_builder import ( + AWSImageBuilder, +) + +__all__ = ["AWSImageBuilder"] diff --git a/src/zenml/integrations/aws/image_builders/aws_image_builder.py b/src/zenml/integrations/aws/image_builders/aws_image_builder.py new file mode 100644 index 00000000000..c9d4287482e --- /dev/null +++ b/src/zenml/integrations/aws/image_builders/aws_image_builder.py @@ -0,0 +1,313 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""AWS Code Build image builder implementation.""" + +import time +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast +from urllib.parse import urlparse +from uuid import uuid4 + +import boto3 + +from zenml.enums import StackComponentType +from zenml.image_builders import BaseImageBuilder +from zenml.integrations.aws import ( + AWS_CONTAINER_REGISTRY_FLAVOR, +) +from zenml.integrations.aws.flavors import AWSImageBuilderConfig +from zenml.logger import get_logger +from zenml.stack import StackValidator +from zenml.utils.archivable import ArchiveType + +if TYPE_CHECKING: + from zenml.container_registries import BaseContainerRegistry + from zenml.image_builders import BuildContext + from zenml.stack import Stack + +logger = get_logger(__name__) + + +class AWSImageBuilder(BaseImageBuilder): + """AWS Code Build image builder implementation.""" + + _code_build_client: Optional[Any] = None + + @property + def config(self) -> AWSImageBuilderConfig: + """The stack component configuration. + + Returns: + The configuration. + """ + return cast(AWSImageBuilderConfig, self._config) + + @property + def is_building_locally(self) -> bool: + """Whether the image builder builds the images on the client machine. + + Returns: + True if the image builder builds locally, False otherwise. + """ + return False + + @property + def validator(self) -> Optional["StackValidator"]: + """Validates the stack for the AWS Code Build Image Builder. + + The AWS Code Build Image Builder requires a container registry to + push the image to and an S3 Artifact Store to upload the build context, + so AWS Code Build can access it. + + Returns: + Stack validator. + """ + + def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: + if stack.artifact_store.flavor != "s3": + return False, ( + "The AWS Image Builder requires an S3 Artifact Store to " + "upload the build context, so AWS Code Build can access it." 
+ "Please update your stack to include an S3 Artifact Store " + "and try again." + ) + + return True, "" + + return StackValidator( + required_components={StackComponentType.CONTAINER_REGISTRY}, + custom_validation_function=_validate_remote_components, + ) + + @property + def code_build_client(self) -> Any: + """The authenticated AWS Code Build client to use for interacting with AWS services. + + Returns: + The authenticated AWS Code Build client. + """ + if ( + self._code_build_client is not None + and self.connector_has_expired() + ): + self._code_build_client = None + if self._code_build_client is not None: + return self._code_build_client + + # Option 1: Service connector + if connector := self.get_connector(): + boto_session = connector.connect() + if not isinstance(boto_session, boto3.Session): + raise RuntimeError( + f"Expected to receive a `boto3.Session` object from the " + f"linked connector, but got type `{type(boto_session)}`." + ) + # Option 2: Explicit or implicit configuration + else: + boto_session = boto3.Session( + aws_access_key_id=self.config.aws_access_key_id, + aws_secret_access_key=self.config.aws_secret_access_key, + region_name=self.config.region, + ) + # If a role ARN is provided for authentication, assume the role + if self.config.aws_auth_role_arn: + sts = boto_session.client("sts") + response = sts.assume_role( + RoleArn=self.config.aws_auth_role_arn, + RoleSessionName="zenml-code-build-session", + ) + credentials = response["Credentials"] + boto_session = boto3.Session( + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + region_name=self.config.region, + ) + + self._code_build_client = boto_session.client("codebuild") + return self._code_build_client + + def build( + self, + image_name: str, + build_context: "BuildContext", + docker_build_options: Dict[str, Any], + container_registry: Optional["BaseContainerRegistry"] = None, + ) -> 
str: + """Builds and pushes a Docker image. + + Args: + image_name: Name of the image to build and push. + build_context: The build context to use for the image. + docker_build_options: Docker build options. + container_registry: Optional container registry to push to. + + Returns: + The Docker image name with digest. + + Raises: + RuntimeError: If no container registry is passed. + RuntimeError: If the Cloud Build build fails. + """ + if not container_registry: + raise RuntimeError( + "The AWS Image Builder requires a container registry to push " + "the image to. Please provide one and try again." + ) + + logger.info("Using AWS Code Build to build image `%s`", image_name) + cloud_build_context = self._upload_build_context( + build_context=build_context, + parent_path_directory_name=f"code-build-contexts/{str(self.id)}", + archive_type=ArchiveType.ZIP, + ) + + url_parts = urlparse(cloud_build_context) + bucket = url_parts.netloc + object_path = url_parts.path.lstrip("/") + logger.info( + "Build context located in bucket `%s` and object path `%s`", + bucket, + object_path, + ) + + # Pass authentication credentials as environment variables, if + # the container registry has credentials and if implicit authentication + # is disabled + environment_variables_override = [] + pre_build_commands = [] + if not self.config.implicit_auth: + credentials = container_registry.credentials + if credentials: + environment_variables_override = [ + { + "name": "CONTAINER_REGISTRY_USERNAME", + "value": credentials[0], + "type": "PLAINTEXT", + }, + { + "name": "CONTAINER_REGISTRY_PASSWORD", + "value": credentials[1], + "type": "PLAINTEXT", + }, + ] + pre_build_commands = [ + "echo Logging in to container registry", + 'echo "$CONTAINER_REGISTRY_PASSWORD" | docker login --username "$CONTAINER_REGISTRY_USERNAME" --password-stdin ' + f"{container_registry.config.uri}", + ] + elif container_registry.flavor == AWS_CONTAINER_REGISTRY_FLAVOR: + pre_build_commands = [ + "echo Logging in to 
ECR", + f"aws ecr get-login-password --region {self.code_build_client._client_config.region_name} | docker login --username AWS --password-stdin {container_registry.config.uri}", + ] + + # Convert the docker_build_options dictionary to a list of strings + docker_build_args = "" + for key, value in docker_build_options.items(): + option = f"--{key}" + if isinstance(value, list): + for val in value: + docker_build_args += f"{option} {val} " + elif value is not None and not isinstance(value, bool): + docker_build_args += f"{option} {value} " + elif value is not False: + docker_build_args += f"{option} " + + pre_build_commands_str = "\n".join( + [f" - {command}" for command in pre_build_commands] + ) + + # Generate and use a unique tag for the Docker image. This is easier + # than trying to parse the image digest from the Code Build logs. + build_id = str(uuid4()) + # Replace the tag in the image name with the unique build ID + repo_name = image_name.split(":")[0] + alt_image_name = f"{repo_name}:{build_id}" + + buildspec = f""" +version: 0.2 +phases: + pre_build: + commands: +{pre_build_commands_str} + build: + commands: + - echo Build started on `date` + - echo Building the Docker image... + - docker build -t {image_name} . {docker_build_args} + - echo Build completed on `date` + post_build: + commands: + - echo Pushing the Docker image... 
+ - docker push {image_name} + - docker tag {image_name} {alt_image_name} + - docker push {alt_image_name} + - echo Pushed the Docker image +artifacts: + files: + - '**/*' +""" + + # Override the build project with the parameters needed to run a + # docker-in-docker build, as covered here: https://docs.aws.amazon.com/codebuild/latest/userguide/sample-docker-section.html + response = self.code_build_client.start_build( + projectName=self.config.code_build_project, + environmentTypeOverride="LINUX_CONTAINER", + imageOverride="bentolor/docker-dind-awscli", # "docker:dind", + computeTypeOverride="BUILD_GENERAL1_SMALL", + privilegedModeOverride=False, + sourceTypeOverride="S3", + sourceLocationOverride=f"{bucket}/{object_path}", + buildspecOverride=buildspec, + environmentVariablesOverride=environment_variables_override, + # no artifacts + artifactsOverride={"type": "NO_ARTIFACTS"}, + ) + + logs_url = response["build"]["logs"]["deepLink"] + + logger.info( + f"Running Code Build to build the Docker image. Code Build logs: `{logs_url}`", + ) + + # Wait for the build to complete + code_build_id = response["build"]["id"] + while True: + build_status = self.code_build_client.batch_get_builds( + ids=[code_build_id] + ) + build = build_status["builds"][0] + status = build["buildStatus"] + if status in [ + "SUCCEEDED", + "FAILED", + "FAULT", + "TIMED_OUT", + "STOPPED", + ]: + break + time.sleep(10) + + if status != "SUCCEEDED": + raise RuntimeError( + f"The Code Build run to build the Docker image has failed. More " + f"information can be found in the Code Build logs: {logs_url}." + ) + + logger.info( + f"The Docker image has been built successfully. More information can " + f"be found in the Code Build logs: `{logs_url}`." 
+ ) + + return alt_image_name diff --git a/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py b/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py index 1a4aaac3ad8..04acf22ae8e 100644 --- a/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py +++ b/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py @@ -25,6 +25,7 @@ from zenml.integrations.kaniko.flavors import KanikoImageBuilderConfig from zenml.logger import get_logger from zenml.stack import StackValidator +from zenml.utils.archivable import ArchiveType if TYPE_CHECKING: from zenml.container_registries import BaseContainerRegistry @@ -295,7 +296,7 @@ def _write_build_context( logger.debug("Writing build context to process stdin.") assert process.stdin with process.stdin as _, tempfile.TemporaryFile(mode="w+b") as f: - build_context.write_archive(f, use_gzip=True) + build_context.write_archive(f, archive_type=ArchiveType.TAR_GZ) while True: data = f.read(1024) if not data: diff --git a/src/zenml/new/pipelines/code_archive.py b/src/zenml/new/pipelines/code_archive.py deleted file mode 100644 index 9eba95cf06b..00000000000 --- a/src/zenml/new/pipelines/code_archive.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) ZenML GmbH 2024. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. 
-"""Code archive.""" - -import os -from pathlib import Path -from typing import IO, TYPE_CHECKING, Dict, Optional - -from zenml.logger import get_logger -from zenml.utils import string_utils -from zenml.utils.archivable import Archivable - -if TYPE_CHECKING: - from git.repo.base import Repo - - -logger = get_logger(__name__) - - -class CodeArchive(Archivable): - """Code archive class. - - This class is used to archive user code before uploading it to the artifact - store. If the user code is stored in a Git repository, only files not - excluded by gitignores will be included in the archive. - """ - - def __init__(self, root: str) -> None: - """Initialize the object. - - Args: - root: Root directory of the archive. - """ - super().__init__() - self._root = root - - @property - def git_repo(self) -> Optional["Repo"]: - """Git repository active at the code archive root. - - Returns: - The git repository if available. - """ - try: - # These imports fail when git is not installed on the machine - from git.exc import InvalidGitRepositoryError - from git.repo.base import Repo - except ImportError: - return None - - try: - git_repo = Repo(path=self._root, search_parent_directories=True) - except InvalidGitRepositoryError: - return None - - return git_repo - - def _get_all_files(self) -> Dict[str, str]: - """Get all files inside the archive root. - - Returns: - All files inside the archive root. - """ - all_files = {} - for root, _, files in os.walk(self._root): - for file in files: - file_path = os.path.join(root, file) - path_in_archive = os.path.relpath(file_path, self._root) - all_files[path_in_archive] = file_path - - return all_files - - def get_files(self) -> Dict[str, str]: - """Gets all regular files that should be included in the archive. - - Raises: - RuntimeError: If the code archive would not include any files. - - Returns: - A dict {path_in_archive: path_on_filesystem} for all regular files - in the archive. 
- """ - all_files = {} - - if repo := self.git_repo: - try: - result = repo.git.ls_files( - "--cached", - "--others", - "--modified", - "--exclude-standard", - self._root, - ) - except Exception as e: - logger.warning( - "Failed to get non-ignored files from git: %s", str(e) - ) - all_files = self._get_all_files() - else: - for file in result.split(): - file_path = os.path.join(repo.working_dir, file) - path_in_archive = os.path.relpath(file_path, self._root) - - if os.path.exists(file_path): - all_files[path_in_archive] = file_path - else: - all_files = self._get_all_files() - - if not all_files: - raise RuntimeError( - "The code archive to be uploaded does not contain any files. " - "This is probably because all files in your source root " - f"`{self._root}` are ignored by a .gitignore file." - ) - - # Explicitly remove .zen directories as we write an updated version - # to disk everytime ZenML is called. This updates the mtime of the - # file, which invalidates the code upload caching. The values in - # the .zen directory are not needed anyway as we set them as - # environment variables. - all_files = { - path_in_archive: file_path - for path_in_archive, file_path in sorted(all_files.items()) - if ".zen" not in Path(path_in_archive).parts[:-1] - } - - return all_files - - def write_archive( - self, output_file: IO[bytes], use_gzip: bool = True - ) -> None: - """Writes an archive of the build context to the given file. - - Args: - output_file: The file to write the archive to. - use_gzip: Whether to use `gzip` to compress the file. - """ - super().write_archive(output_file=output_file, use_gzip=use_gzip) - archive_size = os.path.getsize(output_file.name) - if archive_size > 20 * 1024 * 1024: - logger.warning( - "Code archive size: `%s`. 
If you believe this is " - "unreasonably large, make sure to version your code in git and " - "ignore unnecessary files using a `.gitignore` file.", - string_utils.get_human_readable_filesize(archive_size), - ) diff --git a/src/zenml/utils/archivable.py b/src/zenml/utils/archivable.py index c2d7b83c422..488b55f778d 100644 --- a/src/zenml/utils/archivable.py +++ b/src/zenml/utils/archivable.py @@ -15,11 +15,21 @@ import io import tarfile +import zipfile from abc import ABC, abstractmethod from pathlib import Path -from typing import IO, Any, Dict +from typing import IO, Any, Dict, Optional from zenml.io import fileio +from zenml.utils.enum_utils import StrEnum + + +class ArchiveType(StrEnum): + """Archive types supported by the ZenML build context.""" + + TAR = "tar" + TAR_GZ = "tar.gz" + ZIP = "zip" class Archivable(ABC): @@ -81,52 +91,71 @@ def add_directory(self, source: str, destination: str) -> None: self._extra_files[file_destination.as_posix()] = f.read() def write_archive( - self, output_file: IO[bytes], use_gzip: bool = True + self, + output_file: IO[bytes], + archive_type: ArchiveType = ArchiveType.TAR_GZ, ) -> None: """Writes an archive of the build context to the given file. Args: output_file: The file to write the archive to. - use_gzip: Whether to use `gzip` to compress the file. + archive_type: The type of archive to create. """ files = self.get_files() extra_files = self.get_extra_files() + intermediate_fileobj: Optional[Any] = None + fileobj: Any = output_file - if use_gzip: - from gzip import GzipFile - - # We don't use the builtin gzip functionality of the `tarfile` - # library as that one includes the tar filename and creation - # timestamp in the archive which causes the hash of the resulting - # file to be different each time. We use this hash to avoid - # duplicate uploads, which is why we pass empty values for filename - # and mtime here. 
- fileobj: Any = GzipFile( - filename="", mode="wb", fileobj=output_file, mtime=0.0 - ) + if archive_type == ArchiveType.ZIP: + fileobj = zipfile.ZipFile(output_file, "w", zipfile.ZIP_DEFLATED) else: - fileobj = output_file - - with tarfile.open(mode="w", fileobj=fileobj) as tf: - for archive_path, file_path in files.items(): - if archive_path in extra_files: - continue - - if info := tf.gettarinfo(file_path, arcname=archive_path): - if info.isfile(): - with open(file_path, "rb") as f: - tf.addfile(info, f) + if archive_type == ArchiveType.TAR_GZ: + from gzip import GzipFile + + # We don't use the builtin gzip functionality of the `tarfile` + # library as that one includes the tar filename and creation + # timestamp in the archive which causes the hash of the resulting + # file to be different each time. We use this hash to avoid + # duplicate uploads, which is why we pass empty values for filename + # and mtime here. + fileobj = intermediate_fileobj = GzipFile( + filename="", mode="wb", fileobj=output_file, mtime=0.0 + ) + fileobj = tarfile.open(mode="w", fileobj=fileobj) + + try: + with fileobj as af: + for archive_path, file_path in files.items(): + if archive_path in extra_files: + continue + if archive_type == ArchiveType.ZIP: + assert isinstance(af, zipfile.ZipFile) + af.write(file_path, arcname=archive_path) else: - tf.addfile(info, None) - - for archive_path, contents in extra_files.items(): - info = tarfile.TarInfo(archive_path) - contents_encoded = contents.encode("utf-8") - info.size = len(contents_encoded) - tf.addfile(info, io.BytesIO(contents_encoded)) - - if use_gzip: - fileobj.close() + assert isinstance(af, tarfile.TarFile) + if info := af.gettarinfo( + file_path, arcname=archive_path + ): + if info.isfile(): + with open(file_path, "rb") as f: + af.addfile(info, f) + else: + af.addfile(info, None) + + for archive_path, contents in extra_files.items(): + contents_encoded = contents.encode("utf-8") + + if archive_type == ArchiveType.ZIP: + assert 
isinstance(af, zipfile.ZipFile) + af.writestr(archive_path, contents_encoded) + else: + assert isinstance(af, tarfile.TarFile) + info = tarfile.TarInfo(archive_path) + info.size = len(contents_encoded) + af.addfile(info, io.BytesIO(contents_encoded)) + finally: + if intermediate_fileobj: + intermediate_fileobj.close() output_file.seek(0) diff --git a/src/zenml/utils/code_utils.py b/src/zenml/utils/code_utils.py index d38888aa399..d5d66664a3e 100644 --- a/src/zenml/utils/code_utils.py +++ b/src/zenml/utils/code_utils.py @@ -25,7 +25,7 @@ from zenml.io import fileio from zenml.logger import get_logger from zenml.utils import source_utils, string_utils -from zenml.utils.archivable import Archivable +from zenml.utils.archivable import Archivable, ArchiveType if TYPE_CHECKING: from git.repo.base import Repo @@ -152,15 +152,19 @@ def get_files(self) -> Dict[str, str]: return all_files def write_archive( - self, output_file: IO[bytes], use_gzip: bool = True + self, + output_file: IO[bytes], + archive_type: ArchiveType = ArchiveType.TAR_GZ, ) -> None: """Writes an archive of the build context to the given file. Args: output_file: The file to write the archive to. - use_gzip: Whether to use `gzip` to compress the file. + archive_type: The type of archive to create. """ - super().write_archive(output_file=output_file, use_gzip=use_gzip) + super().write_archive( + output_file=output_file, archive_type=archive_type + ) archive_size = os.path.getsize(output_file.name) if archive_size > 20 * 1024 * 1024: logger.warning(