A few tweaks to the JetStream code for better observability and throu… #158

Merged
merged 1 commit on Jan 3, 2025
85 changes: 82 additions & 3 deletions benchmarks/benchmark_serving.py
@@ -62,6 +62,7 @@
import asyncio
from dataclasses import dataclass, field
from datetime import datetime
import gc
import json
import random
import time
@@ -107,6 +108,40 @@ def str2bool(v: str) -> bool:
raise ValueError(f"Invalid value '{v}'!")


class AsyncCounter:
"""An counter class for counting and quota management with asycio,
not thread safe. It's safe with asyncio as value changes are done
outside of await statements.
"""

def __init__(self, init_value: int, block_on_zero_seconds=0.002):
"""
Args:
init_value: Initial value for the counter.
block_on_zero_seconds: If greater than 0, dec() spins while the value
is 0, so the counter can be used for quota management.
"""
self._init_value = init_value
self._value = init_value
self._block_on_zero_seconds = block_on_zero_seconds

async def inc(self):
self._value += 1

async def dec(self):
while True:
if self._value > 0 or self._block_on_zero_seconds <= 0.0:
self._value -= 1
return
await asyncio.sleep(self._block_on_zero_seconds)

def value(self):
return self._value

def delta(self):
return self._init_value - self._value


@dataclass
class BenchmarkMetrics:
"""Data class to store benchmark metrics."""
@@ -378,13 +413,15 @@ def calculate_metrics(
completed = 0
per_token_latencies = []
ttfts = []
output_sizes = []
for i in range(len(outputs)):
if outputs[i].success:
output_len = len(
outputs[i].generated_token_list
if tokenizer != "test"
else ["Ċ", "Ō", "Ɵ"]
)
output_sizes.append(output_len)
total_output += output_len
total_input += input_requests[i].prompt_len
if output_len == 0:
@@ -397,6 +434,10 @@
ttfts.append(outputs[i].ttft)
completed += 1

print("Mean output size:", float(np.mean(output_sizes)))
print("Median output size:", float(np.median(output_sizes)))
print("P99 output size:", float(np.percentile(output_sizes, 99)))

metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
@@ -416,21 +457,32 @@


async def grpc_async_request(
api_url: str, request: Any
api_url: str,
request: Any,
prefill_quota: AsyncCounter,
active_req_quota: AsyncCounter,
) -> tuple[list[str], float, float]:
"""Send grpc synchronous request since the current grpc server is sync."""
options = [("grpc.keepalive_timeout_ms", 10000)]
async with grpc.aio.insecure_channel(api_url, options=options) as channel:
stub = jetstream_pb2_grpc.OrchestratorStub(channel)
print("Making request")
ttft = 0
token_list = []
request_start_time = time.perf_counter()
response = stub.Decode(request)
async for resp in response:
if ttft == 0:
await prefill_quota.inc()

ttft = time.perf_counter() - request_start_time
if ttft > 2.0:
print(
datetime.now(),
f"slow TTFT {ttft:.2f}",
prefill_quota.value(),
)
token_list.extend(resp.stream_content.samples[0].token_ids)
await active_req_quota.inc()
latency = time.perf_counter() - request_start_time
return token_list, ttft, latency

@@ -439,22 +491,28 @@ async def send_request(
api_url: str,
tokenizer: Any,
input_request: InputRequest,
prefill_quota: AsyncCounter,
active_req_quota: AsyncCounter,
pbar: tqdm,
) -> RequestFuncOutput:
"""Send the request to JetStream server."""

# Tokenization on client side following MLPerf standard.
token_ids = tokenizer.encode(input_request.prompt)
request = jetstream_pb2.DecodeRequest(
token_content=jetstream_pb2.DecodeRequest.TokenContent(
token_ids=token_ids
),
max_tokens=input_request.output_len,
metadata=jetstream_pb2.DecodeRequest.Metadata(
start_time=time.perf_counter()
),
)
output = RequestFuncOutput()
output.input_request = input_request
output.prompt_len = input_request.prompt_len
generated_token_list, ttft, latency = await grpc_async_request(
api_url, request
api_url, request, prefill_quota, active_req_quota
)
output.ttft = ttft
output.latency = latency
@@ -463,6 +521,12 @@
output.generated_text = tokenizer.decode(generated_token_list)
output.success = True
if pbar:
pbar.postfix = (
f"#reqs: {active_req_quota.delta()}/"
f"{active_req_quota.value()}; "
f"#prefill: {prefill_quota.delta()}/"
f"{prefill_quota.value()}"
)
pbar.update(1)
return output

@@ -473,6 +537,8 @@ async def benchmark(
input_requests: list[InputRequest],
request_rate: float,
disable_tqdm: bool,
prefill_quota: AsyncCounter,
active_req_quota: AsyncCounter,
):
"""Benchmark the online serving performance."""
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -482,12 +548,17 @@
benchmark_start_time = time.perf_counter()
tasks = []
async for request in get_request(input_requests, request_rate):
await prefill_quota.dec()
await active_req_quota.dec()

tasks.append(
asyncio.create_task(
send_request(
api_url=api_url,
tokenizer=tokenizer,
input_request=request,
prefill_quota=prefill_quota,
active_req_quota=active_req_quota,
pbar=pbar,
)
)
@@ -579,6 +650,9 @@ def main(args: argparse.Namespace):
tokenizer_id = args.tokenizer
use_hf_tokenizer = args.use_hf_tokenizer

prefill_quota = AsyncCounter(init_value=3)
active_req_quota = AsyncCounter(init_value=450)

api_url = f"{args.server}:{args.port}"

tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
@@ -621,6 +695,8 @@
input_requests=warmup_requests,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
prefill_quota=prefill_quota,
active_req_quota=active_req_quota,
)
)
print(f"{args.warmup_mode} warmup completed.")
@@ -636,6 +712,8 @@
input_requests=input_requests,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
prefill_quota=prefill_quota,
active_req_quota=active_req_quota,
)
)

@@ -836,4 +914,5 @@ def main(args: argparse.Namespace):
)

parsed_args = parser.parse_args()
gc.disable()
main(parsed_args)
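
For readers trying the counter out of context: the sketch below is not part of the PR. It reproduces AsyncCounter in condensed form and replaces the real gRPC call with a hypothetical simulate_request() coroutine, but keeps the same slot discipline as benchmark(): a prefill slot (3 by default) is held from submission until the first token, and an active-request slot (450 by default) is held until the request completes.

```python
import asyncio
import random


class AsyncCounter:
  """Condensed copy of the counter added in benchmark_serving.py above."""

  def __init__(self, init_value: int, block_on_zero_seconds=0.002):
    self._init_value = init_value
    self._value = init_value
    self._block_on_zero_seconds = block_on_zero_seconds

  async def inc(self):
    self._value += 1

  async def dec(self):
    while True:
      if self._value > 0 or self._block_on_zero_seconds <= 0.0:
        self._value -= 1
        return
      await asyncio.sleep(self._block_on_zero_seconds)

  def value(self):
    return self._value

  def delta(self):
    return self._init_value - self._value


async def simulate_request(prefill_quota, active_req_quota):
  # Hypothetical stand-in for grpc_async_request(): the prefill slot is
  # released after the (simulated) first token, the request slot at the end.
  await asyncio.sleep(random.uniform(0.01, 0.05))  # "prefill" / TTFT
  await prefill_quota.inc()
  await asyncio.sleep(random.uniform(0.05, 0.20))  # "decode"
  await active_req_quota.inc()


async def main():
  prefill_quota = AsyncCounter(init_value=3)
  active_req_quota = AsyncCounter(init_value=450)
  tasks = []
  for _ in range(20):
    await prefill_quota.dec()      # spins while 3 prefills are in flight
    await active_req_quota.dec()   # spins while 450 requests are in flight
    tasks.append(asyncio.create_task(
        simulate_request(prefill_quota, active_req_quota)))
  await asyncio.gather(*tasks)


asyncio.run(main())
```

Because the counters are only mutated between await points on a single event loop, no lock is needed; the spinning dec() is what throttles the submission loop.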
5 changes: 5 additions & 0 deletions jetstream/core/config_lib.py
@@ -39,6 +39,11 @@ class ServerConfig:
generate_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
interleaved_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
is_ray_backend: bool = False
# Parameters for a customized gc config; increasing the numbers here will
# potentially increase memory usage.
gc_gen0_allocs: int = 60000  # Default is 700, which is sometimes too frequent.
gc_gen1_multipler: int = 2  # Make gen1 gc run less frequently.
gc_gen2_multipler: int = 3  # Make gen2 gc run less frequently.


@dataclasses.dataclass
44 changes: 33 additions & 11 deletions jetstream/core/orchestrator.py
@@ -74,6 +74,7 @@
to debug hangs due to bugs in threads (it is easier to debug with live logs).
"""

from datetime import datetime
import dataclasses
import functools
import itertools
@@ -98,10 +99,10 @@
import numpy as np

root = logging.getLogger()
root.setLevel(logging.INFO)
root.setLevel(logging.WARNING)

handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.INFO)
handler.setLevel(logging.WARNING)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
@@ -113,18 +114,25 @@
class ActiveRequestMetadata:
"""Inference request metadata."""

start_time: Optional[float] = None
start_time: float = 0.0

prefill_enqueue_time: Optional[float] = None
prefill_dequeue_time: Optional[float] = None
prefill_enqueue_time: float = 0.0
prefill_dequeue_time: float = 0.0

transfer_enqueue_time: Optional[float] = None
transfer_dequeue_time: Optional[float] = None
transfer_enqueue_time: float = 0.0
transfer_dequeue_time: float = 0.0

generate_enqueue_time: Optional[float] = None
generate_dequeue_time: Optional[float] = None
generate_enqueue_time: float = 0.0
generate_dequeue_time: float = 0.0

complete_time: Optional[float] = None
complete_time: float = 0.0

def stats(self) -> str:
return (
f"{self.prefill_enqueue_time - self.start_time:.2f};"
f"{self.prefill_dequeue_time - self.prefill_enqueue_time:.2f};"
f"{time.perf_counter() - self.prefill_dequeue_time:.2f}"
)


@dataclasses.dataclass
@@ -245,7 +253,7 @@ def __init__(
if generate_params is None:
generate_params = []

logging.info(
logging.warning(
"Initialising driver with %d prefill engines and %d generate engines.",
len(prefill_engines),
len(generate_engines),
@@ -476,6 +484,9 @@ def get_total_concurrent_requests(self) -> int:
)
return total_max_concurrent_decodes

def prefill_backlog_size(self):
return self._prefill_backlog.qsize()

def place_request_on_prefill_queue(self, request: ActiveRequest):
"""Used to place new requests for prefilling and generation."""
# Don't block so we can fail and shed load when the queue is full.
@@ -980,6 +991,8 @@ async def Decode( # pylint: disable=invalid-overridden-method
context: Optional[grpc.aio.ServicerContext] = None,
) -> AsyncIterator[jetstream_pb2.DecodeResponse]:
"""Decode."""
request_start_time = time.perf_counter()
ttft = 0
if context is None:
logging.warning(
"LLM orchestrator is being used in offline test mode, and will not"
@@ -1031,6 +1044,15 @@ async def Decode( # pylint: disable=invalid-overridden-method
buffered_response_list = []
async for response in active_request.return_channel:
response = cast(list[ReturnSample], response)
if ttft == 0:
ttft = time.perf_counter() - request_start_time
if ttft > 2.0:
print(
datetime.now(),
f"Slow TTFT: {ttft:.2f}s,"
f" stats={active_request.metadata.stats()},"
f" prefill_qsize={self._driver.prefill_backlog_size()}",
)
if is_client_side_tokenization:
# If is_client_side_tokenization, the client should request with token
# ids, and the JetStream server will return token ids as response.
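
Both the benchmark client above and Decode() measure TTFT the same way: record a start time, take the elapsed time when the first streamed chunk arrives, and log when it crosses a threshold. Below is a minimal, self-contained illustration of that pattern; token_stream() is a simulated stand-in, not a JetStream API.

```python
import asyncio
import time
from datetime import datetime


async def token_stream(n_tokens=5, first_token_delay=0.1):
  # Simulated stream: one delay for "prefill", then steady "decode" steps.
  await asyncio.sleep(first_token_delay)
  for t in range(n_tokens):
    yield t
    await asyncio.sleep(0.01)


async def consume(slow_ttft_threshold=2.0):
  start = time.perf_counter()
  ttft = 0.0
  tokens = []
  async for tok in token_stream():
    if ttft == 0:
      # First chunk: this is the time-to-first-token.
      ttft = time.perf_counter() - start
      if ttft > slow_ttft_threshold:
        print(datetime.now(), f"slow TTFT {ttft:.2f}s")
    tokens.append(tok)
  latency = time.perf_counter() - start
  return tokens, ttft, latency


print(asyncio.run(consume()))
```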
15 changes: 14 additions & 1 deletion jetstream/core/server_lib.py
@@ -19,6 +19,7 @@

import asyncio
from concurrent import futures
import gc
import logging
import os
import signal
@@ -218,8 +219,20 @@ def run(
# to make sure we can fully saturate the model. Set default minimum to 64.
threads = threads or max(driver.get_total_concurrent_requests(), 64)
jetstream_server = JetStreamServer(driver, threads, port, credentials)
logging.info("Starting server on port %d with %d threads", port, threads)

# Tweak gc config.
# Force a gen 2 collection here.
gc.collect(generation=2)
# Freeze objects currently tracked and ignore them in future gc runs.
gc.freeze()
allocs, gen1, gen2 = gc.get_threshold()
allocs = config.gc_gen0_allocs
gen1 = gen1 * config.gc_gen1_multipler
gen2 = gen2 * config.gc_gen2_multipler
gc.set_threshold(allocs, gen1, gen2)
print("GC tweaked (allocs, gen1, gen2): ", allocs, gen1, gen2)

logging.info("Starting server on port %d with %d threads", port, threads)
jetstream_server.start()

if metrics_collector:
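
The GC tweak in run() above can be exercised in isolation. The helper below is a sketch with a hypothetical name (tune_gc), not JetStream code; its defaults mirror the new config values (gc_gen0_allocs=60000 and gen1/gen2 multipliers of 2 and 3).

```python
import gc


def tune_gc(gen0_allocs=60000, gen1_multiplier=2, gen2_multiplier=3):
  # Collect everything collectable now, then freeze the survivors so the
  # long-lived startup objects are never scanned by future gc runs.
  gc.collect(generation=2)
  gc.freeze()
  _, gen1, gen2 = gc.get_threshold()  # CPython default is (700, 10, 10)
  thresholds = (gen0_allocs, gen1 * gen1_multiplier, gen2 * gen2_multiplier)
  gc.set_threshold(*thresholds)
  return thresholds


if __name__ == "__main__":
  print("GC thresholds set to:", tune_gc())
```

Raising the gen-0 allocation threshold trades memory for fewer collection pauses, which is why the config comment warns that larger numbers can increase memory usage; the benchmark client takes the blunter route and calls gc.disable() before main().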