[Core] Reduce TTFT with concurrent partial prefills #10235
Open: joerunde wants to merge 59 commits into vllm-project:main from opendatahub-io:prefill-slots
59 commits
f97eacf (joerunde) :bug: fix multi-chunked-prefill sampler bug
b50a6b8 (prashantgupta24) 🚧 add num_prefill_slots arg
7f23c04 (joerunde) :sparkles: start to write prefill slot logic
d271cc9 (prashantgupta24) 🎨 format
b2cb96f (joerunde) :sparkles: update num tokens for prefill slots
c349ac0 (prashantgupta24) ♻️ add schedule_chunked_prefill logic
e20518d (prashantgupta24) ♻️ change function name
6ba0e34 (joerunde) :sparkles: reserve incoming prefill slots
a7491cc (prashantgupta24) 🎨 fix some typos
1ee6fea (joerunde) :zap: finish awesome scheduler
517915a (joerunde) :bug: fix the deadlocks
ed298c3 (joerunde) :memo: Add more docstrings
90e0c07 (joerunde) :bug: fix deadlock
1c92ac2 (joerunde) :construction: WIP scheduler tests
de95f62 (joerunde) :bug: fix prefix caching
41e20ca (joerunde) :test_tube: add prefix caching test
4dc7310 (prashantgupta24) ✅ add second test iteration
8e3118e (prashantgupta24) ✅ add llm engine test
b6ebec8 (prashantgupta24) ♻️ quicker budget check
7e93668 (prashantgupta24) 🎨 rename to max_num_partial_prefills
557bfe3 (prashantgupta24) 🎨 more renaming to max_num_partial_prefills + docstring updates
d3e94df (prashantgupta24) 🎨 rename big to long
849baf6 (prashantgupta24) ♻️ add cli args for partial_prefill configs
beaf086 (prashantgupta24) 🎨 fix request word typo
672a50c (prashantgupta24) 🎨 more docstring changes
a2751ff (prashantgupta24) 🎨 forgot to add the new args to config
dff757d (prashantgupta24) 🐛 fix range bug on partial_prefill_budget_lookup_list
86ffa04 (prashantgupta24) 🎨 add docstring to test function
3d39942 (joerunde) :construction: WIP move metadata to dataclass
dbb9ae8 (prashantgupta24) 🎨 wrap up PartialPrefillMetadata
4bac8ed (prashantgupta24) ♻️ add some utility functions within partial_prefill_metadata
c44ca1f (prashantgupta24) 🎨 change to long_prefill_token_threshold
38bad7a (prashantgupta24) 🔥 remove commented code
0f3efa1 (prashantgupta24) 🐛 fix the big bug! (Thanks Joe)
3daf35f (joerunde) :memo: docstings galore
241853a (prashantgupta24) 🎨 fix typo
07b6d72 (prashantgupta24) ⏪ revert logging change
c4bdf37 (prashantgupta24) ✅ remove value error from test
7c8b400 (prashantgupta24) ✅ remove value error from test
21796fc (prashantgupta24) 🎨 fix typo
d993861 (prashantgupta24) ✅ make test comprehensive
946d297 (prashantgupta24) 🎨 fix unused vars in test
5535515 (prashantgupta24) 🎨 some more comments
ba91ddf (prashantgupta24) 🎨 fix merge conflict
bccf86f (prashantgupta24) 🎨 fmt
75848c9 (prashantgupta24) ♻️ merge with main
1c80379 (prashantgupta24) Merge branch 'main' into prefill-slots
4f1c322 (prashantgupta24) 🎨 fix fmt
cb8fc93 (prashantgupta24) ⏪ revert quick budget check
8a8a07f (prashantgupta24) 🎨 fmt
90a53ab (prashantgupta24) ♻️ merge with main
29a7ccd (prashantgupta24) Merge remote-tracking branch 'upstream/main' into prefill-slots
752ce1b (prashantgupta24) 🎨 fmt
edc204e (joerunde) Merge remote-tracking branch 'upstream/main' into prefill-slots
0206173 (joerunde) Merge remote-tracking branch 'upstream/main' into prefill-slots
80b72ef (joerunde) Merge remote-tracking branch 'upstream/main' into prefill-slots
03525f2 (joerunde) :bug: fix index out of range
d5f5eb6 (joerunde) :recycle: naming updates
cb5361a (joerunde) :bug: fix long prefill threshold init
@@ -5,6 +5,9 @@
 
 from vllm.config import CacheConfig, SchedulerConfig
 from vllm.core.scheduler import Scheduler
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.sampling_params import SamplingParams
 from vllm.sequence import Logprob, SequenceGroup
 
 from .utils import create_dummy_prompt
@@ -14,7 +17,7 @@ def get_sequence_groups(scheduler_output):
     return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
 
 
-def append_new_token(seq_group, token_id: int):
+def append_new_token(seq_group: SequenceGroup, token_id: int):
     for seq in seq_group.get_seqs():
         seq.append_token_id(token_id, {token_id: Logprob(token_id)})
 
@@ -121,6 +124,214 @@ def test_chunk():
     assert out.num_batched_tokens == 57
 
 
+def test_concurrent_chunking():
+    """Verify prefills are chunked properly when
+    --max-num-partial-prefills is > 1"""
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 2000
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(
+        "generate",
+        max_num_batched_tokens,
+        max_seqs,
+        max_model_len,
+        enable_chunked_prefill=True,
+        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
+    )
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 32
+    cache_config.num_gpu_blocks = 32
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           prompt_length=60,
+                                           block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    # Verify both requests are chunked with half of max_num_batched_tokens each
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 32
+    assert seq_group_meta[1].token_chunk_size == 32
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 64
+
+    # After one iteration, both should have 60 - 32 = 28 tokens left to prefill
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 28
+    assert seq_group_meta[1].token_chunk_size == 28
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 56
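Reader note on the arithmetic in `test_concurrent_chunking`: the sketch below assumes the token budget is simply divided evenly among the requests being partially prefilled, with each chunk capped at the tokens that request still needs. `split_prefill_budget` is a hypothetical helper written for illustration, not a function from vLLM or this PR, and the real scheduler logic may differ.

```python
from typing import List


def split_prefill_budget(token_budget: int, remaining: List[int]) -> List[int]:
    """Give each partially prefilled request an equal share of the budget,
    capped at the number of prompt tokens it still has to prefill."""
    if not remaining:
        return []
    share = token_budget // len(remaining)
    return [min(share, left) for left in remaining]


# Two 60-token prompts and a 64-token budget, mirroring the test above.
first = split_prefill_budget(64, [60, 60])
second = split_prefill_budget(64, [60 - chunk for chunk in first])
print(first, second)
```

Under this assumption the first step yields two chunks of 32 tokens and the second step yields two chunks of 28, which are the values the test asserts.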
+
+
+def test_concurrent_chunking_large_requests():
+    """Verify large prefill requests are run one at a time"""
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 2000
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(
+        "generate",
+        max_num_batched_tokens,
+        max_seqs,
+        max_model_len,
+        enable_chunked_prefill=True,
+        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
+    )
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 3200  # large KV cache size for large requests
+    cache_config.num_gpu_blocks = 3200
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(
+            str(i),
+            prompt_length=1200,  # Very large prompt
+            block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+
+    # Verify only a single request is chunked, and it gets all 64 tokens
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert len(get_sequence_groups(out)) == 1
+    assert seq_group_meta[0].token_chunk_size == 64
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 64
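Reader note on why only one of the two 1200-token prompts is scheduled: a minimal sketch of one possible admission rule, assuming that prompts above a length threshold count as "long" and that only a limited number of long prompts may be partially prefilled at once. The name `long_prefill_token_threshold` appears in this PR's commit messages; `select_prefills`, `max_long_partial_prefills`, the limit of 1, and the threshold of 80 tokens are assumptions chosen for illustration and are not taken from the diff.

```python
from typing import List


def select_prefills(prompt_lens: List[int],
                    max_partial_prefills: int = 2,
                    max_long_partial_prefills: int = 1,
                    long_prefill_token_threshold: int = 80) -> List[int]:
    """Return the queue indices admitted for prefill, in queue order."""
    selected: List[int] = []
    num_long = 0
    for idx, prompt_len in enumerate(prompt_lens):
        if len(selected) >= max_partial_prefills:
            break  # every partial-prefill slot is taken
        is_long = prompt_len > long_prefill_token_threshold
        if is_long and num_long >= max_long_partial_prefills:
            continue  # skip this long prompt; shorter ones behind it may fit
        if is_long:
            num_long += 1
        selected.append(idx)
    return selected


two_long = select_prefills([1200, 1200])
mixed = select_prefills([1200, 1200, 40, 40])
print(two_long, mixed)
```

With two long prompts only index 0 is admitted, so it can take the whole budget. With two long prompts followed by two short ones, indices 0 and 2 are admitted, which is the queue-jumping behaviour the next test checks.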
+
+
+def test_short_prompts_jump_long_prompts_in_queue():
+    """Verify large prefill requests are punted behind smaller ones if
+    another large prefill request is already running"""
+    block_size = 4
+    max_seqs = 60
+    max_model_len = 2000
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(
+        "generate",
+        max_num_batched_tokens,
+        max_seqs,
+        max_model_len,
+        enable_chunked_prefill=True,
+        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
+    )
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 3200  # large KV cache size for large requests
+    cache_config.num_gpu_blocks = 3200
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add 2 large seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(
+            str(i),
+            prompt_length=1200,  # Very large prompt
+            block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+        assert seq_group.is_prefill()
+
+    # Add 2 small seq groups behind them
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(
+            str(i + 2),
+            prompt_length=40,  # Very small prompt
+            block_size=block_size)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+        assert seq_group.is_prefill()
+
+    # Verify one large req and 1 small req chunked
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert seq_group_meta[0].token_chunk_size == 32  # large req gets 32 tokens
+    assert seq_group_meta[1].token_chunk_size == 32  # small req gets 32 tokens
+
+    # all 4 are prefilling
+    assert running[0].is_prefill()
+    assert running[1].is_prefill()
+    assert running[2].is_prefill()
+    assert running[3].is_prefill()
+
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 64
+
+    # in the second iteration,
+    # the first small request had only 8 tokens left
+    # so it went to decode
+    # The other small req is scheduled
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    # the new small req got 64 - (32+8) tokens
+    assert (seq_group_meta[0].token_chunk_size == 24)
+    assert seq_group_meta[1].token_chunk_size == 32  # large req still got 32
+    # the other small request had only 8 tokens left
+    assert seq_group_meta[2].token_chunk_size == 8  # 40-32
+
+    # notice the small request got to decode now
+    # this is because of max_num_partial_prefills logic
+    assert running[0].is_prefill()
+    assert running[1].is_prefill()
+    assert not running[2].is_prefill()
+    assert running[3].is_prefill()
+
+    assert out.num_prefill_groups == 3
+    assert out.num_batched_tokens == 64
+    # the small seq group has a new token appended.
+    append_new_token(running[2], 1)
+
+    # in the third iteration,
+    # the first small request has entered decode
+    # and other small req had 16 tokens left
+    # so it went to decode
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert seq_group_meta[0].token_chunk_size == 32  # large still got 32
+    # small req prefilled 40-24=16 tokens
+    assert (seq_group_meta[1].token_chunk_size == 16)
+    assert seq_group_meta[2].token_chunk_size == 1  # decode
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 49  # (32+16+1 decode)
+
+    # both small requests have now reached decode
+    assert running[0].is_prefill()
+    assert running[1].is_prefill()
+    assert not running[2].is_prefill()
+    assert not running[3].is_prefill()
+
+    # the small seq group has a new token appended.
+    append_new_token(running[2], 1)
+
+    # in the fourth iteration, both small requests are decoding
+    # so large request gets all the budget
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    # large req gets 63 tokens (minus 1 for decode)
+    assert seq_group_meta[0].token_chunk_size == 63
+    assert seq_group_meta[1].token_chunk_size == 1  # decode
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 64
+
+    assert running[0].is_prefill()
+    assert running[1].is_prefill()
+    assert not running[2].is_prefill()
+    assert not running[3].is_prefill()
+
+    # both the small seq groups have a new token appended
+    append_new_token(running[2], 1)
+    append_new_token(running[3], 1)
+
+    # in the fifth iteration, large request gets all the budget
+    # while both small requests are decoding
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert seq_group_meta[0].token_chunk_size == 62
+    assert seq_group_meta[1].token_chunk_size == 1  # decode
+    assert seq_group_meta[2].token_chunk_size == 1  # decode
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == 64
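Reader note on the five iterations of `test_short_prompts_jump_long_prompts_in_queue`: the sketch below is one simple accounting model that happens to reproduce the chunk sizes asserted in the test. It assumes each decoding request takes one token first, that each prefill is capped at the budget divided by the number of occupied partial-prefill slots, and that prefills are served in the order passed in. `allocate_step` is a hypothetical helper; none of this is taken from the scheduler code in the PR.

```python
from typing import List


def allocate_step(token_budget: int, max_partial_prefills: int,
                  num_decodes: int, prefill_remaining: List[int]) -> List[int]:
    """Return the prefill chunk sizes for one scheduling step."""
    # Assumption: every decoding request consumes one token of the budget.
    remaining_budget = token_budget - num_decodes
    # Assumption: the per-prefill cap depends on how many slots are in use.
    slots = max(1, min(len(prefill_remaining), max_partial_prefills))
    per_prefill_cap = token_budget // slots
    chunks = []
    for tokens_left in prefill_remaining:
        chunk = min(per_prefill_cap, tokens_left, remaining_budget)
        chunks.append(chunk)
        remaining_budget -= chunk
    return chunks


steps = [
    allocate_step(64, 2, 0, [1200, 40]),      # iteration 1
    allocate_step(64, 2, 0, [1168, 8, 40]),   # iteration 2
    allocate_step(64, 2, 1, [1136, 16]),      # iteration 3
    allocate_step(64, 2, 1, [1104]),          # iteration 4
    allocate_step(64, 2, 2, [1041]),          # iteration 5
]
for chunks in steps:
    print(chunks)
```

The lists are given in an assumed service order (already-running prefills first), which differs from the order of `seq_group_meta` in iteration 2. In iteration 4 only one decode is counted because the test appended a new token to `running[2]` alone after iteration 3.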
+
+
 def test_complex():
     block_size = 4
     max_seqs = 60
@@ -506,7 +717,7 @@ def test_chunked_prefill_max_seqs():
     assert not running[1].is_prefill()
 
 
-def test_perfix_caching():
+def test_prefix_caching():
     """Verify allocating full blocks when prefix caching is enabled."""
     block_size = 4
     max_seqs = 10
@@ -546,3 +757,86 @@ def test_perfix_caching():
     assert seq_group_meta[1].token_chunk_size == 12
     assert out.num_prefill_groups == 2
     assert out.num_batched_tokens == 62
+
+
+def test_prefix_caching_with_concurrent_partial_prefills():
+    """Verify allocating full blocks when prefix caching is enabled with
+    --max-num-partial-prefills > 1."""
+    block_size = 4
+    max_seqs = 10
+    max_model_len = 8000
+    max_num_batched_tokens = 60  # With two slots, each slot will get 30 tokens
+    scheduler_config = SchedulerConfig("generate",
+                                       max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True,
+                                       max_num_partial_prefills=2)
+    cache_config = CacheConfig(block_size,
+                               1.0,
+                               1,
+                               "auto",
+                               enable_prefix_caching=True)
+    cache_config.num_cpu_blocks = 0
+    cache_config.num_gpu_blocks = 32
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           block_size=block_size,
+                                           prompt_length=50)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    # To partially prefill both sequences, both can chunk up to 30 tokens
+    # But the next lowest multiple of the block size (4) is 28
+    assert seq_group_meta[0].token_chunk_size == 28
+    assert seq_group_meta[1].token_chunk_size == 28
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 56
+
+    # On the next iteration, both sequences should finish prefill
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    # Both sequences have 50 - 28 = 22 tokens left to prefill.
+    # This is not a multiple of the block size, but we don't care since we don't
+    # cache the final partial block of prefix sequences
+    assert seq_group_meta[0].token_chunk_size == 22
+    assert seq_group_meta[1].token_chunk_size == 22
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 44
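Reader note on the block-size rounding in the prefix-caching test: a minimal sketch, assuming a non-final chunk is rounded down to a multiple of the block size while the final chunk of a prompt is left unrounded, as the test comments describe. `block_aligned_chunk` is a hypothetical helper for illustration, not code from this PR.

```python
def block_aligned_chunk(share: int, tokens_left: int, block_size: int) -> int:
    """Chunk size for one request when prefix caching is enabled."""
    if tokens_left <= share:
        # Final chunk: the trailing partial block is not cached, so no rounding.
        return tokens_left
    # Non-final chunk: round down so that only full blocks are filled.
    return (share // block_size) * block_size


share = 60 // 2  # 60-token budget split across two partial prefills
first_chunk = block_aligned_chunk(share, 50, 4)
second_chunk = block_aligned_chunk(share, 50 - first_chunk, 4)
print(first_chunk, second_chunk)
```

For a 50-token prompt this gives a 28-token chunk followed by a 22-token chunk, matching the per-request values asserted above.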
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8])
+def test_chunked_prefill_with_actual_engine(model: str,

Review comment on the line above: cc @rickyyx here's what we tried to do to test that the sampler doesn't throw any assertions: we put multiple prompts into an engine and manually step it forward with them all partially prefilled.

+                                            max_num_partial_prefills: int):
+    """Make sure the model can actually sample with concurrent
+    partial prefills
+    """
+
+    prompt = "hello" * 40
+
+    engine_args = EngineArgs(
+        model=model,
+        max_num_partial_prefills=max_num_partial_prefills,
+        max_num_batched_tokens=40,
+        max_num_seqs=8,
+        enable_chunked_prefill=True,
+        gpu_memory_utilization=0.8,
+    )
+
+    engine = LLMEngine.from_engine_args(engine_args)
+    sampling_params = SamplingParams(temperature=0)
+
+    for req_num in range(max_num_partial_prefills):
+        engine.add_request(f"{req_num}", prompt, sampling_params)
+    # first step
+    request_outputs = engine.step()
+    # means all are prefilling
+    assert len(request_outputs) == 0
+    assert len(engine.scheduler[0].running) == max_num_partial_prefills
Not sure if this is a bug, but at this stage, request#3 should be decoding, but it didn't get any budget. Request#2 got budget for `1` decode token, and request#0 got the remaining budget for prefilling `63` tokens. Is that expected?

Based on this comment, vllm should have prioritized decode requests and given both request#2 and request#3 `1` budget, and request#0 `62`?

This does happen in the next iteration though