Skip to content

Commit

Permalink
Making flash attention build happen after we built and installed the …
Browse files Browse the repository at this point in the history
…required torch version (#270)
  • Loading branch information
gshtras authored Nov 8, 2024
1 parent 15d17c5 commit 1dcc5df
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions Dockerfile.rocm
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,6 @@ ARG COMMON_WORKDIR
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
FROM scratch AS export_rccl_0
FROM export_rccl_${BUILD_RCCL} AS export_rccl

# -----------------------
# flash attn build stages
FROM base AS build_flash_attn
ARG FA_BRANCH="3cea2fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
RUN git clone ${FA_REPO} \
&& cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
&& GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch AS export_flash_attn_1
ARG COMMON_WORKDIR
COPY --from=build_flash_attn ${COMMON_WORKDIR}/flash-attention/dist/*.whl /
FROM scratch AS export_flash_attn_0
FROM export_flash_attn_${BUILD_FA} AS export_flash_attn

# -----------------------
# Triton build stages
FROM base AS build_triton
Expand Down Expand Up @@ -143,6 +126,27 @@ COPY --from=build_pytorch ${COMMON_WORKDIR}/vision/dist/*.whl /
FROM scratch as export_pytorch_0
from export_pytorch_${BUILD_PYTORCH} as export_pytorch

# -----------------------
# flash attn build stages
FROM base AS build_flash_attn
ARG FA_BRANCH="3cea2fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
if ls /install/*.whl; then \
pip uninstall -y torch torchvision \
&& pip install /install/*.whl; \
fi
RUN git clone ${FA_REPO} \
&& cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
&& MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch AS export_flash_attn_1
ARG COMMON_WORKDIR
COPY --from=build_flash_attn ${COMMON_WORKDIR}/flash-attention/dist/*.whl /
FROM scratch AS export_flash_attn_0
FROM export_flash_attn_${BUILD_FA} AS export_flash_attn

# -----------------------
# vLLM (and gradlib) fetch stages
FROM base AS fetch_vllm_0
Expand Down

0 comments on commit 1dcc5df

Please sign in to comment.