diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index e3e35844405ac..43d0c45201de6 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -12,6 +12,7 @@ pydantic >= 2.8 torch py-cpuinfo transformers +boto3 mistral_common >= 1.3.4 aiohttp starlette diff --git a/requirements-common.txt b/requirements-common.txt index 3f5fa571ed4b3..8bfa113e74fd6 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -33,3 +33,4 @@ setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we n einops # Required for Qwen2-VL. compressed-tensors == 0.8.0 # required for compressed-tensors runai-model-streamer +boto3 diff --git a/vllm/config.py b/vllm/config.py index 18e41ceec5942..f61665eb2214f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -19,6 +19,7 @@ ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) +from vllm.transformers_utils.s3_utils import S3Model, is_s3 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, print_warning_once) @@ -195,6 +196,18 @@ def __init__( f"'Please instead use `--hf-overrides '{hf_override!r}'`") warnings.warn(DeprecationWarning(msg), stacklevel=2) + if is_s3(model): + self.s3_model = S3Model() + self.s3_model.pull_files(model, allow_pattern=["*config.json"]) + self.model_weights = self.model + self.model = self.s3_model.dir + + if is_s3(tokenizer): + self.s3_tokenizer = S3Model() + self.s3_tokenizer.pull_files( + model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + self.tokenizer = self.s3_tokenizer.dir + # The tokenizer version is consistent with the model version by default. if tokenizer_revision is None: self.tokenizer_revision = revision diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index ed5a26f88c6dd..b0cb863585821 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -42,6 +42,8 @@ runai_safetensors_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.transformers_utils.s3_utils import glob as s3_glob +from vllm.transformers_utils.s3_utils import is_s3 from vllm.utils import is_pin_memory_available @@ -1174,18 +1176,24 @@ def __init__(self, load_config: LoadConfig): and isinstance(extra_config.get("memory_limit"), int)): os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( extra_config.get("memory_limit")) + + runai_streamer_s3_endpoint = os.getenv('RUNAI_STREAMER_S3_ENDPOINT') + aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL') + if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None: + os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url def _prepare_weights(self, model_name_or_path: str, revision: Optional[str]) -> List[str]: """Prepare weights for the model. If the model is not local, it will be downloaded.""" + is_s3_path = is_s3(model_name_or_path) is_local = os.path.isdir(model_name_or_path) safetensors_pattern = "*.safetensors" index_file = SAFE_WEIGHTS_INDEX_NAME - hf_folder = (model_name_or_path - if is_local else download_weights_from_hf( + hf_folder = (model_name_or_path if + (is_local or is_s3_path) else download_weights_from_hf( model_name_or_path, self.load_config.download_dir, [safetensors_pattern], @@ -1193,15 +1201,17 @@ def _prepare_weights(self, model_name_or_path: str, ignore_patterns=self.load_config.ignore_patterns, )) - hf_weights_files = glob.glob( - os.path.join(hf_folder, safetensors_pattern)) + if is_s3_path: + hf_weights_files = s3_glob(path=hf_folder, + allow_pattern=[safetensors_pattern]) + else: + hf_weights_files = glob.glob( + os.path.join(hf_folder, safetensors_pattern)) - if not is_local: + if not is_local and not is_s3_path: download_safetensors_index_file_from_hf( model_name_or_path, index_file, self.load_config.download_dir, revision) - hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, hf_folder, index_file) if len(hf_weights_files) == 0: raise RuntimeError( @@ -1229,8 +1239,9 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: with target_device: model = _initialize_model(vllm_config=vllm_config) + assert hasattr(model_config, "model_weights") model.load_weights( - self._get_weights_iterator(model_config.model, + self._get_weights_iterator(model_config.model_weights, model_config.revision)) for _, module in model.named_modules(): diff --git a/vllm/transformers_utils/s3_utils.py b/vllm/transformers_utils/s3_utils.py new file mode 100644 index 0000000000000..390a47a24b691 --- /dev/null +++ b/vllm/transformers_utils/s3_utils.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import fnmatch +import os +import shutil +import signal +import tempfile +from pathlib import Path + +import boto3 + + +def is_s3(model_or_path: str) -> bool: + return model_or_path.lower().startswith('s3://') + + +def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]: + return [ + path for path in paths if any( + fnmatch.fnmatch(path, pattern) for pattern in patterns) + ] + + +def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]: + return [ + path for path in paths + if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns) + ] + + +def glob(s3=None, + path: str = "", + allow_pattern: list[str] | None = None) -> list[str]: + if s3 is None: + s3 = boto3.client("s3") + bucket_name, _, paths = list_files(s3, + path=path, + allow_pattern=allow_pattern) + return [f"s3://{bucket_name}/{path}" for path in paths] + + +def list_files( + s3, + path: str, + allow_pattern: list[str] | None = None, + ignore_pattern: list[str] | None = None) -> tuple[str, str, list[str]]: + parts = path.removeprefix('s3://').split('/') + prefix = '/'.join(parts[1:]) + bucket_name = parts[0] + + objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix) + paths = [obj['Key'] for obj in objects.get('Contents', [])] + + paths = _filter_ignore(paths, ["*/"]) + if allow_pattern is not None: + paths = _filter_allow(paths, allow_pattern) + + if ignore_pattern is not None: + paths = _filter_ignore(paths, ignore_pattern) + + return bucket_name, prefix, paths + + +class S3Model: + + def __init__(self) -> None: + self.s3 = boto3.client('s3') + for sig in (signal.SIGINT, signal.SIGTERM): + existing_handler = signal.getsignal(sig) + signal.signal(sig, self.close_by_signal(existing_handler)) + self.dir = tempfile.mkdtemp(dir='/dev/shm') + + def __del__(self): + self.close() + + def close(self) -> None: + if os.path.exists(self.dir): + shutil.rmtree(self.dir) + + def close_by_signal(self, existing_handler=None): + + def new_handler(signum, frame): + self.close() + if existing_handler: + existing_handler(signum, frame) + + return new_handler + + def pull_files(self, + s3_model_path: str = "", + allow_pattern: list[str] | None = None, + ignore_pattern: list[str] | None = None) -> None: + bucket_name, base_dir, files = list_files(self.s3, s3_model_path, + allow_pattern, + ignore_pattern) + if len(files) == 0: + return + + for file in files: + destination_file = self.dir + file.removeprefix(base_dir) + local_dir = Path(destination_file).parent + os.makedirs(local_dir, exist_ok=True) + self.s3.download_file(bucket_name, file, destination_file)