Skip to content

Commit

Permalink
[V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling (#11394)
Browse files Browse the repository at this point in the history
Signed-off-by: Woosuk Kwon <[email protected]>
  • Loading branch information
WoosukKwon authored Dec 27, 2024
1 parent 0c0c201 commit 371d04d
Show file tree
Hide file tree
Showing 6 changed files with 355 additions and 190 deletions.
54 changes: 22 additions & 32 deletions tests/v1/sample/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _create_default_sampling_metadata(
no_top_p=True,
no_top_k=True,
generators={},
max_num_logprobs=VOCAB_SIZE,
max_num_logprobs=0,
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
vocab_size, device),
output_token_ids=output_token_ids,
Expand Down Expand Up @@ -169,20 +169,14 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
sampling_metadata.min_tokens = min_tokens
sampling_metadata.stop_token_ids = stop_token_ids
sampler = Sampler()
sampler_output = sampler(fake_logits, sampling_metadata)
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
for vocab in range(VOCAB_SIZE):
# Verify that the logprobs for stop token ids is set
# to -inf.
logprob_index = torch.where(
sampler_output.logprob_token_ids[batch_idx] ==
vocab)[0].item()
if vocab in stop_token_ids[batch_idx]:
assert sampler_output.logprobs[batch_idx][
logprob_index] == -float("inf")
for token_id in range(VOCAB_SIZE):
if token_id in stop_token_ids[batch_idx]:
assert logits[batch_idx][token_id] == -float("inf")
else:
assert sampler_output.logprobs[batch_idx][
logprob_index] != -float("inf")
assert logits[batch_idx][token_id] != -float("inf")


@pytest.mark.parametrize("device", CUDA_DEVICES)
Expand All @@ -205,18 +199,14 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
batch_size, presence_penalty, torch.device(device))
sampling_metadata.no_penalties = False
sampler = Sampler()
sampler_output = sampler(fake_logits, sampling_metadata)
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
# The logprobs in the SamplerOutput are arranged in descending order.
# Since all tokens initially have the same logprobs, the non-penalized
# tokens will appear at the beginning, while the penalized tokens
# will appear at the end of the list.
penalized_token_id = sampler_output.logprob_token_ids[batch_idx][
VOCAB_SIZE - 1]
penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1]
non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0]
non_penalized_log_prod = sampler_output.logprobs[batch_idx][0]
assert non_penalized_log_prod > penalized_log_prod
# Since all tokens initially have the same logits, the non-penalized
# token ID will be the one with the highest logit value, while the
# penalized token ID will be the one with the lowest logit value.
non_penalized_token_id = logits[batch_idx].argmax().item()
penalized_token_id = logits[batch_idx].argmin().item()
if presence_penalty > 0:
# If `presence_penalty` is set to a value greater than 0, it
# indicates a preference for new tokens over those already
Expand Down Expand Up @@ -256,11 +246,11 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
sampling_metadata.output_token_ids = output_token_ids
sampling_metadata.no_penalties = False
sampler = Sampler()
sampler_output = sampler(fake_logits, sampling_metadata)
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
non_penalized_token_id = logprobs_token_ids[0]
penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
non_penalized_token_id = logits[batch_idx].argmax().item()
penalized_token_id = logits[batch_idx].argmin().item()
distinct_sorted_token_ids_in_output = \
sorted_token_ids_in_output[batch_idx]
most_frequent_token_id = distinct_sorted_token_ids_in_output[
Expand Down Expand Up @@ -305,11 +295,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
batch_size, repetition_penalty, torch.device(device))
sampling_metadata.no_penalties = False
sampler = Sampler()
sampler_output = sampler(fake_logits, sampling_metadata)
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
non_penalized_token_id = logprobs_token_ids[0]
penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
non_penalized_token_id = logits[batch_idx].argmax().item()
penalized_token_id = logits[batch_idx].argmin().item()
prompt_tokens = sampling_metadata.prompt_token_ids[
batch_idx][:].tolist()
output_tokens = sampling_metadata.output_token_ids[batch_idx]
Expand Down
5 changes: 3 additions & 2 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: bool = False
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None
Expand Down Expand Up @@ -277,7 +277,8 @@ def get_default_config_root():

# If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,

# If set, vllm will force flashinfer to use tensor cores;
# otherwise will use heuristic based on model architecture.
Expand Down
Empty file added vllm/v1/sample/ops/__init__.py
Empty file.
57 changes: 57 additions & 0 deletions vllm/v1/sample/ops/penalties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import List, Set, Tuple

import torch

from vllm.model_executor.layers.utils import (
apply_penalties as _apply_penalties)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad


def apply_min_token_penalties(logits: torch.Tensor,
output_token_ids: List[List[int]],
stop_token_ids: List[Set[int]],
min_tokens: List[int]) -> None:
"""
Applies minimum token penalty by setting the logits of the stop tokens
to -inf.
"""
min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
for index, min_token in enumerate(min_tokens):
if (len(output_token_ids[index]) < min_token):
for stop_token_id in stop_token_ids[index]:
min_tokens_logits_to_penalize.append((index, stop_token_id))
if min_tokens_logits_to_penalize:
logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")


def apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
presence_penalties: torch.Tensor,
frequency_penalties: torch.Tensor,
repetition_penalties: torch.Tensor,
output_token_ids: List[List[int]]) -> torch.Tensor:
"""
Applies presence, frequency and repetition penalties to the logits.
"""
_, vocab_size = logits.shape
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
logits.device)
return _apply_penalties(logits, prompt_token_ids, output_tokens_t,
presence_penalties, frequency_penalties,
repetition_penalties)


def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
device: torch.device) -> torch.Tensor:
"""
Convert the different list data structures to tensors.
"""
output_tokens_tensor = make_tensor_with_pad(
output_token_ids,
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
pad=vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=is_pin_memory_available(),
)
return output_tokens_tensor.to(device, non_blocking=True)
201 changes: 201 additions & 0 deletions vllm/v1/sample/ops/topk_topp_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
from typing import Dict

import torch
import torch.nn as nn

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

try:
import flashinfer.sampling
is_flashinfer_available = True
except ImportError:
is_flashinfer_available = False


class TopKTopPSampler(nn.Module):

def __init__(self):
super().__init__()
if current_platform.is_cuda:
if is_flashinfer_available:
if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
# sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
# default it is unused). For backward compatibility, we set
# `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
# interpret it differently in V0 and V1 samplers: In V0,
# None means False, while in V1, None means True. This is
# why we use the condition
# `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
logger.info("Using FlashInfer for top-p & top-k sampling.")
self.forward = self.forward_cuda
else:
logger.warning(
"FlashInfer is available, but it is not enabled. "
"Falling back to the PyTorch-native implementation of "
"top-p & top-k sampling. For the best performance, "
"please set VLLM_USE_FLASHINFER_SAMPLER=1.")
self.forward = self.forward_native
else:
logger.warning(
"FlashInfer is not available. Falling back to the PyTorch-"
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FalshInfer.")
self.forward = self.forward_native
else:
self.forward = self.forward_native

def forward_native(
self,
logits: torch.Tensor,
generators: Dict[int, torch.Generator],
no_top_k: bool,
k: torch.Tensor,
no_top_p: bool,
p: torch.Tensor,
) -> torch.Tensor:
"""PyTorch-native implementation of top-k and top-p sampling."""
logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)

def forward_cuda(
self,
logits: torch.Tensor,
generators: Dict[int, torch.Generator],
no_top_k: bool,
k: torch.Tensor,
no_top_p: bool,
p: torch.Tensor,
) -> torch.Tensor:
"""More optimized implementation for top-k and top-p sampling."""
probs = logits.softmax(dim=-1, dtype=torch.float32)
if no_top_k and no_top_p:
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
return random_sample(probs, generators)
return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)


def apply_top_k_top_p(
logits: torch.Tensor,
no_top_k: bool,
k: torch.Tensor,
no_top_p: bool,
p: torch.Tensor,
) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits.
This function sorts the logits tensor, which can be slow for large batches.
"""
if no_top_k and no_top_p:
return logits
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

if not no_top_k:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))

if not no_top_p:
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))

# Re-sort the probabilities.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits


def random_sample(
probs: torch.Tensor,
generators: Dict[int, torch.Generator],
) -> torch.Tensor:
"""Randomly sample from the probabilities.
We use this function instead of torch.multinomial because torch.multinomial
causes CPU-GPU synchronization.
"""
q = torch.empty_like(probs)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
if len(generators) != probs.shape[0]:
q.exponential_()
if generators:
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1)


def flashinfer_sample(
probs: torch.Tensor,
no_top_k: bool,
k: torch.Tensor,
no_top_p: bool,
p: torch.Tensor,
generators: Dict[int, torch.Generator],
) -> torch.Tensor:
"""Sample from the probabilities using FlashInfer.
Statistically, this function is equivalent to the `random_sample` function.
However, this function is faster because it avoids sorting the logits tensor
via rejection sampling.
NOTE: The outputs of this function do not necessarily match the outputs of
the `random_sample` function. It only guarantees that the outputs are
statistically equivalent.
NOTE: This function includes CPU-GPU synchronization, while `random_sample`
does not. Call this function at the end of the forward pass to minimize
the synchronization overhead.
"""
assert not (no_top_k and no_top_p)
max_top_k_round = 32
batch_size = probs.shape[0]
uniform_samples = torch.empty((max_top_k_round, batch_size),
device=probs.device)
if len(generators) != batch_size:
uniform_samples.uniform_()
if generators:
for i, generator in generators.items():
uniform_samples[:, i].uniform_(generator=generator)

if no_top_k:
# Top-p only.
next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
probs, uniform_samples, p, deterministic=True)
elif no_top_p:
# Top-k only.
next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
probs, uniform_samples, k, deterministic=True)
else:
# Both top-k and top-p.
next_token_ids, success = (
flashinfer.sampling.top_k_top_p_sampling_from_probs(
probs, uniform_samples, k, p, deterministic=True))

# NOTE: CPU-GPU synchronization happens here.
if not success.all():
if not no_top_k:
probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
if not no_top_p:
probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
next_token_ids = flashinfer.sampling.sampling_from_probs(
probs, uniform_samples[0], deterministic=True)
return next_token_ids.view(-1)
Loading

0 comments on commit 371d04d

Please sign in to comment.