diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index a61ec63a365b5..c529cd21f384b 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -139,3 +139,41 @@ def test_engine_core(monkeypatch): engine_core.abort_requests([req2.request_id, req0.request_id]) assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.running) == 0 + + +def test_engine_core_advanced_sampling(monkeypatch): + """ + A basic end-to-end test to verify that the engine functions correctly + when additional sampling parameters, such as min_tokens and + presence_penalty, are set. + """ + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + """Setup the EngineCore.""" + engine_args = EngineArgs(model=MODEL_NAME) + vllm_config = engine_args.create_engine_config( + usage_context=UsageContext.UNKNOWN_CONTEXT) + executor_class = AsyncLLM._get_executor_cls(vllm_config) + + engine_core = EngineCore(vllm_config=vllm_config, + executor_class=executor_class, + usage_context=UsageContext.UNKNOWN_CONTEXT) + """Test basic request lifecycle.""" + # First request. + request: EngineCoreRequest = make_request() + request.sampling_params = SamplingParams( + min_tokens=4, + presence_penalty=1.0, + frequency_penalty=1.0, + repetition_penalty=0.1, + stop_token_ids=[1001, 1002], + ) + engine_core.add_request(request) + assert len(engine_core.scheduler.waiting) == 1 + assert len(engine_core.scheduler.running) == 0 + # Loop through until they are all done. + while len(engine_core.step()) > 0: + pass + + assert len(engine_core.scheduler.waiting) == 0 + assert len(engine_core.scheduler.running) == 0 diff --git a/tests/v1/sample/__init__.py b/tests/v1/sample/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py new file mode 100644 index 0000000000000..d8d055805cbea --- /dev/null +++ b/tests/v1/sample/test_sampler.py @@ -0,0 +1,331 @@ +from typing import List, Set, Tuple + +import numpy as np +import pytest +import torch + +from vllm.utils import make_tensor_with_pad +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler + +VOCAB_SIZE = 1024 +NUM_OUTPUT_TOKENS = 20 +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] +MAX_NUM_PROMPT_TOKENS = 64 + + +def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor: + fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float) + return fake_logits + + +def _create_penalty_tensor(batch_size: int, penalty_value: float, + device: torch.device) -> torch.Tensor: + return torch.full((batch_size, ), + fill_value=penalty_value, + dtype=torch.float, + device=device) + + +def _create_prompt_tokens_tensor( + prompt_token_ids: List[List[int]], + vocab_size: int, + device: torch.device, +) -> torch.Tensor: + return make_tensor_with_pad( + prompt_token_ids, + pad=vocab_size, + device=device, + dtype=torch.int64, + pin_memory=False, + ) + + +def _create_default_sampling_metadata( + num_output_tokens: int, + batch_size: int, + vocab_size: int, + device: torch.device, +) -> SamplingMetadata: + output_token_ids: List[List[int]] = [] + prompt_token_ids: List[List[int]] = [] + for _ in range(batch_size): + output_token_ids.append( + np.random.randint(0, vocab_size, size=num_output_tokens).tolist()) + prompt_token_ids.append( + np.random.randint(0, + vocab_size, + size=np.random.randint( + 1, 
MAX_NUM_PROMPT_TOKENS)).tolist()) + fake_sampling_metadata = SamplingMetadata( + temperature=torch.full((batch_size, ), 0.0), + all_greedy=True, + all_random=False, + top_p=torch.empty(batch_size, ), + top_k=torch.empty(batch_size, ), + no_top_p=True, + no_top_k=True, + generators={}, + max_num_logprobs=VOCAB_SIZE, + prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, + vocab_size, device), + output_token_ids=output_token_ids, + frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device), + presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), + repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), + no_penalties=True, + min_tokens=[], + stop_token_ids=[], + ) + return fake_sampling_metadata + + +def _generate_min_token_penalties_and_stop_tokens( + num_output_tokens: int, batch_size: int, vocab_size: int, + batch_indices_for_min_token_penalty: List[int] +) -> Tuple[List[int], List[Set[int]]]: + """ + Generates and returns a list of minimum token penalties (`min_tokens`) + and a corresponding list of stop token IDs (`stop_token_ids`) for each + batch. + + If a batch index is included in `batch_indices_for_min_token_penalty`, + a higher `min_tokens` value is assigned (within a randomized range), + and a random set of stop token IDs is created. Otherwise, a lower + `min_tokens` value is assigned, and the stop token IDs set is empty. + """ + stop_token_ids: List[Set[int]] = [] + min_tokens: List[int] = [] + for index in range(batch_size): + if index in batch_indices_for_min_token_penalty: + min_tokens.append( + np.random.randint(num_output_tokens + 1, + 2 * num_output_tokens)) + stop_token_ids.append( + set( + np.random.randint(0, vocab_size - 1) + for _ in range(np.random.randint(0, vocab_size)))) + + else: + min_tokens.append(np.random.randint(0, num_output_tokens)) + stop_token_ids.append(set()) + return (min_tokens, stop_token_ids) + + +def _create_weighted_output_token_list( + batch_size: int, + vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]: + """ + Creates an output token list where each token occurs a distinct + number of times. + + For each batch, a random subset of token IDs is selected from the + vocabulary. The selected tokens are then added to the output token + list, each with a different frequency. + + Returns: + Tuple[List[List[int]], List[List[int]]]: + - The first element is the output token list, where each sublist + corresponds to a batch and contains tokens with weighted + frequencies. + - The second element is a list of distinct token IDs for each + batch, ordered by their frequency in the corresponding output + list. + """ + output_token_ids: List[List[int]] = [] + sorted_token_ids_in_output: List[List[int]] = [] + for _ in range(batch_size): + distinct_token_ids = np.random.choice(vocab_size, + size=np.random.randint(1, 10), + replace=False).tolist() + sorted_token_ids_in_output.append(distinct_token_ids) + output_token_ids_for_batch = [] + for index, token_id in enumerate(distinct_token_ids): + output_token_ids_for_batch.extend( + [token_id for _ in range(index + 1)]) + output_token_ids.append(output_token_ids_for_batch) + return (output_token_ids, sorted_token_ids_in_output) + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [1, 2, 32]) +def test_sampler_min_tokens_penalty(device: str, batch_size: int): + """ + Tests that if the number of output tokens is less than + SamplingParams.min_tokens then we will set the logits for + the stop token ids to -inf. 
+ """ + torch.set_default_device(device) + fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) + sampling_metadata = _create_default_sampling_metadata( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + batch_indices_for_min_token_penalty = np.random.randint( + 0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist() + min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, + batch_indices_for_min_token_penalty) + sampling_metadata.min_tokens = min_tokens + sampling_metadata.stop_token_ids = stop_token_ids + sampler = Sampler() + sampler_output = sampler(fake_logits, sampling_metadata) + for batch_idx in range(batch_size): + for vocab in range(VOCAB_SIZE): + # Verify that the logprobs for stop token ids is set + # to -inf. + logprob_index = torch.where( + sampler_output.logprob_token_ids[batch_idx] == + vocab)[0].item() + if vocab in stop_token_ids[batch_idx]: + assert sampler_output.logprobs[batch_idx][ + logprob_index] == -float("inf") + else: + assert sampler_output.logprobs[batch_idx][ + logprob_index] != -float("inf") + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [1, 2, 32]) +@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0]) +def test_sampler_presence_penalty(device: str, batch_size: int, + presence_penalty: float): + """ + Test to verify that if presence penalty is enabled then tokens + are penalized as per their presence in the existing output. + """ + torch.set_default_device(device) + # Create fake logits where each token is assigned the same + # logit value. + fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) + sampling_metadata = _create_default_sampling_metadata( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + output_token_ids = sampling_metadata.output_token_ids + sampling_metadata.presence_penalties = _create_penalty_tensor( + batch_size, presence_penalty, torch.device(device)) + sampling_metadata.no_penalties = False + sampler = Sampler() + sampler_output = sampler(fake_logits, sampling_metadata) + for batch_idx in range(batch_size): + # The logprobs in the SamplerOutput are arranged in descending order. + # Since all tokens initially have the same logprobs, the non-penalized + # tokens will appear at the beginning, while the penalized tokens + # will appear at the end of the list. + penalized_token_id = sampler_output.logprob_token_ids[batch_idx][ + VOCAB_SIZE - 1] + penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1] + non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0] + non_penalized_log_prod = sampler_output.logprobs[batch_idx][0] + assert non_penalized_log_prod > penalized_log_prod + if presence_penalty > 0: + # If `presence_penalty` is set to a value greater than 0, it + # indicates a preference for new tokens over those already + # present in the output. + # Verify that the penalized token ID exists in the output, while the + # non-penalized token ID does not. + assert penalized_token_id in output_token_ids[batch_idx] + assert non_penalized_token_id not in output_token_ids[batch_idx] + elif presence_penalty < 0: + # If `presence_penalty` is set to a value less than 0, it indicates + # a preference for existing tokens over new ones. Verify that the + # non-penalized token ID exists in the output, while the penalized + # token ID does not. 
+ assert non_penalized_token_id in output_token_ids[batch_idx] + assert penalized_token_id not in output_token_ids[batch_idx] + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [1, 2, 32]) +@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0]) +def test_sampler_frequency_penalty(device: str, batch_size: int, + frequency_penalty: float): + """ + Test to verify that if frequency penalty is enabled then tokens are + penalized as per their frequency of occurrence. + """ + torch.set_default_device(device) + # Create fake logits where each token is assigned the same + # logit value. + fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) + sampling_metadata = _create_default_sampling_metadata( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + sampling_metadata.frequency_penalties = _create_penalty_tensor( + batch_size, frequency_penalty, torch.device(device)) + output_token_ids, sorted_token_ids_in_output = \ + _create_weighted_output_token_list(batch_size, VOCAB_SIZE) + sampling_metadata.output_token_ids = output_token_ids + sampling_metadata.no_penalties = False + sampler = Sampler() + sampler_output = sampler(fake_logits, sampling_metadata) + for batch_idx in range(batch_size): + logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] + non_penalized_token_id = logprobs_token_ids[0] + penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1] + distinct_sorted_token_ids_in_output = \ + sorted_token_ids_in_output[batch_idx] + most_frequent_token_id = distinct_sorted_token_ids_in_output[ + len(distinct_sorted_token_ids_in_output) - 1] + if frequency_penalty > 0: + # If `frequency_penalty` is set to > 0, it indicates + # a preference for new tokens over existing ones. Verify that the + # non-penalized token ID is not present in the output, while the + # most penalized token is the one that occurs most frequently in + # the output. + assert non_penalized_token_id \ + not in distinct_sorted_token_ids_in_output + assert penalized_token_id == most_frequent_token_id + elif frequency_penalty < 0: + # If `frequency_penalty` is set to < 0, it indicates + # a preference for existing tokens over new ones. Verify that the + # non-penalized token ID is the one that occurs most frequently + # in the output, while the penalized token ID is one that has not + # yet appeared. + assert non_penalized_token_id == most_frequent_token_id + assert penalized_token_id \ + not in distinct_sorted_token_ids_in_output + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [1, 2, 32]) +@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9]) +def test_sampler_repetition_penalty(device: str, batch_size: int, + repetition_penalty: float): + """ + Test to verify that when the repetition penalty is enabled, tokens + are penalized based on their presence in the prompt or the existing + output. + """ + torch.set_default_device(device) + # Create fake logits where each token is assigned the same + # logit value. 
+ fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) + sampling_metadata = _create_default_sampling_metadata( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + sampling_metadata.repetition_penalties = _create_penalty_tensor( + batch_size, repetition_penalty, torch.device(device)) + sampling_metadata.no_penalties = False + sampler = Sampler() + sampler_output = sampler(fake_logits, sampling_metadata) + for batch_idx in range(batch_size): + logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx] + non_penalized_token_id = logprobs_token_ids[0] + penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1] + prompt_tokens = sampling_metadata.prompt_token_ids[ + batch_idx][:].tolist() + output_tokens = sampling_metadata.output_token_ids[batch_idx] + if repetition_penalty > 1.0: + # If `repetition_penalty` > 1.0, verify that the non-penalized + # token ID has not been seen before, while the penalized token ID + # exists either in the prompt or the output. + assert (non_penalized_token_id not in prompt_tokens and \ + non_penalized_token_id not in output_tokens) + assert (penalized_token_id in prompt_tokens or \ + penalized_token_id in output_tokens) + elif repetition_penalty < 1.0: + # If `repetition_penalty` < 1.0, verify that the penalized + # token ID has not been seen before, while the non-penalized + # token ID exists either in the prompt or the output. + assert (penalized_token_id not in prompt_tokens and \ + penalized_token_id not in output_tokens) + assert (non_penalized_token_id in prompt_tokens or \ + non_penalized_token_id in output_tokens) diff --git a/tests/v1/worker/__init__.py b/tests/v1/worker/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py new file mode 100644 index 0000000000000..694ce81ff6e22 --- /dev/null +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -0,0 +1,224 @@ +from typing import Dict, List, Set, Tuple + +import numpy as np +import pytest +import torch + +from vllm.sampling_params import SamplingParams +from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch + +VOCAB_SIZE = 1024 +NUM_OUTPUT_TOKENS = 20 +MAX_PROMPT_SIZE = 100 +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] +MAX_NUM_PROMPT_TOKENS = 64 + + +def _remove_requests( + input_batch: InputBatch, batch_size: int, + reqs: List[CachedRequestState]) -> Tuple[Set[str], List[int]]: + """ + Remove some requests randomly from the batch and returns a Tuple + of 1) set of request removed 2) indices of the requests removed + ordered in descending order + """ + + num_reqs_to_remove = np.random.randint(0, batch_size) + req_indices_to_remove: Set[int] = set() + for _ in range(num_reqs_to_remove): + req_index_to_remove = np.random.randint(0, batch_size) + req_indices_to_remove.add(req_index_to_remove) + + req_indices_to_remove_list = list(req_indices_to_remove) + req_indices_to_remove_list.sort(reverse=True) + req_ids_to_remove: Set[str] = set() + for index in req_indices_to_remove: + input_batch.remove_request(reqs[index].req_id) + req_ids_to_remove.add(reqs[index].req_id) + return (req_ids_to_remove, req_indices_to_remove_list) + + +def _construct_expected_sampling_metadata( + reqs: List[CachedRequestState], req_ids_retained: Set[int], + req_id_index_in_input_batch: Dict[str, int], + device: 
torch.device) -> SamplingMetadata: + """ + Constructs and returns the expected SamplingMetadata for this + batch. + """ + num_reqs = len(req_ids_retained) + output_token_ids: List[List[int]] = [list() for _ in range(num_reqs)] + prompt_token_ids: List[List[int]] = [list() for _ in range(num_reqs)] + presence_penalties = [0.0 for _ in range(num_reqs)] + frequency_penalties = [0.0 for _ in range(num_reqs)] + repetition_penalties = [1.0 for _ in range(num_reqs)] + top_k = [0 for _ in range(num_reqs)] + top_p = [0.0 for _ in range(num_reqs)] + temperature = [0.0 for _ in range(num_reqs)] + stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)] + min_tokens = [0 for _ in range(num_reqs)] + for req in reqs: + if req.req_id not in req_ids_retained: + continue + index_in_input_batch = req_id_index_in_input_batch[req.req_id] + output_token_ids[index_in_input_batch] = req.output_token_ids + prompt_token_ids[index_in_input_batch] = req.prompt_token_ids + presence_penalties[ + index_in_input_batch] = req.sampling_params.presence_penalty + frequency_penalties[ + index_in_input_batch] = req.sampling_params.frequency_penalty + repetition_penalties[ + index_in_input_batch] = req.sampling_params.repetition_penalty + top_k[index_in_input_batch] = req.sampling_params.top_k + top_p[index_in_input_batch] = req.sampling_params.top_p + temperature[index_in_input_batch] = req.sampling_params.temperature + stop_token_ids[ + index_in_input_batch] = req.sampling_params.all_stop_token_ids + min_tokens[index_in_input_batch] = req.sampling_params.min_tokens + + + return SamplingMetadata( + temperature=torch.tensor(temperature, dtype=torch.float, device=device), + all_greedy=False, + all_random=True, + top_p=torch.tensor(top_p, dtype=torch.float, device=device), + top_k=torch.tensor(top_k, dtype=torch.int, device=device), + no_top_p=all(x == 1.0 for x in top_p), + no_top_k=all(x == 0 for x in top_k), + generators={}, + max_num_logprobs=0, + prompt_token_ids= make_tensor_with_pad( + prompt_token_ids, + pad=VOCAB_SIZE, + device=torch.device(device), + dtype=torch.int64, + ), + frequency_penalties=torch.tensor( + frequency_penalties, dtype=torch.float, + device=device), + presence_penalties=torch.tensor( + presence_penalties, dtype=torch.float, + device=device), + repetition_penalties=torch.tensor( + repetition_penalties, dtype=torch.float, + device=device), + output_token_ids=output_token_ids, + min_tokens=min_tokens, + stop_token_ids=stop_token_ids, + no_penalties=(all(x ==0 for x in presence_penalties) and \ + all(x ==0 for x in frequency_penalties) and \ + all(x ==1 for x in repetition_penalties)) + ) + + +def _create_sampling_params(): + return SamplingParams(top_k=np.random.randint(1, 10), + top_p=np.random.uniform(0.0, 1.0), + presence_penalty=np.random.uniform(-2.0, 2.0), + repetition_penalty=np.random.uniform(0.0, 2.0), + frequency_penalty=np.random.uniform(-2.0, 2.0), + min_tokens=np.random.randint(1, 10), + stop_token_ids=[ + np.random.randint(0, VOCAB_SIZE) + for _ in range(np.random.randint(10)) + ]) + + +def _construct_cached_request_state(req_id_suffix: int): + prompt_token_ids = [ + np.random.randint(0, VOCAB_SIZE) + for _ in range(np.random.randint(0, MAX_PROMPT_SIZE)) + ] + output_token_ids = [ + np.random.randint(0, VOCAB_SIZE) + for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS)) + ] + return CachedRequestState(req_id=f"req_id_{req_id_suffix}", + prompt_token_ids=prompt_token_ids, + prompt=None, + sampling_params=_create_sampling_params(), + mm_inputs=[], + mm_positions=[], + 
block_ids=[], + generator=None, + num_computed_tokens=len(output_token_ids), + output_token_ids=output_token_ids) + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [1, 2, 32, 64]) +def test_sampling_metadata_in_input_batch(device: str, batch_size: int): + """ + Tests the logic for managing sampling metadata in the InputBatch. + + This test involves adding a set of requests to the InputBatch, + followed by removing a subset of them. Afterward, the batch is compacted, + and the `make_sampling_metadata` method is invoked on the batch. The + output of `make_sampling_metadata` is then compared against the expected + results to ensure correctness. + """ + input_batch: InputBatch = InputBatch(max_num_reqs=batch_size, + max_model_len=1024, + max_num_blocks_per_req=10, + device=torch.device(device), + pin_memory=is_pin_memory_available(), + vocab_size=1024) + reqs: List[CachedRequestState] = [] + req_id_reqs = {} + req_id_output_token_ids = {} + # Add requests + for req_index in range(batch_size): + req: CachedRequestState = _construct_cached_request_state(req_index) + input_batch.add_request(req, req_index) + reqs.append(req) + req_id_reqs[req.req_id] = req + req_id_output_token_ids[req.req_id] = req.output_token_ids + + # Remove some requests + req_ids_to_remove, req_indices_to_remove = _remove_requests( + input_batch, batch_size, reqs) + req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove + + # Compact the input batch + input_batch.condense(req_indices_to_remove) + + # Generate the sampling metadata + sampling_metadata = input_batch.make_sampling_metadata( + req_id_output_token_ids, skip_copy=False) + + # Create expected output. + expected_sampling_metadata = _construct_expected_sampling_metadata( + reqs, + req_ids_retained, + input_batch.req_id_to_index, + device=torch.device(device)) + + # Assert the actual and expected output. 
+ assert torch.allclose(expected_sampling_metadata.temperature, + sampling_metadata.temperature) + assert torch.allclose(expected_sampling_metadata.top_p, + sampling_metadata.top_p) + assert torch.allclose(expected_sampling_metadata.top_k, + sampling_metadata.top_k) + assert torch.allclose(expected_sampling_metadata.frequency_penalties, + sampling_metadata.frequency_penalties) + assert torch.allclose(expected_sampling_metadata.presence_penalties, + sampling_metadata.presence_penalties) + assert torch.allclose(expected_sampling_metadata.repetition_penalties, + sampling_metadata.repetition_penalties) + assert torch.allclose(expected_sampling_metadata.prompt_token_ids, + sampling_metadata.prompt_token_ids) + assert (expected_sampling_metadata.output_token_ids == + sampling_metadata.output_token_ids) + assert ( + expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens) + assert (expected_sampling_metadata.stop_token_ids == + sampling_metadata.stop_token_ids) + assert (expected_sampling_metadata.no_penalties == + sampling_metadata.no_penalties) + assert (expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p) + assert (expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c10efefea5471..c2d12c466ba45 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -11,6 +11,7 @@ import torch.nn as nn import vllm.envs as envs +from vllm.model_executor.layers.utils import apply_penalties from vllm.model_executor.sampling_metadata import (SamplingMetadata, SamplingTensors, SequenceGroupToSample) @@ -258,11 +259,11 @@ def forward( # Apply presence and frequency penalties. if do_penalties: - logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) + logits = apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) # Use float32 to apply temperature scaling. # Use in-place division to avoid creating a new tensor. @@ -336,23 +337,6 @@ def _should_modify_greedy_probs_inplace(self) -> bool: return self.should_modify_greedy_probs_inplace -def _get_bin_counts_and_mask( - tokens: torch.Tensor, - vocab_size: int, - num_seqs: int, -) -> Tuple[torch.Tensor, torch.Tensor]: - # Compute the bin counts for the tokens. - # vocab_size + 1 for padding. 
- bin_counts = torch.zeros((num_seqs, vocab_size + 1), - dtype=torch.long, - device=tokens.device) - bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) - bin_counts = bin_counts[:, :vocab_size] - mask = bin_counts > 0 - - return bin_counts, mask - - def _apply_min_tokens_penalty( logits: torch.Tensor, sampling_metadata: SamplingMetadata, @@ -400,29 +384,6 @@ def _apply_min_tokens_penalty( return logits -def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, - output_tokens_tensor: torch.Tensor, - presence_penalties: torch.Tensor, - frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor) -> torch.Tensor: - num_seqs, vocab_size = logits.shape - _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, - num_seqs) - output_bin_counts, output_mask = _get_bin_counts_and_mask( - output_tokens_tensor, vocab_size, num_seqs) - - repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) - repetition_penalties[~(prompt_mask | output_mask)] = 1.0 - logits = torch.where(logits > 0, logits / repetition_penalties, - logits * repetition_penalties) - - # We follow the definition in OpenAI API. - # Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts - logits -= presence_penalties.unsqueeze_(dim=1) * output_mask - return logits - - def _apply_top_k_top_p( logits: torch.Tensor, p: torch.Tensor, diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py new file mode 100644 index 0000000000000..f6f34cd49d953 --- /dev/null +++ b/vllm/model_executor/layers/utils.py @@ -0,0 +1,57 @@ +"""Utility methods for model layers.""" +from typing import Tuple + +import torch + + +def get_token_bin_counts_and_mask( + tokens: torch.Tensor, + vocab_size: int, + num_seqs: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + # Compute the bin counts for the tokens. + # vocab_size + 1 for padding. + bin_counts = torch.zeros((num_seqs, vocab_size + 1), + dtype=torch.long, + device=tokens.device) + bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + + return bin_counts, mask + + +def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: + """ + Applies penalties in place to the logits tensor + logits : The input logits tensor of shape [num_seqs, vocab_size] + prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts + are padded to the maximum prompt length within the batch using + `vocab_size` as the padding value. The value `vocab_size` is used + for padding because it does not correspond to any valid token ID + in the vocabulary. + output_tokens_tensor: The output tokens tensor. 
+ presence_penalties: The presence penalties of shape (num_seqs, ) + frequency_penalties: The frequency penalties of shape (num_seqs, ) + repetition_penalties: The repetition penalties of shape (num_seqs, ) + """ + num_seqs, vocab_size = logits.shape + _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, + vocab_size, num_seqs) + output_bin_counts, output_mask = get_token_bin_counts_and_mask( + output_tokens_tensor, vocab_size, num_seqs) + repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat( + 1, vocab_size) + logits[logits > 0] /= torch.where(prompt_mask | output_mask, + repetition_penalties, 1.0)[logits > 0] + logits[logits <= 0] *= torch.where(prompt_mask | output_mask, + repetition_penalties, 1.0)[logits <= 0] + # We follow the definition in OpenAI API. + # Refer to https://platform.openai.com/docs/api-reference/parameter-details + logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze_(dim=1) * output_mask + return logits diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 9ef36f2e6b212..d60f7eb5d76f9 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict +from typing import Dict, List, Optional, Set import torch @@ -19,3 +19,13 @@ class SamplingMetadata: generators: Dict[int, torch.Generator] max_num_logprobs: int + + no_penalties: bool + prompt_token_ids: Optional[torch.Tensor] + frequency_penalties: torch.Tensor + presence_penalties: torch.Tensor + repetition_penalties: torch.Tensor + + output_token_ids: List[List[int]] + min_tokens: List[int] + stop_token_ids: List[Set[int]] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index d1a755be01ff7..82470fb2610f8 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -1,9 +1,11 @@ """A layer that samples the next tokens from the model's outputs.""" -from typing import Dict +from typing import Dict, List, Set, Tuple import torch import torch.nn as nn +from vllm.model_executor.layers.utils import apply_penalties +from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -17,9 +19,18 @@ def forward( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: + _apply_min_token_penalties(logits, sampling_metadata.output_token_ids, + sampling_metadata.stop_token_ids, + sampling_metadata.min_tokens) + if not sampling_metadata.no_penalties: + assert sampling_metadata.prompt_token_ids is not None + _apply_penalties(logits, sampling_metadata.prompt_token_ids, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.repetition_penalties, + sampling_metadata.output_token_ids) logits = self.apply_temperature(logits, sampling_metadata.temperature) logits = self.apply_top_k_top_p(logits, sampling_metadata) - probs = self.get_probs(logits) sampled = self.sample(probs, sampling_metadata) # Use int32 to reduce the tensor size. @@ -157,3 +168,53 @@ def _apply_top_k_top_p( # Re-sort the probabilities. logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) return logits + + +def _apply_min_token_penalties(logits: torch.Tensor, + output_token_ids: List[List[int]], + stop_token_ids: List[Set[int]], + min_tokens: List[int]): + """ + Applies minimum token penalty by setting the logits of the stop tokens + to -inf. 
+ """ + min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] + for index, min_token in enumerate(min_tokens): + if (len(output_token_ids[index]) < min_token): + for stop_token_id in stop_token_ids[index]: + min_tokens_logits_to_penalize.append((index, stop_token_id)) + if min_tokens_logits_to_penalize: + logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") + + +def _apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor, + output_token_ids: List[List[int]]): + """ + Applies presence, frequency and repetition penalties to the logits. + """ + _, vocab_size = logits.shape + output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, + logits.device) + return apply_penalties(logits, prompt_token_ids, output_tokens_t, + presence_penalties, frequency_penalties, + repetition_penalties) + + +def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int, + device: torch.device) -> torch.Tensor: + """ + Convert the different list data structures to tensors. + """ + output_tokens_tensor = make_tensor_with_pad( + output_token_ids, + # Use the value of vocab_size as a pad since we don't have a + # token_id of this value. + pad=vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=is_pin_memory_available(), + ) + return output_tokens_tensor.to(device, non_blocking=True) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 5c113c74778df..6c4d300ec6efe 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -43,12 +43,14 @@ def __init__( max_num_blocks_per_req: int, device: torch.device, pin_memory: bool, + vocab_size: int, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len self.max_num_blocks_per_req = max_num_blocks_per_req self.device = device self.pin_memory = pin_memory + self.vocab_size = vocab_size self.req_ids: List[Optional[str]] = [None] * max_num_reqs self.req_id_to_index: Dict[str, int] = {} @@ -63,6 +65,7 @@ def __init__( ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) # Attention-related. 
self.block_table = torch.zeros( @@ -110,6 +113,50 @@ def __init__( self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: Set[str] = set() + # Frequency penalty related data structures + self.frequency_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.frequency_penalties_cpu_tensor = torch.empty( + (max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.frequency_penalties_cpu = \ + self.frequency_penalties_cpu_tensor.numpy() + self.frequency_penalties_reqs: Set[str] = set() + + # Presence penalty related data structures + self.presence_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.presence_penalties_cpu = \ + self.presence_penalties_cpu_tensor.numpy() + self.presence_penalties_reqs: Set[str] = set() + + # Repetition penalty related data structures + self.repetition_penalties = torch.empty((max_num_reqs, ), + dtype=torch.float, + device=device) + self.repetition_penalties_cpu_tensor = torch.empty( + (max_num_reqs, ), + dtype=torch.float, + device="cpu", + pin_memory=pin_memory) + self.repetition_penalties_cpu = \ + self.repetition_penalties_cpu_tensor.numpy() + self.repetition_penalties_reqs: Set[str] = set() + + self.min_tokens: List[int] = [0] * max_num_reqs + self.stop_token_ids: List[Set[int]] = [ + set() for _ in range(max_num_reqs) + ] + self.prompt_token_ids: Optional[torch.Tensor] = None + # req_index -> generator # NOTE(woosuk): The indices of the requests that do not have their own # generator should not be included in the dictionary. @@ -133,6 +180,7 @@ def add_request( # Copy the prompt token ids and output token ids. num_prompt_tokens = len(request.prompt_token_ids) + self.num_prompt_tokens[req_index] = num_prompt_tokens self.token_ids_cpu[ req_index, :num_prompt_tokens] = request.prompt_token_ids start_idx = num_prompt_tokens @@ -157,6 +205,20 @@ def add_request( self.top_k_cpu[req_index] = sampling_params.top_k if sampling_params.top_k > 0: self.top_k_reqs.add(req_id) + self.frequency_penalties_cpu[req_index] = \ + sampling_params.frequency_penalty + if sampling_params.frequency_penalty != 0.0: + self.frequency_penalties_reqs.add(req_id) + self.presence_penalties_cpu[req_index] = \ + sampling_params.presence_penalty + if sampling_params.presence_penalty != 0.0: + self.presence_penalties_reqs.add(req_id) + self.repetition_penalties_cpu[req_index] = \ + sampling_params.repetition_penalty + if sampling_params.repetition_penalty != 1.0: + self.repetition_penalties_reqs.add(req_id) + self.min_tokens[req_index] = sampling_params.min_tokens + self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids # NOTE(woosuk): self.generators should not include the requests that # do not have their own generator. 
@@ -179,6 +241,9 @@ def remove_request(self, req_id: str) -> Optional[int]: self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) self.top_k_reqs.discard(req_id) + self.frequency_penalties_reqs.discard(req_id) + self.presence_penalties_reqs.discard(req_id) + self.repetition_penalties_reqs.discard(req_id) self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) @@ -191,6 +256,9 @@ def clear(self) -> None: self.random_reqs.clear() self.top_p_reqs.clear() self.top_k_reqs.clear() + self.frequency_penalties_reqs.clear() + self.presence_penalties_reqs.clear() + self.repetition_penalties_reqs.clear() self.generators.clear() self.num_logprobs.clear() self.prompt_logprob_reqs.clear() @@ -224,6 +292,8 @@ def condense(self, empty_req_indices: List[int]) -> None: # block_table_cpu. self.token_ids_cpu[empty_index] = self.token_ids_cpu[ last_req_index] + self.num_prompt_tokens[empty_index] = \ + self.num_prompt_tokens[last_req_index] self.num_computed_tokens_cpu[ empty_index] = self.num_computed_tokens_cpu[last_req_index] self.block_table_cpu[empty_index] = self.block_table_cpu[ @@ -232,6 +302,15 @@ def condense(self, empty_req_indices: List[int]) -> None: last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + self.frequency_penalties_cpu[empty_index] = \ + self.frequency_penalties_cpu[last_req_index] + self.presence_penalties_cpu[empty_index] = \ + self.presence_penalties_cpu[last_req_index] + self.repetition_penalties_cpu[empty_index] = \ + self.repetition_penalties_cpu[last_req_index] + self.min_tokens[empty_index] = self.min_tokens[last_req_index] + self.stop_token_ids[empty_index] = \ + self.stop_token_ids[last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator @@ -241,6 +320,7 @@ def condense(self, empty_req_indices: List[int]) -> None: def make_sampling_metadata( self, + req_id_output_token_ids: Dict[str, List[int]], skip_copy: bool = False, ) -> SamplingMetadata: if not skip_copy: @@ -250,6 +330,37 @@ def make_sampling_metadata( self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) self.top_k[:self.num_reqs].copy_( self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + if not self.no_penalties: + # Since syncing these tensors is expensive only copy them + # if necessary i.e. if there are requests which require + # penalties to be applied during sampling. + self.frequency_penalties[:self.num_reqs].copy_( + self.frequency_penalties_cpu_tensor[:self.num_reqs], + non_blocking=True) + self.presence_penalties[:self.num_reqs].copy_( + self.presence_penalties_cpu_tensor[:self.num_reqs], + non_blocking=True) + self.repetition_penalties[:self.num_reqs].copy_( + self.repetition_penalties_cpu_tensor[:self.num_reqs], + non_blocking=True) + # The prompt tokens are used only for applying penalties during + # the sampling process. Hence copy these tensors only when + # there are requests which need penalties to be applied. + self.prompt_token_ids = self._make_prompt_token_ids_tensor() + + output_token_ids: List[List[int]] = [] + + for req_id in self.req_ids[:self.num_reqs]: + assert req_id is not None + # Currently we create a tensor for output_token_ids from scratch + # at each step. However, for the penalties computation what we + # need is stats about the token ids present in the output. 
This + # stats can be maintained incrementally instead of computing it + # from scratch at each step. + # TODO - Replace this with incremental update to output token + # statistics. + output_token_ids.append(req_id_output_token_ids[req_id]) + return SamplingMetadata( temperature=self.temperature[:self.num_reqs], all_greedy=self.all_greedy, @@ -260,8 +371,33 @@ def make_sampling_metadata( no_top_k=self.no_top_k, generators=self.generators, max_num_logprobs=self.max_num_logprobs, + prompt_token_ids=self.prompt_token_ids, + frequency_penalties=self.frequency_penalties[:self.num_reqs], + presence_penalties=self.presence_penalties[:self.num_reqs], + repetition_penalties=self.repetition_penalties[:self.num_reqs], + output_token_ids=output_token_ids, + min_tokens=self.min_tokens[:self.num_reqs], + stop_token_ids=self.stop_token_ids[:self.num_reqs], + no_penalties=self.no_penalties, ) + def _make_prompt_token_ids_tensor(self) -> torch.Tensor: + max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() + prompt_token_ids_cpu_tensor = torch.empty( + (self.num_reqs, max_prompt_len), + device="cpu", + dtype=torch.int64, + pin_memory=self.pin_memory) + prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() + prompt_token_ids[:] = ( + self.token_ids_cpu[:self.num_reqs, :max_prompt_len]) + # Use the value of vocab_size as a pad since we don't have a + # token_id of this value. + for i in range(self.num_reqs): + prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size + return prompt_token_ids_cpu_tensor.to(device=self.device, + non_blocking=True) + @property def num_reqs(self) -> int: return len(self.req_id_to_index) @@ -282,6 +418,12 @@ def no_top_p(self) -> bool: def no_top_k(self) -> bool: return len(self.top_k_reqs) == 0 + @property + def no_penalties(self) -> bool: + return (len(self.presence_penalties_reqs) == 0 + and len(self.frequency_penalties_reqs) == 0 + and len(self.repetition_penalties_reqs) == 0) + @property def max_num_logprobs(self) -> int: return max(self.num_logprobs.values()) if self.num_logprobs else 0 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ace62d8978bea..509771b7e2e5a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -105,6 +105,7 @@ def __init__( max_num_blocks_per_req=self.max_num_blocks_per_req, device=self.device, pin_memory=self.pin_memory, + vocab_size=model_config.get_vocab_size(), ) self.use_cuda_graph = (self.vllm_config.compilation_config.level @@ -383,7 +384,12 @@ def _prepare_sampling( or scheduler_output.scheduled_resumed_reqs): skip_copy = False # Create the sampling metadata. - sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) + req_id_output_token_ids: Dict[str, List[int]] = \ + {req_id: req.output_token_ids \ + for req_id, req in self.requests.items()} + + sampling_metadata = self.input_batch.make_sampling_metadata( + req_id_output_token_ids, skip_copy) return sampling_metadata def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
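
For reference, below is a minimal sketch of how the shared `apply_penalties` helper introduced in `vllm/model_executor/layers/utils.py` behaves on a toy batch. The vocabulary size, token IDs, and penalty values are made up for illustration, and the snippet assumes a vLLM build that already includes this change. Per the OpenAI parameter definitions followed in the diff, presence and frequency penalties are driven only by the output tokens, while the repetition penalty also considers the prompt.

```python
import torch

from vllm.model_executor.layers.utils import apply_penalties

VOCAB_SIZE = 8  # toy vocabulary; real models use the tokenizer's vocab size

# Two sequences, every token starting from the same positive logit so the
# effect of each penalty is easy to read off.
logits = torch.ones(2, VOCAB_SIZE)

# Prompt/output token tensors are padded with VOCAB_SIZE, matching the
# convention in the diff (the pad bin is sliced off inside the helper).
prompt_tokens = torch.tensor([[1, 2, VOCAB_SIZE], [3, 3, 4]],
                             dtype=torch.int64)
output_tokens = torch.tensor([[2, 2, VOCAB_SIZE], [5, VOCAB_SIZE, VOCAB_SIZE]],
                             dtype=torch.int64)

presence_penalties = torch.tensor([1.0, 0.0])
frequency_penalties = torch.tensor([0.5, 0.0])
repetition_penalties = torch.tensor([1.0, 1.2])

# apply_penalties mutates `logits` in place and returns it.
out = apply_penalties(logits, prompt_tokens, output_tokens,
                      presence_penalties, frequency_penalties,
                      repetition_penalties)

# Sequence 0: token 2 appears twice in the output, so its logit drops by
# 2 * 0.5 (frequency) + 1.0 (presence): 1.0 - 2.0 = -1.0.
# Token 1 appears only in the prompt, so presence/frequency leave it at 1.0.
print(out[0, 2], out[0, 1])

# Sequence 1: repetition penalty 1.2 divides the positive logits of tokens
# seen in the prompt (3, 4) or output (5): 1.0 / 1.2 ~= 0.833.
print(out[1, 3], out[1, 5])
```

In the V1 path, `Sampler.forward` reuses this helper after `_apply_min_token_penalties`, and `InputBatch.make_sampling_metadata` only copies the penalty tensors to the device (and builds the padded prompt-token tensor) when `no_penalties` is false, so requests that never set a penalty do not pay for the extra host-to-device transfers.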