[Do Not Merge] - LoRA V1 Reference PR #11613

Draft: wants to merge 12 commits into main
56 changes: 40 additions & 16 deletions benchmarks/benchmark_throughput.py
@@ -2,6 +2,7 @@
import argparse
import dataclasses
import json
import pickle
import random
import time
from functools import cache
@@ -21,10 +22,14 @@
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
from vllm.utils import FlexibleArgumentParser, merge_async_iterators

SAMPLING_TEMPERATURE = 0.0
SAMPLING_TOP_P = 1.0


@dataclasses.dataclass
class SampleRequest:
@@ -165,7 +170,7 @@ def run_vllm(
requests: List[SampleRequest],
n: int,
engine_args: EngineArgs,
) -> float:
) -> Tuple[float, Optional[List[RequestOutput]]]:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))

@@ -179,8 +184,8 @@
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
temperature=SAMPLING_TEMPERATURE,
top_p=SAMPLING_TOP_P,
ignore_eos=True,
max_tokens=request.expected_output_len,
))
@@ -190,12 +195,13 @@

use_beam_search = False

outputs = None
if not use_beam_search:
start = time.perf_counter()
llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
outputs = llm.generate(prompts,
sampling_params,
lora_request=lora_requests,
use_tqdm=True)
end = time.perf_counter()
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
@@ -213,15 +219,15 @@
ignore_eos=True,
))
end = time.perf_counter()
return end - start
return end - start, outputs


async def run_vllm_async(
requests: List[SampleRequest],
n: int,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
) -> float:
) -> Tuple[float, Optional[List[RequestOutput]]]:
from vllm import SamplingParams

async with build_async_engine_client_from_engine_args(
@@ -238,8 +244,8 @@ async def run_vllm_async(
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
temperature=SAMPLING_TEMPERATURE,
top_p=SAMPLING_TOP_P,
ignore_eos=True,
max_tokens=request.expected_output_len,
))
@@ -255,10 +261,17 @@
request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
outputs_dict = {}
async for i, res in all_gens:
pass
outputs_dict[i] = res
end = time.perf_counter()
return end - start

num_prompts = len(prompts)
outputs = []
for i in range(num_prompts):
outputs.append(outputs_dict[i])

return end - start, outputs


def run_hf(
@@ -391,16 +404,22 @@ def main(args: argparse.Namespace):
for request in requests)
if args.backend == "vllm":
if args.async_engine:
elapsed_time = uvloop.run(
elapsed_time, outputs = uvloop.run(
run_vllm_async(
requests,
args.n,
AsyncEngineArgs.from_cli_args(args),
args.disable_frontend_multiprocessing,
))
else:
elapsed_time = run_vllm(requests, args.n,
EngineArgs.from_cli_args(args))
elapsed_time, outputs = run_vllm(requests, args.n,
EngineArgs.from_cli_args(args))

if args.pickle_outputs:
print("Pickling request outputs : ")
with open("outputs.pkl", "wb+") as f:
pickle.dump(outputs, f)

elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -490,6 +509,11 @@ def main(args: argparse.Namespace):
help="Path to the lora adapters to use. This can be an absolute path, "
"a relative path, or a Hugging Face model identifier.")

parser.add_argument("--pickle-outputs",
action="store_true",
default=False,
help="Pickle outputs got from benchmark")

parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
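Note: the outputs.pkl file written by the new --pickle-outputs flag is created in the current working directory (the path is hard-coded above) and can be loaded back for offline inspection. A minimal sketch, assuming vLLM's RequestOutput objects with their .outputs list of CompletionOutput entries; the printed summary is only illustrative, not part of this PR:

# Sketch: inspect the pickled benchmark outputs written by --pickle-outputs.
import pickle

with open("outputs.pkl", "rb") as f:
    outputs = pickle.load(f)  # expected: List[RequestOutput]

total_generated = sum(
    len(completion.token_ids)
    for request_output in outputs
    for completion in request_output.outputs)
print(f"{len(outputs)} requests, {total_generated} generated tokens")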
33 changes: 33 additions & 0 deletions tests/lora/conftest.py
@@ -1,5 +1,6 @@
import tempfile
from collections import OrderedDict
from contextlib import contextmanager
from typing import Dict, List, TypedDict
from unittest.mock import MagicMock, patch

@@ -76,6 +77,21 @@ def dist_init():
cleanup_dist_env_and_memory(shutdown_ray=True)


@contextmanager
def _dist_init():
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
initialize_model_parallel(1, 1)
yield
cleanup_dist_env_and_memory(shutdown_ray=True)


@pytest.fixture
def dist_init_torch_only():
if torch.distributed.is_initialized():
@@ -274,3 +290,20 @@ def get_model_patched(**kwargs):
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)


@pytest.fixture(params=[True])
def run_with_both_engines_lora(request):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
skip_v1 = request.node.get_closest_marker("skip_v1")

if use_v1:
if skip_v1:
pytest.skip("Skipping test on vllm V1")
with patch('vllm.envs.VLLM_USE_V1', True):
yield
else:
with patch('vllm.envs.VLLM_USE_V1', False):
yield
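Note: a LoRA test module could opt into this fixture with a small autouse wrapper and mark V0-only tests with the skip_v1 marker the fixture checks for. A hypothetical sketch (fixture and test names are illustrative and not part of this diff; the skip_v1 marker would also need to be registered in the project's pytest configuration to avoid warnings):

# Hypothetical usage in a tests/lora/test_*.py module.
import pytest


@pytest.fixture(autouse=True)
def run_both_engines(run_with_both_engines_lora):
    # Pull in the conftest fixture so every test in this module runs
    # under the engine configuration(s) it parameterizes.
    pass


@pytest.mark.skip_v1  # matched by request.node.get_closest_marker("skip_v1")
def test_v0_only_path():
    ...


def test_lora_generation():
    ...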