diff --git a/.github/workflows/UploadDockerImages.yml b/.github/workflows/UploadDockerImages.yml new file mode 100644 index 00000000..2de44791 --- /dev/null +++ b/.github/workflows/UploadDockerImages.yml @@ -0,0 +1,31 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Build Images + +on: + push: + branches: [ "pb-jax-ss" ] + +jobs: + build-image: + runs-on: ["self-hosted", "e2", "cpu"] + steps: + - uses: actions/checkout@v3 + - name: build jax stable stack image + run: | + bash docker_maxdiffusion_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.30-rev1 CLOUD_IMAGE_NAME=maxdiffusion-jax-stable-stack IMAGE_TAG=jax0.4.30-rev1 MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt DELETE_LOCAL_IMAGE=true \ No newline at end of file diff --git a/docker_maxdiffusion_image_upload.sh b/docker_maxdiffusion_image_upload.sh index c555ae3d..1dabb5ee 100644 --- a/docker_maxdiffusion_image_upload.sh +++ b/docker_maxdiffusion_image_upload.sh @@ -21,10 +21,12 @@ # (minutes). However, if you are simply changing local code and not updating dependencies, uploading just takes a few seconds. # Example command: -# bash docker_maxdiffusion_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=gcr.io/tpu-prod-env-multipod/jax-ss/tpu:jax0.4.28-v1.0.0 CLOUD_IMAGE_NAME=maxdiffusion-jax-ss-0.4.28-v1.0.0 IMAGE_TAG=latest MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_ss.txt +# bash docker_maxdiffusion_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=gcr.io/tpu-prod-env-multipod/jax-ss/tpu:jax0.4.28-v1.0.0 CLOUD_IMAGE_NAME=maxdiffusion-jax-stable-stack IMAGE_TAG=latest MAXDIFFUSION_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt set -e +export LOCAL_IMAGE_NAME=maxdiffusion_base_image + # Set environment variables for ARGUMENT in "$@"; do IFS='=' read -r KEY VALUE <<< "$ARGUMENT" @@ -57,18 +59,32 @@ if [[ ! -v MAXDIFFUSION_REQUIREMENTS_FILE ]]; then exit 1 fi +# Default: Don't delete local image +DELETE_LOCAL_IMAGE="${DELETE_LOCAL_IMAGE:-false}" + +gcloud auth configure-docker us-docker.pkg.dev --quiet + COMMIT_HASH=$(git rev-parse --short HEAD) -echo "Building JAX SS MaxDiffusion at commit hash ${COMMIT_HASH} . . ." +echo "Building JAX Stable Stack MaxDiffusion at commit hash ${COMMIT_HASH} . . ." -docker build \ - --build-arg JAX_SS_BASEIMAGE=${BASEIMAGE} \ +IMAGE_DATE=$(date +%Y-%m-%d) + +FULL_IMAGE_TAG=${IMAGE_TAG}-${IMAGE_DATE} + +docker build --no-cache \ + --build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \ --build-arg COMMIT_HASH=${COMMIT_HASH} \ --build-arg MAXDIFFUSION_REQUIREMENTS_FILE=${MAXDIFFUSION_REQUIREMENTS_FILE} \ --network=host \ - -t gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${IMAGE_TAG} \ - -f ./maxdiffusion_jax_ss_tpu.Dockerfile . + -t gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG} \ + -f maxdiffusion_jax_stable_stack_tpu.Dockerfile . + +docker push gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG} -docker push gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${IMAGE_TAG} +echo "All done, check out your artifacts at: gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG}" -echo "All done, check out your artifacts at: gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${IMAGE_TAG}" \ No newline at end of file +if [ "$DELETE_LOCAL_IMAGE" == "true" ]; then + docker rmi gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}/tpu:${FULL_IMAGE_TAG} + echo "Local image deleted." +fi \ No newline at end of file diff --git a/maxdiffusion_jax_ss_tpu.Dockerfile b/maxdiffusion_jax_ss_tpu.Dockerfile deleted file mode 100644 index 67a32ce5..00000000 --- a/maxdiffusion_jax_ss_tpu.Dockerfile +++ /dev/null @@ -1,31 +0,0 @@ -ARG JAX_SS_BASEIMAGE - -# JAX Stable Stack Base Image -From $JAX_SS_BASEIMAGE - -ARG COMMIT_HASH - -ENV COMMIT_HASH=$COMMIT_HASH - -RUN mkdir -p /deps - -# Set the working directory in the container -WORKDIR /deps - -# Copy all files from local workspace into docker container -COPY . . -RUN ls . - -ARG MAXDIFFUSION_REQUIREMENTS_FILE - -# Install Maxdiffusion requirements -RUN if [ ! -z "${MAXDIFFUSION_REQUIREMENTS_FILE}" ]; then \ - echo "Using MaxDiffusion requirements: ${MAXDIFFUSION_REQUIREMENTS_FILE}" && \ - pip install -r /deps/${MAXDIFFUSION_REQUIREMENTS_FILE}; \ - fi - -# Install MaxDiffusion -RUN pip install . - -# Run the script available in JAX-SS base image to generate the manifest file -RUN bash /generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH \ No newline at end of file diff --git a/maxdiffusion_jax_stable_stack_tpu.Dockerfile b/maxdiffusion_jax_stable_stack_tpu.Dockerfile new file mode 100644 index 00000000..5096ba3d --- /dev/null +++ b/maxdiffusion_jax_stable_stack_tpu.Dockerfile @@ -0,0 +1,31 @@ +ARG JAX_STABLE_STACK_BASEIMAGE + +# JAX Stable Stack Base Image +FROM $JAX_STABLE_STACK_BASEIMAGE + +ARG COMMIT_HASH + +ENV COMMIT_HASH=$COMMIT_HASH + +RUN mkdir -p /deps + +# Set the working directory in the container +WORKDIR /deps + +# Copy all files from local workspace into docker container +COPY . . +RUN ls . + +# ARG MAXDIFFUSION_REQUIREMENTS_FILE + +# # Install Maxdiffusion requirements +# RUN if [ ! -z "${MAXDIFFUSION_REQUIREMENTS_FILE}" ]; then \ +# echo "Using MaxDiffusion requirements: ${MAXDIFFUSION_REQUIREMENTS_FILE}" && \ +# pip install -r /deps/${MAXDIFFUSION_REQUIREMENTS_FILE}; \ +# fi + +# # Install MaxDiffusion +# RUN pip install . + +# Run the script available in JAX-Stable-Stack base image to generate the manifest file +RUN bash /generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH \ No newline at end of file diff --git a/requirements_with_jax_ss.txt b/requirements_with_jax_stable_stack.txt similarity index 90% rename from requirements_with_jax_ss.txt rename to requirements_with_jax_stable_stack.txt index 5d349970..f4ebde55 100644 --- a/requirements_with_jax_ss.txt +++ b/requirements_with_jax_stable_stack.txt @@ -1,5 +1,5 @@ # Requirements for Building the MaxDifussion Docker Image -# These requirements are additional to the dependencies present in the JAX SS base image. +# These requirements are additional to the dependencies present in the JAX Stable Stack base image. absl-py transformers>=4.25.1 datasets