From 371d04d39bf056e4cc56100c83d4812b7cb230e4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 27 Dec 2024 09:32:38 +0900 Subject: [PATCH] [V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling (#11394) Signed-off-by: Woosuk Kwon --- tests/v1/sample/test_sampler.py | 54 +++--- vllm/envs.py | 5 +- vllm/v1/sample/ops/__init__.py | 0 vllm/v1/sample/ops/penalties.py | 57 ++++++ vllm/v1/sample/ops/topk_topp_sampler.py | 201 +++++++++++++++++++++ vllm/v1/sample/sampler.py | 228 ++++++++---------------- 6 files changed, 355 insertions(+), 190 deletions(-) create mode 100644 vllm/v1/sample/ops/__init__.py create mode 100644 vllm/v1/sample/ops/penalties.py create mode 100644 vllm/v1/sample/ops/topk_topp_sampler.py diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index d8d055805cbea..5ebf72927cfd6 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -68,7 +68,7 @@ def _create_default_sampling_metadata( no_top_p=True, no_top_k=True, generators={}, - max_num_logprobs=VOCAB_SIZE, + max_num_logprobs=0, prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, vocab_size, device), output_token_ids=output_token_ids, @@ -169,20 +169,14 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int): sampling_metadata.min_tokens = min_tokens sampling_metadata.stop_token_ids = stop_token_ids sampler = Sampler() - sampler_output = sampler(fake_logits, sampling_metadata) + logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = logits.cpu() for batch_idx in range(batch_size): - for vocab in range(VOCAB_SIZE): - # Verify that the logprobs for stop token ids is set - # to -inf. - logprob_index = torch.where( - sampler_output.logprob_token_ids[batch_idx] == - vocab)[0].item() - if vocab in stop_token_ids[batch_idx]: - assert sampler_output.logprobs[batch_idx][ - logprob_index] == -float("inf") + for token_id in range(VOCAB_SIZE): + if token_id in stop_token_ids[batch_idx]: + assert logits[batch_idx][token_id] == -float("inf") else: - assert sampler_output.logprobs[batch_idx][ - logprob_index] != -float("inf") + assert logits[batch_idx][token_id] != -float("inf") @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -205,18 +199,14 @@ def test_sampler_presence_penalty(device: str, batch_size: int, batch_size, presence_penalty, torch.device(device)) sampling_metadata.no_penalties = False sampler = Sampler() - sampler_output = sampler(fake_logits, sampling_metadata) + logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = logits.cpu() for batch_idx in range(batch_size): - # The logprobs in the SamplerOutput are arranged in descending order. - # Since all tokens initially have the same logprobs, the non-penalized - # tokens will appear at the beginning, while the penalized tokens - # will appear at the end of the list. - penalized_token_id = sampler_output.logprob_token_ids[batch_idx][ - VOCAB_SIZE - 1] - penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1] - non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0] - non_penalized_log_prod = sampler_output.logprobs[batch_idx][0] - assert non_penalized_log_prod > penalized_log_prod + # Since all tokens initially have the same logits, the non-penalized + # token ID will be the one with the highest logit value, while the + # penalized token ID will be the one with the lowest logit value. 
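As the rewritten comment notes, the assertions that follow only hold because every logit starts out identical, so a penalty is the only thing that can separate tokens: the penalized token becomes the row's argmin and any untouched token can be the argmax. A minimal standalone sketch of that property (toy tensors and a hypothetical token id, not the test fixtures):

```python
import torch

# Toy tensors: start from identical logits and penalize one
# hypothetical "already generated" token.
logits = torch.zeros(1, 8)
seen_token_id = 3
presence_penalty = 1.5
logits[0, seen_token_id] -= presence_penalty

# The penalized token is now the unique minimum; any other token is the maximum.
assert logits[0].argmin().item() == seen_token_id
assert logits[0].argmax().item() != seen_token_id
```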
+ non_penalized_token_id = logits[batch_idx].argmax().item() + penalized_token_id = logits[batch_idx].argmin().item() if presence_penalty > 0: # If `presence_penalty` is set to a value greater than 0, it # indicates a preference for new tokens over those already @@ -256,11 +246,11 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, sampling_metadata.output_token_ids = output_token_ids sampling_metadata.no_penalties = False sampler = Sampler() - sampler_output = sampler(fake_logits, sampling_metadata) + logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = logits.cpu() for batch_idx in range(batch_size): - logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] - non_penalized_token_id = logprobs_token_ids[0] - penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1] + non_penalized_token_id = logits[batch_idx].argmax().item() + penalized_token_id = logits[batch_idx].argmin().item() distinct_sorted_token_ids_in_output = \ sorted_token_ids_in_output[batch_idx] most_frequent_token_id = distinct_sorted_token_ids_in_output[ @@ -305,11 +295,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int, batch_size, repetition_penalty, torch.device(device)) sampling_metadata.no_penalties = False sampler = Sampler() - sampler_output = sampler(fake_logits, sampling_metadata) + logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = logits.cpu() for batch_idx in range(batch_size): - logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] - non_penalized_token_id = logprobs_token_ids[0] - penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1] + non_penalized_token_id = logits[batch_idx].argmax().item() + penalized_token_id = logits[batch_idx].argmin().item() prompt_tokens = sampling_metadata.prompt_token_ids[ batch_idx][:].tolist() output_tokens = sampling_metadata.output_token_ids[batch_idx] diff --git a/vllm/envs.py b/vllm/envs.py index 18870c1c6b51a..c4a568c680db0 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -30,7 +30,7 @@ VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None - VLLM_USE_FLASHINFER_SAMPLER: bool = False + VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False VLLM_PP_LAYER_PARTITION: Optional[str] = None @@ -277,7 +277,8 @@ def get_default_config_root(): # If set, vllm will use flashinfer sampler "VLLM_USE_FLASHINFER_SAMPLER": - lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))), + lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])) + if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None, # If set, vllm will force flashinfer to use tensor cores; # otherwise will use heuristic based on model architecture. 
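The envs.py change above turns VLLM_USE_FLASHINFER_SAMPLER into a tri-state value rather than a plain boolean. A standalone sketch of the parsing semantics (hypothetical helper name, same behavior as the lambda):

```python
import os
from typing import Optional

def read_flashinfer_sampler_flag() -> Optional[bool]:
    """Hypothetical helper mirroring the lambda above:
    unset -> None, "0" -> False, any other integer string -> True."""
    if "VLLM_USE_FLASHINFER_SAMPLER" not in os.environ:
        return None
    return bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
```

The point of the third state is explained in the NOTE in topk_topp_sampler.py below: V0 keeps treating an unset variable as disabled, while the V1 sampler treats it as enabled.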
diff --git a/vllm/v1/sample/ops/__init__.py b/vllm/v1/sample/ops/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py new file mode 100644 index 0000000000000..91ebaf9269f32 --- /dev/null +++ b/vllm/v1/sample/ops/penalties.py @@ -0,0 +1,57 @@ +from typing import List, Set, Tuple + +import torch + +from vllm.model_executor.layers.utils import ( + apply_penalties as _apply_penalties) +from vllm.utils import is_pin_memory_available, make_tensor_with_pad + + +def apply_min_token_penalties(logits: torch.Tensor, + output_token_ids: List[List[int]], + stop_token_ids: List[Set[int]], + min_tokens: List[int]) -> None: + """ + Applies minimum token penalty by setting the logits of the stop tokens + to -inf. + """ + min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] + for index, min_token in enumerate(min_tokens): + if (len(output_token_ids[index]) < min_token): + for stop_token_id in stop_token_ids[index]: + min_tokens_logits_to_penalize.append((index, stop_token_id)) + if min_tokens_logits_to_penalize: + logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") + + +def apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor, + output_token_ids: List[List[int]]) -> torch.Tensor: + """ + Applies presence, frequency and repetition penalties to the logits. + """ + _, vocab_size = logits.shape + output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, + logits.device) + return _apply_penalties(logits, prompt_token_ids, output_tokens_t, + presence_penalties, frequency_penalties, + repetition_penalties) + + +def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int, + device: torch.device) -> torch.Tensor: + """ + Convert the different list data structures to tensors. + """ + output_tokens_tensor = make_tensor_with_pad( + output_token_ids, + # Use the value of vocab_size as a pad since we don't have a + # token_id of this value. + pad=vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=is_pin_memory_available(), + ) + return output_tokens_tensor.to(device, non_blocking=True) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py new file mode 100644 index 0000000000000..c088c3c129ca5 --- /dev/null +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -0,0 +1,201 @@ +from typing import Dict + +import torch +import torch.nn as nn + +from vllm import envs +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +try: + import flashinfer.sampling + is_flashinfer_available = True +except ImportError: + is_flashinfer_available = False + + +class TopKTopPSampler(nn.Module): + + def __init__(self): + super().__init__() + if current_platform.is_cuda: + if is_flashinfer_available: + if envs.VLLM_USE_FLASHINFER_SAMPLER is not False: + # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for + # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by + # default it is unused). For backward compatibility, we set + # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and + # interpret it differently in V0 and V1 samplers: In V0, + # None means False, while in V1, None means True. This is + # why we use the condition + # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. 
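Before the FlashInfer-specific branches continue below, a short aside on the new `apply_min_token_penalties` helper added above in `vllm/v1/sample/ops/penalties.py`: it collects `(row, stop_token_id)` pairs for every sequence still below its `min_tokens` and masks them all with a single advanced-indexing assignment. A toy sketch of that indexing trick (assumed shapes and values, plain torch):

```python
import torch

logits = torch.zeros(2, 10)
output_token_ids = [[1, 2], [3]]       # row 0 has 2 tokens so far, row 1 has 1
stop_token_ids = [{9}, {8, 9}]
min_tokens = [4, 1]                    # row 0 must emit at least 4 tokens

pairs = []                             # (row, stop_token_id) pairs to mask
for row, min_tok in enumerate(min_tokens):
    if len(output_token_ids[row]) < min_tok:
        pairs.extend((row, stop_id) for stop_id in stop_token_ids[row])
if pairs:
    # tuple(zip(*pairs)) -> (row_indices, column_indices) for advanced indexing.
    logits[tuple(zip(*pairs))] = -float("inf")

assert logits[0, 9] == -float("inf")    # row 0 is still below min_tokens
assert torch.isfinite(logits[1]).all()  # row 1 already satisfied min_tokens
```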
+ logger.info("Using FlashInfer for top-p & top-k sampling.") + self.forward = self.forward_cuda + else: + logger.warning( + "FlashInfer is available, but it is not enabled. " + "Falling back to the PyTorch-native implementation of " + "top-p & top-k sampling. For the best performance, " + "please set VLLM_USE_FLASHINFER_SAMPLER=1.") + self.forward = self.forward_native + else: + logger.warning( + "FlashInfer is not available. Falling back to the PyTorch-" + "native implementation of top-p & top-k sampling. For the " + "best performance, please install FalshInfer.") + self.forward = self.forward_native + else: + self.forward = self.forward_native + + def forward_native( + self, + logits: torch.Tensor, + generators: Dict[int, torch.Generator], + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, + ) -> torch.Tensor: + """PyTorch-native implementation of top-k and top-p sampling.""" + logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p) + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs, generators) + + def forward_cuda( + self, + logits: torch.Tensor, + generators: Dict[int, torch.Generator], + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, + ) -> torch.Tensor: + """More optimized implementation for top-k and top-p sampling.""" + probs = logits.softmax(dim=-1, dtype=torch.float32) + if no_top_k and no_top_p: + # We prefer `random_sample` over `flashinfer_sample` when sorting is + # not needed. This is because `random_sample` does not require + # CPU-GPU synchronization while `flashinfer_sample` does. + return random_sample(probs, generators) + return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators) + + +def apply_top_k_top_p( + logits: torch.Tensor, + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, +) -> torch.Tensor: + """Apply top-k and top-p masks to the logits. + + This function sorts the logits tensor, which can be slow for large batches. + """ + if no_top_k and no_top_p: + return logits + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + if not no_top_k: + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + if not no_top_p: + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. + logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) + return logits + + +def random_sample( + probs: torch.Tensor, + generators: Dict[int, torch.Generator], +) -> torch.Tensor: + """Randomly sample from the probabilities. + + We use this function instead of torch.multinomial because torch.multinomial + causes CPU-GPU synchronization. + """ + q = torch.empty_like(probs) + # NOTE(woosuk): To batch-process the requests without their own seeds, + # which is the common case, we first assume that every request does + # not have its own seed. Then, we overwrite the values for the requests + # that have their own seeds. + if len(generators) != probs.shape[0]: + q.exponential_() + if generators: + # TODO(woosuk): This can be slow because we handle each request + # one by one. Optimize this. 
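A quick standalone check of the sort-and-mask scheme used by `apply_top_k_top_p` above, with toy numbers (single row, assumed k and p; this mirrors the logic rather than calling the function):

```python
import torch

logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]])
k = torch.tensor([2])                  # keep the 2 largest logits
p = torch.tensor([0.9])                # then keep the smallest nucleus >= 0.9

logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
# Top-k: mask everything below the k-th largest sorted value.
kth_value = logits_sort.gather(1, (logits_sort.size(1) - k).unsqueeze(1))
logits_sort.masked_fill_(logits_sort < kth_value, -float("inf"))
# Top-p: mask the low-probability head of the sorted distribution.
probs_sum = logits_sort.softmax(dim=-1).cumsum(dim=-1)
top_p_mask = probs_sum <= 1 - p.unsqueeze(1)
top_p_mask[:, -1] = False              # always keep at least one token
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Scatter the masked values back to their original columns.
out = torch.empty_like(logits).scatter_(-1, logits_idx, logits_sort)

assert out[0].isfinite().nonzero().flatten().tolist() == [1, 3]
```

The full sort is exactly the cost FlashInfer avoids, which is why `forward_cuda` above prefers `flashinfer_sample` whenever top-k or top-p is active.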
+ for i, generator in generators.items(): + q[i].exponential_(generator=generator) + return probs.div_(q).argmax(dim=-1).view(-1) + + +def flashinfer_sample( + probs: torch.Tensor, + no_top_k: bool, + k: torch.Tensor, + no_top_p: bool, + p: torch.Tensor, + generators: Dict[int, torch.Generator], +) -> torch.Tensor: + """Sample from the probabilities using FlashInfer. + + Statistically, this function is equivalent to the `random_sample` function. + However, this function is faster because it avoids sorting the logits tensor + via rejection sampling. + + NOTE: The outputs of this function do not necessarily match the outputs of + the `random_sample` function. It only guarantees that the outputs are + statistically equivalent. + + NOTE: This function includes CPU-GPU synchronization, while `random_sample` + does not. Call this function at the end of the forward pass to minimize + the synchronization overhead. + """ + assert not (no_top_k and no_top_p) + max_top_k_round = 32 + batch_size = probs.shape[0] + uniform_samples = torch.empty((max_top_k_round, batch_size), + device=probs.device) + if len(generators) != batch_size: + uniform_samples.uniform_() + if generators: + for i, generator in generators.items(): + uniform_samples[:, i].uniform_(generator=generator) + + if no_top_k: + # Top-p only. + next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs( + probs, uniform_samples, p, deterministic=True) + elif no_top_p: + # Top-k only. + next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs( + probs, uniform_samples, k, deterministic=True) + else: + # Both top-k and top-p. + next_token_ids, success = ( + flashinfer.sampling.top_k_top_p_sampling_from_probs( + probs, uniform_samples, k, p, deterministic=True)) + + # NOTE: CPU-GPU synchronization happens here. 
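The `random_sample` helper defined above relies on the exponential-race identity: if the entries of `q` are i.i.d. Exp(1) noise, then `argmax(probs / q)` picks index i with probability `probs[i]`, which is what lets it skip `torch.multinomial` and its CPU-GPU synchronization. A small empirical sanity check of that identity (toy probabilities, fixed seed):

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([[0.1, 0.2, 0.7]])
counts = torch.zeros(3)
for _ in range(10_000):
    q = torch.empty_like(probs).exponential_()
    counts[(probs / q).argmax(dim=-1)] += 1
print(counts / counts.sum())  # roughly tensor([0.1, 0.2, 0.7])
```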
+ if not success.all(): + if not no_top_k: + probs = flashinfer.sampling.top_k_renorm_prob(probs, k) + if not no_top_p: + probs = flashinfer.sampling.top_p_renorm_prob(probs, p) + next_token_ids = flashinfer.sampling.sampling_from_probs( + probs, uniform_samples[0], deterministic=True) + return next_token_ids.view(-1) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 82470fb2610f8..1e38453a0ff28 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -1,53 +1,55 @@ """A layer that samples the next tokens from the model's outputs.""" -from typing import Dict, List, Set, Tuple +from typing import Tuple import torch import torch.nn as nn -from vllm.model_executor.layers.utils import apply_penalties -from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.penalties import (apply_min_token_penalties, + apply_penalties) +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler _SAMPLING_EPS = 1e-5 class Sampler(nn.Module): + def __init__(self): + super().__init__() + self.topk_topp_sampler = TopKTopPSampler() + def forward( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - _apply_min_token_penalties(logits, sampling_metadata.output_token_ids, - sampling_metadata.stop_token_ids, - sampling_metadata.min_tokens) - if not sampling_metadata.no_penalties: - assert sampling_metadata.prompt_token_ids is not None - _apply_penalties(logits, sampling_metadata.prompt_token_ids, - sampling_metadata.presence_penalties, - sampling_metadata.frequency_penalties, - sampling_metadata.repetition_penalties, - sampling_metadata.output_token_ids) - logits = self.apply_temperature(logits, sampling_metadata.temperature) - logits = self.apply_top_k_top_p(logits, sampling_metadata) - probs = self.get_probs(logits) - sampled = self.sample(probs, sampling_metadata) - # Use int32 to reduce the tensor size. - sampled = sampled.to(torch.int32) - - if sampling_metadata.max_num_logprobs > 0: - logprobs = self.get_logprobs(logits) - # FIXME: Mask the sampled token_id, get topk logprobs, - # and concatenate the topk with the sampled token_id. - topk_logprobs, topk_indices = torch.topk( - logprobs, sampling_metadata.max_num_logprobs, dim=-1) - # Use int32 to reduce the tensor size. - topk_indices = topk_indices.to(torch.int32) + needs_logprobs = sampling_metadata.max_num_logprobs > 0 + if needs_logprobs: + # NOTE(woosuk): Use the original logits (before any penalties or + # temperature scaling) for the top-k logprobs. + # This is different from the V0 sampler, which uses the logits that + # is used for sampling (after penalties and temperature scaling). + # NOTE: We compute logprobs first because the below ops may + # modify the logits tensor in-place (and we don't want to clone + # the logits tensor for memory efficiency). + topk_logprobs, topk_indices = self.get_topk_logprobs( + logits, sampling_metadata) else: topk_logprobs = None topk_indices = None + # Use float32 for the logits. + logits = logits.to(torch.float32) + # Apply penalties (e.g., min_tokens, freq_penalties). + logits = self.apply_penalties(logits, sampling_metadata) + # Apply temperature. + logits = self.apply_temperature(logits, sampling_metadata.temperature) + # Sample the next token. + sampled = self.sample(logits, sampling_metadata) + # Use int32 to reduce the tensor size. 
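The ordering in `forward` above matters because the penalty ops and `apply_temperature` mutate `logits` in place; taking the top-k logprobs first is what preserves the raw-logit semantics without paying for a clone. A toy illustration of the hazard being avoided (standalone tensors):

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.0]])
raw_logprobs = logits.log_softmax(dim=-1)   # snapshot before any in-place op
logits.div_(torch.tensor([[0.5]]))          # in-place temperature scaling
# The logprobs of the mutated tensor no longer match the raw-logit logprobs.
assert not torch.allclose(raw_logprobs, logits.log_softmax(dim=-1))
```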
+ sampled = sampled.to(torch.int32) + # NOTE: CPU-GPU synchronization happens here. sampler_output = SamplerOutput( sampled_token_ids=sampled.tolist(), @@ -63,71 +65,37 @@ def apply_temperature( logits: torch.Tensor, temp: torch.Tensor, ) -> torch.Tensor: - # Use float32 to apply temperature scaling. - logits = logits.to(torch.float32) # Avoid division by zero. temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp) # Use in-place division to avoid creating a new tensor. logits.div_(temp.unsqueeze(dim=1)) return logits - def apply_top_k_top_p( + def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: + return logits.argmax(dim=-1).view(-1) + + def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - return _apply_top_k_top_p( + assert not (sampling_metadata.all_greedy + and sampling_metadata.all_random) + if sampling_metadata.all_greedy: + return self.greedy_sample(logits) + + random_sampled = self.topk_topp_sampler( logits, + sampling_metadata.generators, sampling_metadata.no_top_k, sampling_metadata.top_k, sampling_metadata.no_top_p, sampling_metadata.top_p, ) - - def get_probs(self, logits: torch.Tensor) -> torch.Tensor: - return torch.softmax(logits, dim=-1, dtype=torch.float32) - - def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor: - return torch.log_softmax(logits, dim=-1, dtype=torch.float32) - - def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor: - return probs.argmax(dim=-1).view(-1) - - def random_sample( - self, - probs: torch.Tensor, - generators: Dict[int, torch.Generator], - ) -> torch.Tensor: - q = torch.empty_like(probs) - # NOTE(woosuk): To batch-process the requests without their own seeds, - # which is the common case, we first assume that every request does - # not have its own seed. Then, we overwrite the values for the requests - # that have their own seeds. - if len(generators) != probs.shape[0]: - # This might still be done here unnecessarily if there are greedies - q.exponential_() - if generators: - # TODO(woosuk): This can be slow because we handle each request - # one by one. Optimize this. - for i, generator in generators.items(): - q[i].exponential_(generator=generator) - return probs.div_(q).argmax(dim=-1).view(-1) - - def sample( - self, - probs: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - assert not (sampling_metadata.all_greedy - and sampling_metadata.all_random) - if sampling_metadata.all_greedy: - return self.greedy_sample(probs) if sampling_metadata.all_random: - return self.random_sample(probs, sampling_metadata.generators) + return random_sampled - greedy_sampled = self.greedy_sample(probs) - random_sampled = self.random_sample(probs, - sampling_metadata.generators) + greedy_sampled = self.greedy_sample(logits) sampled = torch.where( sampling_metadata.temperature < _SAMPLING_EPS, greedy_sampled, @@ -135,86 +103,34 @@ def sample( ) return sampled + def get_topk_logprobs( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor]: + logprobs = logits.log_softmax(dim=-1, dtype=torch.float32) + # FIXME: Mask the sampled token_id, get topk logprobs, + # and concatenate the topk with the sampled token_id. + topk_logprobs, topk_indices = torch.topk( + logprobs, sampling_metadata.max_num_logprobs, dim=-1) + # Use int32 to reduce the tensor size. + topk_indices = topk_indices.to(torch.int32) + return topk_logprobs, topk_indices -# TODO(woosuk): Optimize this with a custom kernel. 
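For mixed batches, `sample` above computes both the greedy and the random candidate and then selects per request with `torch.where` on the temperature. A toy sketch of that selection (assumed values):

```python
import torch

_SAMPLING_EPS = 1e-5
temperature = torch.tensor([0.0, 0.8, 0.0])      # requests 0 and 2 are greedy
greedy_sampled = torch.tensor([7, 7, 7])
random_sampled = torch.tensor([3, 3, 3])

sampled = torch.where(temperature < _SAMPLING_EPS, greedy_sampled, random_sampled)
assert sampled.tolist() == [7, 3, 7]
```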
-def _apply_top_k_top_p( - logits: torch.Tensor, - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, -) -> torch.Tensor: - if no_top_k and no_top_p: + def apply_penalties( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + apply_min_token_penalties(logits, sampling_metadata.output_token_ids, + sampling_metadata.stop_token_ids, + sampling_metadata.min_tokens) + if not sampling_metadata.no_penalties: + assert sampling_metadata.prompt_token_ids is not None + logits = apply_penalties(logits, + sampling_metadata.prompt_token_ids, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.repetition_penalties, + sampling_metadata.output_token_ids) return logits - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - - if not no_top_k: - # Apply top-k. - top_k_mask = logits_sort.size(1) - k.to(torch.long) - # Get all the top_k values. - top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) - top_k_mask = logits_sort < top_k_mask - logits_sort.masked_fill_(top_k_mask, -float("inf")) - - if not no_top_p: - # Apply top-p. - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = probs_sort.cumsum(dim=-1) - top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) - # at least one - top_p_mask[:, -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - - # Re-sort the probabilities. - logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) - return logits - - -def _apply_min_token_penalties(logits: torch.Tensor, - output_token_ids: List[List[int]], - stop_token_ids: List[Set[int]], - min_tokens: List[int]): - """ - Applies minimum token penalty by setting the logits of the stop tokens - to -inf. - """ - min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] - for index, min_token in enumerate(min_tokens): - if (len(output_token_ids[index]) < min_token): - for stop_token_id in stop_token_ids[index]: - min_tokens_logits_to_penalize.append((index, stop_token_id)) - if min_tokens_logits_to_penalize: - logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") - - -def _apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor, - presence_penalties: torch.Tensor, - frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor, - output_token_ids: List[List[int]]): - """ - Applies presence, frequency and repetition penalties to the logits. - """ - _, vocab_size = logits.shape - output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, - logits.device) - return apply_penalties(logits, prompt_token_ids, output_tokens_t, - presence_penalties, frequency_penalties, - repetition_penalties) - - -def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int, - device: torch.device) -> torch.Tensor: - """ - Convert the different list data structures to tensors. - """ - output_tokens_tensor = make_tensor_with_pad( - output_token_ids, - # Use the value of vocab_size as a pad since we don't have a - # token_id of this value. - pad=vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=is_pin_memory_available(), - ) - return output_tokens_tensor.to(device, non_blocking=True)
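The `_convert_to_tensors` helper (removed here and re-added in `vllm/v1/sample/ops/penalties.py`) pads the ragged per-request output-token lists with `vocab_size`, a value no real token id can take, so the penalty kernels can tell padding from data by construction. A plain-torch sketch of that convention (not the actual `make_tensor_with_pad` utility):

```python
import torch

vocab_size = 10
output_token_ids = [[1, 2, 3], [4]]    # ragged per-request token histories

max_len = max(len(ids) for ids in output_token_ids)
# Fill with vocab_size so padding cannot collide with any real token id.
padded = torch.full((len(output_token_ids), max_len), vocab_size, dtype=torch.int64)
for row, ids in enumerate(output_token_ids):
    padded[row, :len(ids)] = torch.tensor(ids, dtype=torch.int64)

print(padded)
# tensor([[ 1,  2,  3],
#         [ 4, 10, 10]])
```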