[Speculative Decoding 2/2] Integrate typical acceptance sampler into Spec Decode Worker (#5348)
sroy745 authored Jul 1, 2024
1 parent 614aa51 commit 80ca1e6
Showing 14 changed files with 482 additions and 210 deletions.
96 changes: 64 additions & 32 deletions tests/samplers/test_typical_acceptance_sampler.py
@@ -52,6 +52,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
return draft_token_ids


def get_acceptance_sampler(
posterior_threshold: float = 0.03,
posterior_alpha: float = 0.9,
disable_bonus_tokens: bool = False,
strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
"""
Initializes and returns a TypicalAcceptanceSampler.
"""
return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
disable_bonus_tokens, strict_mode)


@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@@ -64,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
different combinations of k, vocab_size, batch_size and num devices.
"""
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler()
typical_acceptance_sampler = get_acceptance_sampler()
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
@@ -76,7 +89,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
size=(batch_size, k),
dtype=torch.int64)
# Verify that sampling succeeds for all cases.
typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids)
typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)


@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@@ -94,7 +110,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
@@ -125,8 +141,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
oob_token_ids[0][0] = rogue_token_id

with pytest.raises(AssertionError):
typical_acceptance_sampler(target_probs, bonus_token_ids,
draft_token_ids)
typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)


@pytest.mark.parametrize("seed", list(range(10)))
@@ -151,7 +169,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -163,9 +181,11 @@ def test_uniform_target_distribution_accepts_all_tokens(
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
# We are using a uniform target probability distribution.
# For a uniform distribution the entropy is very high and it
# should lead to all draft tokens being accepted. Verify that.
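
For context, a minimal sketch of the entropy-based rule these comments reason about, assuming the Medusa-style typical acceptance criterion; vLLM's actual TypicalAcceptanceSampler may differ in its constants and tensor handling:

import torch

def typical_acceptance_mask(target_probs: torch.Tensor,
                            draft_token_ids: torch.Tensor,
                            posterior_threshold: float = 0.03,
                            posterior_alpha: float = 0.9) -> torch.Tensor:
    # Probability the target model assigns to each proposed draft token.
    candidate_probs = target_probs.gather(
        dim=-1, index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
    # Entropy of the target distribution at each draft position.
    entropy = -torch.sum(target_probs * torch.log(target_probs + 1e-5), dim=-1)
    # Accept when the candidate clears min(threshold, alpha * exp(-entropy)).
    # A uniform target has near-maximal entropy, so the dynamic threshold
    # collapses to roughly alpha / vocab_size and almost every token passes.
    dynamic_threshold = torch.minimum(
        torch.full_like(entropy, posterior_threshold),
        posterior_alpha * torch.exp(-entropy))
    return candidate_probs > dynamic_threshold
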
@@ -203,7 +223,7 @@ def test_temperature_zero_target_distribution(seed: int,
vocab_size = 30_000
torch.set_default_device(device)

typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Simulate temperature 0 probability distribution for target probabilities
@@ -224,9 +244,11 @@ def test_temperature_zero_target_distribution(seed: int,
# 1.0 tokens in the target distribution we will reject all of them and
# fallback to the greedy sampling for selecting 1 token for each sequence.
# Verify the same.
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, -1] == -1)
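
As an aside, an illustrative layout (hypothetical token ids, not taken from the test) of the expected output when every draft token is rejected under a one-hot target: position 0 falls back to the target's greedy choice and all later slots, including the bonus position, are padded with -1.

# batch_size = 1, k = 3; all ids below are made up for illustration.
draft_token_ids = [[101, 102, 103]]            # proposals from the draft model
target_argmax_ids = [[7, 8, 9]]                # temperature-0 target choices
expected_output_token_ids = [[7, -1, -1, -1]]  # shape (batch_size, k + 1)
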
@@ -261,7 +283,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
batch_size = 4
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# For sequences 0 and 2 set the distribution to a temperature
@@ -277,9 +299,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
# verify the shape of output_token_ids
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
@@ -326,7 +350,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Create a temperature zero target probability distribution and ensure
@@ -339,9 +363,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -357,9 +383,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
batch_size, k, vocab_size, zero_temperature_token_ids)
draft_token_ids = torch.cat(
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
@@ -384,7 +412,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Simulate temperature 0 probability distribution for target
@@ -402,9 +430,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 1:-1] == -1)
@@ -418,9 +448,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
posterior_threshold=0.0,
posterior_alpha=0.0)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
output_token_ids = typical_acceptance_sampler(target_probs,
bonus_token_ids,
draft_token_ids)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -451,7 +483,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = TypicalAcceptanceSampler(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
54 changes: 53 additions & 1 deletion tests/spec_decode/e2e/test_multistep_correctness.py
@@ -11,9 +11,15 @@
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.
At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
highest probability in the target distribution are accepted. Therefore, we can
expect greedy equality for the TypicalAcceptanceSampler at temp=0.
For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model).
be prohibitively expensive to run with a real model). Similarly, for the
TypicalAcceptanceSampler, we rely on unit tests to validate the temp>0
cases.
NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
@@ -611,3 +617,49 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler"
}
# Try a range of common k.
for k in [1, 2, 3]
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_typical_acceptance_sampling(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with TypicalAcceptanceSampler as the draft token acceptance
sampling method.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
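
For reference, a minimal end-to-end usage sketch of the new acceptance method outside the test harness. It assumes the keyword arguments parametrized above map directly onto vLLM engine arguments; it is not part of this diff.

from vllm import LLM, SamplingParams

# Hypothetical example mirroring the test kwargs above.
llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    spec_decoding_acceptance_method="typical_acceptance_sampler",
    use_v2_block_manager=True,
    enforce_eager=True,
)
outputs = llm.generate(["San Francisco is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
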
12 changes: 7 additions & 5 deletions tests/spec_decode/test_dynamic_spec_decode.py
@@ -3,33 +3,35 @@
import pytest
import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer

from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, mock_worker


@pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
acceptance_sampler_method: str):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size = 3

draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
rejection_sampler=rejection_sampler,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size)
