Skip to content

Commit

Permalink
[Bugfix][SpecDecode] apply sampling parameters to target probabilitie…
Browse files Browse the repository at this point in the history
…s for consistency in rejection sampling. (#10198)

Signed-off-by: jeongin601 <[email protected]>
Signed-off-by: jeong_in.bae <[email protected]>
  • Loading branch information
jeongin601 authored Nov 27, 2024
1 parent 0a4d968 commit 1bf905d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 14 deletions.
2 changes: 1 addition & 1 deletion tests/spec_decode/e2e/test_mlp_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize("temperature", [1.0])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
Expand Down
8 changes: 8 additions & 0 deletions tests/spec_decode/test_batch_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ def test_create_single_target_seq_group_metadata(k: int):
)

assert output.request_id == input_seq_group_metadata.request_id
assert output.sampling_params.repetition_penalty == \
input_seq_group_metadata.sampling_params.repetition_penalty
assert output.sampling_params.temperature == \
input_seq_group_metadata.sampling_params.temperature
assert output.sampling_params.top_p == \
input_seq_group_metadata.sampling_params.top_p
assert output.sampling_params.top_k == \
input_seq_group_metadata.sampling_params.top_k
assert len(output.seq_data) == 1
assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple(
prompt_tokens)
Expand Down
14 changes: 1 addition & 13 deletions vllm/spec_decode/batch_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,28 +307,16 @@ def _create_target_seq_group_metadata(
token_ids_to_score = self._get_token_ids_to_score(
proposal_token_ids[batch_index])

# Use simpler sampling parameters apart from for final token
# (in particular don't do seeded sampling) since those sampled tokens
# aren't used.
# We don't replace the sampling_params in the greedy case because
# this also controls whether the probs get modified in the sampler
# (see use of _modify_greedy_probs_inplace there).
sampling_params = input_seq_group_metadata.sampling_params
non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \
if sampling_params.temperature else sampling_params

target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
last_index = len(token_ids_to_score) - 1
for i, token_ids in enumerate(token_ids_to_score):
target_sampling_params = sampling_params if i == last_index \
else non_bonus_sampling_params
target_seq_group_metadata_list.append(
self._create_single_target_seq_group_metadata(
input_seq_group_metadata,
input_seq_id,
next(target_seq_ids_iter),
token_ids,
sampling_params=target_sampling_params,
sampling_params=sampling_params,
))

return target_seq_group_metadata_list
Expand Down

0 comments on commit 1bf905d

Please sign in to comment.