Commit e19f99b

Add tests

sroy745 committed Dec 2, 2024
1 parent ca0313a commit e19f99b
Showing 2 changed files with 176 additions and 21 deletions.
176 changes: 176 additions & 0 deletions tests/v1/sample/test_sampler.py
@@ -0,0 +1,176 @@
from typing import List, Set, Tuple

import numpy as np
import pytest
import torch

from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler

VOCAB_SIZE = 1024
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
MAX_NUM_PROMPT_TOKENS = 64


def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             dtype=torch.float)
    return fake_logits

def _create_default_sampling_metadata(
        num_output_tokens: int,
        batch_size: int,
        vocab_size: int) -> SamplingMetadata:
    output_token_ids: List[List[int]] = []
    prompt_token_ids: List[List[int]] = []
    for _ in range(batch_size):
        output_token_ids.append(
            np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
        prompt_token_ids.append(
            np.random.randint(
                0, vocab_size,
                size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS)).tolist())
    fake_sampling_metadata = SamplingMetadata(
        temperature=torch.full((batch_size,), 0.0),
        all_greedy=True,
        all_random=False,
        top_p=torch.empty(batch_size,),
        top_k=torch.empty(batch_size,),
        no_top_p=True,
        no_top_k=True,
        generators={},
        max_num_logprobs=VOCAB_SIZE,
        prompt_token_ids=prompt_token_ids,
        output_token_ids=output_token_ids,
        frequency_penalties=[0.0 for _ in range(batch_size)],
        presence_penalties=[0.0 for _ in range(batch_size)],
        repetition_penalties=[1.0 for _ in range(batch_size)],
        min_tokens=[],
        stop_token_ids=[],
    )
    return fake_sampling_metadata

def _create_min_token_penalty_dataset(
        num_output_tokens: int,
        batch_size: int,
        vocab_size: int,
        batch_indices_for_min_token_penalty: List[int]
) -> Tuple[List[int], List[Set[int]]]:
    """
    Generates `min_tokens` and `stop_token_ids` for each request. Requests
    whose index is in `batch_indices_for_min_token_penalty` get a
    `min_tokens` value larger than `num_output_tokens` (so the penalty
    applies) along with a random set of stop token ids; the remaining
    requests get a smaller `min_tokens` value and an empty stop-token set.
    """
    stop_token_ids: List[Set[int]] = []
    min_tokens: List[int] = []
    for index in range(batch_size):
        if index in batch_indices_for_min_token_penalty:
            min_tokens.append(
                np.random.randint(num_output_tokens + 1,
                                  2 * num_output_tokens))
            stop_token_ids.append(
                set(
                    np.random.randint(0, vocab_size - 1)
                    for _ in range(np.random.randint(0, vocab_size))))
        else:
            min_tokens.append(np.random.randint(0, num_output_tokens))
            stop_token_ids.append(set())
    return (min_tokens, stop_token_ids)

def _create_weighted_output_token_list(
        batch_size: int,
        vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    """
    Creates an output token list in which each distinct token id appears a
    different number of times: the token at index `i` of the returned
    distinct-id list occurs `i + 1` times in the output.
    """
    output_token_ids: List[List[int]] = []
    sorted_token_ids_in_output: List[List[int]] = []
    for _ in range(batch_size):
        distinct_token_ids = np.random.choice(vocab_size,
                                              size=np.random.randint(1, 10),
                                              replace=False).tolist()
        sorted_token_ids_in_output.append(distinct_token_ids)
        output_token_ids_for_batch = []
        for index, token_id in enumerate(distinct_token_ids):
            output_token_ids_for_batch.extend(
                [token_id for _ in range(index + 1)])
        output_token_ids.append(output_token_ids_for_batch)
    return (output_token_ids, sorted_token_ids_in_output)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_min_tokens_penalty(device: str, batch_size: int):
"""
Tests that if the number of output tokens is less than
SamplingParams.min_tokens then we will set the logits for
the stop token ids to -inf.
"""
torch.set_default_device(device)
NUM_OUTPUT_TOKENS = 20
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata= _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE)
batch_indices_for_min_token_penalty = np.random.randint(
0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
min_tokens, stop_token_ids = _create_min_token_penalty_dataset(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, batch_indices_for_min_token_penalty)

    sampling_metadata.min_tokens = min_tokens
    sampling_metadata.stop_token_ids = stop_token_ids
    sampler = Sampler()
    sampler_output = sampler(fake_logits, sampling_metadata)
    for batch_idx in range(batch_size):
        for vocab in range(VOCAB_SIZE):
            logprob_index = torch.where(
                sampler_output.logprob_token_ids[batch_idx] == vocab)[0].item()
            if vocab in stop_token_ids[batch_idx]:
                assert sampler_output.logprobs[batch_idx][
                    logprob_index] == -float("inf")
            else:
                assert sampler_output.logprobs[batch_idx][
                    logprob_index] != -float("inf")


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_presence_penalty(device: str, batch_size: int):
torch.set_default_device(device)
NUM_OUTPUT_TOKENS = 20
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata= _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE)
output_token_ids = sampling_metadata.output_token_ids
sampling_metadata.presence_penalties = [2.0 for _ in range(batch_size)]
sampler = Sampler()
sampler_output = sampler(fake_logits, sampling_metadata)
for batch_idx in range(batch_size):
logprob_for_output_token = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1]

        logprob_for_non_output_token = sampler_output.logprobs[batch_idx][0]
        assert logprob_for_non_output_token > logprob_for_output_token
        for vocab in range(VOCAB_SIZE):
            logprob_index = torch.where(
                sampler_output.logprob_token_ids[batch_idx] == vocab)[0].item()
            if vocab in output_token_ids[batch_idx]:
                assert torch.isclose(
                    sampler_output.logprobs[batch_idx][logprob_index],
                    logprob_for_output_token)
            else:
                assert torch.isclose(
                    sampler_output.logprobs[batch_idx][logprob_index],
                    logprob_for_non_output_token)
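The assertions above match the usual OpenAI-style definition of a presence penalty: a flat amount subtracted from the logit of every token that already appears in the output, regardless of how often it appears. A minimal sketch under that assumption (the names below are illustrative, not the code under test):

# Illustrative sketch of an OpenAI-style presence penalty.
import torch
from typing import List


def apply_presence_penalty_sketch(logits: torch.Tensor,
                                  output_token_ids: List[List[int]],
                                  presence_penalties: List[float]
                                  ) -> torch.Tensor:
    vocab_size = logits.shape[-1]
    for req_idx, token_ids in enumerate(output_token_ids):
        counts = torch.bincount(torch.tensor(token_ids, dtype=torch.long),
                                minlength=vocab_size).to(logits.device)
        # Flat penalty for every token that was generated at least once.
        logits[req_idx] -= presence_penalties[req_idx] * (counts > 0)
    return logits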

@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_frequency_penalty(device: str, batch_size: int):
"""
Test to verify that if fre
"""
torch.set_default_device(device)
NUM_OUTPUT_TOKENS = 20
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata= _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE)
sampling_metadata.frequency_penalties = [2.0 for _ in range(batch_size)]
output_token_ids, sorted_token_ids_in_output = \
_create_weighted_output_token_list(batch_size, VOCAB_SIZE)
sampling_metadata.output_token_ids=output_token_ids
sampler = Sampler()
sampler_output = sampler(fake_logits, sampling_metadata)
for batch_idx in range(batch_size):
logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
token_ids_in_output = sorted_token_ids_in_output[batch_idx]
assert not torch.isin(
logprobs_token_ids[ : -len(token_ids_in_output)],
torch.tensor(token_ids_in_output)).any(), "Some values in the tensor are in the list"

        assert logprobs_token_ids[
            -len(token_ids_in_output):].tolist() == token_ids_in_output, \
"The tensor values are not in the same order as the list!"


21 changes: 0 additions & 21 deletions vllm/v1/worker/gpu_model_runner.py
@@ -618,23 +618,6 @@ class CachedRequestState:
    def num_tokens(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    @property
    def stop_token_ids(self) -> Optional[List[int]]:
        return self.sampling_params.stop_token_ids

    @property
    def prompt_tokens_mask(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    @property
    def output_tokens_mask(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    @property
    def output_tokens_bin_counts(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)


class InputBatch:

    def __init__(
@@ -702,9 +685,6 @@ def __init__(
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: Set[str] = set()

        self.prompt_masks = Dict[int, torch.Tensor]
        self.output_masks = Dict[int, torch.Tensor]

        # req_index -> generator
        self.generators: Dict[int, torch.Generator] = {}

@@ -821,7 +801,6 @@ def condense(self, empty_req_indices: List[int]) -> None:
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]

            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator
