Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
Signed-off-by: andoorve <[email protected]>
  • Loading branch information
andoorve committed Nov 19, 2024
1 parent 9a5e8a6 commit 3c266a2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
6 changes: 3 additions & 3 deletions tests/spec_decode/e2e/test_integration_dist_tp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, seed: int):
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, seed: int):
"""Verify spec decode works well with smaller tp for draft models.
"""
run_equality_correctness_test_tp(model,
Expand Down
4 changes: 1 addition & 3 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,9 +715,7 @@ def _run_speculative_decoding_step(
idx for idx in non_spec_indices
if execute_model_req.seq_group_metadata_list[idx].is_prompt
]
broadcast_dict = dict(
run_spec_proposer=bool(non_spec_indices)
)
broadcast_dict = dict(run_spec_proposer=bool(non_spec_indices))
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
if len(non_spec_indices):
all_hidden_states = proposal_scores.hidden_states
Expand Down

0 comments on commit 3c266a2

Please sign in to comment.