From 164ce38c110c3505ad0cd5eb1ee68ca81c76d8db Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Fri, 13 Sep 2024 12:56:52 -0400 Subject: [PATCH] 6.2 dockerfile (#176) * Trying to modernize the dockerfile, pinning rccl; triton; pytorch; hipblaslt to the latest required versions * Dockerfile fixes. Using the scaling factors in scaled_mm where they are required by torch 2.5 or acceptable by others * Building torchvision too when building torch * gradlib as a not-cmake project doesn't inherit `target_compile_definitions(hipblaslt PUBLIC LEGACY_HIPBLAS_DIRECT )` * Using a specific torch commit with scaled_mm fix until it is in mainline. Fixed scaled_mm in gradlib for no reason at all * No point in pinning hipblaslt to rocm6.2 release, if we want to build it, we'll want 0.10 * Removed torch requirement --- Dockerfile.rocm | 64 ++++++++++++++++--- gradlib/gradlib/GemmTuner.py | 3 + gradlib/setup.py | 14 ++-- .../layers/quantization/utils/w8a8_utils.py | 25 ++++---- 4 files changed, 79 insertions(+), 27 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 499f55896c35c..c8fe899356a67 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,13 +1,14 @@ # default base image -ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" +ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0" ARG COMMON_WORKDIR=/app # The following ARGs should be "0" or "1". If "1", the respective component will be built and installed on top of the base image -ARG BUILD_HIPBLASLT="1" +ARG BUILD_HIPBLASLT="0" ARG BUILD_RCCL="1" ARG BUILD_FA="1" ARG BUILD_TRITON="1" +ARG BUILD_PYTORCH="1" # This ARG should also be "0" or "1". If "1", the vLLM development directory is obtained via git clone. # If "0", it is copied in from the local working directory. ARG REMOTE_VLLM="0" @@ -39,11 +40,12 @@ WORKDIR ${COMMON_WORKDIR} # ----------------------- # hipBLASLt build stages FROM base AS build_hipblaslt -ARG HIPBLASLT_BRANCH="6f65c6e" -RUN git clone https://github.com/ROCm/hipBLASLt \ +ARG HIPBLASLT_BRANCH="e6da924" +RUN apt-get purge -y hipblaslt \ + && git clone https://github.com/ROCm/hipBLASLt.git \ && cd hipBLASLt \ && git checkout ${HIPBLASLT_BRANCH} \ - && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \ + && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \ && cd build/release \ && make package FROM scratch AS export_hipblaslt_1 @@ -55,7 +57,7 @@ FROM export_hipblaslt_${BUILD_HIPBLASLT} AS export_hipblaslt # ----------------------- # RCCL build stages FROM base AS build_rccl -ARG RCCL_BRANCH="73221b4" +ARG RCCL_BRANCH="rocm-6.2.0" RUN git clone https://github.com/ROCm/rccl \ && cd rccl \ && git checkout ${RCCL_BRANCH} \ @@ -69,7 +71,7 @@ FROM export_rccl_${BUILD_RCCL} AS export_rccl # ----------------------- # flash attn build stages FROM base AS build_flash_attn -ARG FA_BRANCH="ae7928c" +ARG FA_BRANCH="3cea2fb" ARG FA_REPO="https://github.com/ROCm/flash-attention.git" RUN git clone ${FA_REPO} \ && cd flash-attention \ @@ -85,9 +87,9 @@ FROM export_flash_attn_${BUILD_FA} AS export_flash_attn # ----------------------- # Triton build stages FROM base AS build_triton -ARG TRITON_BRANCH="6ddb79b" -ARG TRITON_REPO="https://github.com/OpenAI/triton.git" -RUN git clone ${TRITON_REPO} \ +ARG TRITON_BRANCH="e192dba" +ARG TRITON_REPO="https://github.com/triton-lang/triton.git" +RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \ && cd triton \ && git checkout ${TRITON_BRANCH} \ && cd python \ @@ -105,6 +107,36 @@ RUN cd /opt/rocm/share/amd_smi \ FROM scratch AS export_amdsmi COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl / +FROM base as build_pytorch +# A commit to fix the output scaling factor issue in _scaled_mm +# Not yet in 2.5.0-rc1 +ARG PYTORCH_BRANCH="cedc116" +ARG PYTORCH_VISION_BRANCH="v0.19.1" +ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" +ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" +#RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \ +#if ls /install/*.deb; then \ +# apt-get purge -y hipblaslt \ +# && dpkg -i /install/*.deb \ +# && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ +# && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \ +#fi +RUN git clone ${PYTORCH_REPO} pytorch \ + && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \ + && python tools/amd_build/build_amd.py \ + && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \ + && pip install dist/*.whl \ + && cd .. \ + && git clone ${PYTORCH_VISION_REPO} vision \ + && cd vision && git checkout ${PYTORCH_VISION_BRANCH} \ + && python3 setup.py bdist_wheel --dist-dir=dist +FROM scratch as export_pytorch_1 +ARG COMMON_WORKDIR +COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl / +COPY --from=build_pytorch ${COMMON_WORKDIR}/vision/dist/*.whl / +FROM scratch as export_pytorch_0 +from export_pytorch_${BUILD_PYTORCH} as export_pytorch + # ----------------------- # vLLM (and gradlib) fetch stages FROM base AS fetch_vllm_0 @@ -129,6 +161,11 @@ if ls /install/*.deb; then \ && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \ fi +# Install pytorch +RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \ +if ls /install/*.whl; then \ + pip install /install/*.whl; \ +fi # Build vLLM RUN cd vllm \ && python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist @@ -197,6 +234,13 @@ RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \ pip uninstall -y amdsmi \ && pip install /install/*.whl; +RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \ + if ls /install/*.whl; then \ + # Preemptively uninstall to prevent pip same-version no-installs + pip uninstall -y torch torchvision \ + && pip install /install/*.whl; \ + fi + RUN python3 -m pip install --upgrade numba scipy huggingface-hub[cli] # Install vLLM (and gradlib) diff --git a/gradlib/gradlib/GemmTuner.py b/gradlib/gradlib/GemmTuner.py index 8e10934f7f7ef..a586d772de0d0 100644 --- a/gradlib/gradlib/GemmTuner.py +++ b/gradlib/gradlib/GemmTuner.py @@ -15,6 +15,7 @@ atol = 1 CACHE_INVALIDATE_BUFFERS = int(os.getenv("CACHE_INVALIDATE_BUFFERS", "37")) +ONE = torch.ones(1, dtype=torch.float32, device='cuda') class Gemm: @@ -68,6 +69,8 @@ def check_gemm_ref(self, libtype, solidx): if self.indtype == torch.float8_e4m3fnuz: ref, _ = torch._scaled_mm(self.inp, self.weights.t(), + scale_a=ONE, + scale_b=ONE, out_dtype=self.outdtype) else: ref = F.linear(self.inp, self.weights) diff --git a/gradlib/setup.py b/gradlib/setup.py index 0400741f61c85..e90eacfe2a7c2 100644 --- a/gradlib/setup.py +++ b/gradlib/setup.py @@ -125,11 +125,14 @@ extra_compile_args={ 'cxx': [ '-O3', + '-DLEGACY_HIPBLAS_DIRECT=ON', ], 'nvcc': [ - '-O3', '-U__CUDA_NO_HALF_OPERATORS__', + '-O3', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - "-ftemplate-depth=1024" + "-ftemplate-depth=1024", + '-DLEGACY_HIPBLAS_DIRECT=ON', ] + extra_args })) ext_modules.append( @@ -142,11 +145,14 @@ extra_compile_args={ 'cxx': [ '-O3', + '-DLEGACY_HIPBLAS_DIRECT=ON', ], 'nvcc': [ - '-O3', '-U__CUDA_NO_HALF_OPERATORS__', + '-O3', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - "-ftemplate-depth=1024" + "-ftemplate-depth=1024", + '-DLEGACY_HIPBLAS_DIRECT=ON', ] + extra_args })) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index a73a08856da40..20c96fbcaed90 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -10,7 +10,7 @@ # providing scaling factor for result. This value is created # as global value to avoid multiple tensor allocations, and # can be removed once pytorch fixes the bug. -TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None +TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None def cutlass_fp8_supported() -> bool: @@ -132,20 +132,17 @@ def apply_fp8_linear( per_tensor_weights = (weight_scale.numel() == 1) per_tensor_activations = (x_scale.numel() == 1) + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY.device != weight.device: + TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ - global TORCH_SCALED_MM_SCALE_RESULT - if TORCH_SCALED_MM_SCALE_RESULT.device != weight.device: - TORCH_SCALED_MM_SCALE_RESULT = TORCH_SCALED_MM_SCALE_RESULT.to( - weight.device) - output = torch._scaled_mm( - qinput, - weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - scale_result=TORCH_SCALED_MM_SCALE_RESULT, - bias=bias) + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) # A fix for discrepancy in scaled_mm which returns tuple # for torch < 2.5 and a single value in torch >= 2.5 if type(output) is tuple and len(output) == 2: @@ -173,6 +170,8 @@ def apply_fp8_linear( # Output in fp32 to allow subsequent ops to happen in-place output, _ = torch._scaled_mm(qinput, weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, out_dtype=torch.float32) # Unpad (undo num_token_padding) output = torch.narrow(output, 0, 0, input.shape[0])