Performance optimized interleaved mode JetStream server
JoeZijunZhou committed Jul 26, 2024
1 parent af1b918 commit 20ba07c
Showing 1 changed file with 60 additions and 25 deletions.
85 changes: 60 additions & 25 deletions jetstream/core/orchestrator.py
@@ -201,6 +201,7 @@ class Driver:
# This can be a list because we can pass it as an arg to generate and
# detokenize threads. It is a list of tokens to be detokenized.
_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
_prefill_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
_generate_slots: list[queue.Queue[int]] = []
_active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = []

@@ -316,6 +317,12 @@ def __init__(
queue.Queue(8)
for _ in self._generate_engines
]
self._prefill_detokenize_backlogs = [
# We don't let detokenization accumulate more than 8 steps to avoid
# synchronization issues.
queue.Queue(8)
for _ in self._prefill_engines
]

# A queue of integers representing available 'slots' in the decode
# operation. I.e. potentially available rows in the batch and/or microbatch.
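
The new _prefill_detokenize_backlogs mirrors the existing per-engine backlog pattern: one bounded queue.Queue(8) per prefill engine, so detokenization can never fall more than eight steps behind and the producer blocks rather than letting work pile up. A minimal, self-contained sketch of that backpressure-plus-sentinel pattern; the names (NUM_PREFILL_ENGINES, producer, consumer, prefill_backlogs) are illustrative, not JetStream's:

import queue
import threading

NUM_PREFILL_ENGINES = 2  # illustrative value, not taken from the commit

# One bounded queue per engine: put() blocks once 8 items are pending, which
# throttles the producer instead of letting detokenization work accumulate.
prefill_backlogs = [queue.Queue(maxsize=8) for _ in range(NUM_PREFILL_ENGINES)]

def producer(idx: int) -> None:
  """Stands in for the prefill thread pushing first tokens."""
  for step in range(20):
    prefill_backlogs[idx].put(("first_token", step), block=True)
  prefill_backlogs[idx].put(None)  # sentinel so the consumer can exit

def consumer(idx: int) -> None:
  """Stands in for the detokenize thread draining its backlog."""
  while True:
    data = prefill_backlogs[idx].get(block=True)
    if data is None:
      break
    # ... detokenize `data` and return it to the caller here ...

threads = [
    threading.Thread(target=fn, args=(idx,))
    for idx in range(NUM_PREFILL_ENGINES)
    for fn in (producer, consumer)
]
for t in threads:
  t.start()
for t in threads:
  t.join()
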
@@ -376,12 +383,23 @@ def __init__(
)
for idx in range(len(self._generate_engines))
]
self.prefill_detokenize_threads = [
JetThread(
target=functools.partial(
self._prefill_detokenize_thread,
idx,
),
name=f"prefill-detokenize-{idx}",
)
for idx in range(len(self._prefill_engines))
]
self._all_threads = list(
itertools.chain(
self._prefill_threads,
self._transfer_threads,
self._generate_threads,
self.detokenize_threads,
self.prefill_detokenize_threads,
)
)
self.live = True
@@ -514,12 +532,12 @@ def _prefill_thread(self, idx: int):
padded_tokens=padded_tokens,
true_length=true_length,
)

first_token.copy_to_host_async()
request.prefill_result = prefill_result

# put first token to detokenize queue
request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
my_detokenize_backlog = self._detokenize_backlogs[idx]
my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
my_detokenize_backlog.put(
(first_token, request, request_start_time), block=True
)
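
Above, the prefill thread starts the device-to-host copy of the first token asynchronously (copy_to_host_async) before handing it to the backlog, so most of the transfer overlaps with other work and the detokenize thread only blocks on whatever remains. A standalone JAX sketch of that overlap, using a plain jax.Array as a stand-in for the engine's ResultTokens:

import jax
import jax.numpy as jnp
import numpy as np

first_token = jnp.arange(8)       # stands in for the prefill result on device
first_token.copy_to_host_async()  # kick off the device-to-host transfer early
# ... hand the request off to the transfer/generate path here ...
first_token = jax.block_until_ready(first_token)  # wait for the computation
host_tokens = np.asarray(first_token)  # cheap if the async copy already landed
print(host_tokens)
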
@@ -736,34 +754,13 @@ def _detokenize_thread(self, idx: int):
if data is None:
break
start_detokenize_time = time.time()
# prefill first token
if isinstance(data[0], engine_api.ResultTokens):
request_first_token, request, request_start_time = data
request_first_token = request_first_token.convert_to_numpy()

results, complete = token_utils.process_result_tokens(
tokenizer=tokenizer,
slot=0, # always 0 as prefill only run 1 sample
slot_max_length=request.max_tokens,
result_tokens=request_first_token,
is_client_side_tokenization=request.is_client_side_tokenization,
complete=request.complete,
)
request.complete = complete
# Return some output samples.
request.enqueue_samples(results)

first_token_return_time = time.perf_counter()
logging.info(
"TTFT duration: %fms",
(first_token_return_time - request_start_time) * 1000,
)
# generate step tokens
elif isinstance(data[1], engine_api.ResultTokens):
if isinstance(data[1], engine_api.ResultTokens):
# We want to detokenize them.
generate_timestep_added, result_tokens = data
# Disable attribute error because pytype doesn't know this
# is a result tokens, and we can't annotate the tuple.
result_tokens = jax.block_until_ready(result_tokens)
result_tokens = result_tokens.convert_to_numpy()

for slot, request in my_live_requests.items():
@@ -795,6 +792,44 @@ def _detokenize_thread(self, idx: int):
slot, active_request = data
my_live_requests[slot] = active_request

def _prefill_detokenize_thread(self, idx: int):
"""Detokenize the prefill token and returns it to the user."""
# We don't need to keep a my_live_requests list since the request is
# passed to tranfer then to generate. We could directly detokenize the
# first token from prefill and return it.
my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
my_prefill_engine = self._prefill_engines[idx]

metadata = my_prefill_engine.get_tokenizer()
tokenizer = my_prefill_engine.build_tokenizer(metadata)
while self.live:
data = my_detokenize_backlog.get(block=True)
if data is None:
break
# prefill first token
if isinstance(data[0], engine_api.ResultTokens):
request_first_token, request, request_start_time = data
request_first_token = jax.block_until_ready(request_first_token)
request_first_token = request_first_token.convert_to_numpy()

results, complete = token_utils.process_result_tokens(
tokenizer=tokenizer,
slot=0, # always 0 as prefill only runs 1 sample
slot_max_length=request.max_tokens,
result_tokens=request_first_token,
is_client_side_tokenization=request.is_client_side_tokenization,
complete=request.complete,
)
request.complete = complete
# Return some output samples.
request.enqueue_samples(results)

first_token_return_time = time.perf_counter()
logging.info(
"TTFT duration: %fms",
(first_token_return_time - request_start_time) * 1000,
)


class LLMOrchestrator(jetstream_pb2_grpc.OrchestratorServicer):
"""Coordinates a set of prefill and generate slices for LLM decoding."""