Merge branch 'main' into kk/grok-1_fix-scale-factor
gshtras authored Sep 13, 2024
2 parents cb082eb + 164ce38 commit ea93135
Showing 4 changed files with 79 additions and 27 deletions.
Dockerfile.rocm: 54 additions & 10 deletions
@@ -1,13 +1,14 @@
 # default base image
-ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
+ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0"
 
 ARG COMMON_WORKDIR=/app
 
 # The following ARGs should be "0" or "1". If "1", the respective component will be built and installed on top of the base image
-ARG BUILD_HIPBLASLT="1"
+ARG BUILD_HIPBLASLT="0"
 ARG BUILD_RCCL="1"
 ARG BUILD_FA="1"
 ARG BUILD_TRITON="1"
+ARG BUILD_PYTORCH="1"
 # This ARG should also be "0" or "1". If "1", the vLLM development directory is obtained via git clone.
 # If "0", it is copied in from the local working directory.
 ARG REMOTE_VLLM="0"
@@ -39,11 +40,12 @@ WORKDIR ${COMMON_WORKDIR}
 # -----------------------
 # hipBLASLt build stages
 FROM base AS build_hipblaslt
-ARG HIPBLASLT_BRANCH="6f65c6e"
-RUN git clone https://github.com/ROCm/hipBLASLt \
+ARG HIPBLASLT_BRANCH="e6da924"
+RUN apt-get purge -y hipblaslt \
+    && git clone https://github.com/ROCm/hipBLASLt.git \
     && cd hipBLASLt \
     && git checkout ${HIPBLASLT_BRANCH} \
-    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
+    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} --legacy_hipblas_direct \
     && cd build/release \
     && make package
 FROM scratch AS export_hipblaslt_1
@@ -55,7 +57,7 @@ FROM export_hipblaslt_${BUILD_HIPBLASLT} AS export_hipblaslt
 # -----------------------
 # RCCL build stages
 FROM base AS build_rccl
-ARG RCCL_BRANCH="73221b4"
+ARG RCCL_BRANCH="rocm-6.2.0"
 RUN git clone https://github.com/ROCm/rccl \
     && cd rccl \
     && git checkout ${RCCL_BRANCH} \
@@ -69,7 +71,7 @@ FROM export_rccl_${BUILD_RCCL} AS export_rccl
 # -----------------------
 # flash attn build stages
 FROM base AS build_flash_attn
-ARG FA_BRANCH="ae7928c"
+ARG FA_BRANCH="3cea2fb"
 ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
 RUN git clone ${FA_REPO} \
     && cd flash-attention \
@@ -85,9 +87,9 @@ FROM export_flash_attn_${BUILD_FA} AS export_flash_attn
 # -----------------------
 # Triton build stages
 FROM base AS build_triton
-ARG TRITON_BRANCH="6ddb79b"
-ARG TRITON_REPO="https://github.com/OpenAI/triton.git"
-RUN git clone ${TRITON_REPO} \
+ARG TRITON_BRANCH="e192dba"
+ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
+RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
     && cd triton \
     && git checkout ${TRITON_BRANCH} \
     && cd python \
@@ -105,6 +107,36 @@ RUN cd /opt/rocm/share/amd_smi \
 FROM scratch AS export_amdsmi
 COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
 
+FROM base as build_pytorch
+# A commit to fix the output scaling factor issue in _scaled_mm
+# Not yet in 2.5.0-rc1
+ARG PYTORCH_BRANCH="cedc116"
+ARG PYTORCH_VISION_BRANCH="v0.19.1"
+ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
+ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
+#RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
+#if ls /install/*.deb; then \
+#    apt-get purge -y hipblaslt \
+#    && dpkg -i /install/*.deb \
+#    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+#    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
+#fi
+RUN git clone ${PYTORCH_REPO} pytorch \
+    && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
+    && python tools/amd_build/build_amd.py \
+    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
+    && pip install dist/*.whl \
+    && cd .. \
+    && git clone ${PYTORCH_VISION_REPO} vision \
+    && cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
+    && python3 setup.py bdist_wheel --dist-dir=dist
+FROM scratch as export_pytorch_1
+ARG COMMON_WORKDIR
+COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /
+COPY --from=build_pytorch ${COMMON_WORKDIR}/vision/dist/*.whl /
+FROM scratch as export_pytorch_0
+from export_pytorch_${BUILD_PYTORCH} as export_pytorch
+
 # -----------------------
 # vLLM (and gradlib) fetch stages
 FROM base AS fetch_vllm_0
@@ -129,6 +161,11 @@ if ls /install/*.deb; then \
     && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
     && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
 fi
+# Install pytorch
+RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
+if ls /install/*.whl; then \
+    pip install /install/*.whl; \
+fi
 # Build vLLM
 RUN cd vllm \
     && python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist
@@ -197,6 +234,13 @@ RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
     pip uninstall -y amdsmi \
     && pip install /install/*.whl;
 
+RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
+    if ls /install/*.whl; then \
+    # Preemptively uninstall to prevent pip same-version no-installs
+    pip uninstall -y torch torchvision \
+    && pip install /install/*.whl; \
+    fi
+
 RUN python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
 
 # Install vLLM (and gradlib)
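The new build_pytorch stage is gated by the BUILD_PYTORCH arg the same way the other components are. As a sketch only (the image tag and build context are illustrative, not part of this commit), a build that skips the patched PyTorch wheels and keeps the base image's stock PyTorch would look like:

    docker build -f Dockerfile.rocm \
        --build-arg BUILD_PYTORCH="0" \
        -t vllm-rocm .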
gradlib/gradlib/GemmTuner.py: 3 additions & 0 deletions
@@ -15,6 +15,7 @@
 atol = 1
 
 CACHE_INVALIDATE_BUFFERS = int(os.getenv("CACHE_INVALIDATE_BUFFERS", "37"))
+ONE = torch.ones(1, dtype=torch.float32, device='cuda')
 
 
 class Gemm:
@@ -68,6 +69,8 @@ def check_gemm_ref(self, libtype, solidx):
         if self.indtype == torch.float8_e4m3fnuz:
             ref, _ = torch._scaled_mm(self.inp,
                                       self.weights.t(),
+                                      scale_a=ONE,
+                                      scale_b=ONE,
                                       out_dtype=self.outdtype)
         else:
             ref = F.linear(self.inp, self.weights)
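The module-level ONE tensor gives torch._scaled_mm explicit unit scaling factors, which the patched PyTorch built in Dockerfile.rocm above expects instead of defaulting them. A minimal sketch of the resulting call pattern, assuming an fp8-capable ROCm GPU where torch.float8_e4m3fnuz is supported (shapes are illustrative):

    import torch

    ONE = torch.ones(1, dtype=torch.float32, device='cuda')

    # Quantize inputs to fp8 (e4m3fnuz is the ROCm flavor of e4m3).
    inp = torch.randn(16, 32, device='cuda').to(torch.float8_e4m3fnuz)
    weights = torch.randn(64, 32, device='cuda').to(torch.float8_e4m3fnuz)

    # Unit scales leave the product unscaled; _scaled_mm wants the second
    # operand column-major, hence the transpose.
    ref = torch._scaled_mm(inp, weights.t(), scale_a=ONE, scale_b=ONE,
                           out_dtype=torch.float16)
    # torch < 2.5 returns an (output, amax) tuple; torch >= 2.5 a bare tensor.
    ref = ref[0] if isinstance(ref, tuple) else ref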
gradlib/setup.py: 10 additions & 4 deletions
@@ -125,11 +125,14 @@
         extra_compile_args={
             'cxx': [
                 '-O3',
+                '-DLEGACY_HIPBLAS_DIRECT=ON',
             ],
             'nvcc': [
-                '-O3', '-U__CUDA_NO_HALF_OPERATORS__',
+                '-O3',
+                '-U__CUDA_NO_HALF_OPERATORS__',
                 '-U__CUDA_NO_HALF_CONVERSIONS__',
-                "-ftemplate-depth=1024"
+                "-ftemplate-depth=1024",
+                '-DLEGACY_HIPBLAS_DIRECT=ON',
             ] + extra_args
         }))
 ext_modules.append(
@@ -142,11 +145,14 @@
         extra_compile_args={
             'cxx': [
                 '-O3',
+                '-DLEGACY_HIPBLAS_DIRECT=ON',
             ],
             'nvcc': [
-                '-O3', '-U__CUDA_NO_HALF_OPERATORS__',
+                '-O3',
+                '-U__CUDA_NO_HALF_OPERATORS__',
                 '-U__CUDA_NO_HALF_CONVERSIONS__',
-                "-ftemplate-depth=1024"
+                "-ftemplate-depth=1024",
+                '-DLEGACY_HIPBLAS_DIRECT=ON',
             ] + extra_args
         }))
 
vllm/model_executor/layers/quantization/utils/w8a8_utils.py: 12 additions & 13 deletions
@@ -10,7 +10,7 @@
 # providing scaling factor for result. This value is created
 # as global value to avoid multiple tensor allocations, and
 # can be removed once pytorch fixes the bug.
-TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None
+TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None
 
 
 def cutlass_fp8_supported() -> bool:
@@ -132,20 +132,17 @@ def apply_fp8_linear(
     per_tensor_weights = (weight_scale.numel() == 1)
     per_tensor_activations = (x_scale.numel() == 1)
 
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY.device != weight.device:
+        TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        global TORCH_SCALED_MM_SCALE_RESULT
-        if TORCH_SCALED_MM_SCALE_RESULT.device != weight.device:
-            TORCH_SCALED_MM_SCALE_RESULT = TORCH_SCALED_MM_SCALE_RESULT.to(
-                weight.device)
-        output = torch._scaled_mm(
-            qinput,
-            weight,
-            out_dtype=out_dtype,
-            scale_a=x_scale,
-            scale_b=weight_scale,
-            scale_result=TORCH_SCALED_MM_SCALE_RESULT,
-            bias=bias)
+        output = torch._scaled_mm(qinput,
+                                  weight,
+                                  out_dtype=out_dtype,
+                                  scale_a=x_scale,
+                                  scale_b=weight_scale,
+                                  bias=bias)
         # A fix for discrepancy in scaled_mm which returns tuple
         # for torch < 2.5 and a single value in torch >= 2.5
         if type(output) is tuple and len(output) == 2:
@@ -173,6 +170,8 @@ def apply_fp8_linear(
         # Output in fp32 to allow subsequent ops to happen in-place
         output, _ = torch._scaled_mm(qinput,
                                      weight,
+                                     scale_a=TORCH_DEVICE_IDENTITY,
+                                     scale_b=TORCH_DEVICE_IDENTITY,
                                      out_dtype=torch.float32)
         # Unpad (undo num_token_padding)
         output = torch.narrow(output, 0, 0, input.shape[0])
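The rename from TORCH_SCALED_MM_SCALE_RESULT to TORCH_DEVICE_IDENTITY reflects the tensor's new role: one cached ones(1) tensor reused as scale_a/scale_b for every call, migrated lazily when the weights live on a different device, rather than a scale_result workaround. A minimal standalone sketch of that caching pattern (the helper name is illustrative, not from the source):

    import torch

    _IDENTITY = torch.ones(1, dtype=torch.float32)

    def device_identity(device: torch.device) -> torch.Tensor:
        # Move the cached identity scale only when the target device changes,
        # avoiding a fresh one-element allocation on every linear-layer call.
        global _IDENTITY
        if _IDENTITY.device != device:
            _IDENTITY = _IDENTITY.to(device)
        return _IDENTITY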
