From 2c2b9f2c96142f21b4f995d7cc9dad67b8445af8 Mon Sep 17 00:00:00 2001 From: OmerD Date: Thu, 19 Dec 2024 19:52:05 +0200 Subject: [PATCH] Code review changes Signed-off-by: OmerD --- vllm/config.py | 13 ++++- vllm/model_executor/model_loader/loader.py | 9 ++- vllm/transformers_utils/s3_utils.py | 67 ++++++++++++++++++---- 3 files changed, 74 insertions(+), 15 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 0f315f21ec0c9..ce064e06250c0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -257,7 +257,7 @@ def __init__(self, f"'Please instead use `--hf-overrides '{hf_override!r}'`") warnings.warn(DeprecationWarning(msg), stacklevel=2) - self.pull_model_tokenizer_for_s3(model, tokenizer) + self.maybe_pull_model_tokenizer_for_s3(model, tokenizer) # The tokenizer version is consistent with the model version by default. if tokenizer_revision is None: @@ -360,7 +360,16 @@ def __init__(self, self._verify_cuda_graph() self._verify_bnb_config() - def pull_model_tokenizer_for_s3(self, model: str, tokenizer: str) -> None: + def maybe_pull_model_tokenizer_for_s3(self, model: str, + tokenizer: str) -> None: + """ + Pull the model config or tokenizer to a temporary directory in case of S3. + + Args: + model: The model name or path. + tokenizer: The tokenizer name or path. + + """ if is_s3(model) or is_s3(tokenizer): try: from vllm.transformers_utils.s3_utils import S3Model diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 94e9e5c3d4f94..4b647e0590d7f 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1236,7 +1236,10 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: class RunaiModelStreamerLoader(BaseModelLoader): - """Model loader that can load different safetensors .""" + """ + Model loader that can load safetensors + files from local FS or S3 bucket. 
+ """ def __init__(self, load_config: LoadConfig): super().__init__(load_config) @@ -1301,7 +1304,7 @@ def _prepare_weights(self, model_name_or_path: str, model_name_or_path, index_file, self.load_config.download_dir, revision) - if len(hf_weights_files) == 0: + if not hf_weights_files: raise RuntimeError( f"Cannot find any safetensors model weights with " f"`{model_name_or_path}`") @@ -1316,9 +1319,11 @@ def _get_weights_iterator( return runai_safetensors_weights_iterator(hf_weights_files) def download_model(self, model_config: ModelConfig) -> None: + """Download model if necessary""" self._prepare_weights(model_config.model, model_config.revision) def load_model(self, vllm_config: VllmConfig) -> nn.Module: + """Perform streaming of the model to destination""" device_config = vllm_config.device_config model_config = vllm_config.model_config diff --git a/vllm/transformers_utils/s3_utils.py b/vllm/transformers_utils/s3_utils.py index 4474d3a914e15..3ec97ff7994c7 100644 --- a/vllm/transformers_utils/s3_utils.py +++ b/vllm/transformers_utils/s3_utils.py @@ -1,11 +1,10 @@ -from __future__ import annotations - import fnmatch import os import shutil import signal import tempfile from pathlib import Path +from typing import Optional import boto3 @@ -26,7 +25,18 @@ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]: def glob(s3=None, path: str = "", - allow_pattern: list[str] | None = None) -> list[str]: + allow_pattern: Optional[list[str]] = None) -> list[str]: + """ + List full file names from S3 path and filter by allow pattern. + + Args: + s3: S3 client to use. + path: The S3 path to list from. + allow_pattern: A list of patterns of which files to pull. 
+ + Returns: + list[str]: List of full S3 paths allowed by the pattern + """ if s3 is None: s3 = boto3.client("s3") bucket_name, _, paths = list_files(s3, @@ -38,8 +48,24 @@ def glob(s3=None, def list_files( s3, path: str, - allow_pattern: list[str] | None = None, - ignore_pattern: list[str] | None = None) -> tuple[str, str, list[str]]: + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None +) -> tuple[str, str, list[str]]: + """ + List files from S3 path and filter by pattern. + + Args: + s3: S3 client to use. + path: The S3 path to list from. + allow_pattern: A list of patterns of which files to pull. + ignore_pattern: A list of patterns of which files not to pull. + + Returns: + tuple[str, str, list[str]]: A tuple where: + - The first element is the bucket name + - The second element is a string representing the bucket and the prefix as a directory-like string + - The third element is a list of files allowed or disallowed by pattern + """ parts = path.removeprefix('s3://').split('/') prefix = '/'.join(parts[1:]) bucket_name = parts[0] @@ -58,22 +84,32 @@ def list_files( class S3Model: + """ + A class representing an S3 model mirrored into a temporary directory. + + Attributes: + s3: S3 client. + dir: The temporary created directory. + + Methods: + pull_files(): Pull model from S3 to the temporary directory. 
+ """ def __init__(self) -> None: self.s3 = boto3.client('s3') for sig in (signal.SIGINT, signal.SIGTERM): existing_handler = signal.getsignal(sig) - signal.signal(sig, self.close_by_signal(existing_handler)) + signal.signal(sig, self._close_by_signal(existing_handler)) self.dir = tempfile.mkdtemp() def __del__(self): - self.close() + self._close() - def close(self) -> None: + def _close(self) -> None: if os.path.exists(self.dir): shutil.rmtree(self.dir) - def close_by_signal(self, existing_handler=None): + def _close_by_signal(self, existing_handler=None): def new_handler(signum, frame): self.close() @@ -84,8 +120,17 @@ def new_handler(signum, frame): def pull_files(self, s3_model_path: str = "", - allow_pattern: list[str] | None = None, - ignore_pattern: list[str] | None = None) -> None: + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None) -> None: + """ + Pull files from S3 storage into the temporary directory. + + Args: + s3_model_path: The S3 path of the model. + allow_pattern: A list of patterns of which files to pull. + ignore_pattern: A list of patterns of which files not to pull. + + """ bucket_name, base_dir, files = list_files(self.s3, s3_model_path, allow_pattern, ignore_pattern)