diff --git a/CMakeLists.txt b/CMakeLists.txt
index aa15b632cdd3b..801429096eaab 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -32,8 +32,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
-set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
-set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
+set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0")
#
# Try to find python package with an executable that exactly matches
@@ -98,18 +97,11 @@ elseif(HIP_FOUND)
# .hip extension automatically, HIP must be enabled explicitly.
enable_language(HIP)
- # ROCm 5.x
- if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND
- NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X})
- message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} "
- "expected for ROCMm 5.x build, saw ${Torch_VERSION} instead.")
- endif()
-
- # ROCm 6.x
- if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND
- NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X})
- message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} "
- "expected for ROCMm 6.x build, saw ${Torch_VERSION} instead.")
+ # ROCm 5.X and 6.X
+ if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
+ NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
+ message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} "
+ "expected for ROCm build, saw ${Torch_VERSION} instead.")
endif()
else()
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
diff --git a/Dockerfile.rocm b/Dockerfile.rocm
index 6bda696859c8b..652f04adf8959 100644
--- a/Dockerfile.rocm
+++ b/Dockerfile.rocm
@@ -1,34 +1,35 @@
-# default base image
-ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
-
-FROM $BASE_IMAGE
-
-ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
-
-RUN echo "Base image is $BASE_IMAGE"
-
-ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
- ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
-
+# Default ROCm 6.1 base image
+ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
+
+# Tested and supported base rocm/pytorch images
+ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \
+ ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \
+ ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
+
+# Default ROCm ARCHes to build vLLM for.
+ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
+
+# Whether to build CK-based flash-attention
+# If 0, will not build flash attention
+# This is useful for gfx target where flash-attention is not supported
+# (i.e. those that do not appear in `FA_GFX_ARCHS`)
+# Triton FA is used by default on ROCm now so this is unnecessary.
+ARG BUILD_FA="1"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
-RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
-
ARG FA_BRANCH="ae7928c"
-RUN echo "FA_BRANCH is $FA_BRANCH"
-# whether to build flash-attention
-# if 0, will not build flash attention
-# this is useful for gfx target where flash-attention is not supported
-# In that case, we need to use the python reference attention implementation in vllm
-ARG BUILD_FA="1"
-
-# whether to build triton on rocm
+# Whether to build triton on rocm
ARG BUILD_TRITON="1"
+ARG TRITON_BRANCH="0ef1848"
-# Install some basic utilities
-RUN apt-get update && apt-get install python3 python3-pip -y
+### Base image build stage
+FROM $BASE_IMAGE AS base
+
+# Import arg(s) defined before this build stage
+ARG PYTORCH_ROCM_ARCH
# Install some basic utilities
+RUN apt-get update && apt-get install python3 python3-pip -y
RUN apt-get update && apt-get install -y \
curl \
ca-certificates \
@@ -39,79 +40,159 @@ RUN apt-get update && apt-get install -y \
build-essential \
wget \
unzip \
- nvidia-cuda-toolkit \
tmux \
ccache \
&& rm -rf /var/lib/apt/lists/*
-### Mount Point ###
-# When launching the container, mount the code directory to /app
+# When launching the container, mount the code directory to /vllm-workspace
ARG APP_MOUNT=/vllm-workspace
-VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}
-RUN python3 -m pip install --upgrade pip
-RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
+RUN pip install --upgrade pip
+# Remove sccache so it doesn't interfere with ccache
+# TODO: implement sccache support across components
+RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
+# Install torch == 2.4.0 on ROCm
+RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
+ *"rocm-5.7"*) \
+ pip uninstall -y torch \
+ && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
+ --index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \
+ *"rocm-6.0"*) \
+ pip uninstall -y torch \
+ && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
+ --index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \
+ *"rocm-6.1"*) \
+ pip uninstall -y torch \
+ && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \
+ --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
+ *) ;; esac
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
-# Install ROCm flash-attention
-RUN if [ "$BUILD_FA" = "1" ]; then \
- mkdir libs \
+ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
+ENV CCACHE_DIR=/root/.cache/ccache
+
+
+### AMD-SMI build stage
+FROM base AS build_amdsmi
+# Build amdsmi wheel always
+RUN cd /opt/rocm/share/amd_smi \
+ && pip wheel . --wheel-dir=/install
+
+
+### Flash-Attention wheel build stage
+FROM base AS build_fa
+ARG BUILD_FA
+ARG FA_GFX_ARCHS
+ARG FA_BRANCH
+# Build ROCm flash-attention wheel if `BUILD_FA = 1`
+RUN --mount=type=cache,target=${CCACHE_DIR} \
+ if [ "$BUILD_FA" = "1" ]; then \
+ mkdir -p libs \
&& cd libs \
&& git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
- && git checkout ${FA_BRANCH} \
+ && git checkout "${FA_BRANCH}" \
&& git submodule update --init \
- && export GPU_ARCHS=${FA_GFX_ARCHS} \
- && if [ "$BASE_IMAGE" = "$ROCm_5_7_BASE" ]; then \
- patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
- && python3 setup.py install \
- && cd ..; \
+ && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
+ *"rocm-5.7"*) \
+ export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \
+ && patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \
+ *) ;; esac \
+ && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
+ # Create an empty directory otherwise as later build stages expect one
+ else mkdir -p /install; \
fi
-# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
-# Manually removed it so that later steps of numpy upgrade can continue
-RUN if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \
- rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
-# build triton
-RUN if [ "$BUILD_TRITON" = "1" ]; then \
+### Triton wheel build stage
+FROM base AS build_triton
+ARG BUILD_TRITON
+ARG TRITON_BRANCH
+# Build triton wheel if `BUILD_TRITON = 1`
+RUN --mount=type=cache,target=${CCACHE_DIR} \
+ if [ "$BUILD_TRITON" = "1" ]; then \
mkdir -p libs \
&& cd libs \
- && pip uninstall -y triton \
- && git clone https://github.com/ROCm/triton.git \
- && cd triton/python \
- && pip3 install . \
- && cd ../..; \
+ && git clone https://github.com/OpenAI/triton.git \
+ && cd triton \
+ && git checkout "${TRITON_BRANCH}" \
+ && cd python \
+ && python3 setup.py bdist_wheel --dist-dir=/install; \
+ # Create an empty directory otherwise as later build stages expect one
+ else mkdir -p /install; \
fi
-WORKDIR /vllm-workspace
+
+### Final vLLM build stage
+FROM base AS final
+# Import the vLLM development directory from the build context
COPY . .
-#RUN python3 -m pip install pynvml # to be removed eventually
-RUN python3 -m pip install --upgrade pip numba
+# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
+# Manually remove it so that later steps of numpy upgrade can continue
+RUN case "$(which python3)" in \
+ *"/opt/conda/envs/py_3.9"*) \
+ rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
+ *) ;; esac
+
+# Package upgrades for useful functionality or to avoid dependency issues
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip install --upgrade numba scipy huggingface-hub[cli]
-# make sure punica kernels are built (for LoRA)
+# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
# Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
+# Silences the HF Tokenizers warning
+ENV TOKENIZERS_PARALLELISM=false
-ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so
-
-ENV CCACHE_DIR=/root/.cache/ccache
-RUN --mount=type=cache,target=/root/.cache/ccache \
+RUN --mount=type=cache,target=${CCACHE_DIR} \
--mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \
- && if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \
- patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch; fi \
- && python3 setup.py install \
- && export VLLM_PYTHON_VERSION=$(python -c "import sys; print(str(sys.version_info.major) + str(sys.version_info.minor))") \
- && cp build/lib.linux-x86_64-cpython-${VLLM_PYTHON_VERSION}/vllm/*.so vllm/ \
- && cd ..
+ && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
+ *"rocm-6.0"*) \
+ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \
+ *"rocm-6.1"*) \
+ # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
+ wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \
+ && cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \
+ # Prevent interference if torch bundles its own HIP runtime
+ && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
+ *) ;; esac \
+ && python3 setup.py clean --all \
+ && python3 setup.py develop
+
+# Copy amdsmi wheel into final image
+RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
+ mkdir -p libs \
+ && cp /install/*.whl libs \
+ # Preemptively uninstall to avoid same-version no-installs
+ && pip uninstall -y amdsmi;
+# Copy triton wheel(s) into final image if they were built
+RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
+ mkdir -p libs \
+ && if ls /install/*.whl; then \
+ cp /install/*.whl libs \
+ # Preemptively uninstall to avoid same-version no-installs
+ && pip uninstall -y triton; fi
+
+# Copy flash-attn wheel(s) into final image if they were built
+RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
+ mkdir -p libs \
+ && if ls /install/*.whl; then \
+ cp /install/*.whl libs \
+ # Preemptively uninstall to avoid same-version no-installs
+ && pip uninstall -y flash-attn; fi
+
+# Install wheels that were built to the final image
+RUN --mount=type=cache,target=/root/.cache/pip \
+ if ls libs/*.whl; then \
+ pip install libs/*.whl; fi
CMD ["/bin/bash"]
diff --git a/cmake/utils.cmake b/cmake/utils.cmake
index 071e16336dfa2..4869cad541135 100644
--- a/cmake/utils.cmake
+++ b/cmake/utils.cmake
@@ -147,19 +147,23 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
if (${GPU_LANG} STREQUAL "HIP")
#
# `GPU_ARCHES` controls the `--offload-arch` flags.
- # `CMAKE_HIP_ARCHITECTURES` is set up by torch and can be controlled
- # via the `PYTORCH_ROCM_ARCH` env variable.
#
-
+ # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list,
+ # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling
+ # "rocm_agent_enumerator" in "enable_language(HIP)"
+ # (in file Modules/CMakeDetermineHIPCompiler.cmake)
+ #
+ if(DEFINED ENV{PYTORCH_ROCM_ARCH})
+ set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
+ else()
+ set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
+ endif()
#
# Find the intersection of the supported + detected architectures to
# set the module architecture flags.
#
-
- set(VLLM_ROCM_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
-
set(${GPU_ARCHES})
- foreach (_ARCH ${VLLM_ROCM_SUPPORTED_ARCHS})
+ foreach (_ARCH ${HIP_ARCHITECTURES})
if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
list(APPEND ${GPU_ARCHES} ${_ARCH})
endif()
@@ -167,7 +171,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
if(NOT ${GPU_ARCHES})
message(FATAL_ERROR
- "None of the detected ROCm architectures: ${CMAKE_HIP_ARCHITECTURES} is"
+ "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
endif()
diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst
index 61fcd45a26347..cc41d47296f8d 100644
--- a/docs/source/getting_started/amd-installation.rst
+++ b/docs/source/getting_started/amd-installation.rst
@@ -88,7 +88,7 @@ Option 2: Build from source
- `Pytorch `_
- `hipBLAS `_
-For installing PyTorch, you can start from a fresh docker image, e.g, `rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`.
+For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`.
Alternatively, you can install pytorch using pytorch wheels. You can check Pytorch installation guild in Pytorch `Getting Started `_
@@ -126,12 +126,12 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl
$ cd vllm
$ pip install -U -r requirements-rocm.txt
- $ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
+ $ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
.. tip::
- You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation.
- Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
- - To use CK flash-attention, please use this flag ``export VLLM_USE_FLASH_ATTN_TRITON=0`` to turn off triton flash attention.
+ - To use CK flash-attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
- The ROCm version of pytorch, ideally, should match the ROCm driver version.
diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py
index cc05d79e56874..332937b874e93 100644
--- a/tests/async_engine/test_openapi_server_ray.py
+++ b/tests/async_engine/test_openapi_server_ray.py
@@ -4,7 +4,7 @@
# and debugging.
import ray
-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
@@ -12,7 +12,7 @@
@pytest.fixture(scope="module")
def ray_ctx():
- ray.init(runtime_env={"working_dir": VLLM_PATH})
+ ray.init()
yield
ray.shutdown()
diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py
index 49d11daca9aec..9ff11b0d27b11 100644
--- a/tests/distributed/test_utils.py
+++ b/tests/distributed/test_utils.py
@@ -1,8 +1,8 @@
-import os
-
import ray
-from vllm.utils import cuda_device_count_stateless
+import vllm.envs as envs
+from vllm.utils import (cuda_device_count_stateless, is_hip,
+ update_environment_variables)
@ray.remote
@@ -12,16 +12,21 @@ def get_count(self):
return cuda_device_count_stateless()
def set_cuda_visible_devices(self, cuda_visible_devices: str):
- os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
+ update_environment_variables(
+ {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
def get_cuda_visible_devices(self):
- return os.environ["CUDA_VISIBLE_DEVICES"]
+ return envs.CUDA_VISIBLE_DEVICES
def test_cuda_device_count_stateless():
"""Test that cuda_device_count_stateless changes return value if
CUDA_VISIBLE_DEVICES is changed."""
-
+ if is_hip():
+ # Set HIP_VISIBLE_DEVICES == CUDA_VISIBLE_DEVICES. Conversion
+ # is handled by `update_environment_variables`
+ update_environment_variables(
+ {"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})
actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore
num_gpus=2).remote()
assert sorted(ray.get(
diff --git a/tests/entrypoints/test_openai_embedding.py b/tests/entrypoints/test_openai_embedding.py
index 2496d2ac3e97d..45f701733df0c 100644
--- a/tests/entrypoints/test_openai_embedding.py
+++ b/tests/entrypoints/test_openai_embedding.py
@@ -2,7 +2,7 @@
import pytest
import ray
-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
@@ -11,7 +11,7 @@
@pytest.fixture(scope="module")
def ray_ctx():
- ray.init(runtime_env={"working_dir": VLLM_PATH})
+ ray.init()
yield
ray.shutdown()
diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index c22a675ff1230..5196d81815502 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -16,7 +16,7 @@
from vllm.transformers_utils.tokenizer import get_tokenizer
-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -81,7 +81,7 @@ def zephyr_lora_files():
@pytest.fixture(scope="module")
def ray_ctx():
- ray.init(runtime_env={"working_dir": VLLM_PATH})
+ ray.init()
yield
ray.shutdown()
diff --git a/tests/entrypoints/test_openai_vision.py b/tests/entrypoints/test_openai_vision.py
index 03dc5d1161f0e..0e8d88b76ffec 100644
--- a/tests/entrypoints/test_openai_vision.py
+++ b/tests/entrypoints/test_openai_vision.py
@@ -8,7 +8,7 @@
from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64
-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
LLAVA_CHAT_TEMPLATE = (Path(__file__).parent.parent.parent /
@@ -27,7 +27,7 @@
@pytest.fixture(scope="module")
def ray_ctx():
- ray.init(runtime_env={"working_dir": VLLM_PATH})
+ ray.init()
yield
ray.shutdown()
diff --git a/tests/utils.py b/tests/utils.py
index 174efca4af532..2a5f82b91c42c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,9 +15,30 @@
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import get_open_port, is_hip
-if (not is_hip()):
+if is_hip():
+ from amdsmi import (amdsmi_get_gpu_vram_usage,
+ amdsmi_get_processor_handles, amdsmi_init,
+ amdsmi_shut_down)
+
+ @contextmanager
+ def _nvml():
+ try:
+ amdsmi_init()
+ yield
+ finally:
+ amdsmi_shut_down()
+else:
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
- nvmlInit)
+ nvmlInit, nvmlShutdown)
+
+ @contextmanager
+ def _nvml():
+ try:
+ nvmlInit()
+ yield
+ finally:
+ nvmlShutdown()
+
# Path to root of repository so that utilities can be imported by ray workers
VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
@@ -160,20 +181,25 @@ def error_on_warning():
yield
+@_nvml()
def wait_for_gpu_memory_to_clear(devices: List[int],
threshold_bytes: int,
timeout_s: float = 120) -> None:
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
- nvmlInit()
start_time = time.time()
while True:
output: Dict[int, str] = {}
output_raw: Dict[int, float] = {}
for device in devices:
- dev_handle = nvmlDeviceGetHandleByIndex(device)
- mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
- gb_used = mem_info.used / 2**30
+ if is_hip():
+ dev_handle = amdsmi_get_processor_handles()[device]
+ mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
+ gb_used = mem_info["vram_used"] / 2**10
+ else:
+ dev_handle = nvmlDeviceGetHandleByIndex(device)
+ mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
+ gb_used = mem_info.used / 2**30
output_raw[device] = gb_used
output[device] = f'{gb_used:.02f}'
diff --git a/vllm/config.py b/vllm/config.py
index 0217a2b569928..0c4d770e46847 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -7,13 +7,15 @@
import torch
from transformers import PretrainedConfig, PreTrainedTokenizerBase
+import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.tracing import is_otel_installed
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
- is_hip, is_neuron, is_tpu, is_xpu)
+ is_hip, is_neuron, is_tpu, is_xpu,
+ update_environment_variables)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -634,6 +636,12 @@ def __init__(
self.distributed_executor_backend = backend
logger.info("Defaulting to use %s for distributed inference",
backend)
+ # If CUDA_VISIBLE_DEVICES is set on ROCm prior to vLLM init,
+ # propagate changes to HIP_VISIBLE_DEVICES (conversion handled by
+ # the update_environment_variables function)
+ if is_hip() and envs.CUDA_VISIBLE_DEVICES:
+ update_environment_variables(
+ {"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})
self._verify_args()
diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py
index d3e41fa710676..6f1aaed9881a2 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py
@@ -13,7 +13,8 @@
import vllm.envs as envs
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils import (cuda_device_count_stateless,
+ update_environment_variables)
logger = init_logger(__name__)
@@ -24,7 +25,8 @@ def producer(batch_src: Sequence[int],
result_queue,
cuda_visible_devices: Optional[str] = None):
if cuda_visible_devices is not None:
- os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
+ update_environment_variables(
+ {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
lib = CudaRTLibrary()
for i in batch_src:
@@ -56,7 +58,8 @@ def consumer(batch_tgt: Sequence[int],
result_queue,
cuda_visible_devices: Optional[str] = None):
if cuda_visible_devices is not None:
- os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
+ update_environment_variables(
+ {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
lib = CudaRTLibrary()
for j in batch_tgt:
@@ -123,7 +126,7 @@ def can_actually_p2p(
processes for testing all pairs of GPUs in batch. The trick is to reset
the device after each test (which is not available in PyTorch).
""" # noqa
- cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
+ cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
# pass the CUDA_VISIBLE_DEVICES to the child process
# to make sure they see the same set of GPUs
diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index e63e5a3a027fa..a5b1d27f27596 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -11,7 +11,8 @@
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (cuda_device_count_stateless,
get_distributed_init_method, get_open_port,
- get_vllm_instance_id, make_async)
+ get_vllm_instance_id, make_async,
+ update_environment_variables)
logger = init_logger(__name__)
@@ -25,8 +26,9 @@ def _init_executor(self) -> None:
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
- os.environ["CUDA_VISIBLE_DEVICES"] = (",".join(
- map(str, range(world_size))))
+ update_environment_variables({
+ "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
+ })
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
diff --git a/vllm/utils.py b/vllm/utils.py
index f0c7df5cf8c22..92abdb3fb9b14 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -376,6 +376,10 @@ def get_open_port() -> int:
def update_environment_variables(envs: Dict[str, str]):
+ if is_hip() and "CUDA_VISIBLE_DEVICES" in envs:
+ # Propagate changes to CUDA_VISIBLE_DEVICES to
+ # ROCm's HIP_VISIBLE_DEVICES as well
+ envs["HIP_VISIBLE_DEVICES"] = envs["CUDA_VISIBLE_DEVICES"]
for k, v in envs.items():
if k in os.environ and os.environ[k] != v:
logger.warning(
@@ -779,9 +783,14 @@ def _cuda_device_count_stateless(
if not torch.cuda._is_compiled():
return 0
- # bypass _device_count_nvml() if rocm (not supported)
- nvml_count = -1 if torch.version.hip else torch.cuda._device_count_nvml()
- r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
+ if is_hip():
+ # ROCm uses amdsmi instead of nvml for stateless device count
+ # This requires a sufficiently modern version of Torch 2.4.0
+ raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
+ torch.cuda, "_device_count_amdsmi")) else -1
+ else:
+ raw_count = torch.cuda._device_count_nvml()
+ r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
return r
@@ -795,7 +804,6 @@ def cuda_device_count_stateless() -> int:
# This can be removed and simply replaced with torch.cuda.get_device_count
# after https://github.com/pytorch/pytorch/pull/122815 is released.
-
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index dc09718de4a32..99482aa93bc59 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -6,7 +6,7 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
-from vllm.utils import (enable_trace_function_call_for_thread,
+from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
update_environment_variables)
logger = init_logger(__name__)
@@ -125,6 +125,14 @@ def update_environment_variables(envs: Dict[str, str]) -> None:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
# suppress the warning in `update_environment_variables`
del os.environ[key]
+ if is_hip():
+ hip_env_var = "HIP_VISIBLE_DEVICES"
+ if hip_env_var in os.environ:
+ logger.warning(
+ "Ignoring pre-set environment variable `%s=%s` as "
+ "%s has also been set, which takes precedence.",
+ hip_env_var, os.environ[hip_env_var], key)
+ os.environ.pop(hip_env_var, None)
update_environment_variables(envs)
def init_worker(self, *args, **kwargs):