From 88c51cc2a35b8a04f66c5b4f766efa37ded5331a Mon Sep 17 00:00:00 2001 From: Param Bole Date: Thu, 25 Jul 2024 17:32:56 +0000 Subject: [PATCH] Refactoring Maxdiffusion-JAX-Stable-Stack Build Process --- .github/workflows/UploadDockerImages.yml | 31 ++++++++++++++++++ docker_maxdiffusion_image_upload.sh | 32 ++++++++++++++----- ...xdiffusion_jax_stable_stack_tpu.Dockerfile | 6 ++-- ... => requirements_with_jax_stable_stack.txt | 2 +- 4 files changed, 59 insertions(+), 12 deletions(-) create mode 100644 .github/workflows/UploadDockerImages.yml rename maxdiffusion_jax_ss_tpu.Dockerfile => maxdiffusion_jax_stable_stack_tpu.Dockerfile (81%) rename requirements_with_jax_ss.txt => requirements_with_jax_stable_stack.txt (90%) diff --git a/.github/workflows/UploadDockerImages.yml b/.github/workflows/UploadDockerImages.yml new file mode 100644 index 00000000..2de44791 --- /dev/null +++ b/.github/workflows/UploadDockerImages.yml @@ -0,0 +1,31 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Build Images + +on: + push: + branches: [ "pb-jax-ss" ] + +jobs: + build-image: + runs-on: ["self-hosted", "e2", "cpu"] + steps: + - uses: actions/checkout@v3 + - name: build jax stable stack image + run: | + bash docker_maxdiffusion_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.30-rev1 CLOUD_IMAGE_NAME=maxdiffusion-jax-stable-stack IMAGE_TAG=jax0.4.30-rev1 MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt DELETE_LOCAL_IMAGE=true \ No newline at end of file diff --git a/docker_maxdiffusion_image_upload.sh b/docker_maxdiffusion_image_upload.sh index c555ae3d..d6e783e1 100644 --- a/docker_maxdiffusion_image_upload.sh +++ b/docker_maxdiffusion_image_upload.sh @@ -21,10 +21,12 @@ # (minutes). However, if you are simply changing local code and not updating dependencies, uploading just takes a few seconds. # Example command: -# bash docker_maxdiffusion_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=gcr.io/tpu-prod-env-multipod/jax-ss/tpu:jax0.4.28-v1.0.0 CLOUD_IMAGE_NAME=maxdiffusion-jax-ss-0.4.28-v1.0.0 IMAGE_TAG=latest MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_ss.txt +# bash docker_maxdiffusion_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=gcr.io/tpu-prod-env-multipod/jax-ss/tpu:jax0.4.28-v1.0.0 CLOUD_IMAGE_NAME=maxdiffusion-jax-stable-stack IMAGE_TAG=latest MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt set -e +export LOCAL_IMAGE_NAME=maxdiffusion_base_image + # Set environment variables for ARGUMENT in "$@"; do IFS='=' read -r KEY VALUE <<< "$ARGUMENT" @@ -57,18 +59,32 @@ if [[ ! -v MAXDIFFUSION_REQUIREMENTS_FILE ]]; then exit 1 fi +# Default: Don't delete local image +DELETE_LOCAL_IMAGE="${DELETE_LOCAL_IMAGE:-false}" + +gcloud auth configure-docker us-docker.pkg.dev --quiet + COMMIT_HASH=$(git rev-parse --short HEAD) -echo "Building JAX SS MaxDiffusion at commit hash ${COMMIT_HASH} . . ." +echo "Building JAX Stable Stack MaxDiffusion at commit hash ${COMMIT_HASH} . . ." -docker build \ - --build-arg JAX_SS_BASEIMAGE=${BASEIMAGE} \ +IMAGE_DATE=$(date +%Y-%m-%d) + +FULL_IMAGE_TAG=${IMAGE_TAG}-${IMAGE_DATE} + +docker build --no-cache \ + --build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \ --build-arg COMMIT_HASH=${COMMIT_HASH} \ --build-arg MAXDIFFUSION_REQUIREMENTS_FILE=${MAXDIFFUSION_REQUIREMENTS_FILE} \ --network=host \ - -t gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${IMAGE_TAG} \ - -f ./maxdiffusion_jax_ss_tpu.Dockerfile . + -t us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG} \ + -f maxdiffusion_jax_stable_stack_tpu.Dockerfile . + +docker push us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG} -docker push gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${IMAGE_TAG} +echo "All done, check out your artifacts at: us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG}" -echo "All done, check out your artifacts at: gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${IMAGE_TAG}" \ No newline at end of file +if [ "$DELETE_LOCAL_IMAGE" == "true" ]; then + docker rmi us-docker.pkg.dev/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG} + echo "Local image deleted." +fi \ No newline at end of file diff --git a/maxdiffusion_jax_ss_tpu.Dockerfile b/maxdiffusion_jax_stable_stack_tpu.Dockerfile similarity index 81% rename from maxdiffusion_jax_ss_tpu.Dockerfile rename to maxdiffusion_jax_stable_stack_tpu.Dockerfile index 67a32ce5..6d7c35fe 100644 --- a/maxdiffusion_jax_ss_tpu.Dockerfile +++ b/maxdiffusion_jax_stable_stack_tpu.Dockerfile @@ -1,7 +1,7 @@ -ARG JAX_SS_BASEIMAGE +ARG JAX_STABLE_STACK_BASEIMAGE # JAX Stable Stack Base Image -From $JAX_SS_BASEIMAGE +FROM $JAX_STABLE_STACK_BASEIMAGE ARG COMMIT_HASH @@ -27,5 +27,5 @@ RUN if [ ! -z "${MAXDIFFUSION_REQUIREMENTS_FILE}" ]; then \ # Install MaxDiffusion RUN pip install . -# Run the script available in JAX-SS base image to generate the manifest file +# Run the script available in JAX-Stable-Stack base image to generate the manifest file RUN bash /generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH \ No newline at end of file diff --git a/requirements_with_jax_ss.txt b/requirements_with_jax_stable_stack.txt similarity index 90% rename from requirements_with_jax_ss.txt rename to requirements_with_jax_stable_stack.txt index 5d349970..f4ebde55 100644 --- a/requirements_with_jax_ss.txt +++ b/requirements_with_jax_stable_stack.txt @@ -1,5 +1,5 @@ # Requirements for Building the MaxDifussion Docker Image -# These requirements are additional to the dependencies present in the JAX SS base image. +# These requirements are additional to the dependencies present in the JAX Stable Stack base image. absl-py transformers>=4.25.1 datasets