Skip to content

Commit

Permalink
Merge branch 'main' into chat-template-content-format
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 committed Nov 5, 2024
2 parents 0410d9f + bbc3619 commit 4856117
Show file tree
Hide file tree
Showing 171 changed files with 4,358 additions and 2,690 deletions.
11 changes: 6 additions & 5 deletions .buildkite/run-amd-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,12 @@ fi
PARALLEL_JOB_COUNT=8
# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs.
if [[ $commands == *"--shard-id="* ]]; then
# assign job count as the number of shards used
commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "}
for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do
#replace shard arguments
commands=${commands//"--shard-id= "/"--shard-id=${GPU} "}
commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "}
echo "Shard ${GPU} commands:$commands"
# assign shard-id for each shard
commands_gpu=${commands//"--shard-id= "/"--shard-id=${GPU} "}
echo "Shard ${GPU} commands:$commands_gpu"
docker run \
--device /dev/kfd --device /dev/dri \
--network host \
Expand All @@ -123,7 +124,7 @@ if [[ $commands == *"--shard-id="* ]]; then
-e HF_HOME=${HF_MOUNT} \
--name ${container_name}_${GPU} \
${image_name} \
/bin/bash -c "${commands}" \
/bin/bash -c "${commands_gpu}" \
|& while read -r line; do echo ">>Shard $GPU: $line"; done &
PIDS+=($!)
done
Expand Down
2 changes: 1 addition & 1 deletion .buildkite/run-tpu-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ remove_docker_container
# For HF_TOKEN.
source /etc/environment
# Run a simple end-to-end example.
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
9 changes: 9 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ updates:
reviewers: ["khluu", "simon-mo"]
allow:
- dependency-type: "all"
ignore:
- dependency-name: "torch"
- dependency-name: "torchvision"
- dependency-name: "xformers"
- dependency-name: "lm-format-enforcer"
- dependency-name: "gguf"
- dependency-name: "compressed-tensors"
- dependency-name: "ray[adag]"
- dependency-name: "lm-eval"
groups:
patch-update:
applies-to: version-updates
Expand Down
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.5.0")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")
set(TORCH_SUPPORTED_VERSION_CUDA "2.5.1")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.1")

#
# Try to find python package with an executable that exactly matches
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ FROM vllm-base AS vllm-openai

# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate hf_transfer 'modelscope!=1.15.0' bitsandbytes>=0.44.0 timm==0.9.10
pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.44.0' timm==0.9.10

ENV VLLM_USAGE_SOURCE production-docker-image

Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.neuron
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi

RUN python3 -m pip install -U \
cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-r requirements-neuron.txt

ENV VLLM_TARGET_DEVICE neuron
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.ppc64le
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ RUN --mount=type=bind,source=.git,target=.git \
# These packages will be in rocketce eventually
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
torch==2.3.1 \
-r requirements-cpu.txt \
xformers uvloop==0.20.0
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.rocm
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip uninstall -y torch torchvision \
&& python3 -m pip install --pre \
torch==2.6.0.dev20240918 \
setuptools-scm>=8 \
'setuptools-scm>=8' \
torchvision==0.20.0.dev20240918 \
--extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2;; \
*) ;; esac
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.tpu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ENV VLLM_TARGET_DEVICE="tpu"
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=.git,target=.git \
python3 -m pip install \
cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-r requirements-tpu.txt
RUN python3 setup.py develop

Expand Down
4 changes: 2 additions & 2 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,9 @@ def calculate_metrics(
median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
for p in selected_percentiles],
mean_e2el_ms=np.median(e2els or 0) * 1000,
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.mean(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles],
)
Expand Down
81 changes: 55 additions & 26 deletions benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import random
import time
from typing import List, Optional, Tuple
from typing import List, Optional

import torch
import uvloop
Expand All @@ -15,16 +15,35 @@
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


@dataclasses.dataclass
class SampleRequest:
"""A class representing a single inference request for benchmarking.
Attributes:
prompt: The input text prompt for the model.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
"""
prompt: str
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[MultiModalDataDict] = None


def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
) -> List[SampleRequest]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")

Expand All @@ -41,7 +60,7 @@ def sample_requests(
random.shuffle(dataset)

# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
filtered_dataset: List[SampleRequest] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
Expand All @@ -60,31 +79,34 @@ def sample_requests(
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
filtered_dataset.append(
SampleRequest(prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len))

return filtered_dataset


def run_vllm(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
n: int,
engine_args: EngineArgs,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))

# Add the requests to the engine.
prompts: List[str] = []
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
for request in requests:
prompts.append(TextPrompt(prompt=request.prompt))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
max_tokens=request.expected_output_len,
))

use_beam_search = False
Expand All @@ -94,11 +116,11 @@ def run_vllm(
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
prompts = [prompt for prompt, _, _ in requests]
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
for prompt, input_len, _output_len in requests:
assert _output_len == output_len
for request in requests:
assert request.expected_output_len == output_len
start = time.perf_counter()
llm.beam_search(
prompts,
Expand All @@ -112,7 +134,7 @@ def run_vllm(


async def run_vllm_async(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
n: int,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
Expand All @@ -123,17 +145,17 @@ async def run_vllm_async(
engine_args, disable_frontend_multiprocessing) as llm:

# Add the requests to the engine.
prompts: List[str] = []
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
for request in requests:
prompts.append(TextPrompt(prompt=request.prompt))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
max_tokens=request.expected_output_len,
))

generators = []
Expand All @@ -149,7 +171,7 @@ async def run_vllm_async(


def run_hf(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
Expand Down Expand Up @@ -207,14 +229,14 @@ def run_hf(


def run_mii(
requests: List[Tuple[str, int, int]],
requests: List[SampleRequest],
model: str,
tensor_parallel_size: int,
output_len: int,
) -> float:
from mii import client, serve
llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [prompt for prompt, _, _ in requests]
prompts = [request.prompt for request in requests]

start = time.perf_counter()
llm.generate(prompts, max_new_tokens=output_len)
Expand Down Expand Up @@ -243,8 +265,12 @@ def main(args: argparse.Namespace):
else:
raise ValueError(
f"Failed to synthesize a prompt with {args.input_len} tokens.")
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
requests = [
SampleRequest(prompt=prompt,
prompt_len=args.input_len,
expected_output_len=args.output_len)
for _ in range(args.num_prompts)
]
else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
args.output_len)
Expand All @@ -270,9 +296,10 @@ def main(args: argparse.Namespace):
args.output_len)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
total_output_tokens = sum(output_len for _, _, output_len in requests)
total_num_tokens = sum(request.prompt_len + request.expected_output_len
for request in requests)
total_output_tokens = sum(request.expected_output_len
for request in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
Expand All @@ -299,7 +326,9 @@ def main(args: argparse.Namespace):
parser.add_argument("--dataset",
type=str,
default=None,
help="Path to the dataset.")
help="Path to the dataset. The dataset is expected to "
"be a json in form of List[Dict[..., conversations: "
"List[Dict[..., value: <prompt_or_response>]]]]")
parser.add_argument("--input-len",
type=int,
default=None,
Expand Down
Loading

0 comments on commit 4856117

Please sign in to comment.