
Commit

Code review changes
Signed-off-by: OmerD <[email protected]>
omer-dayan committed Dec 19, 2024
1 parent fc19e86 commit 2c2b9f2
Showing 3 changed files with 74 additions and 15 deletions.
13 changes: 11 additions & 2 deletions vllm/config.py
@@ -257,7 +257,7 @@ def __init__(self,
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)

self.pull_model_tokenizer_for_s3(model, tokenizer)
self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)

# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
@@ -360,7 +360,16 @@ def __init__(self,
self._verify_cuda_graph()
self._verify_bnb_config()

def pull_model_tokenizer_for_s3(self, model: str, tokenizer: str) -> None:
def maybe_pull_model_tokenizer_for_s3(self, model: str,
tokenizer: str) -> None:
"""
Pull the model config or tokenizer to a temporary directory in case of S3.

Check failure (GitHub Actions / ruff 3.12): vllm/config.py:366:81: E501 Line too long (82 > 80)
Args:
model: The model name or path.
tokenizer: The tokenizer name or path.
"""
if is_s3(model) or is_s3(tokenizer):
try:
from vllm.transformers_utils.s3_utils import S3Model
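The remainder of this hunk is collapsed above. For orientation, here is a hedged sketch of the pattern the method appears to follow: mirror only the lightweight config/tokenizer files into an S3Model temporary directory and repoint the paths. The helper name, the allow patterns, and the import location of is_s3 are assumptions, not the committed code.

# Sketch only, not the committed implementation of
# ModelConfig.maybe_pull_model_tokenizer_for_s3. The allow patterns and the
# is_s3 import location are assumptions.
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3  # assumed module for is_s3


def pull_if_s3(model: str, tokenizer: str) -> tuple[str, str]:
    """Return local paths, mirroring S3-hosted config/tokenizer files first."""
    if is_s3(model):
        s3_model = S3Model()
        # Pull only small metadata files; weights are streamed later by the loader.
        s3_model.pull_files(model, allow_pattern=["*.json", "*.py", "*.txt"])
        model = s3_model.dir
    if is_s3(tokenizer):
        s3_tokenizer = S3Model()
        s3_tokenizer.pull_files(tokenizer, allow_pattern=["*.json", "*.txt"])
        tokenizer = s3_tokenizer.dir
    return model, tokenizer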
9 changes: 7 additions & 2 deletions vllm/model_executor/model_loader/loader.py
@@ -1236,7 +1236,10 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:


class RunaiModelStreamerLoader(BaseModelLoader):
"""Model loader that can load different safetensors ."""
"""
Model loader that can load safetensors
files from local FS or S3 bucket.
"""

def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
@@ -1301,7 +1304,7 @@ def _prepare_weights(self, model_name_or_path: str,
model_name_or_path, index_file, self.load_config.download_dir,
revision)

if len(hf_weights_files) == 0:
if not hf_weights_files:
raise RuntimeError(
f"Cannot find any safetensors model weights with "
f"`{model_name_or_path}`")
@@ -1316,9 +1319,11 @@ def _get_weights_iterator(
return runai_safetensors_weights_iterator(hf_weights_files)

def download_model(self, model_config: ModelConfig) -> None:
"""Download model if necessery"""
self._prepare_weights(model_config.model, model_config.revision)

def load_model(self, vllm_config: VllmConfig) -> nn.Module:
"""Perform streaming of the model to destination"""
device_config = vllm_config.device_config
model_config = vllm_config.model_config

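As a usage note, a hedged sketch of how this loader would be exercised end to end; the load_format value "runai_streamer" and the bucket path are assumptions and do not appear in the hunks shown here.

# Hedged usage sketch for the streaming loader. The load_format value
# "runai_streamer" and the S3 path are assumptions, not taken from this diff.
from vllm import LLM

llm = LLM(
    model="s3://my-bucket/my-model/",   # hypothetical S3 path
    load_format="runai_streamer",       # assumed selector for RunaiModelStreamerLoader
)
outputs = llm.generate("Hello")
print(outputs[0].outputs[0].text)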
67 changes: 56 additions & 11 deletions vllm/transformers_utils/s3_utils.py
@@ -1,11 +1,10 @@
from __future__ import annotations

import fnmatch
import os
import shutil
import signal
import tempfile
from pathlib import Path
from typing import Optional

import boto3

@@ -26,7 +25,18 @@ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:

def glob(s3=None,
path: str = "",
allow_pattern: list[str] | None = None) -> list[str]:
allow_pattern: Optional[list[str]] = None) -> list[str]:
"""
List full file names from S3 path and filter by allow pattern.
Args:
s3: S3 client to use.
path: The S3 path to list from.
allow_pattern: A list of patterns of which files to pull.
Returns:
list[str]: List of full S3 paths allowed by the pattern
"""
if s3 is None:
s3 = boto3.client("s3")
bucket_name, _, paths = list_files(s3,
@@ -38,8 +48,24 @@ def glob(s3=None,
def list_files(
s3,
path: str,
allow_pattern: list[str] | None = None,
ignore_pattern: list[str] | None = None) -> tuple[str, str, list[str]]:
allow_pattern: Optional[list[str]] = None,
ignore_pattern: Optional[list[str]] = None
) -> tuple[str, str, list[str]]:
"""
List files from S3 path and filter by pattern.
Args:
s3: S3 client to use.
path: The S3 path to list from.
allow_pattern: A list of patterns of which files to pull.
ignore_pattern: A list of patterns of which files not to pull.
Returns:
tuple[str, str, list[str]]: A tuple where:
- The first element is the bucket name
- The second element is a string representing the bucket and the prefix as a dir-like string

Check failure (GitHub Actions / ruff 3.12): vllm/transformers_utils/s3_utils.py:66:81: E501 Line too long (99 > 80)
- The third element is a list of files allowed or disallowed by pattern

Check failure (GitHub Actions / ruff 3.12): vllm/transformers_utils/s3_utils.py:67:81: E501 Line too long (83 > 80)
"""
parts = path.removeprefix('s3://').split('/')
prefix = '/'.join(parts[1:])
bucket_name = parts[0]
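
For reference, a small worked example of the bucket/prefix split above, using a hypothetical bucket (pure string handling, no S3 access needed):

# Worked example of the parsing shown above; the bucket name is hypothetical.
path = "s3://my-bucket/models/llama/"
parts = path.removeprefix("s3://").split("/")
bucket_name = parts[0]            # "my-bucket"
prefix = "/".join(parts[1:])      # "models/llama/"
print(bucket_name, prefix)

With boto3 credentials configured, glob(path=path, allow_pattern=["*.safetensors"]) would then return the full s3:// keys that match the pattern.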
@@ -58,22 +84,32 @@ def list_files(


class S3Model:
"""
A class representing an S3 model mirrored into a temporary directory.
Attributes:
s3: S3 client.
dir: The temporarily created directory.
Methods:
pull_files(): Pull model from S3 to the temporary directory.
"""

def __init__(self) -> None:
self.s3 = boto3.client('s3')
for sig in (signal.SIGINT, signal.SIGTERM):
existing_handler = signal.getsignal(sig)
signal.signal(sig, self.close_by_signal(existing_handler))
signal.signal(sig, self._close_by_signal(existing_handler))
self.dir = tempfile.mkdtemp()

def __del__(self):
self.close()
self._close()

def close(self) -> None:
def _close(self) -> None:
if os.path.exists(self.dir):
shutil.rmtree(self.dir)

def close_by_signal(self, existing_handler=None):
def _close_by_signal(self, existing_handler=None):

def new_handler(signum, frame):
self.close()

Check failure (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): vllm/transformers_utils/s3_utils.py:115: "S3Model" has no attribute "close"; maybe "_close"? [attr-defined]
@@ -84,8 +120,17 @@ def new_handler(signum, frame):

def pull_files(self,
s3_model_path: str = "",
allow_pattern: list[str] | None = None,
ignore_pattern: list[str] | None = None) -> None:
allow_pattern: Optional[list[str]] = None,
ignore_pattern: Optional[list[str]] = None) -> None:
"""
Pull files from S3 storage into the temporary directory.
Args:
s3_model_path: The S3 path of the model.
allow_pattern: A list of patterns of which files to pull.
ignore_pattern: A list of patterns of which files not to pull.
"""
bucket_name, base_dir, files = list_files(self.s3, s3_model_path,
allow_pattern,
ignore_pattern)
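To close the file out, a hedged end-to-end sketch of S3Model used in isolation; the bucket path and patterns are hypothetical and boto3 credentials are assumed to be configured in the environment.

# Hedged end-to-end sketch: mirror the metadata of an S3-hosted model locally.
from vllm.transformers_utils.s3_utils import S3Model

s3_model = S3Model()                      # tempfile.mkdtemp() plus SIGINT/SIGTERM cleanup
s3_model.pull_files(
    "s3://my-bucket/my-model/",           # hypothetical path
    allow_pattern=["*.json", "*.txt"],    # config/tokenizer files only
    ignore_pattern=["*.bin", "*.pt"],     # skip weight files
)
print("mirrored into:", s3_model.dir)     # removed again on GC or on signal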
