Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/vllm-project/vllm into cpu-…
Browse files Browse the repository at this point in the history
…offloading2

Signed-off-by: Dahai Tang <[email protected]>
  • Loading branch information
Dahai Tang committed Dec 18, 2024
2 parents 1ccb4e9 + f04e407 commit 05c1b3f
Show file tree
Hide file tree
Showing 56 changed files with 1,442 additions and 323 deletions.
6 changes: 5 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,12 @@ steps:
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/model_executor/layers
- vllm/model_executor/guided_decoding
- tests/test_logits_processor
command: pytest -v -s test_logits_processor.py
- tests/model_executor/test_guided_processors
commands:
- pytest -v -s test_logits_processor.py
- pytest -v -s model_executor/test_guided_processors.py

- label: Speculative decoding tests # 30min
source_file_dependencies:
Expand Down
262 changes: 262 additions & 0 deletions benchmarks/kernels/benchmark_rmsnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
import itertools
from typing import Optional, Tuple, Union

import torch
import triton
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn

from vllm import _custom_ops as vllm_ops


class HuggingFaceRMSNorm(nn.Module):

def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)

variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
if residual is None:
return x
else:
return x, residual


def rmsnorm_naive(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
naive_norm.weight = nn.Parameter(weight)
naive_norm = naive_norm.to(x.device)

orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])

output = naive_norm(x, residual)

if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output


def rmsnorm_flashinfer(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])

if residual is not None:
fused_add_rmsnorm(x, residual, weight, eps)
output = (x, residual)
else:
output = rmsnorm(x, weight, eps)

if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output


def rmsnorm_vllm(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])

if residual is not None:
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
output = (x, residual)
else:
out = torch.empty_like(x)
vllm_ops.rms_norm(out, x, weight, eps)
output = out

if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output


