Performance-optimized interleaved mode JetStream server #122
base: main
Conversation
JoeZijunZhou commented on Jul 26, 2024 (edited):
- Optimized TPU duty cycle (largest gap < 4 ms)
- Optimized TTFT: dispatch prefill tasks as soon as possible without unnecessary CPU blocking, keep backpressure to enforce inserts as soon as possible, and return the first token as soon as possible (see the sketch after this list).
- Optimized TPOT: run the generate and detokenize tasks strictly in sequence without unnecessary CPU blocking.
- Optimized output token throughput: prioritize prefill and balance TTFT against decode under high-throughput load.
- Tested with the llama2-70b JetStream MaxText server on a v5e-8 VM
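To make the description concrete, here is a minimal, hypothetical sketch of the queue-and-thread layout the bullets describe: prefill work is dispatched as soon as it arrives, and bounded queues between stages provide the backpressure that forces inserts and first-token returns to happen promptly. The `engine.prefill()` signature, queue names, and sizes are illustrative assumptions, not the PR's actual code.

```python
import queue
import threading

# Illustrative backlogs; the real server keeps one set per engine.
prefill_backlog = queue.Queue()                      # requests awaiting prefill
generate_backlog = queue.Queue(maxsize=8)            # bounded: backpressure on inserts
prefill_detokenize_backlog = queue.Queue(maxsize=8)  # first-token results to detokenize


def prefill_worker(engine):
  """Runs prefill and hands results downstream without blocking on the host."""
  while True:
    request = prefill_backlog.get()
    if request is None:
      break
    # Hypothetical engine API: prefill returns the first token and the KV cache.
    first_token, kv_cache = engine.prefill(request)
    # put() blocks when a queue is full, so a slow downstream stage naturally
    # throttles prefill instead of letting work pile up on the host.
    generate_backlog.put((request, kv_cache))
    prefill_detokenize_backlog.put((request, first_token))
```

A daemon thread per engine would run this loop, e.g. `threading.Thread(target=prefill_worker, args=(engine,), daemon=True).start()`.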
Optimized TTFT and optimized output token throughput conflict with each other. Can we expose a parameter to tune the trade-off between the two?
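One way to expose such a knob (purely an illustrative sketch; none of these names exist in JetStream) would be a small scheduling config that biases the server toward TTFT or toward throughput:

```python
from dataclasses import dataclass


@dataclass
class SchedulingConfig:
  # Fraction of scheduling slots reserved for prefill work; higher values
  # favor TTFT, lower values favor output token throughput.
  prefill_priority: float = 0.5
  # Backlog sizes; smaller queues mean tighter backpressure.
  generate_backlog_size: int = 8
  prefill_detokenize_backlog_size: int = 8
```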
@@ -316,6 +317,12 @@ def __init__(
        queue.Queue(8)
        for _ in self._generate_engines
    ]
    self._prefill_detokenize_backlogs = [
        # We don't let detokenization accumulate more than 8 steps to avoid
Can you elaborate more on why there is a synchronization issue after 8 steps?
We set it to 8, the same as the detokenize thread's backlog. A value that is too large or too small causes performance issues.
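For readers unfamiliar with the mechanism, here is a self-contained illustration (assumed, not taken from the PR) of what a `queue.Queue(8)` backlog does: the producer blocks once eight items are outstanding, so the detokenize thread can never fall arbitrarily far behind the device, whereas a much larger bound would let stale work accumulate and a much smaller one would stall the producer too often.

```python
import queue
import threading
import time

backlog = queue.Queue(maxsize=8)


def detokenize_worker():
  while True:
    item = backlog.get()
    if item is None:
      break
    time.sleep(0.001)  # stand-in for host-side detokenization work
    backlog.task_done()


t = threading.Thread(target=detokenize_worker, daemon=True)
t.start()
for step in range(100):
  backlog.put(step)  # blocks whenever 8 items are already pending -> backpressure
backlog.put(None)
t.join()
```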
Thanks, is this PR ready to submit?
@@ -795,6 +792,44 @@ def _detokenize_thread(self, idx: int):
        slot, active_request = data
        my_live_requests[slot] = active_request

  def _prefill_detokenize_thread(self, idx: int):
I thought we already had a prefill detokenize thread. Does the current JetStream (before this PR) always return the prefill token (first token) only after the first decode step?
We only had a single detokenize thread that combined prefill detokenization and decode detokenization. The problem is that jax.block_until_ready() blocks the thread while waiting for the prefill token or decode token's async copy to host, so putting both in one thread would make the TTFT slow. JetStream returns the prefill token in the prefill thread (right after the prefill step generates the first token).
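As a rough sketch of that separation (assumed structure; `return_channel` and `tokenizer` are placeholders, not JetStream's actual objects), the dedicated prefill detokenize loop only ever waits on the prefill result's device-to-host copy, so `jax.block_until_ready()` stalls nothing but this thread and the first token can be returned as soon as it lands on the host:

```python
import jax


def _prefill_detokenize_loop(backlog, return_channel, tokenizer):
  """Consumes (request, first_token) pairs and returns the first token ASAP."""
  while True:
    item = backlog.get()
    if item is None:
      break
    request, first_token = item
    # Wait only for the prefill result's async copy to host, nothing else.
    first_token = jax.block_until_ready(first_token)
    text = tokenizer.decode([int(jax.device_get(first_token))])
    return_channel.put((request, text))
```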
Sounds good, thanks for sharing the insights!
Currently, we prioritize prefills in interleaved mode and apply the correct JAX blocking for the async copy-to-host to reduce wasted wait time. One more optimization to do is to ensure the result is returned immediately once the return channel has the result (from the orchestrator).
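As a hypothetical sketch of that remaining optimization (names assumed, not JetStream's API), the serving layer could consume the orchestrator's return channel with a blocking await so each result is streamed out the moment it arrives, rather than on a polling interval or batch boundary:

```python
import asyncio


async def stream_responses(return_channel: asyncio.Queue):
  """Yields each result as soon as the orchestrator puts it on the channel."""
  while True:
    result = await return_channel.get()  # wakes up as soon as a result lands
    if result is None:
      break
    yield result  # hand the token back to the RPC layer immediately
```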