Commit
[Bench] Support multi round conversation (#2898)
This PR adds support for multi-round conversation when benchmarking with the fixed
concurrent request mode. When enabled, the chat history of each concurrent request
slot is logged and appended to its subsequent requests during the benchmark.
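
As a rough illustration of the bookkeeping this enables (a minimal sketch only: plain dicts stand in for the benchmark's ChatCompletionMessage objects, and `send_request` is a hypothetical endpoint call, not part of this PR):

```python
from typing import Dict, List


def send_request(messages: List[Dict[str, str]]) -> str:
    """Hypothetical endpoint call: returns the assistant's reply text."""
    return "reply to: " + messages[-1]["content"]


def run_one_slot(prompts: List[str]) -> List[Dict[str, str]]:
    """Send several requests through one concurrent slot, carrying the history along."""
    history: List[Dict[str, str]] = []
    for prompt in prompts:
        # The history accumulated so far is prepended to the new request's messages.
        messages = history + [{"role": "user", "content": prompt}]
        reply = send_request(messages)
        # The request messages plus the assistant reply become the new history.
        history = messages + [{"role": "assistant", "content": reply}]
    return history


if __name__ == "__main__":
    for turn in run_one_slot(["hello", "tell me more"]):
        print(turn["role"], ":", turn["content"])
```

Each concurrent request slot keeps its own history, so the turns of one slot never leak into another.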
cyx-6 authored Sep 13, 2024
1 parent 2b7f128 commit 9918c4b
Showing 2 changed files with 35 additions and 3 deletions.
8 changes: 8 additions & 0 deletions python/mlc_llm/bench/__main__.py
@@ -317,5 +317,13 @@ def _main():
help="Whether to enable cuda profile on server. "
"The --mlc-model-lib path should be provided when enabling this option.",
)
parser.add_argument(
"--multi-round",
default=False,
action="store_true",
help="Whether to chat like mulit round conversion with history log each request. "
"Only enabled when benchmarked with fixed concurrent request mode."
"The --num-concurrent-requests should be provided when enabling this option.",
)

main(parser.parse_args())
30 changes: 27 additions & 3 deletions python/mlc_llm/bench/request_processor.py
@@ -15,6 +15,7 @@

from mlc_llm.bench.api_endpoint import APIEndPoint
from mlc_llm.bench.request_record import RequestRecord
from mlc_llm.protocol.openai_api_protocol import ChatCompletionMessage
from mlc_llm.support import logging

logging.enable_logging()
@@ -249,19 +250,21 @@ def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
class FixedConcurrentRequestExecutor(Executor): # pylint: disable=too-few-public-methods
"""The benchmark executor of fixing the number of concurrent requests."""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
f_create_api_endpoint: Callable[[], APIEndPoint],
num_processes: Optional[int],
disable_tqdm: bool,
num_concurrent_requests: int,
multi_round: bool,
) -> None:
if num_processes is None:
# We assign each process at most 32 concurrent requests to send
# so that the pressure on each asyncio event loop does not become too high.
num_processes = min((num_concurrent_requests + 31) // 32, 10)
super().__init__(f_create_api_endpoint, num_processes, disable_tqdm)
self.num_concurrent_requests = num_concurrent_requests
self.multi_round = multi_round

def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
partitions: List[List[RequestRecord]] = [
@@ -281,6 +284,7 @@ def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
partition,
self.num_concurrent_requests // self.num_processes
+ int(i < self.num_concurrent_requests % self.num_processes),
self.multi_round,
)
for i, partition in enumerate(partitions)
]
@@ -297,21 +301,26 @@ def _process_task(
f_create_api_endpoint: Callable[[], APIEndPoint],
request_records: List[RequestRecord],
num_concurrent_requests: int,
multi_round: bool,
) -> List[RequestRecord]:
if len(request_records) == 0:
return []
chat_history: List[List[ChatCompletionMessage]] = [
[] for _ in range(num_concurrent_requests)
]

async def process_task_impl(
f_create_api_endpoint: Callable[[], APIEndPoint],
request_records: List[RequestRecord],
num_concurrent_requests: int,
multi_round: bool,
) -> List[RequestRecord]:
api_endpoint = f_create_api_endpoint()
updated_request_records: List[RequestRecord] = [None for _ in request_records]
async with api_endpoint:
num_sent_request = 0

async def _task() -> None:
async def _task(i: int) -> None:
nonlocal num_sent_request
while True:
if num_sent_request == len(request_records):
@@ -320,9 +329,22 @@ async def _task() -> None:
num_sent_request += 1
request = request_records[idx]

if multi_round:
request.chat_cmpl.messages = (
chat_history[i] + request.chat_cmpl.messages
)

updated_request_records[idx] = await api_endpoint(request)

tasks = [asyncio.create_task(_task()) for _ in range(num_concurrent_requests)]
if multi_round:
chat_history[i] = updated_request_records[idx].chat_cmpl.messages + [
ChatCompletionMessage(
content=updated_request_records[idx].output_str,
role="assistant",
)
]

tasks = [asyncio.create_task(_task(i)) for i in range(num_concurrent_requests)]
await asyncio.gather(*tasks)

return updated_request_records
@@ -332,6 +354,7 @@ async def _task() -> None:
f_create_api_endpoint,
request_records,
num_concurrent_requests,
multi_round,
)
)
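
The code above keeps exactly num_concurrent_requests asyncio tasks in flight, each advancing a shared request cursor and owning one chat-history slot. A stripped-down, self-contained sketch of that worker pattern (with a hypothetical fake_endpoint in place of the benchmark's APIEndPoint classes, and a local per-worker history instead of the shared chat_history list):

```python
import asyncio
from typing import List, Optional


async def fake_endpoint(prompt: str) -> str:
    """Hypothetical stand-in for the async API endpoint call."""
    await asyncio.sleep(0.01)
    return f"reply to {prompt}"


async def run_fixed_concurrency(prompts: List[str], num_concurrent: int) -> List[Optional[str]]:
    results: List[Optional[str]] = [None] * len(prompts)
    num_sent = 0  # shared cursor over the pending requests

    async def worker() -> None:
        nonlocal num_sent
        history: List[str] = []  # per-worker history, as in the multi-round mode
        while num_sent < len(prompts):
            idx = num_sent
            num_sent += 1
            # Prior turns from this worker are carried into the new request.
            request = "\n".join(history + [prompts[idx]])
            reply = await fake_endpoint(request)
            history = history + [prompts[idx], reply]
            results[idx] = reply

    await asyncio.gather(*(worker() for _ in range(num_concurrent)))
    return results


if __name__ == "__main__":
    print(asyncio.run(run_fixed_concurrency([f"q{i}" for i in range(8)], num_concurrent=3)))
```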

@@ -491,6 +514,7 @@ def create_pipelines(
args.num_process_workers,
args.disable_tqdm,
num_concurrent_requests,
args.multi_round,
),
cuda_profile_url=cuda_profile_url,
),
