Add RunAI Model Streamer as optional loader.
Add it to the docs as well

Signed-off-by: OmerD <[email protected]>
omer-dayan committed Nov 14, 2024
1 parent 294bf46 commit bac1fd7
Showing 7 changed files with 145 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -88,6 +88,7 @@ Documentation
serving/usage_stats
serving/integrations
serving/tensorizer
serving/runai_model_streamer
serving/compatibility_matrix
serving/faq

32 changes: 32 additions & 0 deletions docs/source/serving/runai_model_streamer.rst
@@ -0,0 +1,32 @@
.. _runai_model_streamer:

Loading Models with Run:ai Model Streamer
=========================================
Run:ai Model Streamer is a library that reads tensors concurrently and streams them to GPU memory.
Further reading can be found in `Run:ai Model Streamer Documentation <https://github.com/run-ai/runai-model-streamer/blob/master/docs/README.md>`_.

vLLM supports loading weights in Safetensors format using the Run:ai Model Streamer.

To run it as an OpenAI-compatible server, add the ``--load-format runai_streamer`` flag:

.. code-block:: console

   $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer

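The streamer can also be used from the offline inference API. A minimal sketch, assuming the ``LLM`` entrypoint forwards ``load_format`` to the engine as it does for the other load formats:

.. code-block:: python

   from vllm import LLM

   # Stream the Safetensors weights with the Run:ai Model Streamer
   # (the model path is illustrative).
   llm = LLM(model="/home/meta-llama/Llama-3.2-3B-Instruct",
             load_format="runai_streamer")
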
Tunable parameters
------------------

You can control the level of concurrency by using the ``concurrency`` parameter in ``--model-loader-extra-config``:

.. code-block:: console

   $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"concurrency":16}'

You can control the amount of CPU memory used to stream tensors with the ``memory_limit`` parameter in ``--model-loader-extra-config`` (the value is in bytes; the example below sets a 5 GiB limit):

.. code-block:: console

   $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"memory_limit":5368709120}'

.. note::
   For further instructions about tunable parameters and additional parameters configurable through environment variables, read the `Environment Variables Documentation <https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md>`_.
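The same knobs can also be set through the environment variables the loader writes (``RUNAI_STREAMER_CONCURRENCY`` and ``RUNAI_STREAMER_MEMORY_LIMIT``). A sketch, assuming the streamer reads them at process startup:

.. code-block:: console

   $ RUNAI_STREAMER_CONCURRENCY=16 RUNAI_STREAMER_MEMORY_LIMIT=5368709120 \
       vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer
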
3 changes: 2 additions & 1 deletion requirements-common.txt
@@ -31,4 +31,5 @@ pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL.
compressed-tensors == 0.8.0 # required for compressed-tensors
compressed-tensors == 0.8.0 # required for compressed-tensors
runai-model-streamer
1 change: 1 addition & 0 deletions vllm/config.py
@@ -868,6 +868,7 @@ class LoadFormat(str, enum.Enum):
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
RUNAI_STREAMER = "runai_streamer"


@dataclass
2 changes: 2 additions & 0 deletions vllm/engine/arg_utils.py
@@ -306,6 +306,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples '
'section for more information.\n'
'* "runai_streamer" will load the Safetensors weights using Run:ai'
'Model Streamer \n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
89 changes: 88 additions & 1 deletion vllm/model_executor/model_loader/loader.py
@@ -39,7 +39,7 @@
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator)
runai_safetensors_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
@@ -1157,6 +1157,90 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
return model


class RunaiModelStreamerLoader(BaseModelLoader):
"""Model loader that can load different safetensors ."""

def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
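        # Translate --model-loader-extra-config keys into the environment
        # variables read by the Run:ai streamer.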
if load_config.model_loader_extra_config:
extra_config = load_config.model_loader_extra_config

if ("concurrency" in extra_config
and isinstance(extra_config.get("concurrency"), int)):
os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
extra_config.get("concurrency"))

if ("memory_limit" in extra_config
and isinstance(extra_config.get("memory_limit"), int)):
os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
extra_config.get("memory_limit"))

def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str]) -> List[str]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
is_local = os.path.isdir(model_name_or_path)
safetensors_pattern = "*.safetensors"
index_file = SAFE_WEIGHTS_INDEX_NAME

hf_folder = (model_name_or_path
if is_local else download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
[safetensors_pattern],
revision,
ignore_patterns=self.load_config.ignore_patterns,
))

hf_weights_files = glob.glob(
os.path.join(hf_folder, safetensors_pattern))

if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path, index_file, self.load_config.download_dir,
revision)
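            # Keep only the weight files referenced by the safetensors index,
            # so duplicate checkpoint formats are not streamed twice.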
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file)

if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any safetensors model weights with "
f"`{model_name_or_path}`")

return hf_weights_files

    def _get_weights_iterator(
            self, model_or_path: str,
            revision: Optional[str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_weights_files = self._prepare_weights(model_or_path, revision)
return runai_safetensors_weights_iterator(hf_weights_files)

def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)

def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config

target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = _initialize_model(vllm_config=vllm_config)

model.load_weights(
self._get_weights_iterator(model_config.model,
model_config.revision))

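        # Let quantization methods post-process their weights on the target
        # device once streaming completes.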
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
return model.eval()


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""

@@ -1178,4 +1262,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.GGUF:
return GGUFModelLoader(load_config)

if load_config.load_format == LoadFormat.RUNAI_STREAMER:
return RunaiModelStreamerLoader(load_config)

return DefaultModelLoader(load_config)
19 changes: 19 additions & 0 deletions vllm/model_executor/model_loader/weight_utils.py
@@ -410,6 +410,25 @@ def safetensors_weights_iterator(
yield name, param


def runai_safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
from runai_model_streamer import SafetensorsStreamer

enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
with SafetensorsStreamer() as streamer:
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors using Runai Model Streamer",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
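            # Request streaming of this file, then yield each tensor as it
            # becomes available in CPU memory.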
streamer.stream_file(st_file)
yield from streamer.get_tensors()


def pt_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
