-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling (#11394)
Signed-off-by: Woosuk Kwon <[email protected]>
- Loading branch information
1 parent
0c0c201
commit 371d04d
Showing
6 changed files
with
355 additions
and
190 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from typing import List, Set, Tuple | ||
|
||
import torch | ||
|
||
from vllm.model_executor.layers.utils import ( | ||
apply_penalties as _apply_penalties) | ||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad | ||
|
||
|
||
def apply_min_token_penalties(logits: torch.Tensor, | ||
output_token_ids: List[List[int]], | ||
stop_token_ids: List[Set[int]], | ||
min_tokens: List[int]) -> None: | ||
""" | ||
Applies minimum token penalty by setting the logits of the stop tokens | ||
to -inf. | ||
""" | ||
min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] | ||
for index, min_token in enumerate(min_tokens): | ||
if (len(output_token_ids[index]) < min_token): | ||
for stop_token_id in stop_token_ids[index]: | ||
min_tokens_logits_to_penalize.append((index, stop_token_id)) | ||
if min_tokens_logits_to_penalize: | ||
logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") | ||
|
||
|
||
def apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor, | ||
presence_penalties: torch.Tensor, | ||
frequency_penalties: torch.Tensor, | ||
repetition_penalties: torch.Tensor, | ||
output_token_ids: List[List[int]]) -> torch.Tensor: | ||
""" | ||
Applies presence, frequency and repetition penalties to the logits. | ||
""" | ||
_, vocab_size = logits.shape | ||
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, | ||
logits.device) | ||
return _apply_penalties(logits, prompt_token_ids, output_tokens_t, | ||
presence_penalties, frequency_penalties, | ||
repetition_penalties) | ||
|
||
|
||
def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int, | ||
device: torch.device) -> torch.Tensor: | ||
""" | ||
Convert the different list data structures to tensors. | ||
""" | ||
output_tokens_tensor = make_tensor_with_pad( | ||
output_token_ids, | ||
# Use the value of vocab_size as a pad since we don't have a | ||
# token_id of this value. | ||
pad=vocab_size, | ||
device="cpu", | ||
dtype=torch.int64, | ||
pin_memory=is_pin_memory_available(), | ||
) | ||
return output_tokens_tensor.to(device, non_blocking=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
from typing import Dict | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from vllm import envs | ||
from vllm.logger import init_logger | ||
from vllm.platforms import current_platform | ||
|
||
logger = init_logger(__name__) | ||
|
||
try: | ||
import flashinfer.sampling | ||
is_flashinfer_available = True | ||
except ImportError: | ||
is_flashinfer_available = False | ||
|
||
|
||
class TopKTopPSampler(nn.Module): | ||
|
||
def __init__(self): | ||
super().__init__() | ||
if current_platform.is_cuda: | ||
if is_flashinfer_available: | ||
if envs.VLLM_USE_FLASHINFER_SAMPLER is not False: | ||
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for | ||
# sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by | ||
# default it is unused). For backward compatibility, we set | ||
# `VLLM_USE_FLASHINFER_SAMPLER` as None by default and | ||
# interpret it differently in V0 and V1 samplers: In V0, | ||
# None means False, while in V1, None means True. This is | ||
# why we use the condition | ||
# `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. | ||
logger.info("Using FlashInfer for top-p & top-k sampling.") | ||
self.forward = self.forward_cuda | ||
else: | ||
logger.warning( | ||
"FlashInfer is available, but it is not enabled. " | ||
"Falling back to the PyTorch-native implementation of " | ||
"top-p & top-k sampling. For the best performance, " | ||
"please set VLLM_USE_FLASHINFER_SAMPLER=1.") | ||
self.forward = self.forward_native | ||
else: | ||
logger.warning( | ||
"FlashInfer is not available. Falling back to the PyTorch-" | ||
"native implementation of top-p & top-k sampling. For the " | ||
"best performance, please install FalshInfer.") | ||
self.forward = self.forward_native | ||
else: | ||
self.forward = self.forward_native | ||
|
||
def forward_native( | ||
self, | ||
logits: torch.Tensor, | ||
generators: Dict[int, torch.Generator], | ||
no_top_k: bool, | ||
k: torch.Tensor, | ||
no_top_p: bool, | ||
p: torch.Tensor, | ||
) -> torch.Tensor: | ||
"""PyTorch-native implementation of top-k and top-p sampling.""" | ||
logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p) | ||
probs = logits.softmax(dim=-1, dtype=torch.float32) | ||
return random_sample(probs, generators) | ||
|
||
def forward_cuda( | ||
self, | ||
logits: torch.Tensor, | ||
generators: Dict[int, torch.Generator], | ||
no_top_k: bool, | ||
k: torch.Tensor, | ||
no_top_p: bool, | ||
p: torch.Tensor, | ||
) -> torch.Tensor: | ||
"""More optimized implementation for top-k and top-p sampling.""" | ||
probs = logits.softmax(dim=-1, dtype=torch.float32) | ||
if no_top_k and no_top_p: | ||
# We prefer `random_sample` over `flashinfer_sample` when sorting is | ||
# not needed. This is because `random_sample` does not require | ||
# CPU-GPU synchronization while `flashinfer_sample` does. | ||
return random_sample(probs, generators) | ||
return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators) | ||
|
||
|
||
def apply_top_k_top_p( | ||
logits: torch.Tensor, | ||
no_top_k: bool, | ||
k: torch.Tensor, | ||
no_top_p: bool, | ||
p: torch.Tensor, | ||
) -> torch.Tensor: | ||
"""Apply top-k and top-p masks to the logits. | ||
This function sorts the logits tensor, which can be slow for large batches. | ||
""" | ||
if no_top_k and no_top_p: | ||
return logits | ||
logits_sort, logits_idx = logits.sort(dim=-1, descending=False) | ||
|
||
if not no_top_k: | ||
# Apply top-k. | ||
top_k_mask = logits_sort.size(1) - k.to(torch.long) | ||
# Get all the top_k values. | ||
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) | ||
top_k_mask = logits_sort < top_k_mask | ||
logits_sort.masked_fill_(top_k_mask, -float("inf")) | ||
|
||
if not no_top_p: | ||
# Apply top-p. | ||
probs_sort = logits_sort.softmax(dim=-1) | ||
probs_sum = probs_sort.cumsum(dim=-1) | ||
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) | ||
# at least one | ||
top_p_mask[:, -1] = False | ||
logits_sort.masked_fill_(top_p_mask, -float("inf")) | ||
|
||
# Re-sort the probabilities. | ||
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) | ||
return logits | ||
|
||
|
||
def random_sample( | ||
probs: torch.Tensor, | ||
generators: Dict[int, torch.Generator], | ||
) -> torch.Tensor: | ||
"""Randomly sample from the probabilities. | ||
We use this function instead of torch.multinomial because torch.multinomial | ||
causes CPU-GPU synchronization. | ||
""" | ||
q = torch.empty_like(probs) | ||
# NOTE(woosuk): To batch-process the requests without their own seeds, | ||
# which is the common case, we first assume that every request does | ||
# not have its own seed. Then, we overwrite the values for the requests | ||
# that have their own seeds. | ||
if len(generators) != probs.shape[0]: | ||
q.exponential_() | ||
if generators: | ||
# TODO(woosuk): This can be slow because we handle each request | ||
# one by one. Optimize this. | ||
for i, generator in generators.items(): | ||
q[i].exponential_(generator=generator) | ||
return probs.div_(q).argmax(dim=-1).view(-1) | ||
|
||
|
||
def flashinfer_sample( | ||
probs: torch.Tensor, | ||
no_top_k: bool, | ||
k: torch.Tensor, | ||
no_top_p: bool, | ||
p: torch.Tensor, | ||
generators: Dict[int, torch.Generator], | ||
) -> torch.Tensor: | ||
"""Sample from the probabilities using FlashInfer. | ||
Statistically, this function is equivalent to the `random_sample` function. | ||
However, this function is faster because it avoids sorting the logits tensor | ||
via rejection sampling. | ||
NOTE: The outputs of this function do not necessarily match the outputs of | ||
the `random_sample` function. It only guarantees that the outputs are | ||
statistically equivalent. | ||
NOTE: This function includes CPU-GPU synchronization, while `random_sample` | ||
does not. Call this function at the end of the forward pass to minimize | ||
the synchronization overhead. | ||
""" | ||
assert not (no_top_k and no_top_p) | ||
max_top_k_round = 32 | ||
batch_size = probs.shape[0] | ||
uniform_samples = torch.empty((max_top_k_round, batch_size), | ||
device=probs.device) | ||
if len(generators) != batch_size: | ||
uniform_samples.uniform_() | ||
if generators: | ||
for i, generator in generators.items(): | ||
uniform_samples[:, i].uniform_(generator=generator) | ||
|
||
if no_top_k: | ||
# Top-p only. | ||
next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs( | ||
probs, uniform_samples, p, deterministic=True) | ||
elif no_top_p: | ||
# Top-k only. | ||
next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs( | ||
probs, uniform_samples, k, deterministic=True) | ||
else: | ||
# Both top-k and top-p. | ||
next_token_ids, success = ( | ||
flashinfer.sampling.top_k_top_p_sampling_from_probs( | ||
probs, uniform_samples, k, p, deterministic=True)) | ||
|
||
# NOTE: CPU-GPU synchronization happens here. | ||
if not success.all(): | ||
if not no_top_k: | ||
probs = flashinfer.sampling.top_k_renorm_prob(probs, k) | ||
if not no_top_p: | ||
probs = flashinfer.sampling.top_p_renorm_prob(probs, p) | ||
next_token_ids = flashinfer.sampling.sampling_from_probs( | ||
probs, uniform_samples[0], deterministic=True) | ||
return next_token_ids.view(-1) |
Oops, something went wrong.