Commit

S3 full support
omer-dayan committed Nov 13, 2024
1 parent 9d3cfa4 commit f8d1673
Showing 4 changed files with 125 additions and 9 deletions.
12 changes: 12 additions & 0 deletions vllm/config.py
@@ -21,6 +21,7 @@
    get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
                        print_warning_once)
from vllm.transformers_utils.s3_utils import is_s3, S3Model

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup
@@ -194,6 +195,17 @@ def __init__(
            msg = ("`--rope-theta` will be removed in a future release. "
                   f"'Please instead use `--hf-overrides '{hf_override!r}'`")
            warnings.warn(DeprecationWarning(msg), stacklevel=2)

        if is_s3(model):
            self.s3_model = S3Model()
            self.s3_model.pull_files(model, allow_pattern=["*config.json"])
            self.model_weights = self.model
            self.model = self.s3_model.dir

        if is_s3(tokenizer):
            self.s3_tokenizer = S3Model()
            self.s3_tokenizer.pull_files(
                model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
            self.tokenizer = self.s3_tokenizer.dir

        # The tokenizer version is consistent with the model version by default.
        if tokenizer_revision is None:
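With this change an s3:// URI can be passed directly as the model: ModelConfig pulls the *config.json files into a tmpfs-backed temp directory, repoints self.model at that directory, and keeps the original URI in self.model_weights for the weight loader (the tokenizer path gets the same treatment, minus the weight files). A minimal end-to-end sketch, assuming a hypothetical bucket, AWS credentials visible to boto3, and the RunAI streamer load format used by the loader modified below:

from vllm import LLM

# Hypothetical bucket/prefix; requires boto3-visible AWS credentials
# (environment variables, shared credentials file, or an instance profile).
llm = LLM(
    model="s3://my-bucket/my-llama-8b/",  # ModelConfig pulls *config.json locally
    load_format="runai_streamer",         # loader below streams *.safetensors from S3
)

outputs = llm.generate("The capital of France is")
print(outputs[0].outputs[0].text)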
21 changes: 13 additions & 8 deletions vllm/model_executor/model_loader/loader.py
@@ -41,6 +41,7 @@
    runai_safetensors_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.transformers_utils.s3_utils import is_s3, glob as s3_glob
from vllm.utils import is_pin_memory_available


@@ -1161,28 +1162,31 @@ def _prepare_weights(self, model_name_or_path: str,
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
is_s3_path = is_s3(model_name_or_path)
is_local = os.path.isdir(model_name_or_path)
safetensors_pattern = "*.safetensors"
index_file = SAFE_WEIGHTS_INDEX_NAME

hf_folder = (model_name_or_path
if is_local else download_weights_from_hf(
hf_folder = (model_name_or_path if
(is_local or is_s3_path) else download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
[safetensors_pattern],
revision,
ignore_patterns=self.load_config.ignore_patterns,
))

hf_weights_files = glob.glob(
os.path.join(hf_folder, safetensors_pattern))
if is_s3_path:
hf_weights_files = s3_glob(path=hf_folder,
allow_pattern=[safetensors_pattern])
else:
hf_weights_files = glob.glob(
os.path.join(hf_folder, safetensors_pattern))

if not is_local:
if not is_local and not is_s3_path:
download_safetensors_index_file_from_hf(
model_name_or_path, index_file, self.load_config.download_dir,
revision)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file)

if len(hf_weights_files) == 0:
raise RuntimeError(
@@ -1210,8 +1214,9 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
             with target_device:
                 model = _initialize_model(vllm_config=vllm_config)

+            assert hasattr(model_config, "model_weights")
             model.load_weights(
-                self._get_weights_iterator(model_config.model,
+                self._get_weights_iterator(model_config.model_weights,
                                            model_config.revision))

             for _, module in model.named_modules():
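The net effect of the loader change: weight discovery now has three cases — a local directory, an s3:// prefix (listed in place, nothing downloaded), or a Hugging Face repo (downloaded as before) — and the weights iterator reads from model_config.model_weights rather than model_config.model, since the latter now points at the temp directory that holds only the configs. A simplified, standalone sketch of the new discovery branch, with a hypothetical bucket path and the HF-download case omitted:

import glob
import os

from vllm.transformers_utils.s3_utils import glob as s3_glob, is_s3


def find_safetensors_files(model_name_or_path: str) -> list[str]:
    """Rough mirror of the new branch in _prepare_weights (HF download path omitted)."""
    pattern = "*.safetensors"
    if is_s3(model_name_or_path):
        # List matching keys under the S3 prefix; nothing is downloaded here,
        # the streamer later reads the shards straight from S3.
        return s3_glob(path=model_name_or_path, allow_pattern=[pattern])
    # Local directory that already contains the shards.
    return glob.glob(os.path.join(model_name_or_path, pattern))


print(find_safetensors_files("s3://my-bucket/my-llama-8b/"))  # hypothetical bucket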
99 changes: 99 additions & 0 deletions vllm/transformers_utils/s3_utils.py
@@ -0,0 +1,99 @@
from __future__ import annotations
import fnmatch
import os
import shutil
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple
import signal

import boto3


def is_s3(model_or_path: str) -> bool:
    return model_or_path.lower().startswith('s3://')


def _filter_allow(paths: List[str], patterns: List[str]) -> List[str]:
    return [
        path for path in paths if any(
            fnmatch.fnmatch(path, pattern) for pattern in patterns)
    ]


def _filter_ignore(paths: List[str], patterns: List[str]) -> List[str]:
    return [
        path for path in paths
        if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
    ]


def glob(s3=None,
         path: str = "",
         allow_pattern: Optional[List[str]] = None) -> List[str]:
    if s3 is None:
        s3 = boto3.client("s3")
    bucket_name, _, paths = list_files(
        s3, path=path, allow_pattern=allow_pattern)
    return [f"s3://{bucket_name}/{path}" for path in paths]


def list_files(
        s3,
        path: str,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None
) -> Tuple[str, str, List[str]]:
    parts = path.removeprefix('s3://').split('/')
    prefix = '/'.join(parts[1:])
    bucket_name = parts[0]

    objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
    paths = [obj['Key'] for obj in objects.get('Contents', [])]

    paths = _filter_ignore(paths, ["*/"])
    if allow_pattern is not None:
        paths = _filter_allow(paths, allow_pattern)

    if ignore_pattern is not None:
        paths = _filter_ignore(paths, ignore_pattern)

    return bucket_name, prefix, paths


class S3Model:
    def __init__(self) -> None:
        self.s3 = boto3.client('s3')
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            signal.signal(sig, self.close_by_signal(existing_handler))
        self.dir = tempfile.mkdtemp(dir='/dev/shm')

    def __del__(self):
        self.close()

    def close(self) -> None:
        if os.path.exists(self.dir):
            shutil.rmtree(self.dir)

    def close_by_signal(self, existing_handler=None):
        def new_handler(signum, frame):
            self.close()
            if existing_handler:
                existing_handler(signum, frame)
        return new_handler

    def pull_files(self,
                   s3_model_path: str = "",
                   allow_pattern: Optional[List[str]] = None,
                   ignore_pattern: Optional[List[str]] = None) -> str:
        bucket_name, base_dir, files = list_files(self.s3, s3_model_path,
                                                  allow_pattern,
                                                  ignore_pattern)
        if len(files) == 0:
            return self.dir

        for file in files:
            destination_file = self.dir + file.removeprefix(base_dir)
            local_dir = Path(destination_file).parent
            os.makedirs(local_dir, exist_ok=True)
            self.s3.download_file(bucket_name, file, destination_file)
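A short usage sketch of the new helpers, assuming a hypothetical bucket and boto3 credentials with list and get permissions on it: glob returns full s3:// URIs for matching keys, and S3Model downloads into a tmpfs-backed temp directory that it removes on close, garbage collection, SIGINT, or SIGTERM.

from vllm.transformers_utils.s3_utils import S3Model, glob, is_s3

uri = "s3://my-bucket/my-llama-8b/"  # hypothetical prefix
assert is_s3(uri)

# List only the weight shards under the prefix (returned as full s3:// URIs).
print(glob(path=uri, allow_pattern=["*.safetensors"]))

# Pull everything except the weights into a tmpfs-backed directory, so code
# that expects a local HF model folder (configs, tokenizer files) can use it.
s3_model = S3Model()
s3_model.pull_files(uri, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
print(s3_model.dir)  # e.g. a temp directory under /dev/shm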
2 changes: 1 addition & 1 deletion vllm/transformers_utils/tokenizer.py
@@ -82,7 +82,7 @@ def _pad(
        return orig_pad(*args, **kwargs)

    tokenizer._pad = MethodType(_pad, tokenizer)


def get_tokenizer(
    tokenizer_name: Union[str, Path],