def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
dtype = torch.bfloat16
x = torch.randn(batch_size,
seq_len,
hidden_size,
dtype=dtype,
device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None

output_naive = rmsnorm_naive(
x.clone(), weight,
residual.clone() if residual is not None else None)
output_flashinfer = rmsnorm_flashinfer(
x.clone(), weight,
residual.clone() if residual is not None else None)
output_vllm = rmsnorm_vllm(
x.clone(), weight,
residual.clone() if residual is not None else None)

if use_residual:
output_naive = output_naive[0]
output_flashinfer = output_flashinfer[0]
output_vllm = output_vllm[0]

print(f"Naive output={output_naive}")
print(f"FlashInfer output={output_flashinfer}")
print(f"VLLM output={output_vllm}")

if torch.allclose(output_naive, output_flashinfer, atol=1e-2,
rtol=1e-2) and torch.allclose(
output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")


batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
head_num_range = [32, 48]
configs = list(
itertools.product(head_num_range, batch_size_range, seq_length_range))


def get_benchmark(use_residual):

@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["head_num", "batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["huggingface", "flashinfer", "vllm"],
line_names=["HuggingFace", "FlashInfer", "vLLM"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name=
f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
args={},
))
def benchmark(head_num, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_size = head_num * 128 # assuming head_dim = 128

x = torch.randn(batch_size,
seq_len,
hidden_size,
dtype=dtype,
device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None

quantiles = [0.5, 0.2, 0.8]

if provider == "huggingface":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_naive(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
elif provider == "flashinfer":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_flashinfer(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_vllm(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)

return 1000 * ms, 1000 * max_ms, 1000 * min_ms

return benchmark


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
"--batch-size",
type=int,
default=4,
help="Batch size",
)
parser.add_argument(
"--seq-len",
type=int,
default=128,
help="Sequence length",
)
parser.add_argument(
"--hidden-size",
type=int,
default=4096,
help="Hidden size (2nd dimension) of the sequence",
)
parser.add_argument("--use-residual",
action="store_true",
help="Whether to use residual connection")
parser.add_argument(
"--save-path",
type=str,
default="./configs/rmsnorm/",
help="Path to save rmsnorm benchmark results",
)

args = parser.parse_args()

# Run correctness test
calculate_diff(batch_size=args.batch_size,
seq_len=args.seq_len,
hidden_size=args.hidden_size,
use_residual=args.use_residual)

# Get the benchmark function with proper use_residual setting
benchmark = get_benchmark(args.use_residual)
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)
10 changes: 5 additions & 5 deletions docs/source/serving/openai_compatible_server.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ We currently support the following OpenAI APIs:
- *Note: `suffix` parameter is not supported.*
- [Chat Completions API](#chat-api) (`/v1/chat/completions`)
- Only applicable to [text generation models](../models/generative_models.rst) (`--task generate`) with a [chat template](#chat-template).
- [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Multimodal Inputs](../usage/multimodal_inputs.rst).
- *Note: `image_url.detail` parameter is not supported.*
- We also support `audio_url` content type for audio files.
- Refer to [vllm.entrypoints.chat_utils](https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/chat_utils.py) for the exact schema.
- *TODO: Support `input_audio` content type as defined [here](https://github.com/openai/openai-python/blob/v1.52.2/src/openai/types/chat/chat_completion_content_part_input_audio_param.py).*
- *Note: `parallel_tool_calls` and `user` parameters are ignored.*
- [Embeddings API](#embeddings-api) (`/v1/embeddings`)
- Only applicable to [embedding models](../models/pooling_models.rst) (`--task embed`).
Expand Down Expand Up @@ -209,6 +204,11 @@ The following extra parameters are supported:

Refer to [OpenAI's API reference](https://platform.openai.com/docs/api-reference/chat) for more details.

We support both [Vision](https://platform.openai.com/docs/guides/vision)- and
[Audio](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in)-related parameters;
see our [Multimodal Inputs](../usage/multimodal_inputs.rst) guide for more information.
- *Note: `image_url.detail` parameter is not supported.*

#### Extra parameters

The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
Expand Down
90 changes: 89 additions & 1 deletion docs/source/usage/multimodal_inputs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,95 @@ You can use `these tests <https://github.com/vllm-project/vllm/blob/main/tests/e
Audio
^^^^^

Instead of :code:`image_url`, you can pass an audio file via :code:`audio_url`.
Audio input is supported according to `OpenAI Audio API <https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in>`_.
Here is a simple example using Ultravox-v0.3.

First, launch the OpenAI-compatible server:

.. code-block:: bash
vllm serve fixie-ai/ultravox-v0_3
Then, you can use the OpenAI client as follows:

.. code-block:: python
import base64
import requests
from openai import OpenAI
from vllm.assets.audio import AudioAsset
def encode_base64_content_from_url(content_url: str) -> str:
"""Encode a content retrieved from a remote url to base64 format."""
with requests.get(content_url) as response:
response.raise_for_status()
result = base64.b64encode(response.content).decode('utf-8')
return result
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
# Any format supported by librosa is supported
audio_url = AudioAsset("winning_call").url
audio_base64 = encode_base64_content_from_url(audio_url)
chat_completion_from_base64 = client.chat.completions.create(
messages=[{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this audio?"
},
{
"type": "input_audio",
"input_audio": {
"data": audio_base64,
"format": "wav"
},
},
],
}],
model=model,
max_completion_tokens=64,
)
result = chat_completion_from_base64.choices[0].message.content
print("Chat completion output from input audio:", result)
Alternatively, you can pass :code:`audio_url`, which is the audio counterpart of :code:`image_url` for image input:

.. code-block:: python
chat_completion_from_url = client.chat.completions.create(
messages=[{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this audio?"
},
{
"type": "audio_url",
"audio_url": {
"url": audio_url
},
},
],
}],
model=model,
max_completion_tokens=64,
)
result = chat_completion_from_url.choices[0].message.content
print("Chat completion output from audio url:", result)
A full code example can be found in `examples/openai_chat_completion_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_completion_client_for_multimodal.py>`_.

Expand Down
Loading

0 comments on commit 05c1b3f

Please sign in to comment.