diff --git a/docs/source/index.rst b/docs/source/index.rst
index 8457d4476a1c4..75e6db8f70382 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -88,6 +88,7 @@ Documentation
    serving/usage_stats
    serving/integrations
    serving/tensorizer
+   serving/runai_model_streamer
    serving/compatibility_matrix
    serving/faq

diff --git a/docs/source/serving/runai_model_streamer.rst b/docs/source/serving/runai_model_streamer.rst
new file mode 100644
index 0000000000000..8cee969260770
--- /dev/null
+++ b/docs/source/serving/runai_model_streamer.rst
@@ -0,0 +1,32 @@
+.. _runai_model_streamer:
+
+Loading Models with Run:ai Model Streamer
+=========================================
+Run:ai Model Streamer is a library that reads tensors concurrently and streams them to GPU memory.
+Further reading can be found in the `Run:ai Model Streamer Documentation <https://github.com/run-ai/runai-model-streamer/blob/master/docs/README.md>`_.
+
+vLLM supports loading weights in Safetensors format using the Run:ai Model Streamer.
+
+To run it as an OpenAI-compatible server, add the ``--load-format runai_streamer`` flag:
+
+.. code-block:: console
+
+   $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer
+
+Tunable parameters
+------------------
+
+You can control the level of concurrency with the ``concurrency`` parameter in ``--model-loader-extra-config``:
+
+.. code-block:: console
+
+   $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"concurrency":16}'
+
+You can control the amount of CPU memory used to stream tensors with the ``memory_limit`` parameter in ``--model-loader-extra-config``:
+
+.. code-block:: console
+
+   $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"memory_limit":5368709120}'
+
+.. note::
+   For further instructions on tunable parameters and additional parameters configurable through environment variables, read the `Environment Variables Documentation <https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md>`_.

diff --git a/requirements-common.txt b/requirements-common.txt
index ef5ed8b645158..6bc15ead74415 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -32,3 +32,4 @@ six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that need
 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
 einops # Required for Qwen2-VL.
 compressed-tensors == 0.7.1 # required for compressed-tensors
+runai-model-streamer

diff --git a/vllm/config.py b/vllm/config.py
index f9b230e1bc688..8879ad12c50d0 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -868,6 +868,7 @@ class LoadFormat(str, enum.Enum):
     GGUF = "gguf"
     BITSANDBYTES = "bitsandbytes"
     MISTRAL = "mistral"
+    RUNAI_STREAMER = "runai_streamer"
 
 
 @dataclass

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 95d55e86e08e8..acfcb6456f878 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -307,6 +307,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             '* "tensorizer" will load the weights using tensorizer from '
             'CoreWeave. See the Tensorize vLLM Model script in the Examples '
             'section for more information.\n'
+            '* "runai_streamer" will load the Safetensors weights using '
+            'Run:ai Model Streamer.\n'
             '* "bitsandbytes" will load the weights using bitsandbytes '
             'quantization.\n')
         parser.add_argument(

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 8d3024534734b..e415942b59812 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -38,7 +38,7 @@
     filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
     get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
     initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
-    safetensors_weights_iterator)
+    runai_safetensors_weights_iterator, safetensors_weights_iterator)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
@@ -1138,6 +1138,90 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         return model
 
 
+class RunaiModelStreamerLoader(BaseModelLoader):
+    """Model loader that loads safetensors files using Run:ai Model Streamer."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        if load_config.model_loader_extra_config:
+            extra_config = load_config.model_loader_extra_config
+
+            if ("concurrency" in extra_config
+                    and isinstance(extra_config.get("concurrency"), int)):
+                os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
+                    extra_config.get("concurrency"))
+
+            if ("memory_limit" in extra_config
+                    and isinstance(extra_config.get("memory_limit"), int)):
+                os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
+                    extra_config.get("memory_limit"))
+
+    def _prepare_weights(self, model_name_or_path: str,
+                         revision: Optional[str]) -> List[str]:
+        """Prepare weights for the model.
+
+        If the model is not local, it will be downloaded."""
+        is_local = os.path.isdir(model_name_or_path)
+        safetensors_pattern = "*.safetensors"
+        index_file = SAFE_WEIGHTS_INDEX_NAME
+
+        hf_folder = (model_name_or_path
+                     if is_local else download_weights_from_hf(
+                         model_name_or_path,
+                         self.load_config.download_dir,
+                         [safetensors_pattern],
+                         revision,
+                         ignore_patterns=self.load_config.ignore_patterns,
+                     ))
+
+        hf_weights_files = glob.glob(
+            os.path.join(hf_folder, safetensors_pattern))
+
+        if not is_local:
+            download_safetensors_index_file_from_hf(
+                model_name_or_path, index_file,
+                self.load_config.download_dir, revision)
+        hf_weights_files = filter_duplicate_safetensors_files(
+            hf_weights_files, hf_folder, index_file)
+
+        if len(hf_weights_files) == 0:
+            raise RuntimeError(
+                f"Cannot find any safetensors model weights with "
+                f"`{model_name_or_path}`")
+
+        return hf_weights_files
+
+    def _get_weights_iterator(
+        self, model_or_path: str, revision: Optional[str]
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights based on the load format."""
+        hf_weights_files = self._prepare_weights(model_or_path, revision)
+        return runai_safetensors_weights_iterator(hf_weights_files)
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        self._prepare_weights(model_config.model, model_config.revision)
+
+    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+        device_config = vllm_config.device_config
+        model_config = vllm_config.model_config
+
+        target_device = torch.device(device_config.device)
+        with set_default_torch_dtype(model_config.dtype):
+            with target_device:
+                model = _initialize_model(vllm_config=vllm_config)
+
+            model.load_weights(
+                self._get_weights_iterator(model_config.model,
+                                           model_config.revision))
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+        return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
@@ -1159,4 +1243,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.GGUF:
         return GGUFModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.RUNAI_STREAMER:
+        return RunaiModelStreamerLoader(load_config)
+
     return DefaultModelLoader(load_config)

diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 9488d54edf365..e600dd3b7b4ce 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -410,6 +410,25 @@ def safetensors_weights_iterator(
         yield name, param
 
 
+def runai_safetensors_weights_iterator(
+    hf_weights_files: List[str]
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensors files."""
+    from runai_model_streamer import SafetensorsStreamer
+
+    enable_tqdm = not torch.distributed.is_initialized(
+    ) or torch.distributed.get_rank() == 0
+    with SafetensorsStreamer() as streamer:
+        for st_file in tqdm(
+                hf_weights_files,
+                desc="Loading safetensors using Run:ai Model Streamer",
+                disable=not enable_tqdm,
+                bar_format=_BAR_FORMAT,
+        ):
+            streamer.stream_file(st_file)
+            yield from streamer.get_tensors()
+
+
 def pt_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
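As a usage sketch: the new load format should also be reachable from the Python API, since `LLM` forwards keyword arguments such as `load_format` and `model_loader_extra_config` to the engine arguments. The model path below is the illustrative placeholder from the docs; the extra-config keys mirror what `RunaiModelStreamerLoader.__init__` consumes.

# Sketch: offline inference with the "runai_streamer" load format.
# The model path is illustrative (same placeholder as the docs above).
from vllm import LLM, SamplingParams

llm = LLM(
    model="/home/meta-llama/Llama-3.2-3B-Instruct",
    load_format="runai_streamer",
    # Integer values here are exported by the loader as the
    # RUNAI_STREAMER_CONCURRENCY / RUNAI_STREAMER_MEMORY_LIMIT env vars.
    model_loader_extra_config={
        "concurrency": 16,
        "memory_limit": 5368709120,  # 5 GiB of CPU staging memory
    },
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)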
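And a minimal sketch of the new iterator in isolation, with a hypothetical file path; the signature (a List[str] of files in, (name, tensor) pairs out) is taken from the weight_utils.py hunk:

# Sketch: drive runai_safetensors_weights_iterator directly.
# The .safetensors path below is hypothetical.
from vllm.model_executor.model_loader.weight_utils import (
    runai_safetensors_weights_iterator)

files = ["/path/to/model-00001-of-00001.safetensors"]
for name, tensor in runai_safetensors_weights_iterator(files):
    print(f"{name}: shape={tuple(tensor.shape)}, dtype={tensor.dtype}")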