Skip to content

Commit

Permalink
A few tweaks to the JetStream code for better observability and throu…
Browse files Browse the repository at this point in the history
…ghput.

+ Added custom GC config on the serve side, by defult Python does too much GC as we allocate a lot of objects.
+ Tweaked log level in orchestrator to WARNING so important messages don't hide in server logs.
+ Added slow TTFT detection and text logging on both server and client side (benchmark_serving as the client).
+ Fixed timestamp recording on the server side.
+ Added prefill based throttling on the client side.
+ Added concurrent active request throttling on the client side.
  • Loading branch information
fenghuizhang committed Dec 29, 2024
1 parent 9ca4421 commit bda2e22
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 15 deletions.
87 changes: 84 additions & 3 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import asyncio
from dataclasses import dataclass, field
from datetime import datetime
import gc
import json
import random
import time
Expand Down Expand Up @@ -107,6 +108,42 @@ def str2bool(v: str) -> bool:
raise ValueError(f"Invalid value '{v}'!")


class AtomicCounter:
"""An atomic counter class for counting and quota management with asycio,
not thread safe.
"""

def __init__(self, init_value: int, block_on_zero_seconds=0.002):
"""
Args:
init_value: Initial value for the counter.
block_on_zero_seconds: if greater than 0, the counter will spin when
value hits 0, hence can be used for quota management.
"""
self._init_value = init_value
self._value = init_value
self._block_on_zero_seconds = block_on_zero_seconds
self._lock = asyncio.Lock()

async def inc(self):
async with self._lock:
self._value += 1

async def dec(self):
while True:
async with self._lock:
if self._value > 0 or self._block_on_zero_seconds <= 0.0:
self._value -= 1
return
await asyncio.sleep(self._block_on_zero_seconds)

def approximate_value(self):
return self._value

def approximate_delta(self):
return self._init_value - self._value


