From 80ca1e6a3a28a0373dc00c5b4fe956c16de952fa Mon Sep 17 00:00:00 2001 From: sroy745 <142070531+sroy745@users.noreply.github.com> Date: Mon, 1 Jul 2024 00:33:05 -0700 Subject: [PATCH] [Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (#5348) --- .../test_typical_acceptance_sampler.py | 96 +++++++---- .../e2e/test_multistep_correctness.py | 54 +++++- tests/spec_decode/test_dynamic_spec_decode.py | 12 +- tests/spec_decode/test_metrics.py | 94 +++++------ tests/spec_decode/test_spec_decode_worker.py | 154 ++++++++++-------- tests/spec_decode/test_utils.py | 22 +++ vllm/config.py | 75 ++++++++- vllm/engine/arg_utils.py | 42 ++++- vllm/engine/metrics.py | 2 +- .../layers/rejection_sampler.py | 18 +- .../layers/spec_decode_base_sampler.py | 15 +- .../layers/typical_acceptance_sampler.py | 22 ++- vllm/spec_decode/metrics.py | 24 +-- vllm/spec_decode/spec_decode_worker.py | 62 ++++--- 14 files changed, 482 insertions(+), 210 deletions(-) diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 87cf37bc926bc..4f6290795b2ce 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -52,6 +52,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int, return draft_token_ids +def get_acceptance_sampler( + posterior_threshold: float = 0.03, + posterior_alpha: float = 0.9, + disable_bonus_tokens: bool = False, + strict_mode: bool = False, +) -> TypicalAcceptanceSampler: + """ + Initializes and returns a TypicalAcceptanceSampler. + """ + return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha, + disable_bonus_tokens, strict_mode) + + @pytest.mark.parametrize("k", list(range(1, 6))) @pytest.mark.parametrize("vocab_size", [30_000, 50_000]) @pytest.mark.parametrize("batch_size", list(range(1, 32))) @@ -64,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, different combinations of k, vocab_size, batch_size and num devices. """ torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler() + typical_acceptance_sampler = get_acceptance_sampler() typical_acceptance_sampler.init_gpu_tensors(rank=0) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = torch.randint(low=0, @@ -76,7 +89,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, size=(batch_size, k), dtype=torch.int64) # Verify that sampling succeeds for all cases. 
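Editorial note (not part of the patch): the hunks above introduce the get_acceptance_sampler test helper and switch every call site to keyword arguments with draft_probs=None. A minimal standalone sketch of that call pattern follows, assuming a CUDA device is available (init_gpu_tensors allocates counters on the GPU); shapes and dtypes mirror the tests.

import torch
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)

torch.set_default_device("cuda")
batch_size, k, vocab_size = 2, 3, 30_000
# After this patch, posterior_threshold and posterior_alpha are required args.
sampler = TypicalAcceptanceSampler(posterior_threshold=0.03,
                                   posterior_alpha=0.9,
                                   disable_bonus_tokens=False,
                                   strict_mode=False)
sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0, high=vocab_size,
                                size=(batch_size, 1), dtype=torch.int64)
draft_token_ids = torch.randint(low=0, high=vocab_size,
                                size=(batch_size, k), dtype=torch.int64)
# The typical acceptance sampler ignores draft_probs, so the tests pass None.
output_token_ids = sampler(target_probs,
                           bonus_token_ids,
                           draft_probs=None,
                           draft_token_ids=draft_token_ids)
assert output_token_ids.shape == (batch_size, k + 1)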
- typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids) + typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) @@ -94,7 +110,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, batch_size = 5 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True) + typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) typical_acceptance_sampler.init_gpu_tensors(rank=0) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = torch.randint(low=0, @@ -125,8 +141,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, oob_token_ids[0][0] = rogue_token_id with pytest.raises(AssertionError): - typical_acceptance_sampler(target_probs, bonus_token_ids, - draft_token_ids) + typical_acceptance_sampler(target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) @pytest.mark.parametrize("seed", list(range(10))) @@ -151,7 +169,7 @@ def test_uniform_target_distribution_accepts_all_tokens( batch_size = 5 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( + typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) @@ -163,9 +181,11 @@ def test_uniform_target_distribution_accepts_all_tokens( high=vocab_size, size=(batch_size, 1), dtype=torch.int64) - output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) # We are using a uniform target probability distribution. # For a uniform distribution the entropy is very high and it # should lead to all draft tokens being accepted. Verify that. @@ -203,7 +223,7 @@ def test_temperature_zero_target_distribution(seed: int, vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( + typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) # Simulate temperature 0 probability distribution for target probabilities @@ -224,9 +244,11 @@ def test_temperature_zero_target_distribution(seed: int, # 1.0 tokens in the target distribution we will reject all of them and # fallback to the greedy sampling for selecting 1 token for each sequence. # Verify the same. 
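Editorial note (not part of the patch): a short sketch of what the temperature-zero test above exercises, assuming a CUDA device. With a one-hot target distribution the entropy-based term does not relax the acceptance threshold, so random draft tokens, which carry no posterior mass in the target, are rejected (with overwhelming probability), and the sampler falls back to the target's greedy token in position 0 while the remaining slots are filled with -1.

import torch
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)

torch.set_default_device("cuda")
batch_size, k, vocab_size = 4, 3, 30_000
# Build a one-hot ("temperature 0") target distribution.
target_token_ids = torch.randint(0, vocab_size, (batch_size, k))
target_probs = torch.zeros(batch_size, k, vocab_size, dtype=torch.float32)
target_probs.scatter_(2, target_token_ids.unsqueeze(-1), 1.0)
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1), dtype=torch.int64)
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k), dtype=torch.int64)

sampler = TypicalAcceptanceSampler(posterior_threshold=0.09, posterior_alpha=0.3)
sampler.init_gpu_tensors(rank=0)
output_token_ids = sampler(target_probs,
                           bonus_token_ids,
                           draft_probs=None,
                           draft_token_ids=draft_token_ids)
# Expected: output_token_ids[:, 0] equals the target argmax token and
# output_token_ids[:, 1:] is -1, matching the assertions in the test above.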
- output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, -1] == -1) @@ -261,7 +283,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, batch_size = 4 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( + typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) # For sequences 0 and 2 set the distribution to a temperature @@ -277,9 +299,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, high=vocab_size, size=(batch_size, 1), dtype=torch.int64) - output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) # verify the shape of output_token_ids assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) @@ -326,7 +350,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, batch_size = 1 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( + typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) # Create a temperature zero target probability distribution and ensure @@ -339,9 +363,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, high=vocab_size, size=(batch_size, 1), dtype=torch.int64) - output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) @@ -357,9 +383,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, batch_size, k, vocab_size, zero_temperature_token_ids) draft_token_ids = torch.cat( (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1) - output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) @@ -384,7 +412,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, batch_size = 1 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( + typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) # Simulate temperature 0 probability distribution for target @@ -402,9 +430,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, high=vocab_size, size=(batch_size, 1), 
dtype=torch.int64) - output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, 1:-1] == -1) @@ -418,9 +448,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, posterior_threshold=0.0, posterior_alpha=0.0) typical_acceptance_sampler.init_gpu_tensors(rank=0) - output_token_ids = typical_acceptance_sampler(target_probs, - bonus_token_ids, - draft_token_ids) + output_token_ids = typical_acceptance_sampler( + target_probs, + bonus_token_ids, + draft_probs=None, + draft_token_ids=draft_token_ids) assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) @@ -451,7 +483,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool, batch_size = 5 vocab_size = 30_000 torch.set_default_device(device) - typical_acceptance_sampler = TypicalAcceptanceSampler( + typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(rank=0) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 94d71fb012727..94cc36f22875a 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -11,9 +11,15 @@ numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy equality. This gives us good coverage of temp=0. +At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the +highest probability in the target distribution are accepted. Therefore, we can +expect greedy equality for the TypicalAcceptanceSampler at temp=0. + For temp>0, we rely on unit tests on the rejection sampler to verify that the output distribution is the same with spec decode vs. no spec decode (this would -be prohibitively expensive to run with a real model). +be prohibitively expensive to run with a real model). Similarly, for the +TypicalAcceptance sampler also, we rely on unit tests to validate temp>0 +test cases. NOTE: Speculative decoding's distribution equality requires that the measured distributions of the target model and proposal model be deterministic given the @@ -611,3 +617,49 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, batch_size, max_output_len=output_len, force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-160m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": k, + "spec_decoding_acceptance_method": "typical_acceptance_sampler" + } + # Try a range of common k. + for k in [1, 2, 3] + ]) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. 
+ 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_typical_acceptance_sampling(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): + """Verify that speculative decoding produces exact equality to without spec + decode with TypicalAcceptanceSampler as the draft token acceptance + sampling method. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py index bb6d1c23a0039..29ed96999cb4c 100644 --- a/tests/spec_decode/test_dynamic_spec_decode.py +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -3,33 +3,35 @@ import pytest import torch -from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.spec_decode.top1_proposer import Top1Proposer +from .test_utils import mock_spec_decode_sampler from .utils import create_batch, mock_worker @pytest.mark.parametrize('queue_size', [4]) @pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('k', [1]) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): +def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int, + acceptance_sampler_method: str): """Verify that speculative tokens are disabled when the batch size exceeds the threshold. """ disable_by_batch_size = 3 - draft_worker = mock_worker(cls=MultiStepWorker) target_worker = mock_worker() - rejection_sampler = MagicMock(spec=RejectionSampler) metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker(proposer_worker=draft_worker, scorer_worker=target_worker, - rejection_sampler=rejection_sampler, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), metrics_collector=metrics_collector, disable_by_batch_size=disable_by_batch_size) diff --git a/tests/spec_decode/test_metrics.py b/tests/spec_decode/test_metrics.py index 312878804b86e..2918fabddc900 100644 --- a/tests/spec_decode/test_metrics.py +++ b/tests/spec_decode/test_metrics.py @@ -10,16 +10,16 @@ def test_initial_call_returns_none(): """Expect first call to get metrics to return None. """ - rej_sampler = MagicMock() - rej_sampler.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_draft_tokens = 0 - - collector = AsyncMetricsCollector(rej_sampler) + spec_decode_sampler = MagicMock() + spec_decode_sampler.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_draft_tokens = 0 + + collector = AsyncMetricsCollector(spec_decode_sampler) collector.init_gpu_tensors(rank=0) maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5) assert maybe_metrics is None @@ -28,14 +28,14 @@ def test_initial_call_returns_none(): def test_second_call_returns_metrics(): """Expect second call to not return None. 
""" - rej_sampler = MagicMock() - rej_sampler.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_draft_tokens = 0 + spec_decode_sampler = MagicMock() + spec_decode_sampler.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_draft_tokens = 0 collect_interval_s = 5.0 timer = MagicMock() @@ -43,7 +43,7 @@ def test_second_call_returns_metrics(): 0.0, collect_interval_s + 0.1, collect_interval_s + 0.2 ] - collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, + collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler, timer=timer, collect_interval_s=collect_interval_s) collector.init_gpu_tensors(rank=0) @@ -56,16 +56,16 @@ def test_second_call_returns_metrics(): def test_nonzero_rank_noop(rank): """Verify nonzero ranks don't collect metrics. """ - rej_sampler = MagicMock() - rej_sampler.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_draft_tokens = 0 - - collector = AsyncMetricsCollector(rej_sampler) + spec_decode_sampler = MagicMock() + spec_decode_sampler.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_draft_tokens = 0 + + collector = AsyncMetricsCollector(spec_decode_sampler) collector.init_gpu_tensors(rank=rank) _ = collector.maybe_collect_rejsample_metrics(k=5) metrics = collector.maybe_collect_rejsample_metrics(k=5) @@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank): def test_noop_until_time(): """Verify metrics aren't collected until enough time passes. 
""" - rej_sampler = MagicMock() - rej_sampler.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - rej_sampler.num_draft_tokens = 0 + spec_decode_sampler = MagicMock() + spec_decode_sampler.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_draft_tokens = 0 collect_interval_s = 5.0 timer = MagicMock() @@ -91,7 +91,7 @@ def test_noop_until_time(): collect_interval_s + 0.1, collect_interval_s + 0.1 ] - collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, + collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler, timer=timer, collect_interval_s=collect_interval_s) collector.init_gpu_tensors(rank=0) @@ -122,14 +122,14 @@ def test_initial_metrics_has_correct_values(has_data: bool): max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens( num_draft_tokens, k) - rej_sampler = MagicMock() - rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens, - dtype=torch.long, - device='cuda') - rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens, - dtype=torch.long, - device='cuda') - rej_sampler.num_draft_tokens = num_draft_tokens + spec_decode_sampler = MagicMock() + spec_decode_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_draft_tokens = num_draft_tokens collect_interval_s = 5.0 timer = MagicMock() @@ -137,7 +137,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): 0.0, collect_interval_s + 0.1, collect_interval_s + 0.2 ] - collector = AsyncMetricsCollector(rejection_sampler=rej_sampler, + collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler, timer=timer, collect_interval_s=collect_interval_s) collector.init_gpu_tensors(rank=0) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index a20c793c9bfd7..527e7eddd7e33 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -6,7 +6,6 @@ import pytest import torch -from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.utils import set_random_seed from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput from vllm.spec_decode.interfaces import SpeculativeProposals @@ -16,23 +15,26 @@ from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, split_num_cache_blocks_evenly) +from .test_utils import mock_spec_decode_sampler from .utils import create_batch, create_sampler_output_list, mock_worker @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_correctly_calls_draft_model(k: int, batch_size: int): +def test_correctly_calls_draft_model(k: int, batch_size: int, + acceptance_sampler_method: str): """Verify SpecDecodeWorker calls the draft worker with correct inputs. Everything else is mocked out. 
""" draft_worker = mock_worker(cls=MultiStepWorker) target_worker = mock_worker() - rejection_sampler = MagicMock(spec=RejectionSampler) metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, - metrics_collector) - + worker = SpecDecodeWorker( + draft_worker, target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) @@ -53,15 +55,16 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_correctly_calls_target_model(k: int, batch_size: int): +def test_correctly_calls_target_model(k: int, batch_size: int, + acceptance_sampler_method: str): """Verify SpecDecodeWorker calls the target model with correct inputs. Everything else is mocked out. """ draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) target_worker = mock_worker(use_spec=False) - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' @@ -69,8 +72,9 @@ def test_correctly_calls_target_model(k: int, batch_size: int): set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, - metrics_collector) + worker = SpecDecodeWorker( + draft_worker, target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) worker.init_device() vocab_size = 32_000 @@ -133,8 +137,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int): @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_correctly_calls_rejection_sampler(k: int, batch_size: int): +def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, + acceptance_sampler_method: str): """Verify SpecDecodeWorker calls the rejection sampler with correct inputs. Everything else is mocked out. 
""" @@ -144,15 +151,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 + spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' target_worker.device = 'cuda' set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, metrics_collector) worker.init_device() @@ -199,15 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): target_worker.execute_model.return_value = [target_output[0]] exception_secret = 'artificial stop' - rejection_sampler.side_effect = ValueError(exception_secret) + + spec_decode_sampler.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): worker.execute_model(execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)) - assert len(rejection_sampler.call_args_list) == 1 - _, kwargs = rejection_sampler.call_args_list[0] + assert len(spec_decode_sampler.call_args_list) == 1 + _, kwargs = spec_decode_sampler.call_args_list[0] actual = SimpleNamespace(**kwargs) assert torch.equal(actual.bonus_token_ids, @@ -221,8 +228,11 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): @pytest.mark.parametrize('k', [1, 2, 6]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_correctly_formats_output(k: int, batch_size: int): +def test_correctly_formats_output(k: int, batch_size: int, + acceptance_sampler_method: str): """Verify SpecDecodeWorker formats sampler output correctly. Everything else is mocked out. 
""" @@ -232,15 +242,13 @@ def test_correctly_formats_output(k: int, batch_size: int): vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' target_worker.device = 'cuda' set_random_seed(1) - - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) + worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, metrics_collector) worker.init_device() @@ -286,24 +294,23 @@ def test_correctly_formats_output(k: int, batch_size: int): target_worker.execute_model.return_value = [target_output[0]] - rejection_sampler_output = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k + 1), - dtype=torch.int64, - device='cuda') + spec_decode_sampler_output = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k + 1), + dtype=torch.int64, + device='cuda') for i in range(batch_size): minimum_accepted_tokens = 1 - rejection_sampler_output[i][ + spec_decode_sampler_output[i][ -random.randint(minimum_accepted_tokens, k + 1):] = -1 - rejection_sampler.return_value = rejection_sampler_output - + spec_decode_sampler.return_value = spec_decode_sampler_output output = worker.execute_model(execute_model_req=ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)) expected_output = create_sampler_output_list( - token_ids=rejection_sampler_output.transpose(0, 1), + token_ids=spec_decode_sampler_output.transpose(0, 1), probs=[None for _ in range(k + 1)], logprobs=[None for _ in range(k + 1)]) @@ -350,8 +357,11 @@ def test_correctly_formats_output(k: int, batch_size: int): @pytest.mark.parametrize('k', [1, 2]) @pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('returns_metrics', [True, False]) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): +def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool, + acceptance_sampler_method: str): """Verify SpecDecodeWorker collects metrics. 
""" vocab_size = 32_000 @@ -360,15 +370,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): vocab_size=vocab_size, use_spec=False) target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 + spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) metrics_collector = MagicMock(spec=AsyncMetricsCollector) draft_worker.device = 'cuda' target_worker.device = 'cuda' set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, metrics_collector) worker.init_device() @@ -414,17 +423,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): target_worker.execute_model.return_value = [target_output[0]] - rejection_sampler_output = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k + 1), - dtype=torch.int64, - device='cuda') + spec_decode_sampler_output = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k + 1), + dtype=torch.int64, + device='cuda') for i in range(batch_size): minimum_accepted_tokens = 1 - rejection_sampler_output[i][ + spec_decode_sampler_output[i][ -random.randint(minimum_accepted_tokens, k + 1):] = -1 - - rejection_sampler.return_value = rejection_sampler_output + spec_decode_sampler.return_value = spec_decode_sampler_output mock_rejsample_metrics = MagicMock( spec=SpecDecodeWorkerMetrics) if returns_metrics else None @@ -445,15 +453,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): @pytest.mark.parametrize('k', [0]) @pytest.mark.parametrize('batch_size', [1, 2, 32]) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_k_equals_zero(k: int, batch_size: int): +def test_k_equals_zero(k: int, batch_size: int, + acceptance_sampler_method: str): """Verify that the SpecDecodeWorker calls the draft and target workers when k is zero. This happens during prefill. """ draft_worker = mock_worker(cls=MultiStepWorker) target_worker = mock_worker() - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) sampler_output = MagicMock(spec=SamplerOutput) @@ -465,8 +474,9 @@ def test_k_equals_zero(k: int, batch_size: int): set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, - metrics_collector) + worker = SpecDecodeWorker( + draft_worker, target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -487,16 +497,17 @@ def test_k_equals_zero(k: int, batch_size: int): @pytest.mark.parametrize('k', [0, 5]) @pytest.mark.parametrize('batch_size', [0]) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_empty_input_batch(k: int, batch_size: int): +def test_empty_input_batch(k: int, batch_size: int, + acceptance_sampler_method: str): """Verify that the SpecDecodeWorker calls the draft and target workers when the input batch is empty. This can happen if the engine communicates to the workers information without scheduling a batch. 
""" draft_worker = mock_worker(cls=MultiStepWorker) target_worker = mock_worker() - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) sampler_output = MagicMock(spec=SamplerOutput) @@ -508,8 +519,9 @@ def test_empty_input_batch(k: int, batch_size: int): set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, - metrics_collector) + worker = SpecDecodeWorker( + draft_worker, target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -528,18 +540,19 @@ def test_empty_input_batch(k: int, batch_size: int): target_worker.execute_model.assert_called_once_with(execute_model_req) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @pytest.mark.skip_global_cleanup -def test_init_device(): +def test_init_device(acceptance_sampler_method: str): """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as well as other GPU initialization. """ draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) target_worker = mock_worker(use_spec=False) - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 + spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, + worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, metrics_collector) worker.init_device() @@ -549,22 +562,23 @@ def test_init_device(): target_worker.init_device.assert_called_once() metrics_collector.init_gpu_tensors.assert_called_once() - rejection_sampler.init_gpu_tensors.assert_called_once() + spec_decode_sampler.init_gpu_tensors.assert_called_once() +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_initialize_cache(): +def test_initialize_cache(acceptance_sampler_method): """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer workers. """ draft_worker = mock_worker(cls=MultiStepWorker) target_worker = mock_worker() - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, - metrics_collector) + worker = SpecDecodeWorker( + draft_worker, target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} worker.initialize_cache(**kwargs) @@ -577,19 +591,20 @@ def test_initialize_cache(): @pytest.mark.parametrize('available_cpu_blocks', [500]) @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096]) @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) +@pytest.mark.parametrize("acceptance_sampler_method", + ["rejection_sampler", "typical_acceptance_sampler"]) @pytest.mark.skip_global_cleanup def test_determine_num_available_blocks(available_gpu_blocks: int, available_cpu_blocks: int, target_cache_block_size_bytes: int, - draft_kv_size_bytes: int): + draft_kv_size_bytes: int, + acceptance_sampler_method: str): """Verify SpecDecodeWorker correctly profiles num available GPU blocks. 
Specifically, it should run profiling in the scorer worker, and then evenly split the blocks between proposer and scorer worker. """ draft_worker = mock_worker(cls=MultiStepWorker) target_worker = mock_worker() - rejection_sampler = MagicMock(spec=RejectionSampler) - rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) target_worker.determine_num_available_blocks.return_value = ( @@ -598,8 +613,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int, target_cache_block_size_bytes) draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes - worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, - metrics_collector) + worker = SpecDecodeWorker( + draft_worker, target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks() diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index bccbf9a6aaaeb..18dbdd5bc952f 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -1,7 +1,11 @@ from unittest.mock import MagicMock import pytest +import torch +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.typical_acceptance_sampler import ( + TypicalAcceptanceSampler) from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids from vllm.spec_decode.util import split_batch_by_proposal_len @@ -109,3 +113,21 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata): assert filtered_groups == [] assert indices == [] + + +def mock_spec_decode_sampler(acceptance_sampler_method): + """ + Returns either a RejectionSampler or TypicalAcceptanceSampler + object depending on whether acceptance_sampler_method is + 'rejection_sampler' or 'typical_acceptance_sampler' respectively. + """ + if acceptance_sampler_method == "rejection_sampler": + sampler = MagicMock(spec=RejectionSampler) + sampler.token_id_dtype = torch.int64 + return sampler + elif acceptance_sampler_method == "typical_acceptance_sampler": + sampler = MagicMock(spec=TypicalAcceptanceSampler) + sampler.token_id_dtype = torch.int64 + return sampler + else: + raise ValueError(f"Invalid sampler name {acceptance_sampler_method}") diff --git a/vllm/config.py b/vllm/config.py index 3551e8f6fa031..9854f175065a2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -753,7 +753,6 @@ def __init__(self, self.chunked_prefill_enabled = enable_chunked_prefill self.embedding_mode = embedding_mode self.preemption_mode = preemption_mode - self._verify_args() def _verify_args(self) -> None: @@ -834,6 +833,9 @@ def maybe_create_spec_config( speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], + draft_token_acceptance_method: str, + typical_acceptance_sampler_posterior_threshold: Optional[float], + typical_acceptance_sampler_posterior_alpha: Optional[float], ) -> Optional["SpeculativeConfig"]: """Create a SpeculativeConfig if possible, else return None. @@ -870,7 +872,20 @@ def maybe_create_spec_config( window, if provided. ngram_prompt_lookup_min (Optional[int]): Min size of ngram token window, if provided. - + draft_token_acceptance_method (str): The method to use for + accepting draft tokens. This can take two possible + values 'rejection_sampler' and 'typical_acceptance_sampler' + for RejectionSampler and TypicalAcceptanceSampler + respectively. 
+ typical_acceptance_sampler_posterior_threshold (Optional[float]): + A threshold value that sets a lower bound on the posterior + probability of a token in the target model for it to be + accepted. This threshold is used only when we use the + TypicalAcceptanceSampler for token acceptance. + typical_acceptance_sampler_posterior_alpha (Optional[float]): + A scaling factor for the entropy-based threshold in the + TypicalAcceptanceSampler. + Returns: Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if the necessary conditions are met, else None. @@ -984,6 +999,11 @@ def maybe_create_spec_config( "speculative_model unless the draft model config contains an " "n_predict parameter.") + if typical_acceptance_sampler_posterior_threshold is None: + typical_acceptance_sampler_posterior_threshold = 0.09 + if typical_acceptance_sampler_posterior_alpha is None: + typical_acceptance_sampler_posterior_alpha = 0.3 + return SpeculativeConfig( draft_model_config, draft_parallel_config, @@ -991,6 +1011,11 @@ def maybe_create_spec_config( speculative_disable_by_batch_size, ngram_prompt_lookup_max, ngram_prompt_lookup_min, + draft_token_acceptance_method=draft_token_acceptance_method, + typical_acceptance_sampler_posterior_threshold=\ + typical_acceptance_sampler_posterior_threshold, + typical_acceptance_sampler_posterior_alpha=\ + typical_acceptance_sampler_posterior_alpha, ) @staticmethod @@ -1072,6 +1097,9 @@ def __init__( speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], + draft_token_acceptance_method: str, + typical_acceptance_sampler_posterior_threshold: float, + typical_acceptance_sampler_posterior_alpha: float, ): """Create a SpeculativeConfig object. @@ -1085,6 +1113,19 @@ def __init__( enqueue requests is larger than this value. ngram_prompt_lookup_max: Max size of ngram token window. ngram_prompt_lookup_min: Min size of ngram token window. + draft_token_acceptance_method (str): The method to use for + accepting draft tokens. This can take two possible + values 'rejection_sampler' and 'typical_acceptance_sampler' + for RejectionSampler and TypicalAcceptanceSampler + respectively. + typical_acceptance_sampler_posterior_threshold (Optional[float]): + A threshold value that sets a lower bound on the posterior + probability of a token in the target model for it to be + accepted. This threshold is used only when we use the + TypicalAcceptanceSampler for token acceptance. + typical_acceptance_sampler_posterior_alpha (Optional[float]): + A scaling factor for the entropy-based threshold in the + TypicalAcceptanceSampler. """ self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config @@ -1093,6 +1134,11 @@ def __init__( speculative_disable_by_batch_size self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 + self.draft_token_acceptance_method = draft_token_acceptance_method + self.typical_acceptance_sampler_posterior_threshold = \ + typical_acceptance_sampler_posterior_threshold + self.typical_acceptance_sampler_posterior_alpha = \ + typical_acceptance_sampler_posterior_alpha self._verify_args() @@ -1104,6 +1150,31 @@ def _verify_args(self) -> None: if self.draft_model_config: self.draft_model_config.verify_with_parallel_config( self.draft_parallel_config) + # Validate and set draft token acceptance related settings. 
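Editorial note (not part of the patch): a usage sketch of how these settings reach SpeculativeConfig through the engine arguments. The model names are simply the ones used in the e2e test above, and the threshold/alpha values shown are the defaults that maybe_create_spec_config falls back to when the flags are left unset.

from vllm import LLM

llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    use_v2_block_manager=True,
    spec_decoding_acceptance_method="typical_acceptance_sampler",
    typical_acceptance_sampler_posterior_threshold=0.09,
    typical_acceptance_sampler_posterior_alpha=0.3,
)
outputs = llm.generate("The future of AI is")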
+ + if (self.draft_token_acceptance_method is None): + raise ValueError("draft_token_acceptance_method is not set. " + "Expected values are rejection_sampler or " + "typical_acceptance_sampler.") + + if (self.draft_token_acceptance_method != 'rejection_sampler' + and self.draft_token_acceptance_method != + 'typical_acceptance_sampler'): + raise ValueError( + "Expected draft_token_acceptance_method to be either " + "rejection_sampler or typical_acceptance_sampler. Instead it " + f"is {self.draft_token_acceptance_method}") + + if (self.typical_acceptance_sampler_posterior_threshold < 0 + or self.typical_acceptance_sampler_posterior_alpha < 0): + raise ValueError( + "Expected typical_acceptance_sampler_posterior_threshold " + "and typical_acceptance_sampler_posterior_alpha to be > 0. " + "Instead found " + f"typical_acceptance_sampler_posterior_threshold = " + f"{self.typical_acceptance_sampler_posterior_threshold} and " + f"typical_acceptance_sampler_posterior_alpha = " + f"{self.typical_acceptance_sampler_posterior_alpha}") @property def num_lookahead_slots(self) -> int: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f9d089091ffc2..d4044adfce615 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -100,7 +100,9 @@ class EngineArgs: speculative_disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None - + spec_decoding_acceptance_method: str = 'rejection_sampler' + typical_acceptance_sampler_posterior_threshold: Optional[float] = None + typical_acceptance_sampler_posterior_alpha: Optional[float] = None qlora_adapter_name_or_path: Optional[str] = None otlp_traces_endpoint: Optional[str] = None @@ -577,6 +579,38 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help='Min size of window for ngram prompt lookup in speculative ' 'decoding.') + parser.add_argument( + '--spec-decoding-acceptance-method', + type=str, + default=EngineArgs.spec_decoding_acceptance_method, + choices=['rejection_sampler', 'typical_acceptance_sampler'], + help='Specify the acceptance method to use during draft token ' + 'verification in speculative decoding. Two types of acceptance ' + 'routines are supported: ' + '1) RejectionSampler which does not allow changing the ' + 'acceptance rate of draft tokens, ' + '2) TypicalAcceptanceSampler which is configurable, allowing for ' + 'a higher acceptance rate at the cost of lower quality, ' + 'and vice versa.') + + parser.add_argument( + '--typical-acceptance-sampler-posterior-threshold', + type=float, + default=EngineArgs.typical_acceptance_sampler_posterior_threshold, + help='Set the lower bound threshold for the posterior ' + 'probability of a token to be accepted. This threshold is ' + 'used by the TypicalAcceptanceSampler to make sampling decisions ' + 'during speculative decoding. Defaults to 0.09') + + parser.add_argument( + '--typical-acceptance-sampler-posterior-alpha', + type=float, + default=EngineArgs.typical_acceptance_sampler_posterior_alpha, + help='A scaling factor for the entropy-based threshold for token ' + 'acceptance in the TypicalAcceptanceSampler. Typically defaults ' + 'to sqrt of --typical-acceptance-sampler-posterior-threshold ' + 'i.e. 
0.3') + parser.add_argument('--model-loader-extra-config', type=nullable_str, default=EngineArgs.model_loader_extra_config, @@ -737,6 +771,12 @@ def create_engine_config(self, ) -> EngineConfig: use_v2_block_manager=self.use_v2_block_manager, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, + draft_token_acceptance_method=\ + self.spec_decoding_acceptance_method, + typical_acceptance_sampler_posterior_threshold=self. + typical_acceptance_sampler_posterior_threshold, + typical_acceptance_sampler_posterior_alpha=self. + typical_acceptance_sampler_posterior_alpha, ) scheduler_config = SchedulerConfig( diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 2c1210c90c632..77de42bc0ed5d 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -457,4 +457,4 @@ def log(self, stats: Stats): class RayPrometheusStatLogger(PrometheusStatLogger): """RayPrometheusStatLogger uses Ray metrics instead.""" - _metrics_cls = RayMetrics + _metrics_cls = RayMetrics \ No newline at end of file diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index fe9b2fac1117e..e189610461a70 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -3,13 +3,12 @@ import torch import torch.jit -import torch.nn as nn from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) -class RejectionSampler(SpecDecodeBaseSampler, nn.Module): +class RejectionSampler(SpecDecodeBaseSampler): """Apply modified rejection sampling as described in "Accelerating Large Language Model Decoding with Speculative Sampling" https://arxiv.org/pdf/2302.01318.pdf. @@ -28,8 +27,8 @@ def __init__(self, during sampling. This catches correctness issues but adds nontrivial latency. """ - SpecDecodeBaseSampler.__init__(self, disable_bonus_tokens, strict_mode) - nn.Module.__init__(self) + super().__init__(disable_bonus_tokens=disable_bonus_tokens, + strict_mode=strict_mode) def forward( self, @@ -78,11 +77,12 @@ def forward( self._raise_if_incorrect_input(target_probs, bonus_token_ids, draft_probs, draft_token_ids) - accepted, recovered_token_ids = self._batch_modified_rejection_sampling( - target_probs, - draft_probs, - draft_token_ids, - ) + accepted, recovered_token_ids = ( + self._batch_modified_rejection_sampling( + target_probs, + draft_probs, + draft_token_ids, + )) output_token_ids = self._create_output( accepted, diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 9856a7e7ddea0..692024056495c 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -1,9 +1,12 @@ +from abc import abstractmethod from typing import Optional import torch +import torch.jit +import torch.nn as nn -class SpecDecodeBaseSampler(): +class SpecDecodeBaseSampler(nn.Module): """Base class for samplers used for Speculative Decoding verification step. 
""" @@ -51,6 +54,16 @@ def probs_dtype(self): def token_id_dtype(self): return torch.int64 + @abstractmethod + def forward( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + def _create_output( self, accepted: torch.Tensor, # [batch_size, k] diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index f12d6a03b4d16..9bf3c84a161c5 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -1,12 +1,11 @@ import torch import torch.jit -import torch.nn as nn from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) -class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): +class TypicalAcceptanceSampler(SpecDecodeBaseSampler): """Apply typical acceptance sampling as described in section 3.3.1 in "MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads" @@ -15,10 +14,10 @@ class TypicalAcceptanceSampler(SpecDecodeBaseSampler, nn.Module): def __init__( self, + posterior_threshold: float, + posterior_alpha: float, disable_bonus_tokens: bool = False, strict_mode: bool = False, - posterior_threshold: float = 0.09, - posterior_alpha: float = 0.3, ): """Create a Typical Acceptance Sampler. @@ -31,23 +30,20 @@ def __init__( nontrivial latency. posterior_threshold : A threshold value that sets a lower bound on the posterior probability of a token in target model for it - to be accepted. Default is 0.09 + to be accepted. posterior_alpha : A scaling factor for the entropy-based - threshold in typical acceptance sampling. Typically defaults to - sqrt of posterior_threshold and is set to 0.3. + threshold in typical acceptance sampling. """ - SpecDecodeBaseSampler.__init__( - self, - disable_bonus_tokens=disable_bonus_tokens, - strict_mode=strict_mode) - nn.Module.__init__(self) self._posterior_threshold = posterior_threshold self._posterior_alpha = posterior_alpha + super().__init__(disable_bonus_tokens=disable_bonus_tokens, + strict_mode=strict_mode) def forward( self, target_probs: torch.Tensor, bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, ) -> torch.Tensor: """Sample token ids using typical acceptance sampling. This accepts @@ -69,6 +65,8 @@ def forward( speculative tokens in a sequence are accepted. shape = [batch_size, num_bonus_tokens] + draft_probs: This parameter is unused by the acceptance sampler. + draft_token_ids: The token ids that were sampled from the draft probabilities. shape = [batch_size, num_speculative_tokens] diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index ab1d96c558de7..2c4ae0b22744b 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -4,7 +4,8 @@ import torch -from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler) from vllm.utils import is_pin_memory_available @@ -46,15 +47,15 @@ class SpecDecodeWorkerMetrics: class AsyncMetricsCollector: - """Class which copies rejection sampler metrics from the device to CPU on a - non-default Torch stream. + """Class which copies rejection/typical-acceptance sampler metrics + from the device to CPU on a non-default Torch stream. 
""" def __init__(self, - rejection_sampler: RejectionSampler, + spec_decode_sampler: SpecDecodeBaseSampler, timer: Optional[Timer] = None, collect_interval_s: float = 5.0): - self._rejection_sampler = rejection_sampler + self.spec_decode_sampler = spec_decode_sampler self._timer = time.time if timer is None else timer self._rank: Optional[int] = None @@ -95,7 +96,7 @@ def maybe_collect_rejsample_metrics( return None def _should_collect_rejsample_metrics(self, now: float) -> bool: - """Return whether or not this iteration should print rejection sampling + """Return whether or not this iteration should print sampling metrics. """ if self._rank != 0: @@ -107,8 +108,8 @@ def _should_collect_rejsample_metrics(self, now: float) -> bool: return True def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: - """Copy rejection sampling metrics (number of accepted tokens, etc) to - CPU asynchronously. + """Copy rejection/typical-acceptance sampling metrics + (number of accepted tokens, etc) to CPU asynchronously. Returns a CUDA event recording when the copy is complete. """ @@ -117,13 +118,14 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: with torch.cuda.stream(self._copy_stream): self._aggregate_num_accepted_tokens.copy_( - self._rejection_sampler.num_accepted_tokens, non_blocking=True) + self.spec_decode_sampler.num_accepted_tokens, + non_blocking=True) self._aggregate_num_emitted_tokens.copy_( - self._rejection_sampler.num_emitted_tokens, non_blocking=True) + self.spec_decode_sampler.num_emitted_tokens, non_blocking=True) # Number of draft tokens is calculated on CPU, so no copy is # required. self._aggregate_num_draft_tokens = ( - self._rejection_sampler.num_draft_tokens) + self.spec_decode_sampler.num_draft_tokens) aggregate_metrics_ready = torch.cuda.Event() aggregate_metrics_ready.record(self._copy_stream) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index f1e64cae8fc5b..ca470bee21c91 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -7,6 +7,10 @@ from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler) +from vllm.model_executor.layers.typical_acceptance_sampler import ( + TypicalAcceptanceSampler) from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SamplerOutput, SequenceGroupMetadata, get_all_seq_ids) @@ -56,7 +60,12 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": draft_worker_kwargs=draft_worker_kwargs, disable_by_batch_size=speculative_config. speculative_disable_by_batch_size, - ) + draft_token_acceptance_method=speculative_config. + draft_token_acceptance_method, + typical_acceptance_sampler_posterior_threshold=speculative_config. + typical_acceptance_sampler_posterior_threshold, + typical_acceptance_sampler_posterior_alpha=speculative_config. + typical_acceptance_sampler_posterior_alpha) return spec_decode_worker @@ -78,8 +87,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): welcome!). * Only top-1 proposal and scoring are implemented. Tree-attention is left as future work. - * Only lossless rejection sampling is supported. Contributions adding lossy - verification routines are welcome (e.g. Medusa's typical acceptance). * All sequences in a batch must have the same proposal length, or zero. 
This can be improved by having per-sequence speculation in the future. * The scoring forward pass is done without an MQA kernel, which is @@ -95,6 +102,9 @@ def create_worker( scorer_worker: Worker, draft_worker_kwargs: Dict[str, Any], disable_by_batch_size: Optional[int], + draft_token_acceptance_method: str, + typical_acceptance_sampler_posterior_threshold: float, + typical_acceptance_sampler_posterior_alpha: float, ) -> "SpecDecodeWorker": ngram_prompt_lookup_max = ( @@ -127,17 +137,30 @@ def create_worker( logger.info("Configuring SpecDecodeWorker with proposer=%s", type(proposer_worker)) + spec_decode_sampler: SpecDecodeBaseSampler = None + if draft_token_acceptance_method == "rejection_sampler": + spec_decode_sampler = RejectionSampler( + disable_bonus_tokens=disable_bonus_tokens, ) + elif draft_token_acceptance_method == "typical_acceptance_sampler": + spec_decode_sampler = TypicalAcceptanceSampler( + disable_bonus_tokens=disable_bonus_tokens, + posterior_threshold=\ + typical_acceptance_sampler_posterior_threshold, + posterior_alpha=typical_acceptance_sampler_posterior_alpha, + ) + logger.info("Configuring SpecDecodeWorker with sampler=%s", + type(spec_decode_sampler)) + return SpecDecodeWorker(proposer_worker, scorer_worker, disable_by_batch_size=disable_by_batch_size, - rejection_sampler=RejectionSampler( - disable_bonus_tokens=disable_bonus_tokens)) + spec_decode_sampler=spec_decode_sampler) def __init__( self, proposer_worker: ProposerWorkerBase, scorer_worker: WorkerBase, - rejection_sampler: RejectionSampler, + spec_decode_sampler: SpecDecodeBaseSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, ): @@ -150,8 +173,12 @@ def __init__( scorer_worker: A worker that produces probabilities of speculative tokens according to some base model. Typically a vanilla vLLM Worker. - rejection_sampler: A Torch module used to perform modified rejection - sampling for speculative decoding. + spec_decode_sampler: A Torch module used to perform acceptance + sampling of the draft tokens in the verification step of + speculative decoding. Currently we support two different + types of sampler namely RejectionSampler and + TypicalAcceptanceSampler. 'spec_decode_sampler' is either an + instance of RejectionSampler or TypicalAcceptanceSampler. disable_by_batch_size: If the batch size is larger than this, disable speculative decoding for new incoming requests. metrics_collector: Helper class for collecting metrics; can be set @@ -160,15 +187,12 @@ def __init__( self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker self.disable_by_batch_size = disable_by_batch_size or float("inf") - self.rejection_sampler = rejection_sampler - + self.spec_decode_sampler = spec_decode_sampler self._metrics = AsyncMetricsCollector( - rejection_sampler + self.spec_decode_sampler ) if metrics_collector is None else metrics_collector - - self.probs_dtype = self.rejection_sampler.probs_dtype - self.token_id_dtype = self.rejection_sampler.token_id_dtype - + self.probs_dtype = self.spec_decode_sampler.probs_dtype + self.token_id_dtype = self.spec_decode_sampler.token_id_dtype # Lazy initiazliation. 
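Editorial note (not part of the patch): a minimal construction sketch of the new wiring, assuming draft_worker and target_worker are already-initialized proposer/scorer workers (for example, the mocks used in the unit tests above). Any SpecDecodeBaseSampler now plugs into SpecDecodeWorker; leaving metrics_collector unset makes the worker build an AsyncMetricsCollector around the same sampler.

from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker

spec_decode_sampler = TypicalAcceptanceSampler(posterior_threshold=0.09,
                                               posterior_alpha=0.3,
                                               disable_bonus_tokens=False)
worker = SpecDecodeWorker(proposer_worker=draft_worker,   # assumed to exist
                          scorer_worker=target_worker,    # assumed to exist
                          spec_decode_sampler=spec_decode_sampler)
worker.init_device()  # also calls spec_decode_sampler.init_gpu_tensors(rank)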
self.scorer: SpeculativeScorer @@ -189,7 +213,8 @@ def init_device(self) -> None: self.proposer_worker.load_model() self._metrics.init_gpu_tensors(self.rank) - self.rejection_sampler.init_gpu_tensors(self.rank) + self.spec_decode_sampler.init_gpu_tensors(self.rank) + self.scorer = BatchExpansionTop1Scorer( scorer_worker=self.scorer_worker, device=self.device, @@ -203,7 +228,7 @@ def load_model(self, *args, **kwargs): def _configure_model_sampler_for_spec_decode(self): """Configure model sampler to emit GPU tensors. This allows spec decode to keep data on device without transferring to CPU and serializing, - which significantly reduces overhead of rejection sampling. + which significantly reduces overhead of sampling during verification. NOTE(cade): This breaks abstraction boundaries pretty badly. The better design is to have the "move to CPU and serialize" sampling decision be @@ -481,7 +506,7 @@ def _verify_tokens( # Get proposed tokens. proposal_token_ids = proposals.proposal_token_ids[spec_indices] - accepted_token_ids = self.rejection_sampler( + accepted_token_ids = self.spec_decode_sampler( target_probs=proposal_verifier_probs, bonus_token_ids=bonus_token_ids, draft_probs=proposal_probs, @@ -496,7 +521,6 @@ def _verify_tokens( accepted_token_ids = torch.cat( [accepted_token_ids, non_spec_token_ids]) logprobs = proposal_scores.logprobs - # Rearrange so that results are in the order of the original seq group # metadata. accepted_token_ids[original_indices] = accepted_token_ids.clone()
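Editorial note (not part of the patch): after this change both samplers expose the same forward contract, the abstract method added to SpecDecodeBaseSampler, which is what lets _verify_tokens above call self.spec_decode_sampler uniformly. A hedged sketch of that contract; tensor names and shapes follow the docstrings in the diff.

import torch
from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeBaseSampler)

def verify_draft_tokens(sampler: SpecDecodeBaseSampler,
                        target_probs: torch.Tensor,      # [batch, k, vocab]
                        bonus_token_ids: torch.Tensor,   # [batch, 1]
                        draft_probs: torch.Tensor,       # [batch, k, vocab]
                        draft_token_ids: torch.Tensor,   # [batch, k]
                        ) -> torch.Tensor:
    # Returns accepted token ids of shape [batch, k + 1]; -1 marks the slots
    # after the first rejected draft token. RejectionSampler consumes
    # draft_probs, while TypicalAcceptanceSampler ignores it (the unit tests
    # pass draft_probs=None for that reason).
    return sampler(target_probs=target_probs,
                   bonus_token_ids=bonus_token_ids,
                   draft_probs=draft_probs,
                   draft_token_ids=draft_token_ids)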