Skip to content

Commit

Permalink
S3 full support
Browse files Browse the repository at this point in the history
Signed-off-by: OmerD <[email protected]>
  • Loading branch information
omer-dayan committed Nov 14, 2024
1 parent bac1fd7 commit c1f75d6
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pydantic >= 2.8
torch
py-cpuinfo
transformers
boto3
mistral_common >= 1.3.4
aiohttp
starlette
Expand Down
1 change: 1 addition & 0 deletions requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we n
einops # Required for Qwen2-VL.
compressed-tensors == 0.8.0 # required for compressed-tensors
runai-model-streamer
boto3
13 changes: 13 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model, is_s3
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
print_warning_once)

Expand Down Expand Up @@ -195,6 +196,18 @@ def __init__(
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)

if is_s3(model):
self.s3_model = S3Model()
self.s3_model.pull_files(model, allow_pattern=["*config.json"])
self.model_weights = self.model
self.model = self.s3_model.dir

if is_s3(tokenizer):
self.s3_tokenizer = S3Model()
self.s3_tokenizer.pull_files(
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
self.tokenizer = self.s3_tokenizer.dir

# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
self.tokenizer_revision = revision
Expand Down
22 changes: 14 additions & 8 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
runai_safetensors_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.transformers_utils.s3_utils import glob as s3_glob
from vllm.transformers_utils.s3_utils import is_s3
from vllm.utils import is_pin_memory_available


Expand Down Expand Up @@ -1180,28 +1182,31 @@ def _prepare_weights(self, model_name_or_path: str,
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
is_s3_path = is_s3(model_name_or_path)
is_local = os.path.isdir(model_name_or_path)
safetensors_pattern = "*.safetensors"
index_file = SAFE_WEIGHTS_INDEX_NAME

hf_folder = (model_name_or_path
if is_local else download_weights_from_hf(
hf_folder = (model_name_or_path if
(is_local or is_s3_path) else download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
[safetensors_pattern],
revision,
ignore_patterns=self.load_config.ignore_patterns,
))

hf_weights_files = glob.glob(
os.path.join(hf_folder, safetensors_pattern))
if is_s3_path:
hf_weights_files = s3_glob(path=hf_folder,
allow_pattern=[safetensors_pattern])
else:
hf_weights_files = glob.glob(
os.path.join(hf_folder, safetensors_pattern))

if not is_local:
if not is_local and not is_s3_path:
download_safetensors_index_file_from_hf(
model_name_or_path, index_file, self.load_config.download_dir,
revision)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file)

if len(hf_weights_files) == 0:
raise RuntimeError(
Expand Down Expand Up @@ -1229,8 +1234,9 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
with target_device:
model = _initialize_model(vllm_config=vllm_config)

assert hasattr(model_config, "model_weights")
model.load_weights(
self._get_weights_iterator(model_config.model,
self._get_weights_iterator(model_config.model_weights,
model_config.revision))

for _, module in model.named_modules():
Expand Down
103 changes: 103 additions & 0 deletions vllm/transformers_utils/s3_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations

import fnmatch
import os
import shutil
import signal
import tempfile
from pathlib import Path

import boto3


def is_s3(model_or_path: str) -> bool:
return model_or_path.lower().startswith('s3://')


def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
return [
path for path in paths if any(
fnmatch.fnmatch(path, pattern) for pattern in patterns)
]


def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
return [
path for path in paths
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
]


def glob(s3=None,
path: str = "",
allow_pattern: list[str] | None = None) -> list[str]:
if s3 is None:
s3 = boto3.client("s3")
bucket_name, _, paths = list_files(s3,
path=path,
allow_pattern=allow_pattern)
return [f"s3://{bucket_name}/{path}" for path in paths]


def list_files(
s3,
path: str,
allow_pattern: list[str] | None = None,
ignore_pattern: list[str] | None = None) -> tuple[str, str, list[str]]:
parts = path.removeprefix('s3://').split('/')
prefix = '/'.join(parts[1:])
bucket_name = parts[0]

objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
paths = [obj['Key'] for obj in objects.get('Contents', [])]

paths = _filter_ignore(paths, ["*/"])
if allow_pattern is not None:
paths = _filter_allow(paths, allow_pattern)

if ignore_pattern is not None:
paths = _filter_ignore(paths, ignore_pattern)

return bucket_name, prefix, paths


class S3Model:

def __init__(self) -> None:
self.s3 = boto3.client('s3')
for sig in (signal.SIGINT, signal.SIGTERM):
existing_handler = signal.getsignal(sig)
signal.signal(sig, self.close_by_signal(existing_handler))
self.dir = tempfile.mkdtemp(dir='/dev/shm')

def __del__(self):
self.close()

def close(self) -> None:
if os.path.exists(self.dir):
shutil.rmtree(self.dir)

def close_by_signal(self, existing_handler=None):

def new_handler(signum, frame):
self.close()
if existing_handler:
existing_handler(signum, frame)

return new_handler

def pull_files(self,
s3_model_path: str = "",
allow_pattern: list[str] | None = None,
ignore_pattern: list[str] | None = None) -> None:
bucket_name, base_dir, files = list_files(self.s3, s3_model_path,
allow_pattern,
ignore_pattern)
if len(files) == 0:
return

for file in files:
destination_file = self.dir + file.removeprefix(base_dir)
local_dir = Path(destination_file).parent
os.makedirs(local_dir, exist_ok=True)
self.s3.download_file(bucket_name, file, destination_file)

0 comments on commit c1f75d6

Please sign in to comment.