-
-
Notifications
You must be signed in to change notification settings - Fork 5.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Rafael Vasquez <[email protected]>
- Loading branch information
Showing
13 changed files
with
457 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
(runai-model-streamer)= | ||
|
||
# Loading Models with Run:ai Model Streamer | ||
|
||
Run:ai Model Streamer is a library to read tensors in concurrency, while streaming it to GPU memory. | ||
Further reading can be found in [Run:ai Model Streamer Documentation](https://github.com/run-ai/runai-model-streamer/blob/master/docs/README.md). | ||
|
||
vLLM supports loading weights in Safetensors format using the Run:ai Model Streamer. | ||
You first need to install vLLM RunAI optional dependency: | ||
|
||
```console | ||
$ pip3 install vllm[runai] | ||
``` | ||
|
||
To run it as an OpenAI-compatible server, add the `--load-format runai_streamer` flag: | ||
|
||
```console | ||
$ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer | ||
``` | ||
|
||
To run model from AWS S3 object store run: | ||
|
||
```console | ||
$ vllm serve s3://core-llm/Llama-3-8b --load-format runai_streamer | ||
``` | ||
|
||
To run model from a S3 compatible object store run: | ||
|
||
```console | ||
$ RUNAI_STREAMER_S3_USE_VIRTUAL_ADDRESSING=0 AWS_EC2_METADATA_DISABLED=true AWS_ENDPOINT_URL=https://storage.googleapis.com vllm serve s3://core-llm/Llama-3-8b --load-format runai_streamer | ||
``` | ||
|
||
## Tunable parameters | ||
|
||
You can tune parameters using `--model-loader-extra-config`: | ||
|
||
You can tune `concurrency` that controls the level of concurrency and number of OS threads reading tensors from the file to the CPU buffer. | ||
For reading from S3, it will be the number of client instances the host is opening to the S3 server. | ||
|
||
> ```console | ||
> $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"concurrency":16}' | ||
> ``` | ||
You can controls the size of the CPU Memory buffer to which tensors are read from the file, and limit this size. | ||
You can read further about CPU buffer memory limiting [here](https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md#runai_streamer_memory_limit). | ||
> ```console | ||
> $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"memory_limit":5368709120}' | ||
> ``` | ||
```{note} | ||
For further instructions about tunable parameters and additional parameters configurable through environment variables, read the [Environment Variables Documentation](https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md). | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
31 changes: 31 additions & 0 deletions
31
tests/runai_model_streamer/test_runai_model_streamer_loader.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from vllm import SamplingParams | ||
from vllm.config import LoadConfig, LoadFormat | ||
from vllm.model_executor.model_loader.loader import (RunaiModelStreamerLoader, | ||
get_model_loader) | ||
|
||
test_model = "openai-community/gpt2" | ||
|
||
prompts = [ | ||
"Hello, my name is", | ||
"The president of the United States is", | ||
"The capital of France is", | ||
"The future of AI is", | ||
] | ||
# Create a sampling params object. | ||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) | ||
|
||
|
||
def get_runai_model_loader(): | ||
load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER) | ||
return get_model_loader(load_config) | ||
|
||
|
||
def test_get_model_loader_with_runai_flag(): | ||
model_loader = get_runai_model_loader() | ||
assert isinstance(model_loader, RunaiModelStreamerLoader) | ||
|
||
|
||
def test_runai_model_loader_download_files(vllm_runner): | ||
with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm: | ||
deserialized_outputs = llm.generate(prompts, sampling_params) | ||
assert deserialized_outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import glob | ||
import tempfile | ||
|
||
import huggingface_hub.constants | ||
import torch | ||
|
||
from vllm.model_executor.model_loader.weight_utils import ( | ||
download_weights_from_hf, runai_safetensors_weights_iterator, | ||
safetensors_weights_iterator) | ||
|
||
|
||
def test_runai_model_loader(): | ||
with tempfile.TemporaryDirectory() as tmpdir: | ||
huggingface_hub.constants.HF_HUB_OFFLINE = False | ||
download_weights_from_hf("openai-community/gpt2", | ||
allow_patterns=["*.safetensors"], | ||
cache_dir=tmpdir) | ||
safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True) | ||
assert len(safetensors) > 0 | ||
|
||
runai_model_streamer_tensors = {} | ||
hf_safetensors_tensors = {} | ||
|
||
for name, tensor in runai_safetensors_weights_iterator(safetensors): | ||
runai_model_streamer_tensors[name] = tensor | ||
|
||
for name, tensor in safetensors_weights_iterator(safetensors): | ||
hf_safetensors_tensors[name] = tensor | ||
|
||
assert len(runai_model_streamer_tensors) == len(hf_safetensors_tensors) | ||
|
||
for name, runai_tensor in runai_model_streamer_tensors.items(): | ||
assert runai_tensor.dtype == hf_safetensors_tensors[name].dtype | ||
assert runai_tensor.shape == hf_safetensors_tensors[name].shape | ||
assert torch.all(runai_tensor.eq(hf_safetensors_tensors[name])) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_runai_model_loader() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.