clean up
Signed-off-by: NickLucche <[email protected]>
NickLucche committed Nov 7, 2024
1 parent dfc52a5 commit 85367e2
Showing 9 changed files with 6 additions and 94 deletions.
2 changes: 1 addition & 1 deletion tests/models/utils.py
@@ -166,7 +166,7 @@ def check_logprobs_close(
# If the seq 0 token's logprobs are not `None`,
# the seq 1 token's logprobs must not be `None`
assert logprobs_elem_1 is not None, fail_msg
# Logprobs check: top-k token choices must be the same
# Logprobs check: top-k token choices must be the same
assert (set(logprobs_elem_0.keys()) == set(
logprobs_elem_1.keys())), fail_msg
else:
3 changes: 0 additions & 3 deletions tests/spec_decode/e2e/conftest.py
@@ -154,9 +154,6 @@ def _check_logprobs_when_output_disabled(
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
assert spec_pos_logprob.rank == -1
assert spec_pos_logprob.logprob == 0.0
# FIXME shouldnt have a tensor here?
if not isinstance(spec_pos_logprob_token_id, int):
spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
assert spec_pos_logprob_token_id in baseline_pos_logprobs


10 changes: 2 additions & 8 deletions tests/spec_decode/e2e/test_multistep_correctness.py
@@ -79,7 +79,7 @@
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2, 32])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
@@ -143,9 +143,6 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
# Note that one is equal to the draft model, another isn't.
{
"model_name": "JackFram/llama-68m",
# "enable_chunked_prefill": True,
# "max_num_batched_tokens": 4,
# "max_num_seqs": 4
},
{
"model_name": "JackFram/llama-160m",
@@ -163,9 +160,6 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"enable_chunked_prefill": True,
# TODO test with #prompt>>k
# "max_num_batched_tokens": 12,
# "max_num_seqs": 12,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"disable_logprobs_during_spec_decoding": False
@@ -177,7 +171,7 @@
# Use long output len for the small model test.
10,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
2 changes: 1 addition & 1 deletion vllm/engine/output_processor/single_step.py
@@ -48,7 +48,7 @@ def single_step_process_prompt_logprob(
seq_group,
prompt_logprobs,
position_offset=len(seq_group.prompt_logprobs))
# The second chunk should get this appended so it doesnt add None no more!!

seq_group.prompt_logprobs.extend(prompt_logprobs)


1 change: 0 additions & 1 deletion vllm/sequence.py
@@ -1197,7 +1197,6 @@ def __post_init__(self):
if self.seq_group_metadata_list is not None:
assert len(self.seq_group_metadata_list) == len(self.hidden_states)
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
print("HIDDEN STATE DIM", self.hidden_states.shape)

@property
def seq_ids(self) -> List[int]:
61 changes: 1 addition & 60 deletions vllm/spec_decode/spec_decode_worker.py
@@ -731,7 +731,7 @@ def _run_speculative_decoding_step(
stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
scoring_timer.elapsed_time_ms,
verification_timer.elapsed_time_ms)
# TODO since so far we only had decodes here, no one bothered to add prompt_logprobs

return self._create_output_sampler_list(
execute_model_req.seq_group_metadata_list,
accepted_token_ids,
@@ -886,65 +886,7 @@ def _create_output_sampler_list(
# i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
# terminal chunks will only have one generated token at time 0.
sampler_output_list: List[SamplerOutput] = []
def create_prompt_logprobs_for_prefill_seq(prompt_logprobs):
# we do this because output needs to be per-step
from vllm.sequence import Logprob
from vllm.model_executor.layers.sampler import _get_ranks
topk = seq_group_metadata_list[0].sampling_params.prompt_logprobs
# NkV->NxK, NxK, topk_token_ids: for each row, index that was selected (token id \in 0..V-1)
topk_probs, topk_token_ids = prompt_logprobs.topk(topk, axis=1)
# when the actual token is already in the top K, we only return K
# otherwise we return K+1 (actual token is always included)
seq_id_to_plogprob = dict()
num_prompt_tokens = 0
for i, seq_meta in enumerate(seq_group_metadata_list):
if not seq_meta.is_prompt:
break
seq_data = list(seq_meta.seq_data.values())[0]
# only get the tokens in this chunk!
prompt_token_ids = seq_data.get_prompt_token_ids()
# TODO all prompt logprobs output are shifted by one=>thats because at step i you know about next token
prompt_token_ids = prompt_token_ids[seq_data._num_computed_tokens+1:
seq_data._num_computed_tokens+seq_meta.token_chunk_size+1]

# this can be smaller, like final chunk [1024, 338]->only logprob is 338
assert len(prompt_token_ids) <= seq_meta.token_chunk_size
# for the first token of the prompt we have no logprob=> we dont care we add None in postproc
is_first_chunk = seq_data._num_computed_tokens == 0
# iterate over prompt tokens of this request
logprobs_per_seq: List[Dict[int, Logprob]] = []
# for j in range(seq_meta.token_chunk_size-is_first_chunk):
for actual_prompt_token in prompt_token_ids:
step_logprobs: Dict[int, Logprob] = {}
# idx = j+is_first_chunk
# actual_prompt_token = prompt_token_ids[idx]
step_logprobs[actual_prompt_token] = Logprob(
# NOTE ASSUMING PROMPT LOGPROBS HERE HAS VALUE EVEN FOR FIRST ONE
logprob=prompt_logprobs[num_prompt_tokens, actual_prompt_token],
# TODO cache/re-use
# FIXME oom
# rank=_get_ranks(prompt_logprobs[num_prompt_tokens].reshape(-1,1), torch.tensor(actual_prompt_token).reshape(-1,1)),
rank = 1
)

# add the other topk tokens
for tok_id, lprob in zip(topk_token_ids[num_prompt_tokens], topk_probs[num_prompt_tokens]):
step_logprobs[tok_id.item()] = Logprob(
logprob=lprob,
# rank=_get_ranks(prompt_logprobs[num_prompt_tokens].reshape(-1,1), torch.tensor(tok_id).reshape(-1,1)),
rank=22
)
logprobs_per_seq.append(step_logprobs)
# cum count that keeps track of processed
num_prompt_tokens += len(prompt_token_ids)
# TODO or just use i?
seq_id_to_plogprob[seq_meta.request_id] = logprobs_per_seq
return seq_id_to_plogprob

# if prompt_logprobs is not None:
# seq_id_to_plogprob = {sg.request_id:o.prompt_logprobs for sg, o in zip(seq_group_metadata_list, prompt_logprobs) if sg.is_prompt}


# Prefills are not multi-step (return at most 1 token), in order to
# avoid padding or repetition to fit decodes, we separate them.
for i, sg in enumerate(seq_group_metadata_list):
@@ -1016,7 +958,6 @@ def create_prompt_logprobs_for_prefill_seq(prompt_logprobs):

# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs = num_logprobs_per_seq[sequence_index]
# NOTE this was meant for decodes only
step_output_token_ids.append(
create_sequence_group_output(
token_id=accepted_token_ids_by_step[step_index]
1 change: 0 additions & 1 deletion vllm/spec_decode/top1_proposer.py
@@ -68,7 +68,6 @@ def get_spec_proposals(
# in batch size list
hidden_states = execute_model_req.previous_hidden_states
if hidden_states is not None:
# NOTE check this out, only the hidden states for decodes are passed in!
hidden_states.prune(nonzero_proposal_len_seqs)
nonzero_execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=nonzero_proposal_len_seqs,
15 changes: 1 addition & 14 deletions vllm/transformers_utils/detokenizer.py
@@ -41,26 +41,15 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
seq = seq_group.get_seqs()[0]
# Only prompt, without the generated token.
all_token_ids = seq.get_token_ids()
# TODO you do NOT always have a generated token here non-terminal chunk
# these are actually hte tokens of the entire prompt so it should be ok
# Request 1 prompt tokens: (1, 450, 6673, 310, 278, 3303, 3900, 338)
# Request 1 chunk tokens: (1, 450)
# SEQ GROUP 1
# all_token_ids: [1, 450, 6673, 310, 278, 3303, 3900, 338]
# prompt_token_ids: [1, 450, 6673, 310, 278, 3303, 3900] 338 it's still prompt though isnt it??
prompt_token_ids = all_token_ids[:-1]
tokenizer = self.get_tokenizer_for_seq(seq)
# TODO it appears smt is off for terminal chunks, when prefll only route is chosen..?
prefix_offset = 0
read_offset = 0
next_iter_prefix_offset = 0
next_iter_read_offset = 0
next_iter_tokens: List[str] = []
prev_tokens = None
# NOTE when len(prompt_logprobs) < len(all_token_ids) this is a chunk! OR for chunk BUT the first, position_offset>0
print("TOKENS RECEIVED", all_token_ids)
print("POSITION OFFSET FOR CHUNK", position_offset)
print("Seq data comp tokens", seq.data._num_computed_tokens, 'out tok', seq.data.get_output_token_ids())

for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
prompt_logprobs):

@@ -73,8 +62,6 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
for token_id, sample_logprob in prompt_logprobs_for_token.items():
if (sample_logprob.decoded_token is None
and token_id != VLLM_INVALID_TOKEN_ID):
# TODO uuuuhh I see, we're kinda expected to append the sequence of plogs chunks together so we can read the actual pos here..
# but like how I thought this was done in single_step?
prompt_token_ids_with_token = (
prompt_token_ids[:token_position] + [token_id])
(new_tokens, new_text, new_prefix_offset,
5 changes: 0 additions & 5 deletions vllm/worker/worker_base.py
@@ -332,11 +332,6 @@ def execute_model(
num_steps=num_steps,
**kwargs,
)

# import pickle
# with open('data_original_cp.pkl', 'ab') as f:
# new_data = {sg.request_id:o.prompt_logprobs for sg, o in zip(execute_model_req.seq_group_metadata_list, output[0].outputs) if sg.is_prompt}
# pickle.dump(new_data, f)

model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank:
