From 1bf905ddaa969e6458fe0d15a1db80318f39fade Mon Sep 17 00:00:00 2001 From: jeongin601 <78595701+jeongin601@users.noreply.github.com> Date: Wed, 27 Nov 2024 14:07:30 +0900 Subject: [PATCH] [Bugfix][SpecDecode] apply sampling parameters to target probabilities for consistency in rejection sampling. (#10198) Signed-off-by: jeongin601 <0200angela@gmail.com> Signed-off-by: jeong_in.bae --- tests/spec_decode/e2e/test_mlp_correctness.py | 2 +- tests/spec_decode/test_batch_expansion.py | 8 ++++++++ vllm/spec_decode/batch_expansion.py | 14 +------------- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 5ecc0d4e95719..183ff2f5db274 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -203,7 +203,7 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}]) @pytest.mark.parametrize("output_len", [64]) @pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize("temperature", [0.1, 1.0]) +@pytest.mark.parametrize("temperature", [1.0]) @pytest.mark.parametrize("seed", [1]) def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, diff --git a/tests/spec_decode/test_batch_expansion.py b/tests/spec_decode/test_batch_expansion.py index 0d6aaa449d856..3504fcf43e361 100644 --- a/tests/spec_decode/test_batch_expansion.py +++ b/tests/spec_decode/test_batch_expansion.py @@ -90,6 +90,14 @@ def test_create_single_target_seq_group_metadata(k: int): ) assert output.request_id == input_seq_group_metadata.request_id + assert output.sampling_params.repetition_penalty == \ + input_seq_group_metadata.sampling_params.repetition_penalty + assert output.sampling_params.temperature == \ + input_seq_group_metadata.sampling_params.temperature + assert output.sampling_params.top_p == \ + input_seq_group_metadata.sampling_params.top_p + assert output.sampling_params.top_k == \ + input_seq_group_metadata.sampling_params.top_k assert len(output.seq_data) == 1 assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple( prompt_tokens) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 25ef27b8378f0..01b9cdad963da 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -307,28 +307,16 @@ def _create_target_seq_group_metadata( token_ids_to_score = self._get_token_ids_to_score( proposal_token_ids[batch_index]) - # Use simpler sampling parameters apart from for final token - # (in particular don't do seeded sampling) since those sampled tokens - # aren't used. - # We don't replace the sampling_params in the greedy case because - # this also controls whether the probs get modified in the sampler - # (see use of _modify_greedy_probs_inplace there). sampling_params = input_seq_group_metadata.sampling_params - non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \ - if sampling_params.temperature else sampling_params - target_seq_group_metadata_list: List[SequenceGroupMetadata] = [] - last_index = len(token_ids_to_score) - 1 for i, token_ids in enumerate(token_ids_to_score): - target_sampling_params = sampling_params if i == last_index \ - else non_bonus_sampling_params target_seq_group_metadata_list.append( self._create_single_target_seq_group_metadata( input_seq_group_metadata, input_seq_id, next(target_seq_ids_iter), token_ids, - sampling_params=target_sampling_params, + sampling_params=sampling_params, )) return target_seq_group_metadata_list