From d22d4bbc192dc1a7e233cf975e7a2228c77f8625 Mon Sep 17 00:00:00 2001 From: OmerD Date: Thu, 19 Dec 2024 19:52:05 +0200 Subject: [PATCH] Code review changes Signed-off-by: OmerD --- vllm/config.py | 4 +- vllm/model_executor/model_loader/loader.py | 2 +- vllm/transformers_utils/s3_utils.py | 49 +++++++++++++++++++--- 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 0f315f21ec0c9..c0fcd8f3ffb0a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -257,7 +257,7 @@ def __init__(self, f"'Please instead use `--hf-overrides '{hf_override!r}'`") warnings.warn(DeprecationWarning(msg), stacklevel=2) - self.pull_model_tokenizer_for_s3(model, tokenizer) + self.maybe_pull_model_tokenizer_for_s3(model, tokenizer) # The tokenizer version is consistent with the model version by default. if tokenizer_revision is None: @@ -360,7 +360,7 @@ def __init__(self, self._verify_cuda_graph() self._verify_bnb_config() - def pull_model_tokenizer_for_s3(self, model: str, tokenizer: str) -> None: + def maybe_pull_model_tokenizer_for_s3(self, model: str, tokenizer: str) -> None: if is_s3(model) or is_s3(tokenizer): try: from vllm.transformers_utils.s3_utils import S3Model diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 94e9e5c3d4f94..7c5ed9628930b 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1301,7 +1301,7 @@ def _prepare_weights(self, model_name_or_path: str, model_name_or_path, index_file, self.load_config.download_dir, revision) - if len(hf_weights_files) == 0: + if not hf_weights_files: raise RuntimeError( f"Cannot find any safetensors model weights with " f"`{model_name_or_path}`") diff --git a/vllm/transformers_utils/s3_utils.py b/vllm/transformers_utils/s3_utils.py index 4474d3a914e15..5bb53758f3250 100644 --- a/vllm/transformers_utils/s3_utils.py +++ b/vllm/transformers_utils/s3_utils.py @@ -6,6 +6,7 @@ import signal import tempfile from pathlib import Path +from typing import Optional import boto3 @@ -26,7 +27,16 @@ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]: def glob(s3=None, path: str = "", - allow_pattern: list[str] | None = None) -> list[str]: + allow_pattern: Optional[list[str]] = None) -> list[str]: + """ + List full file names from S3 path and filter by allow pattern. + + Args: + s3: S3 client to use. + path: The S3 path to list from. + allow_pattern: A list of patterns of which files to pull. + + """ if s3 is None: s3 = boto3.client("s3") bucket_name, _, paths = list_files(s3, @@ -38,8 +48,18 @@ def glob(s3=None, def list_files( s3, path: str, - allow_pattern: list[str] | None = None, - ignore_pattern: list[str] | None = None) -> tuple[str, str, list[str]]: + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None) -> tuple[str, str, list[str]]: + """ + List files from S3 path and filter by pattern. + + Args: + s3: S3 client to use. + path: The S3 path to list from. + allow_pattern: A list of patterns of which files to pull. + ignore_pattern: A list of patterns of which files not to pull. + + """ parts = path.removeprefix('s3://').split('/') prefix = '/'.join(parts[1:]) bucket_name = parts[0] @@ -58,6 +78,16 @@ def list_files( class S3Model: + """ + A class representing a S3 model mirrored into a temporary directory. + + Attributes: + s3: S3 client. + dir: The temporary created directory. + + Methods: + pull_files(): Pull model from S3 to the temporary directory. + """ def __init__(self) -> None: self.s3 = boto3.client('s3') @@ -84,8 +114,17 @@ def new_handler(signum, frame): def pull_files(self, s3_model_path: str = "", - allow_pattern: list[str] | None = None, - ignore_pattern: list[str] | None = None) -> None: + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None) -> None: + """ + Pull files from S3 storage into the temporary directory. + + Args: + s3_model_path: The S3 path of the model. + allow_pattern: A list of patterns of which files to pull. + ignore_pattern: A list of patterns of which files not to pull. + + """ bucket_name, base_dir, files = list_files(self.s3, s3_model_path, allow_pattern, ignore_pattern)