diff --git a/python/mlc_llm/bench/__main__.py b/python/mlc_llm/bench/__main__.py
index b0ecc46e2a..a74e187b44 100644
--- a/python/mlc_llm/bench/__main__.py
+++ b/python/mlc_llm/bench/__main__.py
@@ -317,5 +317,13 @@ def _main():
         help="Whether to enable cuda profile on server. "
         "The --mlc-model-lib path should be provided when enabling this option.",
     )
+    parser.add_argument(
+        "--multi-round",
+        default=False,
+        action="store_true",
+        help="Whether to chat in multi-round conversation style, carrying over the history "
+        "of each request. Only enabled with the fixed concurrent request benchmark mode. "
+        "The --num-concurrent-requests option should be provided when enabling this option.",
+    )
 
     main(parser.parse_args())
diff --git a/python/mlc_llm/bench/request_processor.py b/python/mlc_llm/bench/request_processor.py
index 4b69dade09..7c6748c3fa 100644
--- a/python/mlc_llm/bench/request_processor.py
+++ b/python/mlc_llm/bench/request_processor.py
@@ -15,6 +15,7 @@
 
 from mlc_llm.bench.api_endpoint import APIEndPoint
 from mlc_llm.bench.request_record import RequestRecord
+from mlc_llm.protocol.openai_api_protocol import ChatCompletionMessage
 from mlc_llm.support import logging
 
 logging.enable_logging()
@@ -249,12 +250,13 @@ def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
 class FixedConcurrentRequestExecutor(Executor):  # pylint: disable=too-few-public-methods
     """The benchmark executor of fixing the number of concurrent requests."""
 
-    def __init__(
+    def __init__(  # pylint: disable=too-many-arguments
         self,
         f_create_api_endpoint: Callable[[], APIEndPoint],
         num_processes: Optional[int],
         disable_tqdm: bool,
         num_concurrent_requests: int,
+        multi_round: bool,
     ) -> None:
         if num_processes is None:
             # We assign each process at most 32 concurrent requests to send
@@ -262,6 +264,7 @@ def __init__(
             num_processes = min((num_concurrent_requests + 31) // 32, 10)
         super().__init__(f_create_api_endpoint, num_processes, disable_tqdm)
         self.num_concurrent_requests = num_concurrent_requests
+        self.multi_round = multi_round
 
     def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
         partitions: List[List[RequestRecord]] = [
@@ -281,6 +284,7 @@ def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
                 partition,
                 self.num_concurrent_requests // self.num_processes
                 + int(i < self.num_concurrent_requests % self.num_processes),
+                self.multi_round,
             )
             for i, partition in enumerate(partitions)
         ]
@@ -297,21 +301,26 @@ def _process_task(
         f_create_api_endpoint: Callable[[], APIEndPoint],
         request_records: List[RequestRecord],
         num_concurrent_requests: int,
+        multi_round: bool,
     ) -> List[RequestRecord]:
         if len(request_records) == 0:
             return []
+        chat_history: List[List[ChatCompletionMessage]] = [
+            [] for _ in range(num_concurrent_requests)
+        ]
 
         async def process_task_impl(
             f_create_api_endpoint: Callable[[], APIEndPoint],
             request_records: List[RequestRecord],
             num_concurrent_requests: int,
+            multi_round: bool,
         ) -> List[RequestRecord]:
             api_endpoint = f_create_api_endpoint()
             updated_request_records: List[RequestRecord] = [None for _ in request_records]
             async with api_endpoint:
                 num_sent_request = 0
 
-                async def _task() -> None:
+                async def _task(i: int) -> None:
                     nonlocal num_sent_request
                     while True:
                         if num_sent_request == len(request_records):
@@ -320,9 +329,22 @@ async def _task() -> None:
                         num_sent_request += 1
                         request = request_records[idx]
+                        if multi_round:
+                            request.chat_cmpl.messages = (
+                                chat_history[i] + request.chat_cmpl.messages
+                            )
+
                         updated_request_records[idx] = await api_endpoint(request)
+
+                        if multi_round:
+                            chat_history[i] = updated_request_records[idx].chat_cmpl.messages + [
+                                ChatCompletionMessage(
+                                    content=updated_request_records[idx].output_str,
+                                    role="assistant",
+                                )
+                            ]
 
-                tasks = [asyncio.create_task(_task()) for _ in range(num_concurrent_requests)]
+                tasks = [asyncio.create_task(_task(i)) for i in range(num_concurrent_requests)]
                 await asyncio.gather(*tasks)
 
             return updated_request_records
 
@@ -332,6 +354,7 @@ async def _task() -> None:
                 f_create_api_endpoint,
                 request_records,
                 num_concurrent_requests,
+                multi_round,
             )
         )
 
@@ -491,6 +514,7 @@ def create_pipelines(
                     args.num_process_workers,
                     args.disable_tqdm,
                     num_concurrent_requests,
+                    args.multi_round,
                 ),
                 cuda_profile_url=cuda_profile_url,
             ),
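
Usage sketch (illustrative, not part of the patch): since the patched argparse lives in python/mlc_llm/bench/__main__.py, the bench package can be run as a module. Only --num-concurrent-requests and --multi-round below come from this patch; any other flags a real run needs (endpoint, dataset, request count, etc.) depend on the local setup and are omitted.

    # hedged example: remaining required flags depend on your benchmark setup
    python -m mlc_llm.bench --num-concurrent-requests 16 --multi-round

With --multi-round set, each concurrent slot i keeps its own chat_history[i]: the slot prepends its accumulated messages to the next request's chat_cmpl.messages and, after the response, appends an assistant ChatCompletionMessage built from output_str, so every slot emulates one ongoing conversation rather than independent single-turn requests.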