@dataclass
class BenchmarkMetrics:
"""Data class to store benchmark metrics."""
Expand Down Expand Up @@ -378,13 +415,15 @@ def calculate_metrics(
completed = 0
per_token_latencies = []
ttfts = []
output_sizes = []
for i in range(len(outputs)):
if outputs[i].success:
output_len = len(
outputs[i].generated_token_list
if tokenizer != "test"
else ["Ċ", "Ō", "Ɵ"]
)
output_sizes.append(output_len)
total_output += output_len
total_input += input_requests[i].prompt_len
if output_len == 0:
Expand All @@ -397,6 +436,10 @@ def calculate_metrics(
ttfts.append(outputs[i].ttft)
completed += 1

print("Mean output size:", float(np.mean(output_sizes)))
print("Median output size:", float(np.median(output_sizes)))
print("P99 output size:", float(np.percentile(output_sizes, 99)))

metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
Expand All @@ -416,21 +459,32 @@ def calculate_metrics(


async def grpc_async_request(
api_url: str, request: Any
api_url: str,
request: Any,
prefill_quota: AtomicCounter,
active_req_quota: AtomicCounter,
) -> tuple[list[str], float, float]:
"""Send grpc synchronous request since the current grpc server is sync."""
options = [("grpc.keepalive_timeout_ms", 10000)]
async with grpc.aio.insecure_channel(api_url, options=options) as channel:
stub = jetstream_pb2_grpc.OrchestratorStub(channel)
print("Making request")
ttft = 0
token_list = []
request_start_time = time.perf_counter()
response = stub.Decode(request)
async for resp in response:
if ttft == 0:
await prefill_quota.inc()

ttft = time.perf_counter() - request_start_time
if ttft > 2.0:
print(
datetime.now(),
f"slow TTFT {ttft:.2f}",
prefill_quota.approximate_value(),
)
token_list.extend(resp.stream_content.samples[0].token_ids)
await active_req_quota.inc()
latency = time.perf_counter() - request_start_time
return token_list, ttft, latency

Expand All @@ -439,22 +493,28 @@ async def send_request(
api_url: str,
tokenizer: Any,
input_request: InputRequest,
prefill_quota: AtomicCounter,
active_req_quota: AtomicCounter,
pbar: tqdm,
) -> RequestFuncOutput:
"""Send the request to JetStream server."""

# Tokenization on client side following MLPerf standard.
token_ids = tokenizer.encode(input_request.prompt)
request = jetstream_pb2.DecodeRequest(
token_content=jetstream_pb2.DecodeRequest.TokenContent(
token_ids=token_ids
),
max_tokens=input_request.output_len,
metadata=jetstream_pb2.DecodeRequest.Metadata(
start_time=time.perf_counter()
),
)
output = RequestFuncOutput()
output.input_request = input_request
output.prompt_len = input_request.prompt_len
generated_token_list, ttft, latency = await grpc_async_request(
api_url, request
api_url, request, prefill_quota, active_req_quota
)
output.ttft = ttft
output.latency = latency
Expand All @@ -463,6 +523,12 @@ async def send_request(
output.generated_text = tokenizer.decode(generated_token_list)
output.success = True
if pbar:
pbar.postfix = (
f"#reqs: {active_req_quota.approximate_delta()}/"
f"{active_req_quota.approximate_value()}; "
f"#prefill: {prefill_quota.approximate_delta()}/"
f"{prefill_quota.approximate_value()}"
)
pbar.update(1)
return output

Expand All @@ -473,6 +539,8 @@ async def benchmark(
input_requests: list[InputRequest],
request_rate: float,
disable_tqdm: bool,
prefill_quota: AtomicCounter,
active_req_quota: AtomicCounter,
):
"""Benchmark the online serving performance."""
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
Expand All @@ -482,12 +550,17 @@ async def benchmark(
benchmark_start_time = time.perf_counter()
tasks = []
async for request in get_request(input_requests, request_rate):
await prefill_quota.dec()
await active_req_quota.dec()

tasks.append(
asyncio.create_task(
send_request(
api_url=api_url,
tokenizer=tokenizer,
input_request=request,
prefill_quota=prefill_quota,
active_req_quota=active_req_quota,
pbar=pbar,
)
)
Expand Down Expand Up @@ -579,6 +652,9 @@ def main(args: argparse.Namespace):
tokenizer_id = args.tokenizer
use_hf_tokenizer = args.use_hf_tokenizer

prefill_quota = AtomicCounter(init_value=3)
active_req_quota = AtomicCounter(init_value=450)

api_url = f"{args.server}:{args.port}"

tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
Expand Down Expand Up @@ -621,6 +697,8 @@ def main(args: argparse.Namespace):
input_requests=warmup_requests,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
prefill_quota=prefill_quota,
active_req_quota=active_req_quota,
)
)
print(f"{args.warmup_mode} warmup completed.")
Expand All @@ -636,6 +714,8 @@ def main(args: argparse.Namespace):
input_requests=input_requests,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
prefill_quota=prefill_quota,
active_req_quota=active_req_quota,
)
)

Expand Down Expand Up @@ -836,4 +916,5 @@ def main(args: argparse.Namespace):
)

parsed_args = parser.parse_args()
gc.disable()
main(parsed_args)
5 changes: 5 additions & 0 deletions jetstream/core/config_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ class ServerConfig:
generate_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
interleaved_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
is_ray_backend: bool = False
# Parameters for customized gc config, increase the numbers here will
# potentially increase memory usage.
gc_gen0_allocs: int = 60000 # default is 700, too frequent sometimes.
gc_gen1_multipler: int = 2 # Make gen1 gc runs less frequent
gc_gen2_multipler: int = 3 # Make gen2 gc runs less frequent


@dataclasses.dataclass
Expand Down
44 changes: 33 additions & 11 deletions jetstream/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
to debug hangs due to bugs in threads (it is easier to debug with live logs).
"""

from datetime import datetime
import dataclasses
import functools
import itertools
Expand All @@ -98,10 +99,10 @@
import numpy as np

root = logging.getLogger()
root.setLevel(logging.INFO)
root.setLevel(logging.WARNING)

handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.INFO)
handler.setLevel(logging.WARNING)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
Expand All @@ -113,18 +114,25 @@
class ActiveRequestMetadata:
"""Inference request metadata."""

start_time: Optional[float] = None
start_time: float = 0.0

prefill_enqueue_time: Optional[float] = None
prefill_dequeue_time: Optional[float] = None
prefill_enqueue_time: float = 0.0
prefill_dequeue_time: float = 0.0

transfer_enqueue_time: Optional[float] = None
transfer_dequeue_time: Optional[float] = None
transfer_enqueue_time: float = 0.0
transfer_dequeue_time: float = 0.0

generate_enqueue_time: Optional[float] = None
generate_dequeue_time: Optional[float] = None
generate_enqueue_time: float = 0.0
generate_dequeue_time: float = 0.0

complete_time: Optional[float] = None
complete_time: float = 0.0

def stats(self) -> str:
return (
f"{self.prefill_enqueue_time - self.start_time:.2f};"
f"{self.prefill_dequeue_time - self.prefill_enqueue_time:.2f};"
f"{time.perf_counter() - self.prefill_dequeue_time:.2f}"
)


@dataclasses.dataclass
Expand Down Expand Up @@ -245,7 +253,7 @@ def __init__(
if generate_params is None:
generate_params = []

logging.info(
logging.warning(
"Initialising driver with %d prefill engines and %d generate engines.",
len(prefill_engines),
len(generate_engines),
Expand Down Expand Up @@ -476,6 +484,9 @@ def get_total_concurrent_requests(self) -> int:
)
return total_max_concurrent_decodes

def prefill_backlog_size(self):
return self._prefill_backlog.qsize()

def place_request_on_prefill_queue(self, request: ActiveRequest):
"""Used to place new requests for prefilling and generation."""
# Don't block so we can fail and shed load when the queue is full.
Expand Down Expand Up @@ -980,6 +991,8 @@ async def Decode( # pylint: disable=invalid-overridden-method
context: Optional[grpc.aio.ServicerContext] = None,
) -> AsyncIterator[jetstream_pb2.DecodeResponse]:
"""Decode."""
request_start_time = time.perf_counter()
ttft = 0
if context is None:
logging.warning(
"LLM orchestrator is being used in offline test mode, and will not"
Expand Down Expand Up @@ -1031,6 +1044,15 @@ async def Decode( # pylint: disable=invalid-overridden-method
buffered_response_list = []
async for response in active_request.return_channel:
response = cast(list[ReturnSample], response)
if ttft == 0:
ttft = time.perf_counter() - request_start_time
if ttft > 2.0:
print(
datetime.now(),
f"Slow TTFT: {ttft:.2f}s,"
f" stats={active_request.metadata.stats()},"
f" prefill_qsize={self._driver.prefill_backlog_size()}",
)
if is_client_side_tokenization:
# If is_client_side_tokenization, the client should request with token
# ids, and the JetStream server will return token ids as response.
Expand Down
15 changes: 14 additions & 1 deletion jetstream/core/server_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import asyncio
from concurrent import futures
import gc
import logging
import os
import signal
Expand Down Expand Up @@ -218,8 +219,20 @@ def run(
# to make sure we can fully saturate the model. Set default minimum to 64.
threads = threads or max(driver.get_total_concurrent_requests(), 64)
jetstream_server = JetStreamServer(driver, threads, port, credentials)
logging.info("Starting server on port %d with %d threads", port, threads)

# Tweak gc config.
# Force a gen 2 collection here.
gc.collect(generation=2)
# Freeze objects currently tracked and ignore them in future gc runs.
gc.freeze()
allocs, gen1, gen2 = gc.get_threshold()
allocs = config.gc_gen0_allocs
gen1 = gen1 * config.gc_gen1_multipler
gen2 = gen2 * config.gc_gen2_multipler
gc.set_threshold(allocs, gen1, gen2)
print("GC tweaked (allocs, gen1, gen2): ", allocs, gen1, gen2)

logging.info("Starting server on port %d with %d threads", port, threads)
jetstream_server.start()

if metrics_collector:
Expand Down

0 comments on commit bda2e22

Please sign in to comment.