Commit 6bc8e01: Minor

Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon committed Dec 21, 2024
1 parent a781c11 commit 6bc8e01
Showing 1 changed file with 9 additions and 6 deletions.
vllm/v1/worker/gpu_input_batch.py (9 additions, 6 deletions)

@@ -151,7 +151,6 @@ def __init__(
             self.repetition_penalties_cpu_tensor.numpy()
         self.repetition_penalties_reqs: Set[str] = set()
 
-        self.prompt_tokens_tensor: Optional[torch.Tensor] = None
         self.min_tokens: List[int] = [0] * max_num_reqs
         self.stop_token_ids: List[Set[int]] = [
             set() for _ in range(max_num_reqs)
@@ -244,6 +243,8 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.frequency_penalties_reqs.discard(req_id)
         self.presence_penalties_reqs.discard(req_id)
         self.repetition_penalties_reqs.discard(req_id)
+        self.min_tokens[req_index] = 0
+        self.stop_token_ids[req_index].clear()
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.prompt_logprob_reqs.discard(req_id)
@@ -259,6 +260,9 @@ def clear(self) -> None:
         self.frequency_penalties_reqs.clear()
         self.presence_penalties_reqs.clear()
         self.repetition_penalties_reqs.clear()
+        self.min_tokens = [0] * self.max_num_reqs
+        for stop_token_ids in self.stop_token_ids:
+            stop_token_ids.clear()
         self.generators.clear()
         self.num_logprobs.clear()
         self.prompt_logprob_reqs.clear()
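
The two hunks above harden slot reuse: min_tokens and stop_token_ids live in fixed-size, index-addressed containers, so a slot freed by remove_request can later be handed to a new request, and any values left behind by the previous occupant would leak into it. A minimal sketch of the hazard and the fix, using a hypothetical SlotTable class (invented for illustration, not vLLM's API):

    # Toy illustration of index-addressed slot reuse; SlotTable and its
    # methods are invented for this sketch and are not part of vLLM.
    from typing import List, Set

    class SlotTable:
        def __init__(self, max_slots: int) -> None:
            self.stop_token_ids: List[Set[int]] = [
                set() for _ in range(max_slots)
            ]
            self.free_slots: List[int] = list(range(max_slots))

        def add_request(self, stop_ids: Set[int]) -> int:
            idx = self.free_slots.pop()
            # Without the reset below, stop tokens from this slot's
            # previous occupant would still apply to the new request.
            self.stop_token_ids[idx].update(stop_ids)
            return idx

        def remove_request(self, idx: int) -> None:
            self.stop_token_ids[idx].clear()  # the reset this commit adds
            self.free_slots.append(idx)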
@@ -323,6 +327,7 @@ def make_sampling_metadata(
         req_id_output_token_ids: Dict[str, List[int]],
         skip_copy: bool = False,
     ) -> SamplingMetadata:
+        prompt_tokens_tensor: Optional[torch.Tensor] = None
         if not skip_copy:
             self.temperature[:self.num_reqs].copy_(
                 self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
@@ -346,9 +351,8 @@ def make_sampling_metadata(
             # The prompt tokens are used only for applying penalties during
             # the sampling process. Hence copy these tensors only when
             # there are requests which need penalties to be applied.
-            self.prompt_tokens_tensor = \
-                self._construct_prompt_tokens_tensor(
-                    self.vocab_size, device=self.device)
+            prompt_tokens_tensor = self._construct_prompt_tokens_tensor(
+                self.vocab_size, device=self.device)
 
         output_token_ids: List[List[int]] = []
 
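The comment in the hunk above carries the reasoning: the padded prompt-token tensor exists solely so penalties can be applied, so building it and copying it to the device can be skipped whenever no request in the batch uses penalties. A hedged sketch of that gating; the helper name, signature, and the choice of vocab_size as an out-of-range pad id are assumptions for illustration, not vLLM's exact API:

    # Illustrative sketch: build the prompt-token tensor only when
    # penalties are in use; padding with vocab_size is assumed to let
    # downstream code tell pad positions apart from real token ids.
    from typing import List, Optional
    import torch

    def maybe_build_prompt_tensor(
        prompts: List[List[int]],
        vocab_size: int,
        need_penalties: bool,
        device: str = "cpu",
    ) -> Optional[torch.Tensor]:
        if not need_penalties:
            return None  # skip both the padding work and the device copy
        max_len = max(len(p) for p in prompts)
        out = torch.full((len(prompts), max_len), vocab_size,
                         dtype=torch.long)
        for i, p in enumerate(prompts):
            out[i, :len(p)] = torch.tensor(p, dtype=torch.long)
        return out.to(device, non_blocking=True)
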
@@ -373,8 +377,7 @@ def make_sampling_metadata(
             no_top_k=self.no_top_k,
             generators=self.generators,
             max_num_logprobs=self.max_num_logprobs,
-            prompt_token_ids=self.prompt_tokens_tensor[:self.num_reqs] \
-                if self.prompt_tokens_tensor is not None else None,
+            prompt_token_ids=prompt_tokens_tensor,
             frequency_penalties=self.frequency_penalties[:self.num_reqs],
             presence_penalties=self.presence_penalties[:self.num_reqs],
             repetition_penalties=self.repetition_penalties[:self.num_reqs],
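
Together with the attribute removed from __init__ in the first hunk, the remaining hunks demote prompt_tokens_tensor from cached instance state to a per-call local. With the cached attribute, a later batch with no penalty requests would skip reconstruction yet still pass the sampler the tensor built for an earlier batch; a local starts as None on every call. A small self-contained sketch of the before/after behavior (the Batch class is a hypothetical stand-in):

    # Hypothetical before/after sketch; Batch and its methods stand in
    # for InputBatch.make_sampling_metadata.
    from typing import Optional
    import torch

    class Batch:
        def __init__(self) -> None:
            self.prompt_tensor: Optional[torch.Tensor] = None  # old cache

        def metadata_before(self, need_penalties: bool):
            # Cached attribute: survives across calls, so a penalty-free
            # batch can observe the tensor built for a previous batch.
            if need_penalties:
                self.prompt_tensor = torch.zeros(4, 8, dtype=torch.long)
            return self.prompt_tensor

        def metadata_after(self, need_penalties: bool):
            # Per-call local: starts as None and reflects only this batch.
            prompt_tensor: Optional[torch.Tensor] = None
            if need_penalties:
                prompt_tensor = torch.zeros(4, 8, dtype=torch.long)
            return prompt_tensor

    batch = Batch()
    batch.metadata_before(need_penalties=True)
    print(batch.metadata_before(need_penalties=False) is None)  # False: stale
    print(batch.metadata_after(need_penalties=False) is None)   # True: fresh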