Make RunAI an optional dependency
Signed-off-by: OmerD <[email protected]>
omer-dayan committed Nov 19, 2024
1 parent 7ce8dbd commit fdbce70
Showing 10 changed files with 50 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -214,7 +214,7 @@ FROM vllm-base AS vllm-openai

# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.44.0' timm==0.9.10
+    pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.44.0' timm==0.9.10 boto3 runai-model-streamer runai-model-streamer[s3]

ENV VLLM_USAGE_SOURCE production-docker-image

1 change: 0 additions & 1 deletion docs/requirements-docs.txt
@@ -12,7 +12,6 @@ pydantic >= 2.8
torch
py-cpuinfo
transformers
-boto3
mistral_common >= 1.3.4
aiohttp
starlette
5 changes: 5 additions & 0 deletions docs/source/serving/runai_model_streamer.rst
@@ -6,6 +6,11 @@ Run:ai Model Streamer is a library to read tensors in concurrency, while streami
Further reading can be found in `Run:ai Model Streamer Documentation <https://github.com/run-ai/runai-model-streamer/blob/master/docs/README.md>`_.

vLLM supports loading weights in Safetensors format using the Run:ai Model Streamer.
+You first need to install vLLM's Run:ai optional dependency:
+
+.. code-block:: console
+
+    $ pip3 install vllm[runai]

To run it as an OpenAI-compatible server, add the ``--load-format runai_streamer`` flag:

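Putting the two documented steps together, a serving invocation might look like the following sketch; the S3 path is a placeholder for any bucket holding Safetensors weights:

    $ vllm serve s3://example-bucket/llama-3-8b/ --load-format runai_streamer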
2 changes: 0 additions & 2 deletions requirements-common.txt
@@ -32,5 +32,3 @@ six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that need
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL.
compressed-tensors == 0.8.0 # required for compressed-tensors
-runai-model-streamer
-boto3
1 change: 1 addition & 0 deletions setup.py
@@ -554,6 +554,7 @@ def _read_requirements(filename: str) -> List[str]:
ext_modules=ext_modules,
extras_require={
"tensorizer": ["tensorizer>=2.9.0"],
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
"audio": ["librosa", "soundfile"], # Required for audio processing
"video": ["decord"] # Required for video processing
},
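With the extra defined, the streaming packages are pulled in only on request. One practical note: square brackets are glob characters in some shells (zsh in particular), so the requirement spec should be quoted:

    $ pip install 'vllm[runai]'   # resolves runai-model-streamer, runai-model-streamer-s3, boto3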
34 changes: 22 additions & 12 deletions vllm/config.py
@@ -21,7 +21,7 @@
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
-from vllm.transformers_utils.s3_utils import S3Model, is_s3
+from vllm.transformers_utils.utils import is_s3
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
identity, print_warning_once)

@@ -194,17 +194,27 @@ def __init__(
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)

-if is_s3(model):
-    self.s3_model = S3Model()
-    self.s3_model.pull_files(model, allow_pattern=["*config.json"])
-    self.model_weights = self.model
-    self.model = self.s3_model.dir
-
-if is_s3(tokenizer):
-    self.s3_tokenizer = S3Model()
-    self.s3_tokenizer.pull_files(
-        model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
-    self.tokenizer = self.s3_tokenizer.dir
+if is_s3(model) or is_s3(tokenizer):
+    try:
+        from vllm.transformers_utils.s3_utils import S3Model
+    except ImportError as err:
+        raise ImportError(
+            "Please install Run:ai optional dependency "
+            "to use the S3 capabilities. "
+            "You can install it with: pip install vllm[runai]"
+        ) from err
+
+    if is_s3(model):
+        self.s3_model = S3Model()
+        self.s3_model.pull_files(model, allow_pattern=["*config.json"])
+        self.model_weights = self.model
+        self.model = self.s3_model.dir
+
+    if is_s3(tokenizer):
+        self.s3_tokenizer = S3Model()
+        self.s3_tokenizer.pull_files(
+            model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
+        self.tokenizer = self.s3_tokenizer.dir

# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
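The guard above is the usual lazy-import pattern for optional extras: the import moves from module scope into the code path that needs it, and the ImportError is re-raised with an actionable install hint. A minimal standalone sketch of the same idea, with an illustrative helper name that is not part of vLLM:

    import importlib


    def require_optional(module_name: str, extra: str):
        """Import a module shipped in an optional extra, or fail with a hint."""
        try:
            return importlib.import_module(module_name)
        except ImportError as err:
            raise ImportError(
                f"Please install the '{extra}' optional dependency to use "
                f"{module_name}. You can install it with: pip install "
                f"vllm[{extra}]") from err


    # Resolved lazily, so a plain `import vllm` never needs the extra:
    # s3_utils = require_optional("vllm.transformers_utils.s3_utils", "runai")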
13 changes: 11 additions & 2 deletions vllm/model_executor/model_loader/loader.py
@@ -43,8 +43,7 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config
-from vllm.transformers_utils.s3_utils import glob as s3_glob
-from vllm.transformers_utils.s3_utils import is_s3
+from vllm.transformers_utils.utils import is_s3
from vllm.utils import is_pin_memory_available


@@ -1209,6 +1208,16 @@ def _prepare_weights(self, model_name_or_path: str,
If the model is not local, it will be downloaded."""
is_s3_path = is_s3(model_name_or_path)
if is_s3_path:
+    try:
+        from vllm.transformers_utils.s3_utils import glob as s3_glob
+    except ImportError as err:
+        raise ImportError(
+            "Please install Run:ai optional dependency "
+            "to use the S3 capabilities. "
+            "You can install it with: pip install vllm[runai]"
+        ) from err

is_local = os.path.isdir(model_name_or_path)
safetensors_pattern = "*.safetensors"
index_file = SAFE_WEIGHTS_INDEX_NAME
7 changes: 6 additions & 1 deletion vllm/model_executor/model_loader/weight_utils.py
@@ -414,7 +414,12 @@ def runai_safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
-from runai_model_streamer import SafetensorsStreamer
+try:
+    from runai_model_streamer import SafetensorsStreamer
+except ImportError as err:
+    raise ImportError(
+        "Please install Run:ai optional dependency. "
+        "You can install it with: pip install vllm[runai]") from err

enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
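Per the signature above, the iterator yields (name, tensor) pairs, so callers can materialize weights directly. A small consumption sketch, assuming the runai extra is installed and using placeholder file paths (vLLM normally resolves the list in _prepare_weights):

    files = [
        "/models/llama/model-00001-of-00002.safetensors",
        "/models/llama/model-00002-of-00002.safetensors",
    ]
    state_dict = dict(runai_safetensors_weights_iterator(files))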
4 changes: 0 additions & 4 deletions vllm/transformers_utils/s3_utils.py
@@ -10,10 +10,6 @@
import boto3


-def is_s3(model_or_path: str) -> bool:
-    return model_or_path.lower().startswith('s3://')


def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
return [
path for path in paths if any(
4 changes: 4 additions & 0 deletions vllm/transformers_utils/utils.py
@@ -3,6 +3,10 @@
from typing import Union


+def is_s3(model_or_path: str) -> bool:
+    return model_or_path.lower().startswith('s3://')


def check_gguf_file(model: Union[str, PathLike]) -> bool:
"""Check if the file is a GGUF model."""
model = Path(model)
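Since the helper lowercases its input before the prefix check, the scheme match is case-insensitive; illustrative calls:

    assert is_s3("s3://my-bucket/llama-3-8b")
    assert is_s3("S3://MY-BUCKET/model")
    assert not is_s3("/local/path/to/model")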
