Performance optimized interleaved mode JetStream server #122
base: main
@@ -201,6 +201,7 @@ class Driver:
   # This can be a list because we can pass it as an arg to generate and
   # detokenize threads. It is a list of tokens to be detokenized.
   _detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
+  _prefill_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
   _generate_slots: list[queue.Queue[int]] = []
   _active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = []

@@ -316,6 +317,12 @@ def __init__(
         queue.Queue(8)
         for _ in self._generate_engines
     ]
+    self._prefill_detokenize_backlogs = [
+        # We don't let detokenization accumulate more than 8 steps to avoid
+        # synchronization issues.
+        queue.Queue(8)
+        for _ in self._prefill_engines
+    ]

     # A queue of integers representing available 'slots' in the decode
     # operation. I.e. potentially available rows in the batch and/or microbatch.

@@ -376,12 +383,23 @@ def __init__(
         )
         for idx in range(len(self._generate_engines))
     ]
+    self.prefill_detokenize_threads = [
+        JetThread(
+            target=functools.partial(
+                self._prefill_detokenize_thread,
+                idx,
+            ),
+            name=f"prefill-detokenize-{idx}",
+        )
+        for idx in range(len(self._prefill_engines))
+    ]
     self._all_threads = list(
         itertools.chain(
             self._prefill_threads,
             self._transfer_threads,
             self._generate_threads,
             self.detokenize_threads,
+            self.prefill_detokenize_threads,
         )
     )
     self.live = True

@@ -514,12 +532,12 @@ def _prefill_thread(self, idx: int):
           padded_tokens=padded_tokens,
           true_length=true_length,
       )

       first_token.copy_to_host_async()
       request.prefill_result = prefill_result

       # put first token to detokenize queue
       request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
-      my_detokenize_backlog = self._detokenize_backlogs[idx]
+      my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
       my_detokenize_backlog.put(
           (first_token, request, request_start_time), block=True
       )

@@ -736,34 +754,13 @@ def _detokenize_thread(self, idx: int):
       if data is None:
         break
       start_detokenize_time = time.time()
-      # prefill first token
-      if isinstance(data[0], engine_api.ResultTokens):
-        request_first_token, request, request_start_time = data
-        request_first_token = request_first_token.convert_to_numpy()
-
-        results, complete = token_utils.process_result_tokens(
-            tokenizer=tokenizer,
-            slot=0,  # always 0 as prefill only run 1 sample
-            slot_max_length=request.max_tokens,
-            result_tokens=request_first_token,
-            is_client_side_tokenization=request.is_client_side_tokenization,
-            complete=request.complete,
-        )
-        request.complete = complete
-        # Return some output samples.
-        request.enqueue_samples(results)
-
-        first_token_return_time = time.perf_counter()
-        logging.info(
-            "TTFT duration: %fms",
-            (first_token_return_time - request_start_time) * 1000,
-        )
-      # generate step tokens
-      elif isinstance(data[1], engine_api.ResultTokens):
+      if isinstance(data[1], engine_api.ResultTokens):
         # We want to detokenize them.
         generate_timestep_added, result_tokens = data
         # Disable attribute error because pytype doesn't know this
         # is a result tokens, and we can't annotate the tuple.
         result_tokens = jax.block_until_ready(result_tokens)
         result_tokens = result_tokens.convert_to_numpy()

         for slot, request in my_live_requests.items():

@@ -795,6 +792,44 @@ def _detokenize_thread(self, idx: int):
         slot, active_request = data
         my_live_requests[slot] = active_request

+  def _prefill_detokenize_thread(self, idx: int):
Review thread on this line:

Comment: I thought we already had a prefill detokenize thread. Does the current JetStream (before this PR) always return the prefill token (first token) only after the first decode step?

Reply: We only had a detokenize thread that combined prefill detokenization and decode detokenization. The problem is that we have ...

Reply: Sounds good, thanks for sharing insights!
"""Detokenize the prefill token and returns it to the user.""" | ||
# We don't need to keep a my_live_requests list since the request is | ||
# passed to tranfer then to generate. We could directly detokenize the | ||
# first token from prefill and return it. | ||
my_detokenize_backlog = self._prefill_detokenize_backlogs[idx] | ||
my_prefill_engine = self._prefill_engines[idx] | ||
|
||
metadata = my_prefill_engine.get_tokenizer() | ||
tokenizer = my_prefill_engine.build_tokenizer(metadata) | ||
while self.live: | ||
data = my_detokenize_backlog.get(block=True) | ||
if data is None: | ||
break | ||
# prefill first token | ||
if isinstance(data[0], engine_api.ResultTokens): | ||
request_first_token, request, request_start_time = data | ||
request_first_token = jax.block_until_ready(request_first_token) | ||
request_first_token = request_first_token.convert_to_numpy() | ||
|
||
results, complete = token_utils.process_result_tokens( | ||
tokenizer=tokenizer, | ||
slot=0, # always 0 as prefill only run 1 sample | ||
slot_max_length=request.max_tokens, | ||
result_tokens=request_first_token, | ||
is_client_side_tokenization=request.is_client_side_tokenization, | ||
complete=request.complete, | ||
) | ||
request.complete = complete | ||
# Return some output samples. | ||
request.enqueue_samples(results) | ||
|
||
first_token_return_time = time.perf_counter() | ||
logging.info( | ||
"TTFT duration: %fms", | ||
(first_token_return_time - request_start_time) * 1000, | ||
) | ||


 class LLMOrchestrator(jetstream_pb2_grpc.OrchestratorServicer):
   """Coordinates a set of prefill and generate slices for LLM decoding."""
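To make the intent of the diff concrete, below is a minimal, self-contained sketch of the producer/consumer split it introduces: the prefill side hands the first token to its own bounded queue, and a dedicated thread detokenizes and returns it right away instead of competing with decode-step detokenization on a shared backlog. This is not JetStream code; FakeRequest, prefill_worker, and prefill_detokenize_worker are hypothetical stand-ins for ActiveRequest, _prefill_thread, and _prefill_detokenize_thread, and only the pattern mirrors the PR.

```python
import queue
import threading
import time
from dataclasses import dataclass, field


@dataclass
class FakeRequest:
    """Stand-in for an active request: collects the samples returned to the user."""
    samples: list = field(default_factory=list)

    def enqueue_samples(self, results):
        self.samples.extend(results)


def prefill_worker(prefill_backlog: queue.Queue, num_requests: int):
    """Producer: emulates the prefill thread handing off first tokens."""
    for i in range(num_requests):
        request_start_time = time.perf_counter()
        time.sleep(0.01)                # pretend prefill compute
        first_token = f"<token-{i}>"    # stand-in for engine_api.ResultTokens
        # Hand off to the dedicated, bounded prefill detokenize queue.
        prefill_backlog.put((first_token, FakeRequest(), request_start_time), block=True)
    prefill_backlog.put(None)           # sentinel to stop the consumer


def prefill_detokenize_worker(prefill_backlog: queue.Queue):
    """Consumer: returns the first token as soon as prefill produces it,
    independent of decode-step detokenization."""
    while True:
        data = prefill_backlog.get(block=True)
        if data is None:
            break
        first_token, request, request_start_time = data
        request.enqueue_samples([first_token])  # "detokenize" and return it
        ttft_ms = (time.perf_counter() - request_start_time) * 1000
        print(f"TTFT: {ttft_ms:.2f} ms for {first_token}")


if __name__ == "__main__":
    backlog = queue.Queue(8)  # bounded, mirroring the PR's queue.Queue(8)
    consumer = threading.Thread(target=prefill_detokenize_worker, args=(backlog,))
    consumer.start()
    prefill_worker(backlog, num_requests=4)
    consumer.join()
```

Because the dedicated queue only ever carries (first_token, request, start_time) tuples, the consumer no longer needs the isinstance dispatch that the combined _detokenize_thread used to tell prefill tokens apart from decode-step results.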
Review discussion:

Comment: Can you elaborate more on why there is a synchronization issue after 8 steps?

Reply: It is set to 8 to match the detokenize thread's backlog. A bound that is too large or too small causes performance issues.

Comment: Thanks, is this PR ready to submit?
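One way to read the reply above is that the bound on the backlog acts as backpressure: with queue.Queue(8) the producing thread can run at most 8 steps ahead before put(block=True) stalls it, whereas an unbounded or very large queue would let detokenization fall arbitrarily far behind (stale results, growing memory) and a size of 1 would force the two threads into lock-step. The sketch below illustrates that tradeoff with generic producer/consumer stand-ins, not JetStream's actual threads.

```python
import queue
import threading
import time

BACKLOG_SIZE = 8  # illustrative; the PR uses queue.Queue(8)


def producer(backlog: queue.Queue, steps: int):
    for step in range(steps):
        # With a bounded queue, put() blocks once the consumer falls
        # BACKLOG_SIZE items behind, so work cannot pile up without limit,
        # yet the producer is not forced to wait for every single item.
        backlog.put(step, block=True)
    backlog.put(None)  # sentinel


def consumer(backlog: queue.Queue):
    while True:
        item = backlog.get(block=True)
        if item is None:
            break
        time.sleep(0.005)  # pretend detokenization work


if __name__ == "__main__":
    backlog = queue.Queue(BACKLOG_SIZE)
    t = threading.Thread(target=consumer, args=(backlog,))
    t.start()
    producer(backlog, steps=32)
    t.join()
```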