
Performance optimized interleaved mode JetStream server #122

Open

wants to merge 1 commit into main
Conversation

JoeZijunZhou
Collaborator

@JoeZijunZhou JoeZijunZhou commented Jul 26, 2024

  • Optimized TPU duty cycle (largest gap < 4 ms)
  • Optimized TTFT: dispatch prefill tasks as soon as possible without unnecessary CPU-side blocking, keep backpressure so inserts happen promptly, and return the first token as soon as it is available (see the sketch after this list).
  • Optimized TPOT: run the generate and detokenize tasks strictly in sequence without unnecessary CPU-side blocking.
  • Optimized output token throughput: prioritize prefill properly, balancing TTFT against decode in high-throughput situations.
  • Tested with a llama2-70b JetStream MaxText server on a v5e-8 VM.
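
A minimal sketch of the thread-and-queue layout these bullets describe, assuming hypothetical helper and queue names (this is not the actual JetStream code): bounded queues between stages provide the backpressure, and the first token is returned from the prefill stage before any decode step runs.

```python
import queue
import threading

# Bounded queues between pipeline stages; a blocking put() is the backpressure.
generate_backlog = queue.Queue(maxsize=4)    # prefill -> insert/generate
detokenize_backlog = queue.Queue(maxsize=8)  # generate -> detokenize

def run_prefill(request):  # hypothetical stand-in for the real prefill step
  return f"first_token[{request}]", f"kv_cache[{request}]"

def prefill_worker(requests):
  for request in requests:
    first_token, kv_cache = run_prefill(request)
    # Return the first token ASAP, before any decode step runs.
    print("TTFT token:", first_token)
    # Blocking put: if inserts fall behind, prefill stalls here instead of
    # racing ahead. This is the backpressure the bullets refer to.
    generate_backlog.put((request, kv_cache))

def generate_worker():
  while True:
    request, kv_cache = generate_backlog.get()  # insert ASAP
    # A real server would run decode steps here, pushing each step's tokens
    # into detokenize_backlog; its bound keeps generate and detokenize in
    # near-lockstep.
    detokenize_backlog.put(f"tokens[{request}]")

threading.Thread(target=generate_worker, daemon=True).start()
prefill_worker(["req0", "req1", "req2"])
```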

@JoeZijunZhou JoeZijunZhou marked this pull request as ready for review July 26, 2024 10:53
@JoeZijunZhou JoeZijunZhou requested a review from vipannalla as a code owner July 26, 2024 10:53
@FanhaiLu1
Contributor

> Optimized TTFT … Optimized output token throughput …

Optimizing TTFT and optimizing output token throughput conflict with each other. Can we expose a parameter to tune the trade-off between the two?

@@ -316,6 +317,12 @@ def __init__(
        queue.Queue(8)
        for _ in self._generate_engines
    ]
    self._prefill_detokenize_backlogs = [
        # We don't let detokenization accumulate more than 8 steps to avoid
Contributor

Can you elaborate on why there is a synchronization issue after 8 steps?

Collaborator Author

We set it to 8 to match the pace of the detokenize thread. A bound that is too large or too small will cause performance issues.
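
A toy illustration (not the PR's code) of why the bound matters: with queue.Queue(8), put() blocks once eight steps are buffered, so the producing side can never run more than eight steps ahead of detokenization. An unbounded queue would let results pile up, while a much smaller bound would stall the producer.

```python
import queue
import threading
import time

backlog = queue.Queue(maxsize=8)  # mirrors queue.Queue(8) in the diff above

def producer():  # stands in for the thread feeding detokenization
  for step in range(32):
    backlog.put(step)  # blocks whenever 8 steps are already buffered
  backlog.put(None)    # sentinel: no more steps

threading.Thread(target=producer).start()

while (step := backlog.get()) is not None:  # stands in for the detokenize thread
  time.sleep(0.01)                          # pretend detokenization takes time
  print("detokenized step", step)
```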

Contributor

Thanks, is this PR ready to submit?

@@ -795,6 +792,44 @@ def _detokenize_thread(self, idx: int):
        slot, active_request = data
        my_live_requests[slot] = active_request

  def _prefill_detokenize_thread(self, idx: int):
Contributor

I thought we already had a prefill detokenize thread. Does the current JetStream (before this PR) always return the prefill token (first token) only after the first decode step?

Collaborator Author

We only had a single detokenize thread that combined prefill detokenization and decode detokenization. The problem is that jax.block_until_ready() blocks the thread while waiting for the prefill token or decode token's async copy to host, so putting both in one thread made TTFT slow. JetStream returns the prefill token in the prefill thread (right after the prefill step generates the first token).
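
A hedged sketch of the structure described here, with illustrative queue and function names: two detokenize threads, each calling jax.block_until_ready() on its own stream, so a slow decode-step copy can no longer sit in front of the prefill token.

```python
import queue
import threading

import jax
import jax.numpy as jnp

prefill_detokenize_backlog = queue.Queue(8)  # first tokens from prefill
decode_detokenize_backlog = queue.Queue(8)   # per-step tokens from generate

def prefill_detokenize_worker():
  while (first_token := prefill_detokenize_backlog.get()) is not None:
    # Waits only for the prefill token's async device-to-host copy;
    # decode tokens queued on the other thread cannot delay TTFT here.
    first_token = jax.block_until_ready(first_token)
    print("first token ready:", first_token)

def decode_detokenize_worker():
  while (tokens := decode_detokenize_backlog.get()) is not None:
    tokens = jax.block_until_ready(tokens)  # may wait on a full decode step
    print("decode tokens ready:", tokens)

threading.Thread(target=decode_detokenize_worker, daemon=True).start()
worker = threading.Thread(target=prefill_detokenize_worker)
worker.start()
prefill_detokenize_backlog.put(jnp.array([42]))  # pretend prefill produced this
prefill_detokenize_backlog.put(None)
worker.join()
```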

Contributor

Sounds good, thanks for sharing the insights!

@JoeZijunZhou
Collaborator Author

> Optimizing TTFT and optimizing output token throughput conflict with each other. Can we expose a parameter to tune the trade-off between the two?

Currently we prioritize prefills in interleaved mode, and we apply the correct JAX blocking for the async copy-to-host to reduce wasted wait time. One more optimization to do is to ensure the result is returned immediately once the return channel has it (from the orchestrator).
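
A minimal sketch of what "prioritize prefills in interleaved mode" could look like, with hypothetical helpers passed in as parameters (this is not the actual orchestrator code): pending prefills are drained into free slots before each decode step, so new requests reach their first token quickly even under heavy decode load.

```python
import queue

def interleaved_loop_step(prefill_backlog, free_slots, decode_state,
                          insert_prefill, run_generate_step):
  # Prefill takes priority whenever there is pending work and a free slot.
  while free_slots and not prefill_backlog.empty():
    request = prefill_backlog.get_nowait()
    decode_state = insert_prefill(decode_state, request, free_slots.pop())
  # Only then run a single decode step for everything already inserted.
  return run_generate_step(decode_state)

# Toy usage with stand-in helpers:
backlog = queue.Queue()
backlog.put("req0")
state = interleaved_loop_step(
    backlog, free_slots=[0, 1], decode_state={},
    insert_prefill=lambda s, r, slot: {**s, slot: r},
    run_generate_step=lambda s: s,  # pretend the decode step returns state
)
print(state)  # {1: 'req0'}
```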
