From 231acc8b79c878e0f044ec6a615630f14afde047 Mon Sep 17 00:00:00 2001 From: fenghuizhang Date: Thu, 26 Dec 2024 11:13:49 -0800 Subject: [PATCH] A few tweaks to the JetStream code for better observability and throughput. + Added custom GC config on the serve side, by defult Python does too much GC as we allocate a lot of objects. + Tweaked log level in orchestrator to WARNING so important messages don't hide in server logs. + Added slow TTFT detection and text logging on both server and client side (benchmark_serving as the client). + Fixed timestamp recording on the server side. + Added prefill based throttling on the client side. + Added concurrent active request throttling on the client side. --- benchmarks/benchmark_serving.py | 85 +++++++++++++++++++++++++++++++-- jetstream/core/config_lib.py | 5 ++ jetstream/core/orchestrator.py | 44 ++++++++++++----- jetstream/core/server_lib.py | 15 +++++- 4 files changed, 134 insertions(+), 15 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 97628372..15a663ce 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -62,6 +62,7 @@ import asyncio from dataclasses import dataclass, field from datetime import datetime +import gc import json import random import time @@ -107,6 +108,40 @@ def str2bool(v: str) -> bool: raise ValueError(f"Invalid value '{v}'!") +class AsyncCounter: + """An counter class for counting and quota management with asycio, + not thread safe. It's safe with asyncio as value changes are done + outside of await statements. + """ + + def __init__(self, init_value: int, block_on_zero_seconds=0.002): + """ + Args: + init_value: Initial value for the counter. + block_on_zero_seconds: if greater than 0, the counter will spin when + value hits 0, hence can be used for quota management. + """ + self._init_value = init_value + self._value = init_value + self._block_on_zero_seconds = block_on_zero_seconds + + async def inc(self): + self._value += 1 + + async def dec(self): + while True: + if self._value > 0 or self._block_on_zero_seconds <= 0.0: + self._value -= 1 + return + await asyncio.sleep(self._block_on_zero_seconds) + + def value(self): + return self._value + + def delta(self): + return self._init_value - self._value + + @dataclass class BenchmarkMetrics: """Data class to store benchmark metrics.""" @@ -378,6 +413,7 @@ def calculate_metrics( completed = 0 per_token_latencies = [] ttfts = [] + output_sizes = [] for i in range(len(outputs)): if outputs[i].success: output_len = len( @@ -385,6 +421,7 @@ def calculate_metrics( if tokenizer != "test" else ["Ċ", "Ō", "Ɵ"] ) + output_sizes.append(output_len) total_output += output_len total_input += input_requests[i].prompt_len if output_len == 0: @@ -397,6 +434,10 @@ def calculate_metrics( ttfts.append(outputs[i].ttft) completed += 1 + print("Mean output size:", float(np.mean(output_sizes))) + print("Median output size:", float(np.median(output_sizes))) + print("P99 output size:", float(np.percentile(output_sizes, 99))) + metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -416,21 +457,32 @@ def calculate_metrics( async def grpc_async_request( - api_url: str, request: Any + api_url: str, + request: Any, + prefill_quota: AsyncCounter, + active_req_quota: AsyncCounter, ) -> tuple[list[str], float, float]: """Send grpc synchronous request since the current grpc server is sync.""" options = [("grpc.keepalive_timeout_ms", 10000)] async with grpc.aio.insecure_channel(api_url, options=options) as channel: stub = jetstream_pb2_grpc.OrchestratorStub(channel) - print("Making request") ttft = 0 token_list = [] request_start_time = time.perf_counter() response = stub.Decode(request) async for resp in response: if ttft == 0: + await prefill_quota.inc() + ttft = time.perf_counter() - request_start_time + if ttft > 2.0: + print( + datetime.now(), + f"slow TTFT {ttft:.2f}", + prefill_quota.value(), + ) token_list.extend(resp.stream_content.samples[0].token_ids) + await active_req_quota.inc() latency = time.perf_counter() - request_start_time return token_list, ttft, latency @@ -439,9 +491,12 @@ async def send_request( api_url: str, tokenizer: Any, input_request: InputRequest, + prefill_quota: AsyncCounter, + active_req_quota: AsyncCounter, pbar: tqdm, ) -> RequestFuncOutput: """Send the request to JetStream server.""" + # Tokenization on client side following MLPerf standard. token_ids = tokenizer.encode(input_request.prompt) request = jetstream_pb2.DecodeRequest( @@ -449,12 +504,15 @@ async def send_request( token_ids=token_ids ), max_tokens=input_request.output_len, + metadata=jetstream_pb2.DecodeRequest.Metadata( + start_time=time.perf_counter() + ), ) output = RequestFuncOutput() output.input_request = input_request output.prompt_len = input_request.prompt_len generated_token_list, ttft, latency = await grpc_async_request( - api_url, request + api_url, request, prefill_quota, active_req_quota ) output.ttft = ttft output.latency = latency @@ -463,6 +521,12 @@ async def send_request( output.generated_text = tokenizer.decode(generated_token_list) output.success = True if pbar: + pbar.postfix = ( + f"#reqs: {active_req_quota.delta()}/" + f"{active_req_quota.value()}; " + f"#prefill: {prefill_quota.delta()}/" + f"{prefill_quota.value()}" + ) pbar.update(1) return output @@ -473,6 +537,8 @@ async def benchmark( input_requests: list[InputRequest], request_rate: float, disable_tqdm: bool, + prefill_quota: AsyncCounter, + active_req_quota: AsyncCounter, ): """Benchmark the online serving performance.""" pbar = None if disable_tqdm else tqdm(total=len(input_requests)) @@ -482,12 +548,17 @@ async def benchmark( benchmark_start_time = time.perf_counter() tasks = [] async for request in get_request(input_requests, request_rate): + await prefill_quota.dec() + await active_req_quota.dec() + tasks.append( asyncio.create_task( send_request( api_url=api_url, tokenizer=tokenizer, input_request=request, + prefill_quota=prefill_quota, + active_req_quota=active_req_quota, pbar=pbar, ) ) @@ -579,6 +650,9 @@ def main(args: argparse.Namespace): tokenizer_id = args.tokenizer use_hf_tokenizer = args.use_hf_tokenizer + prefill_quota = AsyncCounter(init_value=3) + active_req_quota = AsyncCounter(init_value=450) + api_url = f"{args.server}:{args.port}" tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer) @@ -621,6 +695,8 @@ def main(args: argparse.Namespace): input_requests=warmup_requests, request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, + prefill_quota=prefill_quota, + active_req_quota=active_req_quota, ) ) print(f"{args.warmup_mode} warmup completed.") @@ -636,6 +712,8 @@ def main(args: argparse.Namespace): input_requests=input_requests, request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, + prefill_quota=prefill_quota, + active_req_quota=active_req_quota, ) ) @@ -836,4 +914,5 @@ def main(args: argparse.Namespace): ) parsed_args = parser.parse_args() + gc.disable() main(parsed_args) diff --git a/jetstream/core/config_lib.py b/jetstream/core/config_lib.py index f3022d01..0bae35b8 100644 --- a/jetstream/core/config_lib.py +++ b/jetstream/core/config_lib.py @@ -39,6 +39,11 @@ class ServerConfig: generate_engine_create_fns: Tuple[CreateEngineFn, ...] = () interleaved_engine_create_fns: Tuple[CreateEngineFn, ...] = () is_ray_backend: bool = False + # Parameters for customized gc config, increase the numbers here will + # potentially increase memory usage. + gc_gen0_allocs: int = 60000 # default is 700, too frequent sometimes. + gc_gen1_multipler: int = 2 # Make gen1 gc runs less frequent + gc_gen2_multipler: int = 3 # Make gen2 gc runs less frequent @dataclasses.dataclass diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 15fc36dd..dfd2594d 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -74,6 +74,7 @@ to debug hangs due to bugs in threads (it is easier to debug with live logs). """ +from datetime import datetime import dataclasses import functools import itertools @@ -98,10 +99,10 @@ import numpy as np root = logging.getLogger() -root.setLevel(logging.INFO) +root.setLevel(logging.WARNING) handler = logging.StreamHandler(sys.stdout) -handler.setLevel(logging.INFO) +handler.setLevel(logging.WARNING) formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) @@ -113,18 +114,25 @@ class ActiveRequestMetadata: """Inference request metadata.""" - start_time: Optional[float] = None + start_time: float = 0.0 - prefill_enqueue_time: Optional[float] = None - prefill_dequeue_time: Optional[float] = None + prefill_enqueue_time: float = 0.0 + prefill_dequeue_time: float = 0.0 - transfer_enqueue_time: Optional[float] = None - transfer_dequeue_time: Optional[float] = None + transfer_enqueue_time: float = 0.0 + transfer_dequeue_time: float = 0.0 - generate_enqueue_time: Optional[float] = None - generate_dequeue_time: Optional[float] = None + generate_enqueue_time: float = 0.0 + generate_dequeue_time: float = 0.0 - complete_time: Optional[float] = None + complete_time: float = 0.0 + + def stats(self) -> str: + return ( + f"{self.prefill_enqueue_time - self.start_time:.2f};" + f"{self.prefill_dequeue_time - self.prefill_enqueue_time:.2f};" + f"{time.perf_counter() - self.prefill_dequeue_time:.2f}" + ) @dataclasses.dataclass @@ -245,7 +253,7 @@ def __init__( if generate_params is None: generate_params = [] - logging.info( + logging.warning( "Initialising driver with %d prefill engines and %d generate engines.", len(prefill_engines), len(generate_engines), @@ -476,6 +484,9 @@ def get_total_concurrent_requests(self) -> int: ) return total_max_concurrent_decodes + def prefill_backlog_size(self): + return self._prefill_backlog.qsize() + def place_request_on_prefill_queue(self, request: ActiveRequest): """Used to place new requests for prefilling and generation.""" # Don't block so we can fail and shed load when the queue is full. @@ -980,6 +991,8 @@ async def Decode( # pylint: disable=invalid-overridden-method context: Optional[grpc.aio.ServicerContext] = None, ) -> AsyncIterator[jetstream_pb2.DecodeResponse]: """Decode.""" + request_start_time = time.perf_counter() + ttft = 0 if context is None: logging.warning( "LLM orchestrator is being used in offline test mode, and will not" @@ -1031,6 +1044,15 @@ async def Decode( # pylint: disable=invalid-overridden-method buffered_response_list = [] async for response in active_request.return_channel: response = cast(list[ReturnSample], response) + if ttft == 0: + ttft = time.perf_counter() - request_start_time + if ttft > 2.0: + print( + datetime.now(), + f"Slow TTFT: {ttft:.2f}s," + f" stats={active_request.metadata.stats()}," + f" prefill_qsize={self._driver.prefill_backlog_size()}", + ) if is_client_side_tokenization: # If is_client_side_tokenization, the client should request with token # ids, and the JetStream server will return token ids as response. diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index b323286a..3f7ef79b 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -19,6 +19,7 @@ import asyncio from concurrent import futures +import gc import logging import os import signal @@ -218,8 +219,20 @@ def run( # to make sure we can fully saturate the model. Set default minimum to 64. threads = threads or max(driver.get_total_concurrent_requests(), 64) jetstream_server = JetStreamServer(driver, threads, port, credentials) - logging.info("Starting server on port %d with %d threads", port, threads) + # Tweak gc config. + # Force a gen 2 collection here. + gc.collect(generation=2) + # Freeze objects currently tracked and ignore them in future gc runs. + gc.freeze() + allocs, gen1, gen2 = gc.get_threshold() + allocs = config.gc_gen0_allocs + gen1 = gen1 * config.gc_gen1_multipler + gen2 = gen2 * config.gc_gen2_multipler + gc.set_threshold(allocs, gen1, gen2) + print("GC tweaked (allocs, gen1, gen2): ", allocs, gen1, gen2) + + logging.info("Starting server on port %d with %d threads", port, threads) jetstream_server.start() if metrics_collector: