From b71c956debf045a9a1545ebfe06961ca5163d91c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 11 Sep 2024 20:31:51 -0700 Subject: [PATCH 01/98] [TPU] Use Ray for default distributed backend (#8389) --- vllm/config.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index 26e4b169587e1..8fc8ae6b7dfc5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -869,6 +869,13 @@ def __init__( f"distributed executor backend " f"'{self.distributed_executor_backend}'.") + if current_platform.is_tpu() and self.world_size > 1: + if self.distributed_executor_backend is None: + self.distributed_executor_backend = "ray" + if self.distributed_executor_backend != "ray": + raise ValueError( + "TPU backend only supports Ray for distributed inference.") + if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. From b6c75e1cf27681ec92629930c03b616c7c9b9929 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 11 Sep 2024 23:35:33 -0400 Subject: [PATCH 02/98] Fix the AMD weight loading tests (#8390) --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 8fc8ae6b7dfc5..9684cea813134 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -883,7 +883,7 @@ def __init__( from vllm.executor import ray_utils backend = "mp" ray_found = ray_utils.ray_is_available() - if (torch.cuda.is_available() + if (current_platform.is_cuda() and cuda_device_count_stateless() < self.world_size): if not ray_found: raise ValueError("Unable to load Ray which is " From 5a60699c452c0b9b8086a978d8572c257c2c3cc4 Mon Sep 17 00:00:00 2001 From: tomeras91 <57313761+tomeras91@users.noreply.github.com> Date: Thu, 12 Sep 2024 06:55:30 +0300 Subject: [PATCH 03/98] [Bugfix]: Fix the logic for deciding if tool parsing is used (#8366) --- vllm/entrypoints/openai/serving_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a81d2aa989aaf..8ac4caffb37f0 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -607,7 +607,7 @@ async def chat_completion_full_generator( # if auto tools are not enabled, and a named tool choice using # outlines is not being used - if not (self.enable_auto_tools + if (not self.enable_auto_tools or not self.tool_parser) and not isinstance( request.tool_choice, ChatCompletionNamedToolChoiceParam): From 1bf2dd9df025feb82e27f90f534a3bf829ae75e9 Mon Sep 17 00:00:00 2001 From: Blueyo0 <30562758+blueyo0@users.noreply.github.com> Date: Thu, 12 Sep 2024 12:53:12 +0800 Subject: [PATCH 04/98] [Gemma2] add bitsandbytes support for Gemma2 (#8338) --- vllm/model_executor/models/gemma2.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 90449ec51ef0b..f9d9f9e7567c8 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -312,6 +312,14 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): # Gemma does not apply LoRA to the embedding layer. 
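    # The bitsandbytes mapping added below tells the quantized-weight loader how
    # the separate q/k/v and gate/up checkpoint shards fold into vLLM's fused
    # qkv_proj and gate_up_proj parameters.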
embedding_modules = {} embedding_padding_modules = [] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } def __init__( self, From 295c4730a85ce419e5b46e256240d69ad1cce619 Mon Sep 17 00:00:00 2001 From: Kevin Lin <42618777+kevin314@users.noreply.github.com> Date: Thu, 12 Sep 2024 00:45:24 -0500 Subject: [PATCH 05/98] [Misc] Raise error when using encoder/decoder model with cpu backend (#8355) --- vllm/utils.py | 4 ++++ vllm/worker/cpu_model_runner.py | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index a22081ebe8df0..aba243071b69a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -82,6 +82,9 @@ "currently supported with encoder/" "decoder models.") +STR_NOT_IMPL_ENC_DEC_CPU = ("CPU is not currently supported with " + "encoder/decoder models.") + # Efficiently import all enc/dec error strings # rather than having to import all of the above STR_NOT_IMPL_ENC_DEC_ERR_STRS = { @@ -97,6 +100,7 @@ "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, + "STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU } # Constants related to forcing the attention backend selection diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 7205b1a7beb8d..7b2caf4973589 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -15,7 +15,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import make_tensor_with_pad +from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, @@ -121,6 +121,10 @@ def __init__( # Lazy initialization. 
self.model: nn.Module # Set after init_Model + if self.model_config.is_encoder_decoder_model: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) + def load_model(self) -> None: self.model = get_model(model_config=self.model_config, load_config=self.load_config, From 42ffba11ad4597289b5ae609900a74a153fbd067 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 11 Sep 2024 23:13:14 -0700 Subject: [PATCH 06/98] [Misc] Use RoPE cache for MRoPE (#8396) --- vllm/model_executor/layers/rotary_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 7fa6c5e7fcde4..d4e9ed87ed54f 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -990,7 +990,7 @@ def get_rope( base, is_neox_style, dtype, short_factor, long_factor, **extra_kwargs) elif scaling_type == "mrope": - return MRotaryEmbedding( + rotary_emb = MRotaryEmbedding( head_size, rotary_dim, max_position, From 7de49aa86c7f169eb0962b6db29ad53fff519ffb Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Sep 2024 00:11:55 -0700 Subject: [PATCH 07/98] [torch.compile] hide slicing under custom op for inductor (#8384) --- tests/compile/test_full_graph.py | 4 +- vllm/attention/backends/flash_attn.py | 105 +++++++++++++++++--------- 2 files changed, 74 insertions(+), 35 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index d5b59db8c7887..0a6e781e18834 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -16,5 +16,7 @@ def test_full_graph(model): "The future of AI is", ] sampling_params = SamplingParams(temperature=0) - llm = LLM(model="meta-llama/Meta-Llama-3-8B") + llm = LLM(model="meta-llama/Meta-Llama-3-8B", + enforce_eager=True, + load_format="dummy") llm.generate(prompts, sampling_params) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 69faa6d343eda..ec9cbde7467d6 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -122,6 +122,40 @@ def _( return torch.empty_like(decode_query) +@torch.library.custom_op("vllm::reshape_and_cache_flash", + mutates_args=["kv_cache"]) +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + """Inductor cannot deal with inplace operations on views. + See https://github.com/pytorch/pytorch/issues/131192 + and https://github.com/pytorch/pytorch/issues/130174 + This is a workaround to hide the view operation from the inductor. + """ + return torch.ops._C_cache_ops.reshape_and_cache_flash( + key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype, + k_scale, v_scale) + + +@reshape_and_cache_flash.register_fake # type: ignore +def _( + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + pass + + class FlashAttentionBackend(AttentionBackend): @staticmethod @@ -653,11 +687,10 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. 
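        # The replacement below routes through the vllm::reshape_and_cache_flash
        # custom op defined earlier in this file, passing the whole kv_cache tensor
        # instead of the key_cache/value_cache views so that Inductor never sees an
        # in-place write to a view.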
- ops.reshape_and_cache_flash( + torch.ops.vllm.reshape_and_cache_flash( key, value, - key_cache, - value_cache, + kv_cache, attn_metadata.slot_mapping.flatten(), self.kv_cache_dtype, k_scale, @@ -669,7 +702,6 @@ def forward( assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens - output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] # QKV for prefill. @@ -680,6 +712,9 @@ def forward( assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if (kv_cache is None or prefill_meta.block_tables is None @@ -687,7 +722,7 @@ def forward( # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - out = torch.ops.vllm.flash_attn_varlen_func( + prefill_output = torch.ops.vllm.flash_attn_varlen_func( q=query, k=key, v=value, @@ -701,42 +736,44 @@ def forward( alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out else: # prefix-enabled attention assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) - output[: - num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=self.logits_soft_cap, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - output[ - num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache( - decode_query.unsqueeze(1), - key_cache, - value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, + prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, softmax_scale=self.scale, causal=True, alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, softcap=self.logits_soft_cap, - ).squeeze(1) + ) - # Reshape the output tensor. + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. 
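            # Decode queries attend against the paged KV cache via flash_attn_with_kvcache;
            # the result is kept in decode_output and either returned directly or
            # concatenated with prefill_output at the end of forward().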
+ decode_output = torch.ops.vllm.flash_attn_with_kvcache( + decode_query.unsqueeze(1), + key_cache, + value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ).squeeze(1) + + if prefill_output is None: + assert decode_output is not None + return decode_output.view(num_decode_tokens, hidden_size) + if decode_output is None: + assert prefill_output is not None + return prefill_output.view(num_prefill_tokens, hidden_size) + output = torch.cat([prefill_output, decode_output], dim=0) return output.view(num_tokens, hidden_size) From 520ca380aef75f34cd2f5a146d30849b483e3be4 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 12 Sep 2024 09:28:37 -0700 Subject: [PATCH 08/98] [Hotfix][VLM] Fixing max position embeddings for Pixtral (#8399) --- vllm/transformers_utils/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 5ad6f6802d046..29a1ae1850500 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -206,6 +206,8 @@ def recurse_elems(elem: Any): config_dict["tie_word_embeddings"] = config_dict.get( "tie_embeddings", False) config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000) + config_dict["max_position_embeddings"] = config_dict.get( + "max_position_embeddings", 128_000) if config_dict.get("moe") is not None: config_dict["architectures"] = ["MixtralForCausalLM"] From e56bf2774158dca80637a1b8309bbc4d308774b1 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 13 Sep 2024 01:10:35 +0800 Subject: [PATCH 09/98] [Bugfix] Fix InternVL2 inference with various num_patches (#8375) Co-authored-by: DarkLight1337 --- tests/models/test_internvl.py | 35 ++++++++++++++++++++++++++ vllm/model_executor/models/internvl.py | 7 +++--- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/tests/models/test_internvl.py b/tests/models/test_internvl.py index fa3369dc53345..881068b3afe41 100644 --- a/tests/models/test_internvl.py +++ b/tests/models/test_internvl.py @@ -331,6 +331,41 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, ) +@pytest.mark.parametrize("model", ["OpenGVLab/InternVL2-2B"]) +@pytest.mark.parametrize("size_factors", [[0.5, 1.0]]) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +@torch.inference_mode() +def test_different_num_patches(hf_runner, vllm_runner, image_assets, model, + size_factors, dtype: str, max_tokens: int, + num_logprobs: int) -> None: + images = [asset.pil_image.resize((896, 896)) for asset in image_assets] + + inputs_batching = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + + inputs_multi_images = [ + ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], + [[rescale_image_size(image, factor) for image in images] + for factor in size_factors]) + ] + for inputs in [inputs_batching, inputs_multi_images]: + run_test( + hf_runner, + vllm_runner, + inputs, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=2, + tensor_parallel_size=1, + ) + + @pytest.mark.parametrize( "models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")]) @pytest.mark.parametrize( 
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 81819578a4d8c..507d7014714a2 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -270,6 +270,7 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): # Add an N dimension for number of images per prompt (currently 1). data = data.unsqueeze(0) elif is_list_of(data, Image.Image): + # we can't stack here because the images may have different num_patches data = [ image_to_pixel_values(img, image_size, @@ -277,7 +278,6 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): max_num, use_thumbnail=use_thumbnail) for img in data ] - data = torch.stack(data) model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) @@ -449,11 +449,12 @@ def _parse_and_validate_image_input( if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - + # We need to flatten (B, N, P) to (B*N*P), + # so we call flatten_bn twice. return InternVLImagePixelInputs( type="pixel_values", data=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True).flatten(0, 1)), + flatten_bn(flatten_bn(pixel_values), concat=True)), ) raise AssertionError("This line should be unreachable.") From c6202daeedb22cd675942c37ae5e194549803c89 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Thu, 12 Sep 2024 11:10:54 -0600 Subject: [PATCH 10/98] [Model] Support multiple images for qwen-vl (#8247) Signed-off-by: Alex-Brooks Co-authored-by: Cyrus Leung Co-authored-by: DarkLight1337 --- docs/source/models/supported_models.rst | 2 +- ...e_inference_vision_language_multi_image.py | 84 +++-- tests/models/test_qwen.py | 308 ++++++++++++++++-- vllm/model_executor/models/qwen.py | 14 +- 4 files changed, 343 insertions(+), 65 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index be81c38833400..faac2b97722b7 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -254,7 +254,7 @@ Multimodal Language Models - * - :code:`QWenLMHeadModel` - Qwen-VL - - Image\ :sup:`E` + - Image\ :sup:`E+` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - * - :code:`Qwen2VLForConditionalGeneration` diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index ed7e886d57806..454872c628373 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -19,7 +19,39 @@ ] -def load_phi3v(question, image_urls: List[str]): +def load_qwenvl_chat(question: str, image_urls: List[str]): + model_name = "Qwen/Qwen-VL-Chat" + llm = LLM( + model=model_name, + trust_remote_code=True, + max_num_seqs=5, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + placeholders = "".join(f"Picture {i}: \n" + for i, _ in enumerate(image_urls, start=1)) + + # This model does not have a chat_template attribute on its tokenizer, + # so we need to explicitly pass it. 
We use ChatML since it's used in the + # generation utils of the model: + # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265 + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + + # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating + chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501 + + messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True, + chat_template=chat_template) + + stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"] + stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] + return llm, prompt, stop_token_ids, None, chat_template + + +def load_phi3v(question: str, image_urls: List[str]): llm = LLM( model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True, @@ -30,10 +62,10 @@ def load_phi3v(question, image_urls: List[str]): for i, _ in enumerate(image_urls, start=1)) prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n" stop_token_ids = None - return llm, prompt, stop_token_ids, None + return llm, prompt, stop_token_ids, None, None -def load_internvl(question, image_urls: List[str]): +def load_internvl(question: str, image_urls: List[str]): model_name = "OpenGVLab/InternVL2-2B" llm = LLM( @@ -61,7 +93,7 @@ def load_internvl(question, image_urls: List[str]): stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] - return llm, prompt, stop_token_ids, None + return llm, prompt, stop_token_ids, None, None def load_qwen2_vl(question, image_urls: List[str]): @@ -111,18 +143,19 @@ def load_qwen2_vl(question, image_urls: List[str]): else: image_data, _ = process_vision_info(messages) - return llm, prompt, stop_token_ids, image_data + return llm, prompt, stop_token_ids, image_data, None model_example_map = { "phi3_v": load_phi3v, "internvl_chat": load_internvl, "qwen2_vl": load_qwen2_vl, + "qwen_vl_chat": load_qwenvl_chat, } def run_generate(model, question: str, image_urls: List[str]): - llm, prompt, stop_token_ids, image_data = model_example_map[model]( + llm, prompt, stop_token_ids, image_data, _ = model_example_map[model]( question, image_urls) if image_data is None: image_data = [fetch_image(url) for url in image_urls] @@ -146,29 +179,32 @@ def run_generate(model, question: str, image_urls: List[str]): def run_chat(model: str, question: str, image_urls: List[str]): - llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls) + llm, _, stop_token_ids, _, chat_template = model_example_map[model]( + question, image_urls) sampling_params = SamplingParams(temperature=0.0, max_tokens=128, stop_token_ids=stop_token_ids) - - outputs = llm.chat([{ - "role": - "user", - "content": [ - { - "type": "text", - "text": question, - }, - *({ - "type": "image_url", - "image_url": { - "url": image_url + outputs = llm.chat( + [{ + "role": + "user", + "content": [ + { + "type": "text", + "text": question, }, - } for image_url in image_urls), - ], - }], - sampling_params=sampling_params) + *({ + "type": "image_url", + "image_url": { + "url": image_url + }, + } for 
image_url in image_urls), + ], + }], + sampling_params=sampling_params, + chat_template=chat_template, + ) for o in outputs: generated_text = o.outputs[0].text diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index 05f5cbf8c3435..5e7f1de99d6c3 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -1,11 +1,17 @@ import pathlib -from typing import List, Optional, Type +from typing import Dict, List, Optional, Tuple, Type, Union import pytest +import torch +from PIL.Image import Image -from vllm.multimodal.utils import rescale_image_size +from vllm.config import ModelConfig +from vllm.inputs import InputContext, LLMInputs +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ..conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput, + VllmRunner, _ImageAssets) from .utils import check_logprobs_close pytestmark = pytest.mark.vlm @@ -23,19 +29,205 @@ "Picture 1: \nWhat is the season?: ", }) +HF_MULTIIMAGE_IMAGE_PROMPT = "Picture 1: \nPicture 2: \nCan you compare these images?\n" # noqa: E501 +HF_MULTIIMAGE_IMAGE_PROMPT = "Picture 1: \nPicture 2: \nDescribe the two images in detail.\n" # noqa: E501 +### Multimodal preprocessing tests +SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image +# These values are specific to Qwen-VL/Chat; we can get these from the model +# config also, but they are hardcoded here to keep the parameterize/fixtures +# easy to read. +IMG_START_ID = 151857 +IMG_END_ID = 151858 +IMG_PAD_ID = 151859 +TOKS_PER_IMG = 256 +VIS_ENC_DIM = 4096 +IMG_SIZE = 448 + + +def build_model_context(model_name: str, + tokenizer_name: Optional[str] = None, + trust_remote_code: bool = False): + """Creates an InputContext for a given model. + + Args: + model_name: Name of the model being considered. + tokenizer_name: Name of the tokenizer being considered. + trust_remote_code: Whether or not to allow loading remote code. + + Returns: + InputContext for the model being considered. 
+ """ + if tokenizer_name is None: + tokenizer_name = model_name + model_config = ModelConfig( + model_name, + tokenizer_name, + tokenizer_mode="auto", + trust_remote_code=trust_remote_code, + dtype="float32", + seed=0, + ) + return InputContext(model_config) + + +@pytest.fixture() +def input_mapper_for_qwen(): + # Lazy import to avoid initializing CUDA during test collection + from vllm.model_executor.models.qwen import input_mapper_for_qwen + return input_mapper_for_qwen + + +@pytest.fixture() +def input_processor_for_qwen(): + # Lazy import to avoid initializing CUDA during test collection + from vllm.model_executor.models.qwen import input_processor_for_qwen + return input_processor_for_qwen + + +@pytest.fixture() +def qwen_vl_context() -> InputContext: + """Get an InputContext for Qwen-VL.""" + return build_model_context(model_name="Qwen/Qwen-VL", + trust_remote_code=True) + + +# Happy path tests for single/multi-image scenarios for the multimodal +# input processor and mapper, respectively +@pytest.mark.parametrize("num_images", [1, 2]) +def test_input_processor_valid_mm_data(input_processor_for_qwen, + qwen_vl_context: InputContext, + num_images: int): + """Happy cases for image inputs to Qwen's multimodal input processor.""" + prompt = "".join( + [f"Picture {num}: \n" for num in range(1, num_images + 1)]) + inputs = LLMInputs( + prompt=prompt, + # When processing multimodal data for a multimodal model, the qwen + # input processor will overwrite the provided prompt_token_ids with + # the image prompts + prompt_token_ids=None, + multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)}, + ) + proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs) + assert isinstance(proc_inputs, dict) + + # Each image should have one start / stop and a fixed context of 256 + proc_tokens = proc_inputs["prompt_token_ids"] + assert proc_tokens.count(IMG_START_ID) == num_images + assert proc_tokens.count(IMG_END_ID) == num_images + assert proc_tokens.count(IMG_PAD_ID) == num_images * TOKS_PER_IMG + + +@pytest.mark.parametrize( + "img_data,expected_shape", + [ + # single / multi-image + (SAMPLE_IMAGE, (1, 3, IMG_SIZE, IMG_SIZE)), + (2 * [SAMPLE_IMAGE], (2, 3, IMG_SIZE, IMG_SIZE)), + # single / multi-image embeddings + (torch.rand( + (TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)), + (torch.rand( + (1, TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)), + (torch.rand( + (2, TOKS_PER_IMG, VIS_ENC_DIM)), (2, TOKS_PER_IMG, VIS_ENC_DIM)), + ]) +def test_input_mapper_valid_mm_data(input_mapper_for_qwen, + qwen_vl_context: InputContext, + img_data: Union[torch.Tensor, List[Image], + Image], + expected_shape: List[int]): + """Happy cases for image inputs to Qwen's multimodal input mapper.""" + mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data) + # Ensure that we get the appropriately shaped pixel_values + # for images and image embeddings, respectively. 
+ assert isinstance(mapped_img_data, MultiModalInputs) + assert "pixel_values" in mapped_img_data + assert mapped_img_data["pixel_values"].shape == expected_shape + + +# Sad path tests for the multimodal input processor and mapper, respectively +@pytest.mark.parametrize("mm_data", [ + { + "image": torch.rand((5)) + }, + { + "image": torch.rand((5, 5, 5, 5, 5)) + }, +]) +def test_input_processor_invalid_mm_data(input_processor_for_qwen, + qwen_vl_context: InputContext, + mm_data: Dict[str, torch.Tensor]): + """Test sad cases validated in Qwen's multimodal input processor.""" + tokenizer = cached_get_tokenizer(qwen_vl_context.model_config.tokenizer, + trust_remote_code=True) + prompt = "Picture 1: \n" + prompt_token_ids = tokenizer.encode(prompt) + inputs = LLMInputs(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=mm_data) + # Should fail since we have too many or too few dimensions for embeddings + with pytest.raises(ValueError): + input_processor_for_qwen(qwen_vl_context, inputs) + + +@pytest.mark.parametrize( + "img_data", + [ + # Wrong context length + torch.rand((1, TOKS_PER_IMG + 10, VIS_ENC_DIM)), + # Wrong visual encoder output size + torch.rand((1, TOKS_PER_IMG, VIS_ENC_DIM + 10)), + ]) +def test_input_mapper_invalid_mm_data( + input_mapper_for_qwen, + qwen_vl_context: InputContext, + img_data: Union[torch.Tensor, List[Image], Image], +): + """Sad cases validated in Qwen VL's multimodal input mapper.""" + with pytest.raises(ValueError): + input_mapper_for_qwen(qwen_vl_context, img_data) + + +### End-to-end generation tests +def get_prompt_with_path(tmp_path: pathlib.PosixPath, prompt: str, + assets: Union[_ImageAssets, List[ImageAsset]]) -> str: + """Given a temporary dir path, export one or more image assets into the + tempdir & replace its contents with the local path to the string so that + the HF version of Qwen-VL can resolve the path and load the image ni its + forward() call. + + Args: + tmp_path: Tempdir for test under consideration. + prompt: Prompt with image placeholders. + assets: List of image assets whose len equals the num placeholders. + """ + # Ensure that the number of placeholders matches the number of assets; + # If this is not true, the test is probably written incorrectly. + assert prompt.count("") == len(assets) + + # Replace the placeholders with local paths to the exported assets + for asset in assets: + image_tmp_path = tmp_path / f"{asset.name}.jpg" + asset.pil_image.save(image_tmp_path) + prompt = prompt.replace( + "", + f"{image_tmp_path}", + 1, + ) + return prompt + -### Tests for multimodal Qwen models def run_test( - tmp_path: pathlib.PosixPath, hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], - image_assets: _ImageAssets, + inputs: List[Tuple[List[str], PromptImageInput]], model: str, *, - size_factors: List[float], dtype: str, max_tokens: int, num_logprobs: int, + mm_limit: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, ): @@ -48,23 +240,6 @@ def run_test( Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ - images = [asset.pil_image for asset in image_assets] - - # Export the images to a tempdir and substitute it into the hf prompt; - # the contents between / will be ignored by VLLM, but the - # transformers implementation for the visual transformer parses this to - # reload it in the forward call; the contents are treated as a URL or a - # local path. 
- for idx, asset in enumerate(image_assets): - image_tmp_path = tmp_path / f"{asset.name}.jpg" - asset.pil_image.save(image_tmp_path) - HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace( - "", f"{image_tmp_path}") - - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. @@ -72,11 +247,12 @@ def run_test( # will hurt multiprocessing backend with fork method (the default method). # max_model_len should be greater than image_feature_size - # Qwen encodes images into a fixed content size of 256 + # Qwen encodes each image into a fixed content size of 256 with vllm_runner(model, - max_model_len=300, + max_model_len=1024, max_num_seqs=1, dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, enforce_eager=True) as vllm_model: @@ -85,7 +261,7 @@ def run_test( max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] with hf_runner(model, dtype=dtype) as hf_model: @@ -94,7 +270,7 @@ def run_test( max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, @@ -125,19 +301,81 @@ def run_test( @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [8]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets, - model, size_factors, dtype, max_tokens, - num_logprobs) -> None: +def test_multimodal_models_single_image(tmp_path: pathlib.PosixPath, + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, model: str, + size_factors: List[float], dtype: str, + max_tokens: int, + num_logprobs: int) -> None: + """Tests multimodal models with single image prompts.""" + images = [asset.pil_image for asset in image_assets] + + prompts = [ + get_prompt_with_path(tmp_path, prompt, [asset]) + for prompt, asset in zip(HF_IMAGE_PROMPTS, image_assets) + ] + + inputs = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, prompts)] + + run_test( + hf_runner, + vllm_runner, + inputs, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + ) + + +@pytest.mark.parametrize("model", multimodal_models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_multimodal_models_multi_image(tmp_path: pathlib.PosixPath, + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, model: str, + size_factors: List[float], dtype: str, + max_tokens: int, + num_logprobs: int) -> None: + """Tests multimodal models with multi-image prompts.""" + images = [asset.pil_image for asset in image_assets] + # Put all of the images into one prompt. 
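    # HF_MULTIIMAGE_IMAGE_PROMPT carries one image placeholder per asset, so the
    # single prompt built below references both images and run_test is invoked
    # with mm_limit=2 to allow two images per request.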
+ prompt = get_prompt_with_path(tmp_path, HF_MULTIIMAGE_IMAGE_PROMPT, + image_assets) + inputs = [([prompt for _ in size_factors], + [[rescale_image_size(image, factor) for image in images] + for factor in size_factors])] + run_test( - tmp_path, hf_runner, vllm_runner, - image_assets, + inputs, model, - size_factors=size_factors, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, + mm_limit=2, tensor_parallel_size=1, ) @@ -150,7 +388,7 @@ def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets, @pytest.mark.parametrize("num_logprobs", [5]) def test_text_only_qwen_model_can_be_loaded_and_run( vllm_runner: Type[VllmRunner], - example_prompts, + example_prompts: List[str], model: str, *, dtype: str, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a726ec10984c0..18bc6b303f485 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -47,6 +47,7 @@ from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) +from vllm.utils import is_list_of from .utils import flatten_bn, is_pp_missing_parameter, make_layers @@ -684,9 +685,12 @@ def input_processor_for_qwen(ctx: InputContext, raise ValueError( f"Expected img embeds to be have 3 dimensions, got {num_dims}") num_images = 1 if num_dims == 2 else image_data.shape[0] - else: - # TODO - handle multiple image inputs once the API is solidified + elif isinstance(image_data, Image.Image): num_images = 1 + elif is_list_of(image_data, Image.Image): + num_images = len(image_data) + else: + raise TypeError(f"Invalid image type: {type(image_data)}") if prompt is None: prompt = tokenizer.decode(prompt_token_ids) @@ -767,11 +771,11 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but " f"received shape [{data.shape}]") pixel_values = data - else: transform = build_normalization_transform(image_size) - # TODO - handle multiple image inputs once the API is solidified - transformed_images = [transform(data)] + if not isinstance(data, (list, tuple)): + data = [data] + transformed_images = [transform(datum) for datum in data] pixel_values = torch.stack(transformed_images, dim=0) return MultiModalInputs({"pixel_values": pixel_values}) From 8a23e933026bdb66b0b141c69454457428aa056d Mon Sep 17 00:00:00 2001 From: WANGWEI Date: Fri, 13 Sep 2024 01:47:42 +0800 Subject: [PATCH 11/98] [BugFix] lazy init _copy_stream to avoid torch init wrong gpu instance (#8403) --- vllm/worker/multi_step_model_runner.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 9a196c3dfcd1f..cd9b20083c1a6 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -230,12 +230,15 @@ def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): self._base_model_runner: GPUModelRunnerBase = base_model_runner self.is_multi_step = self.scheduler_config.is_multi_step - # used to copy tensors from GPU to CPU asynchronously - self._copy_stream = torch.cuda.Stream() self.pinned_sampled_token_ids: Optional[torch.Tensor] = None self.pythonization_cache = PythonizationCache() + @functools.cached_property + def _copy_stream(self): + # used to copy tensors from GPU to CPU asynchronously + return torch.cuda.Stream() + def make_model_input_from_broadcasted_tensor_dict( self, 
tensor_dict: Dict[str, Any]) -> StatefulModelInput: model_input = (StatefulModelInput.from_broadcasted_tensor_dict( From 1f0c75afa95303fcb628861f040199090e82004d Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Thu, 12 Sep 2024 11:10:11 -0700 Subject: [PATCH 12/98] [BugFix] Fix Duplicate Assignment in Hermes2ProToolParser (#8423) --- vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index bde9b47ce60d5..ad6f536838a88 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -33,7 +33,6 @@ def __init__(self, tokenizer: AnyTokenizer): self.current_tool_name_sent: bool = False self.prev_tool_call_arr: List[Dict] = [] self.current_tool_id: int = -1 - self.current_tool_name_sent = False self.streamed_args_for_tool: List[str] = [ ] # map what has been streamed for each tool so far to a list From f2e263b801743596f5dda0680e0bcb0fc3c05e26 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 12 Sep 2024 12:11:57 -0600 Subject: [PATCH 13/98] [Bugfix] Offline mode fix (#8376) Signed-off-by: Joe Runde --- .buildkite/test-pipeline.yaml | 1 + tests/entrypoints/offline_mode/__init__.py | 0 .../offline_mode/test_offline_mode.py | 77 +++++++++++++++++++ vllm/transformers_utils/config.py | 30 +++++++- 4 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 tests/entrypoints/offline_mode/__init__.py create mode 100644 tests/entrypoints/offline_mode/test_offline_mode.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 5b8d6a8739f1b..25f18cc57793e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -91,6 +91,7 @@ steps: - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process - pytest -v -s entrypoints/openai - pytest -v -s entrypoints/test_chat_utils.py + - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - label: Distributed Tests (4 GPUs) # 10min diff --git a/tests/entrypoints/offline_mode/__init__.py b/tests/entrypoints/offline_mode/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py new file mode 100644 index 0000000000000..0b6026a89c758 --- /dev/null +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -0,0 +1,77 @@ +"""Tests for HF_HUB_OFFLINE mode""" +import importlib +import sys +import weakref + +import pytest + +from vllm import LLM + +from ...conftest import cleanup + +MODEL_NAME = "facebook/opt-125m" + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=4096, + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup() + + +@pytest.mark.skip_global_cleanup +def test_offline_mode(llm: LLM, monkeypatch): + # we use the llm fixture to ensure the model files are in-cache + del llm + + # Set HF to offline mode and ensure we can still construct an LLM + try: + monkeypatch.setenv("HF_HUB_OFFLINE", "1") + # Need to re-import huggingface_hub and friends to setup offline mode + _re_import_modules() + # Cached model files should be used in 
offline mode + LLM(model=MODEL_NAME, + max_num_batched_tokens=4096, + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) + finally: + # Reset the environment after the test + # NB: Assuming tests are run in online mode + monkeypatch.delenv("HF_HUB_OFFLINE") + _re_import_modules() + pass + + +def _re_import_modules(): + hf_hub_module_names = [ + k for k in sys.modules if k.startswith("huggingface_hub") + ] + transformers_module_names = [ + k for k in sys.modules if k.startswith("transformers") + and not k.startswith("transformers_modules") + ] + + reload_exception = None + for module_name in hf_hub_module_names + transformers_module_names: + try: + importlib.reload(sys.modules[module_name]) + except Exception as e: + reload_exception = e + # Try to continue clean up so that other tests are less likely to + # be affected + + # Error this test if reloading a module failed + if reload_exception is not None: + raise reload_exception diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 29a1ae1850500..3c269bc10cdf8 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -4,7 +4,9 @@ from pathlib import Path from typing import Any, Dict, Optional, Type, Union -from huggingface_hub import file_exists, hf_hub_download +import huggingface_hub +from huggingface_hub import (file_exists, hf_hub_download, + try_to_load_from_cache) from transformers import GenerationConfig, PretrainedConfig from transformers.models.auto.image_processing_auto import ( get_image_processor_config) @@ -70,7 +72,22 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision, if Path(model).exists(): return (Path(model) / config_name).is_file() - return file_exists(model, config_name, revision=revision, token=token) + # Offline mode support: Check if config file is cached already + cached_filepath = try_to_load_from_cache(repo_id=model, + filename=config_name, + revision=revision) + if isinstance(cached_filepath, str): + # The config file exists in cache- we can continue trying to load + return True + + # NB: file_exists will only check for the existence of the config file on + # hf_hub. This will fail in offline mode. + try: + return file_exists(model, config_name, revision=revision, token=token) + except huggingface_hub.errors.OfflineModeIsEnabled: + # Don't raise in offline mode, all we know is that we don't have this + # file cached. + return False def get_config( @@ -102,6 +119,15 @@ def get_config( token=kwargs.get("token")): config_format = ConfigFormat.MISTRAL else: + # If we're in offline mode and found no valid config format, then + # raise an offline mode error to indicate to the user that they + # don't have files cached and may need to go online. + # This is conveniently triggered by calling file_exists(). 
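            # With HF_HUB_OFFLINE set, this call raises huggingface_hub's
            # OfflineModeIsEnabled error (the same one swallowed in
            # file_or_path_exists above), which is clearer than the generic
            # ValueError raised afterwards.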
+ file_exists(model, + HF_CONFIG_NAME, + revision=revision, + token=kwargs.get("token")) + raise ValueError(f"No supported config format found in {model}") if config_format == ConfigFormat.HF: From a6c0f3658da4f2f23460e3e15bfa7d70ac7e60c1 Mon Sep 17 00:00:00 2001 From: William Lin Date: Thu, 12 Sep 2024 11:16:22 -0700 Subject: [PATCH 14/98] [multi-step] add flashinfer backend (#7928) --- csrc/ops.h | 19 +- csrc/prepare_inputs/advance_step.cu | 225 ++++++++++++++++-- csrc/torch_bindings.cpp | 15 +- .../multi_step/test_correctness_async_llm.py | 12 +- vllm/_custom_ops.py | 38 ++- vllm/attention/backends/abstract.py | 4 +- vllm/attention/backends/flash_attn.py | 18 +- vllm/attention/backends/flashinfer.py | 87 ++++++- vllm/worker/multi_step_model_runner.py | 37 ++- 9 files changed, 371 insertions(+), 84 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 05b89e183ca29..5333b22c536d6 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -54,10 +54,21 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_quick(torch::Tensor& out, torch::Tensor& input); -void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables); +void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, + int64_t block_size, torch::Tensor& input_tokens, + torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, + torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, + torch::Tensor& block_tables); + +void advance_step_flashinfer( + int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables, + torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, + torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); #ifndef USE_ROCM torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index 0e537ddd6c4cd..a9d08ca0dc14c 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -12,13 +12,11 @@ namespace prepare_inputs { // template -__global__ void advance_step_kernel(int num_seqs, int num_queries, - int block_size, long* input_tokens_ptr, - long const* sampled_token_ids_ptr, - long* input_positions_ptr, - int* seq_lens_ptr, long* slot_mapping_ptr, - int const* block_tables_ptr, - int64_t const block_tables_stride) { +__global__ void advance_step_flashattn_kernel( + int num_seqs, int num_queries, int block_size, long* input_tokens_ptr, + long const* sampled_token_ids_ptr, long* input_positions_ptr, + int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr, + int64_t const block_tables_stride) { int num_query_blocks = div_ceil(num_queries, num_threads); if (blockIdx.x >= num_query_blocks) { @@ -79,16 +77,91 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t, } } -void advance_step(int num_seqs, int num_queries, int block_size, - torch::Tensor& input_tokens, // type: long - torch::Tensor& sampled_token_ids, // type: long - torch::Tensor& input_positions, // type: long - torch::Tensor& seq_lens, // type: int - torch::Tensor& slot_mapping, // type: long - torch::Tensor& block_tables) { // type: int +__global__ void 
advance_step_flashinfer_kernel( + int num_threads, int num_seqs, int num_queries, int block_size, + long* input_tokens_ptr, long const* sampled_token_ids_ptr, + long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, + int const* block_tables_ptr, int64_t const block_tables_stride, + int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) { + int num_query_blocks = div_ceil(num_queries, num_threads); + + if (blockIdx.x < num_query_blocks) { + int cur_query_id = blockIdx.x * num_threads + threadIdx.x; + + if (cur_query_id < num_queries) { + // Update input_tokens + input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; + + int seq_len = seq_lens_ptr[cur_query_id]; + int next_seq_len = seq_len + 1; + int next_input_pos = next_seq_len - 1; + + // Update seq_lens + seq_lens_ptr[cur_query_id] = next_seq_len; + // Update input_positions + input_positions_ptr[cur_query_id] = next_input_pos; + + int const* seq_block_tables_ptr = + block_tables_ptr + block_tables_stride * cur_query_id; + + int block_index = next_input_pos / block_size; + int block_offset = next_input_pos % block_size; + + // Update paged_kv_last_page_len + paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1; + + int slot_num = + seq_block_tables_ptr[block_index] * block_size + block_offset; + // Update slot_mapping + slot_mapping_ptr[cur_query_id] = slot_num; + block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size); + } + } +} + +__global__ void advance_step_flashinfer_indptr_kernel( + int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr, + int* block_table_bound_ptr) { + int idx = blockIdx.x * num_threads + threadIdx.x; + + // Update paged_kv_indptr + if (idx < num_queries) { + int sum = 0; + for (int i = 0; i <= idx; ++i) { + sum += block_table_bound_ptr[i]; + } + paged_kv_indptr_ptr[idx + 1] = sum; + } +} + +__global__ void advance_step_flashinfer_indices_kernel( + int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr, + int64_t const block_tables_stride, int* paged_kv_indices_ptr, + int* paged_kv_indptr_ptr, int* block_table_bound_ptr) { + int idx = blockIdx.x * num_threads + threadIdx.x; + int row = idx / block_tables_stride; + int col = idx % block_tables_stride; + + if (row < num_queries && col < block_table_bound_ptr[row]) { + paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] = + block_tables_ptr[row * block_tables_stride + col]; + } + // if cudagraph, fill padded seqs with the last valid seq's indptr + if (num_queries < row && row <= num_seqs) { + paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries]; + } +} + +void advance_step_flashattn(int num_seqs, int num_queries, int block_size, + torch::Tensor& input_tokens, // type: long + torch::Tensor& sampled_token_ids, // type: long + torch::Tensor& input_positions, // type: long + torch::Tensor& seq_lens, // type: int + torch::Tensor& slot_mapping, // type: long + torch::Tensor& block_tables) { // type: int if (logging) { - printf("advance_step:\n"); + printf("advance_step_flashattn:\n"); printf(" num_seqs = %d\n", num_seqs); printf(" num_queries = %d\n", num_queries); printf(" block_size = %d\n", block_size); @@ -108,24 +181,126 @@ void advance_step(int num_seqs, int num_queries, int block_size, int blocks; cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); - advance_step_kernel<<>>( - num_seqs, num_queries, block_size, + advance_step_flashattn_kernel + <<>>( + num_seqs, num_queries, block_size, + reinterpret_cast(input_tokens.data_ptr()), + 
reinterpret_cast(sampled_token_ids.data_ptr()), + reinterpret_cast(input_positions.data_ptr()), + reinterpret_cast(seq_lens.data_ptr()), + reinterpret_cast(slot_mapping.data_ptr()), + reinterpret_cast(block_tables.data_ptr()), + block_tables.stride(0)); +} + +void advance_step_flashinfer( + int num_seqs, int num_queries, int block_size, + torch::Tensor& input_tokens, // type: long + torch::Tensor& sampled_token_ids, // type: long + torch::Tensor& input_positions, // type: long + torch::Tensor& seq_lens, // type: int + torch::Tensor& slot_mapping, // type: long + torch::Tensor& block_tables, // type: int + torch::Tensor& paged_kv_indices, // type: int + torch::Tensor& paged_kv_indptr, // type: int + torch::Tensor& paged_kv_last_page_len, // type: int + torch::Tensor& block_table_bound) { // type: int + + if (logging) { + printf("advance_step_flashinfer:\n"); + printf(" num_seqs = %d\n", num_seqs); + printf(" num_queries = %d\n", num_queries); + printf(" block_size = %d\n", block_size); + printf(" block_tables.stride(0) = %d\n", block_tables.stride(0)); + } + // Verify all tensors + verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); + // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, + // at::kLong); + verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); + verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); + verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); + verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); + + verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt); + verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt); + verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1, + at::kInt); + + verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt); + + int dev = sampled_token_ids.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + + int blocks; + int threads; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev); + if (logging) { + printf("launching kernel with %d blocks\n", blocks); + } + + // TODO(will): support arbitrary block_tables stride + if ((blocks * threads) / block_tables.stride(0) < num_queries) { + TORCH_CHECK(false, + "multi-step: not enough threads to map block_table to" + "FlashInfer's paged_kv_indices on GPU. 
Try reducing the number " + "of seqs,", + " increasing the block size or take smaller steps.", + " num_queries = ", num_queries, + " block_tables.stride(0) = ", block_tables.stride(0), + " blocks = ", blocks, " max_threads = ", threads); + } + + advance_step_flashinfer_kernel<<>>( + threads, num_seqs, num_queries, block_size, reinterpret_cast(input_tokens.data_ptr()), reinterpret_cast(sampled_token_ids.data_ptr()), reinterpret_cast(input_positions.data_ptr()), reinterpret_cast(seq_lens.data_ptr()), reinterpret_cast(slot_mapping.data_ptr()), reinterpret_cast(block_tables.data_ptr()), - block_tables.stride(0)); + block_tables.stride(0), + reinterpret_cast(paged_kv_last_page_len.data_ptr()), + reinterpret_cast(block_table_bound.data_ptr())); + + advance_step_flashinfer_indptr_kernel<<>>( + threads, num_seqs, num_queries, + reinterpret_cast(paged_kv_indptr.data_ptr()), + reinterpret_cast(block_table_bound.data_ptr())); + + advance_step_flashinfer_indices_kernel<<>>( + threads, num_seqs, num_queries, + reinterpret_cast(block_tables.data_ptr()), + block_tables.stride(0), + reinterpret_cast(paged_kv_indices.data_ptr()), + reinterpret_cast(paged_kv_indptr.data_ptr()), + reinterpret_cast(block_table_bound.data_ptr())); } } // namespace prepare_inputs -void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables) { - prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens, - sampled_token_ids, input_positions, seq_lens, - slot_mapping, block_tables); +void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, + int64_t block_size, torch::Tensor& input_tokens, + torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, + torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, + torch::Tensor& block_tables) { + prepare_inputs::advance_step_flashattn( + num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, block_tables); +} + +void advance_step_flashinfer( + int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables, + torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, + torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) { + prepare_inputs::advance_step_flashinfer( + num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices, + paged_kv_indptr, paged_kv_last_page_len, block_table_bound); } \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 57103c0936f5b..51afeacfdc0ad 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -74,11 +74,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // prepare_inputs advance_step ops.def( - "advance_step(int num_seqs, int num_queries, int block_size, " + "advance_step_flashattn(int num_seqs, int num_queries, int block_size, " "Tensor! input_tokens, Tensor sampled_token_ids, " "Tensor! input_positions, Tensor! seq_lens, Tensor! 
slot_mapping, " "Tensor block_tables) -> ()"); - ops.impl("advance_step", torch::kCUDA, &advance_step); + ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn); + + ops.def( + "advance_step_flashinfer(" + " int num_seqs, int num_queries, int block_size," + " Tensor! input_tokens, Tensor sampled_token_ids," + " Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping," + " Tensor block_tables, Tensor! paged_kv_indices," + " Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len," + " Tensor! block_table_bounds" + ") -> ()"); + ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer); // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index 0cbe8371e235a..a75a671e57f74 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -1,9 +1,10 @@ # Test the AsyncLLMEngine with multi-step-decoding - from typing import List, Optional import pytest +from tests.kernels.utils import override_backend_env_variable + from ..models.utils import check_logprobs_close from ..utils import (completions_with_server_args, get_client_text_generations, get_client_text_logprob_generations) @@ -33,8 +34,9 @@ @pytest.mark.parametrize("eager_mode", [False, True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) -@pytest.mark.parametrize("num_logprobs", [None, 5]) -@pytest.mark.parametrize("is_async", [False, True]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("is_async", [True]) +@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"]) @pytest.mark.asyncio async def test_multi_step( example_prompts, @@ -46,6 +48,8 @@ async def test_multi_step( num_prompts: int, is_async: bool, num_logprobs: Optional[int], + attention_backend: str, + monkeypatch, ) -> None: """Test vLLM engine with multi-step scheduling in an OpenAI-protocol client/server environment. 
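# A rough standalone sketch of what this parametrization exercises end to end.
# It assumes the VLLM_ATTENTION_BACKEND environment variable (the knob that
# override_backend_env_variable flips) and the num_scheduler_steps engine
# argument for multi-step decoding; neither name is taken from this diff.
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # or "FLASH_ATTN"
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", num_scheduler_steps=8)
print(llm.generate(["The future of AI is"],
                   SamplingParams(temperature=0, max_tokens=16)))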
@@ -71,6 +75,8 @@ async def test_multi_step( completions endpoint; `None` -> no logprobs """ + override_backend_env_variable(monkeypatch, attention_backend) + prompts = example_prompts if len(prompts) < num_prompts: prompts = prompts * ((num_prompts // len(prompts)) + 1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7a9061526ef2c..efa02d36c4acd 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -161,16 +161,36 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) -def advance_step(num_seqs: int, num_queries: int, block_size: int, - input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor, seq_lens: torch.Tensor, - slot_mapping: torch.Tensor, - block_tables: torch.Tensor) -> None: +def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor) -> None: """Advance a step on GPU for existing inputs for a multi-step runner""" - return torch.ops._C.advance_step(num_seqs, num_queries, block_size, - input_tokens, sampled_token_ids, - input_positions, seq_lens, slot_mapping, - block_tables) + return torch.ops._C.advance_step_flashattn(num_seqs, num_queries, + block_size, input_tokens, + sampled_token_ids, + input_positions, seq_lens, + slot_mapping, block_tables) + + +def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + block_table_bound: torch.Tensor) -> None: + + return torch.ops._C.advance_step_flashinfer( + num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, block_tables, + paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, + block_table_bound) # quantization ops diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index ccfc6b254c1e7..adc8390e6f9ec 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -83,7 +83,9 @@ def copy_blocks( ) -> None: raise NotImplementedError - def advance_step(self, num_seqs: int, num_queries: int): + def advance_step(self, model_input: "ModelRunnerInputBase", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, num_seqs: int, num_queries: int) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ec9cbde7467d6..bf883987bd80b 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -380,15 +380,15 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", self.seq_lens[i] += 1 self.max_decode_seq_len = max(self.seq_lens) - ops.advance_step(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + 
input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) class FlashAttentionMetadataBuilder( diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 7aec8203eb1e5..58d62e02e8733 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -30,7 +30,8 @@ make_tensor_with_pad) if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) class FlashInferBackend(AttentionBackend): @@ -268,6 +269,10 @@ class FlashInferMetadata(AttentionMetadata): query_start_loc: Optional[torch.Tensor] = None block_tables: Optional[torch.Tensor] = None + # used for GPU in-place advance_step + seq_lens_tensor: Optional[torch.Tensor] = None + block_table_bound: Optional[torch.Tensor] = None + # An example for paged_kv_indices, paged_kv_indptr: # request 1, page indices [0, 5, 8] # request 2, page indices [1, 6, 7] @@ -318,6 +323,8 @@ def begin_forward(self): assert self.paged_kv_indices is not None assert self.paged_kv_indptr is not None assert self.paged_kv_last_page_len is not None + assert self.block_table_bound is not None + assert self.seq_lens_tensor is not None batch_size = self.query_start_loc.shape[0] - 1 assert batch_size >= 0 # We will use flash attention for profiling to @@ -327,6 +334,8 @@ def begin_forward(self): self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.device) + self.block_table_bound = self.block_table_bound.to(self.device) + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.prefill_wrapper.end_forward() self.prefill_wrapper.begin_forward( @@ -335,14 +344,18 @@ def begin_forward(self): self.num_qo_heads, self.num_kv_heads, self.head_dim, self.page_size) else: - if not self.use_cuda_graph: - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None - self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - self.device) + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + # handle model warmup path + if self.block_table_bound is not None: + self.block_table_bound = self.block_table_bound.to(self.device) + if self.seq_lens_tensor is not None: + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) assert self.decode_wrapper is not None self.decode_wrapper.end_forward() @@ -391,6 +404,48 @@ def decode_metadata(self) -> Optional["FlashInferMetadata"]: return self + def advance_step( + self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + ): + """ + Update metadata in-place to advance one decode step. 
+ """ + + assert num_seqs > 0 + assert num_queries > 0 + assert model_input.attn_metadata is not None + assert sampled_token_ids is not None + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + model_input.input_tokens[:num_queries] = sampled_token_ids.flatten() + + # Update GPU tensors + ops.advance_step_flashinfer( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=model_input.input_tokens, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables, + paged_kv_indices=self.paged_kv_indices, + paged_kv_indptr=self.paged_kv_indptr, + paged_kv_last_page_len=self.paged_kv_last_page_len, + block_table_bound=self.block_table_bound) + class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): @@ -428,7 +483,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.paged_kv_indptr: List[int] = [0] # paged_kv_last_page_len is the length of the last page of each request self.paged_kv_last_page_len: List[int] = [] - + self.total_blocks = 0 self.is_profile_run: bool = False def _add_seq_group( @@ -499,6 +554,7 @@ def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): # block_table_bound is 1 with 1 valid block. # If seq_len = 15, block_size = 16, # block_table_bound is 0 + 1 with 1 valid block. + self.total_blocks += len(block_table) block_table_bound = seq_len // self.block_size + 1 \ if seq_len % self.block_size != 0 \ else seq_len // self.block_size @@ -583,6 +639,10 @@ def build(self, seq_lens: List[int], query_lens: List[int], out=query_start_loc[1:]) if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, device="cpu", dtype=torch.int) @@ -591,10 +651,15 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=torch.int) paged_kv_last_page_len_tensor = torch.tensor( self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device="cpu", + dtype=torch.int) else: paged_kv_indices_tensor = None paged_kv_indptr_tensor = None paged_kv_last_page_len_tensor = None + block_table_bound_tensor = None if self.runner.kv_cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( @@ -613,6 +678,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indices=paged_kv_indices_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor, + block_table_bound=block_table_bound_tensor, + seq_lens_tensor=seq_lens_tensor, num_qo_heads=self.runner.model_config.get_num_attention_heads( self.runner.parallel_config), num_kv_heads=self.runner.model_config.get_num_kv_heads( diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index cd9b20083c1a6..b900eb5a610ff 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -4,13 +4,6 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, 
List, Optional, Tuple, Union) -try: - from vllm.attention.backends.flash_attn import FlashAttentionMetadata -except ModuleNotFoundError: - # vllm_flash_attn is not installed, use the identical ROCm FA metadata - from vllm.attention.backends.rocm_flash_attn import ( - ROCmFlashAttentionMetadata as FlashAttentionMetadata) - import torch from vllm.distributed import get_pp_group @@ -36,6 +29,8 @@ logger = init_logger(__name__) +MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"] + def seq_output_builder(): return SequenceOutput( @@ -489,27 +484,27 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs, def _advance_step(self, model_input: StatefulModelInput, out: SamplerOutput) -> StatefulModelInput: - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None - assert frozen_model_input.attn_metadata is not None + if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS: + raise ValueError( + f"Multi-step not supported for attention backend: " + f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND " + f"to a value from {MULTI_STEP_ATTENTION_BACKENDS}.") + sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids num_seqs = model_input.num_seqs num_queries = model_input.num_queries - assert num_seqs > 0 - assert num_queries > 0 - assert num_seqs >= num_queries - + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None attn_metadata = frozen_model_input.attn_metadata - assert isinstance(attn_metadata, FlashAttentionMetadata) + assert attn_metadata is not None attn_metadata.advance_step( frozen_model_input, - model_input.cached_outputs[-1].sampled_token_ids, self.block_size, - num_seqs, num_queries) - - if frozen_model_input.seq_lens is not None: - for i in range(num_queries): - frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i] + sampled_token_ids, + self.block_size, + num_seqs, + num_queries, + ) return model_input From 551ce01078a655068e5ec3764d0a55ac744ea425 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 12 Sep 2024 20:02:00 +0100 Subject: [PATCH 15/98] [Core] Add engine option to return only deltas or final output (#7381) --- .buildkite/test-pipeline.yaml | 1 + tests/async_engine/test_async_llm_engine.py | 161 ++++++++++++++++-- vllm/engine/llm_engine.py | 24 +-- vllm/entrypoints/llm.py | 23 +-- vllm/entrypoints/openai/protocol.py | 7 +- vllm/entrypoints/openai/serving_chat.py | 125 ++++++++------ vllm/entrypoints/openai/serving_completion.py | 32 ++-- vllm/outputs.py | 79 ++++++--- vllm/sampling_params.py | 17 +- vllm/sequence.py | 39 ++++- 10 files changed, 371 insertions(+), 137 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 25f18cc57793e..d0732ec3fe2fb 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -50,6 +50,7 @@ steps: - tests/worker commands: - pytest -v -s async_engine # Async Engine + - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py - pytest -v -s multimodal - pytest -v -s test_utils.py # Utils diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 3bf11fbcfb3b8..bab42942d311f 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -1,7 +1,10 @@ import asyncio +import os +import uuid from asyncio import CancelledError +from copy import copy from dataclasses import dataclass -from typing import Optional +from 
typing import List, Optional import pytest import pytest_asyncio @@ -11,6 +14,7 @@ from vllm.config import ParallelConfig from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine from vllm.outputs import RequestOutput as RealRequestOutput +from vllm.sampling_params import RequestOutputKind from ..conftest import cleanup from ..utils import wait_for_gpu_memory_to_clear @@ -122,8 +126,17 @@ def start_engine(): timeout_s=60, ) + num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1")) + print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}") + return AsyncLLMEngine.from_engine_args( - AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True)) + AsyncEngineArgs(model="facebook/opt-125m", + enforce_eager=True, + num_scheduler_steps=num_scheduler_steps)) + + +def uid() -> str: + return str(uuid.uuid4()) @pytest_asyncio.fixture(scope="module") @@ -148,57 +161,177 @@ def should_do_global_cleanup_after_test(request) -> bool: @pytest.mark.asyncio(scope="module") async def test_asyncio_run(async_engine): + scheduler_config = await async_engine.get_scheduler_config() + num_scheduler_steps = scheduler_config.num_scheduler_steps + async def run(prompt: str): sampling_params = SamplingParams( temperature=0, max_tokens=32, + min_tokens=32, ) + output_count = 0 + final_output = None async for output in async_engine.generate(prompt, sampling_params, - request_id=prompt): + request_id=uid()): + output_count += 1 final_output = output - return final_output + return final_output, output_count results = await asyncio.gather( run("test0"), - run("test1"), + run("test0"), ) assert len(results) == 2 + first, second = results + + # remove nondeterministic fields for comparison + first[0].metrics = None + second[0].metrics = None + first[0].request_id = None + second[0].request_id = None + + assert str(first) == str(second) + + output_count = results[0][1] + if num_scheduler_steps == 1: + assert output_count == 32 + else: + assert 1 < output_count < 32 + + +@pytest.mark.asyncio(scope="module") +async def test_output_kinds(async_engine): + """Test that output_kind works as expected and that + results are equivalent across different kinds.""" + + scheduler_config = await async_engine.get_scheduler_config() + num_scheduler_steps = scheduler_config.num_scheduler_steps + + sampling_params = SamplingParams( + temperature=0, + max_tokens=32, + min_tokens=32, + ) + + async def run(prompt: str, kind: RequestOutputKind): + params = copy(sampling_params) + params.output_kind = kind + + output_count = 0 + final_output = None + async for output in async_engine.generate(prompt, + params, + request_id=uid()): + output_count += 1 + final_output = output + + assert final_output is not None + return (final_output.prompt_token_ids, + final_output.outputs[0].token_ids, + final_output.outputs[0].text, output_count) + + async def run_deltas(prompt: str): + params = copy(sampling_params) + params.output_kind = RequestOutputKind.DELTA + + prompt_tokens = None + output_tokens: List[int] = [] + output_text = "" + output_count = 0 + async for output in async_engine.generate(prompt, + params, + request_id=uid()): + token_ids = output.outputs[0].token_ids + text = output.outputs[0].text + + # Ensure we get prompt ids iff we haven't yet received output tokens + if output_tokens: + assert 1 <= len(token_ids) <= num_scheduler_steps + assert text + assert not output.prompt_token_ids + else: + assert output.prompt_token_ids + prompt_tokens = output.prompt_token_ids + + output_tokens.extend(token_ids) + 
output_text += text + + output_count += 1 + return prompt_tokens, output_tokens, output_text, output_count + + results = await asyncio.gather( + run("common input prompt", RequestOutputKind.CUMULATIVE), + run("common input prompt", RequestOutputKind.FINAL_ONLY), + run_deltas("common input prompt")) + + # Make sure outputs are the same + prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results) + assert len(prompt_set) == 1 + + text_set = set(text for _, _, text, _ in results) + assert len(text_set) == 1 + + tokens_set = set(tuple(ids) for _, ids, _, _ in results) + assert len(tokens_set) == 1 + + cumulative, final, deltas = results + + # output message counts + assert cumulative[3] == deltas[3] + + if num_scheduler_steps == 1: + assert cumulative[3] == 32 + else: + assert 1 < cumulative[3] < 32 + + assert final[3] == 1 @pytest.mark.asyncio(scope="module") async def test_cancellation(async_engine): + scheduler_config = await async_engine.get_scheduler_config() + num_scheduler_steps = scheduler_config.num_scheduler_steps + sampling_params = SamplingParams( temperature=0, - min_tokens=10, - max_tokens=10, + min_tokens=13, + max_tokens=13, ) + stop_at = 5 if num_scheduler_steps == 1 else 1 + + request_id = uid() + i = 0 with pytest.raises(CancelledError): async for output in async_engine.generate("test2", sampling_params, - request_id="test2"): + request_id=request_id): assert not output.finished i += 1 - if i == 5: - await async_engine.abort("test2") + if i == stop_at: + await async_engine.abort(request_id) - assert i == 5 + assert i == stop_at @pytest.mark.asyncio(scope="module") async def test_delayed_generator(async_engine): + scheduler_config = await async_engine.get_scheduler_config() + + if scheduler_config.num_scheduler_steps != 1: + pytest.skip("no need to test this one with multistep") + sampling_params = SamplingParams( temperature=0, min_tokens=10, max_tokens=10, ) - stream = async_engine.generate("test3", - sampling_params, - request_id="test3") + stream = async_engine.generate("test3", sampling_params, request_id=uid()) i = 0 final_output: Optional[RealRequestOutput] = None async for output in stream: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 92e46c7af5162..e07893b29ec38 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -39,7 +39,7 @@ RequestOutputFactory) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) @@ -225,9 +225,6 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, - # To improve performance, only final requests outputs may be required. - # If this set to true, then no intermediate outputs will be returned. 
- step_return_finished_only: bool = False, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -295,7 +292,6 @@ def __init__( self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats - self.step_return_finished_only = step_return_finished_only if not self.model_config.skip_tokenizer_init: self.tokenizer = self._init_tokenizer() @@ -1273,7 +1269,7 @@ def _process_model_outputs(self, ctx: The virtual engine context to work on request_id: If provided, then only this request is going to be processed - + """ now = time.time() @@ -1378,7 +1374,8 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) request_output = RequestOutputFactory.create(seq_group) - ctx.request_outputs.append(request_output) + if request_output: + ctx.request_outputs.append(request_output) # When we process a single request, we skip it for the next time, # and invoke the request output callback (if there was final output) @@ -1415,14 +1412,19 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - if (seq_group.is_finished() - if self.step_return_finished_only else True): - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create(seq_group) + if request_output: ctx.request_outputs.append(request_output) for seq_group in scheduler_outputs.ignored_seq_groups: + params = seq_group.sampling_params + if params is not None and params.output_kind == ( + RequestOutputKind.DELTA) and not seq_group.is_finished(): + continue + request_output = RequestOutputFactory.create(seq_group) - ctx.request_outputs.append(request_output) + if request_output: + ctx.request_outputs.append(request_output) # Immediately process request outputs here (if callback is given) if (ctx.request_outputs diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b1d9f386b6c3e..c01bffeb4289d 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -19,7 +19,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -642,14 +642,12 @@ def _validate_and_add_requests( raise ValueError("The lengths of prompts and lora_request " "must be the same.") - if isinstance(params, list): - params = [ - self._add_guided_processor(param, guided_options) - if isinstance(param, SamplingParams) else param - for param in params - ] - elif isinstance(params, SamplingParams): - params = self._add_guided_processor(params, guided_options) + for sp in params if isinstance(params, list) else (params, ): + if isinstance(sp, SamplingParams): + self._add_guided_processor(sp, guided_options) + + # We only care about the final output + sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. for i, request_inputs in enumerate(inputs): @@ -709,9 +707,6 @@ def _run_engine( f"output: {0:.2f} toks/s"), ) - # In the loop below, only finished outputs are used - self.llm_engine.step_return_finished_only = True - # Run the engine. 
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] total_in_toks = 0 @@ -724,6 +719,7 @@ def _run_engine( if use_tqdm: if isinstance(output, RequestOutput): # Calculate tokens only for RequestOutput + assert output.prompt_token_ids is not None total_in_toks += len(output.prompt_token_ids) in_spd = total_in_toks / pbar.format_dict["elapsed"] total_out_toks += sum( @@ -735,9 +731,6 @@ def _run_engine( f"output: {out_spd:.2f} toks/s") pbar.update(1) - # Restore original behavior - self.llm_engine.step_return_finished_only = False - if use_tqdm: pbar.close() # Sort the outputs by request ID. diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 374196044b7e8..7e9f53b1816d1 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -12,7 +12,8 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.pooling_params import PoolingParams -from vllm.sampling_params import LogitsProcessor, SamplingParams +from vllm.sampling_params import (LogitsProcessor, RequestOutputKind, + SamplingParams) from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -316,6 +317,8 @@ def to_sampling_params( length_penalty=self.length_penalty, logits_processors=logits_processors, truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, ) @model_validator(mode="before") @@ -559,6 +562,8 @@ def to_sampling_params( length_penalty=self.length_penalty, logits_processors=logits_processors, truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, ) @model_validator(mode="before") diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 8ac4caffb37f0..58e42fb5363fb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -246,8 +246,7 @@ async def create_chat_completion( def get_chat_request_role(self, request: ChatCompletionRequest) -> str: if request.add_generation_prompt: return self.response_role - else: - return request.messages[-1]["role"] + return request.messages[-1]["role"] async def chat_completion_stream_generator( self, @@ -264,15 +263,37 @@ async def chat_completion_stream_generator( # Send response for each token for each request.n (index) num_choices = 1 if request.n is None else request.n - previous_texts = [""] * num_choices previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices + num_prompt_tokens = 0 + tool_parser: Optional[ToolParser] = self.tool_parser( tokenizer) if self.tool_parser else None + if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): + tool_choice_function_name = request.tool_choice.function.name + else: + tool_choice_function_name = None + + # Determine whether tools are in use with "auto" tool choice + tool_choice_auto = ( + not tool_choice_function_name + and self._should_stream_with_auto_tool_parsing(request)) + + all_previous_token_ids: Optional[List[List[int]]] + if tool_choice_auto: + # These are only required in "auto" tool choice case + previous_texts = [""] * num_choices + all_previous_token_ids = [[]] * num_choices + else: + previous_texts, all_previous_token_ids = None, None + try: async for res in result_generator: + if 
res.prompt_token_ids is not None: + num_prompt_tokens = len(res.prompt_token_ids) + # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST # response (by the try...catch). @@ -305,10 +326,10 @@ async def chat_completion_stream_generator( and request.stream_options.include_usage): # if continuous usage stats are requested, add it if request.stream_options.continuous_usage_stats: - prompt_tokens = len(res.prompt_token_ids) - usage = UsageInfo(prompt_tokens=prompt_tokens, - completion_tokens=0, - total_tokens=prompt_tokens) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) chunk.usage = usage # otherwise don't else: @@ -344,12 +365,10 @@ async def chat_completion_stream_generator( request.stream_options.include_usage): if (request.stream_options. continuous_usage_stats): - prompt_tokens = len( - res.prompt_token_ids) usage = UsageInfo( - prompt_tokens=prompt_tokens, + prompt_tokens=num_prompt_tokens, completion_tokens=0, - total_tokens=prompt_tokens) + total_tokens=num_prompt_tokens) chunk.usage = usage else: chunk.usage = None @@ -360,65 +379,66 @@ async def chat_completion_stream_generator( first_iteration = False for output in res.outputs: - i = output.index if finish_reason_sent[i]: continue - delta_token_ids = output.token_ids[previous_num_tokens[i]:] - out_logprobs = output.logprobs[ - previous_num_tokens[i]:] if output.logprobs else None - if request.logprobs and request.top_logprobs is not None: - assert out_logprobs is not None, ( + assert output.logprobs is not None, ( "Did not output logprobs") logprobs = self._create_chat_logprobs( - token_ids=delta_token_ids, - top_logprobs=out_logprobs, + token_ids=output.token_ids, + top_logprobs=output.logprobs, tokenizer=tokenizer, num_output_top_logprobs=request.top_logprobs, ) else: logprobs = None - delta_text = output.text[len(previous_texts[i]):] - delta_message: Optional[DeltaMessage] = None + delta_text = output.text + delta_message: Optional[DeltaMessage] # handle streaming deltas for tools with named tool_choice - if (request.tool_choice and type(request.tool_choice) is - ChatCompletionNamedToolChoiceParam): + if tool_choice_function_name: delta_message = DeltaMessage(tool_calls=[ DeltaToolCall(function=DeltaFunctionCall( - name=request.tool_choice.function.name, + name=tool_choice_function_name, arguments=delta_text), index=i) ]) # handle streaming deltas for tools with "auto" tool choice - elif (self._should_stream_with_auto_tool_parsing(request) - and tool_parser): + elif tool_choice_auto: + assert previous_texts is not None + assert all_previous_token_ids is not None + assert tool_parser is not None + #TODO optimize manipulation of these lists + previous_text = previous_texts[i] + previous_token_ids = all_previous_token_ids[i] + current_text = previous_text + delta_text + current_token_ids = previous_token_ids + list( + output.token_ids) + delta_message = ( tool_parser.extract_tool_calls_streaming( - previous_text=previous_texts[i], - current_text=output.text, + previous_text=previous_text, + current_text=current_text, delta_text=delta_text, - previous_token_ids= \ - output.token_ids[ - :-1 * len(delta_token_ids) - ], - current_token_ids=output.token_ids, - delta_token_ids=delta_token_ids - ) - ) + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=output.token_ids)) + + # update the previous values for the next iteration + previous_texts[i] = current_text + 
all_previous_token_ids[i] = current_token_ids # handle streaming just a content delta else: delta_message = DeltaMessage(content=delta_text) # set the previous values for the next iteration - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) + previous_num_tokens[i] += len(output.token_ids) # if the message delta is None (e.g. because it was a # "control token" for tool calls or the parser otherwise @@ -445,13 +465,12 @@ async def chat_completion_stream_generator( # handle usage stats if requested & if continuous if (request.stream_options and request.stream_options.include_usage): - if (request.stream_options.continuous_usage_stats): - prompt_tokens = len(res.prompt_token_ids) + if request.stream_options.continuous_usage_stats: completion_tokens = len(output.token_ids) usage = UsageInfo( - prompt_tokens=prompt_tokens, + prompt_tokens=num_prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + + total_tokens=num_prompt_tokens + completion_tokens, ) chunk.usage = usage @@ -482,7 +501,7 @@ async def chat_completion_stream_generator( tool_parser.prev_tool_call_arr[index].get( "arguments", {})) - # get what we've streamed so for for arguments + # get what we've streamed so far for arguments # for the current tool actual_call = tool_parser.streamed_args_for_tool[ index] @@ -500,7 +519,6 @@ async def chat_completion_stream_generator( ]) # Send the finish response for each request.n only once - prompt_tokens = len(res.prompt_token_ids) choice_data = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, @@ -518,13 +536,12 @@ async def chat_completion_stream_generator( model=model_name) if (request.stream_options and request.stream_options.include_usage): - if (request.stream_options.continuous_usage_stats): - prompt_tokens = len(res.prompt_token_ids) + if request.stream_options.continuous_usage_stats: completion_tokens = len(output.token_ids) usage = UsageInfo( - prompt_tokens=prompt_tokens, + prompt_tokens=num_prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + + total_tokens=num_prompt_tokens + completion_tokens, ) chunk.usage = usage @@ -538,10 +555,11 @@ async def chat_completion_stream_generator( # is sent, send the usage if (request.stream_options and request.stream_options.include_usage): + completion_tokens = previous_num_tokens[i] final_usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=previous_num_tokens[i], - total_tokens=prompt_tokens + previous_num_tokens[i], + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, ) final_usage_chunk = ChatCompletionStreamResponse( @@ -680,6 +698,7 @@ async def chat_completion_full_generator( or "") choice.message.content = full_message + assert final_res.prompt_token_ids is not None num_prompt_tokens = len(final_res.prompt_token_ids) num_generated_tokens = sum( len(output.token_ids) for output in final_res.outputs) @@ -789,9 +808,9 @@ def _should_check_for_unstreamed_tool_arg_tokens( return bool( # if there is a delta message that includes tool calls which # include a function that has arguments - self.enable_auto_tools and self.tool_parser and delta_message + output.finish_reason is not None + and self.enable_auto_tools and self.tool_parser and delta_message and delta_message.tool_calls and delta_message.tool_calls[0] and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function.arguments is not None - and output.finish_reason is not None ) 
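
For reference, a minimal sketch of how a caller consumes the new RequestOutputKind.DELTA outputs introduced by this patch. It is not part of the patch itself; it mirrors the run_deltas pattern from tests/async_engine/test_async_llm_engine.py above, and the model name and request_id are arbitrary placeholders. With DELTA, each RequestOutput carries only the tokens and text generated since the previous message, so the client accumulates them itself; with FINAL_ONLY, a single finished RequestOutput is returned, which is what LLM._run_engine now relies on.

    import asyncio

    from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
    from vllm.sampling_params import RequestOutputKind, SamplingParams

    # Placeholder engine setup, matching the test fixture above.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))

    async def stream_deltas(prompt: str) -> str:
        params = SamplingParams(temperature=0, max_tokens=32)
        params.output_kind = RequestOutputKind.DELTA
        text = ""
        async for output in engine.generate(prompt, params, request_id="demo-1"):
            # Each message holds only the newly generated piece, not the
            # cumulative text, so concatenate the pieces client-side.
            text += output.outputs[0].text
        return text

    # Example usage (placeholder prompt):
    # print(asyncio.run(stream_deltas("Hello, my name is")))
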
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 34f1200753f8d..42142efb5f23e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -223,9 +223,10 @@ async def completion_stream_generator( tokenizer: AnyTokenizer, ) -> AsyncGenerator[str, None]: num_choices = 1 if request.n is None else request.n - previous_texts = [""] * num_choices * num_prompts + previous_text_lens = [0] * num_choices * num_prompts previous_num_tokens = [0] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts + num_prompt_tokens = [0] * num_prompts try: async for prompt_idx, res in result_generator: @@ -233,6 +234,10 @@ async def completion_stream_generator( prompt_logprobs = res.prompt_logprobs prompt_text = res.prompt + # Prompt details are excluded from later streamed outputs + if res.prompt_token_ids is not None: + num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids) + delta_token_ids: GenericSequence[int] out_logprobs: Optional[GenericSequence[Optional[Dict[ int, Logprob]]]] @@ -244,6 +249,7 @@ async def completion_stream_generator( assert request.max_tokens is not None if request.echo and request.max_tokens == 0: + assert prompt_token_ids is not None assert prompt_text is not None # only return the prompt delta_text = prompt_text @@ -252,6 +258,7 @@ async def completion_stream_generator( has_echoed[i] = True elif (request.echo and request.max_tokens > 0 and not has_echoed[i]): + assert prompt_token_ids is not None assert prompt_text is not None assert prompt_logprobs is not None # echo the prompt and first token @@ -266,11 +273,9 @@ async def completion_stream_generator( has_echoed[i] = True else: # return just the delta - delta_text = output.text[len(previous_texts[i]):] - delta_token_ids = output.token_ids[ - previous_num_tokens[i]:] - out_logprobs = output.logprobs[previous_num_tokens[ - i]:] if output.logprobs else None + delta_text = output.text + delta_token_ids = output.token_ids + out_logprobs = output.logprobs if request.logprobs is not None: assert out_logprobs is not None, ( @@ -280,13 +285,13 @@ async def completion_stream_generator( top_logprobs=out_logprobs, num_output_top_logprobs=request.logprobs, tokenizer=tokenizer, - initial_text_offset=len(previous_texts[i]), + initial_text_offset=previous_text_lens[i], ) else: logprobs = None - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) + previous_text_lens[i] += len(output.text) + previous_num_tokens[i] += len(output.token_ids) finish_reason = output.finish_reason stop_reason = output.stop_reason @@ -307,8 +312,8 @@ async def completion_stream_generator( and request.stream_options.include_usage): if (request.stream_options.continuous_usage_stats or output.finish_reason is not None): - prompt_tokens = len(prompt_token_ids) - completion_tokens = len(output.token_ids) + prompt_tokens = num_prompt_tokens[prompt_idx] + completion_tokens = previous_num_tokens[i] usage = UsageInfo( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, @@ -356,6 +361,7 @@ def request_output_to_completion_response( for final_res in final_res_batch: prompt_token_ids = final_res.prompt_token_ids + assert prompt_token_ids is not None prompt_logprobs = final_res.prompt_logprobs prompt_text = final_res.prompt @@ -411,9 +417,9 @@ def request_output_to_completion_response( ) choices.append(choice_data) + num_generated_tokens += len(output.token_ids) + num_prompt_tokens += 
len(prompt_token_ids) - num_generated_tokens += sum( - len(output.token_ids) for output in final_res.outputs) usage = UsageInfo( prompt_tokens=num_prompt_tokens, diff --git a/vllm/outputs.py b/vllm/outputs.py index e091b576f5972..85ea9196b25df 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -5,6 +5,7 @@ from typing import Union from vllm.lora.request import LoRARequest +from vllm.sampling_params import RequestOutputKind from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceStatus) @@ -92,7 +93,7 @@ def __init__( self, request_id: str, prompt: Optional[str], - prompt_token_ids: List[int], + prompt_token_ids: Optional[List[int]], prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, @@ -113,19 +114,26 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids @classmethod - def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": - if seq_group.sampling_params is None: + def from_seq_group(cls, + seq_group: SequenceGroup) -> Optional["RequestOutput"]: + sampling_params = seq_group.sampling_params + if sampling_params is None: raise ValueError( "Sampling parameters are missing for a CompletionRequest.") + finished = seq_group.is_finished() + if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( + not finished): + return None + seqs = seq_group.get_seqs() if len(seqs) == 1: top_n_seqs = seqs else: # Get the top-n sequences. - n = seq_group.sampling_params.n - if seq_group.sampling_params.use_beam_search: + n = sampling_params.n + if sampling_params.use_beam_search: sorting_key = lambda seq: seq.get_beam_search_score( - seq_group.sampling_params.length_penalty) + sampling_params.length_penalty) else: sorting_key = lambda seq: seq.get_cumulative_logprob() sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) @@ -135,26 +143,49 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # NOTE: We need omit logprobs here explicitly because the sequence # always has the logprobs of the sampled tokens even if the # logprobs are not requested. 
- include_logprobs = seq_group.sampling_params.logprobs is not None - text_buffer_length = seq_group.sampling_params.output_text_buffer_length - outputs = [ - CompletionOutput( - seqs.index(seq), - seq.get_output_text_to_return(text_buffer_length), - seq.data._output_token_ids, - seq.get_cumulative_logprob() if include_logprobs else None, - seq.output_logprobs if include_logprobs else None, - SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason) for seq in top_n_seqs - ] + include_logprobs = sampling_params.logprobs is not None + text_buffer_length = sampling_params.output_text_buffer_length + delta = sampling_params.output_kind == RequestOutputKind.DELTA + + outputs = [] + include_prompt = True + for seq in top_n_seqs: + output_text = seq.get_output_text_to_return( + text_buffer_length, delta) + output_token_ids = seq.get_output_token_ids_to_return(delta) + output_logprobs = seq.output_logprobs if include_logprobs else None + + if delta: + # Slice logprobs delta if applicable + if output_logprobs: + output_logprobs = output_logprobs[-len(output_token_ids):] + # Don't include prompt if this is after the first output + # containing decode token ids + if include_prompt and seq.get_output_len() > len( + output_token_ids): + include_prompt = False + + outputs.append( + CompletionOutput( + seqs.index(seq), output_text, output_token_ids, + seq.get_cumulative_logprob() if include_logprobs else None, + output_logprobs, + SequenceStatus.get_finished_reason(seq.status), + seq.stop_reason)) # Every sequence in the sequence group should have the same prompt. - prompt = seq_group.prompt - prompt_token_ids = seq_group.prompt_token_ids - encoder_prompt = seq_group.encoder_prompt - encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids - prompt_logprobs = seq_group.prompt_logprobs - finished = seq_group.is_finished() + if include_prompt: + prompt = seq_group.prompt + prompt_token_ids = seq_group.prompt_token_ids + encoder_prompt = seq_group.encoder_prompt + encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids + prompt_logprobs = seq_group.prompt_logprobs + else: + prompt = None + prompt_token_ids = None + encoder_prompt = None + encoder_prompt_token_ids = None + prompt_logprobs = None finished_time = time.time() if finished else None seq_group.set_finished_time(finished_time) return cls(seq_group.request_id, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index c83ed5cca6791..5edbc8e424e81 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,6 +1,6 @@ """Sampling parameters for text generation.""" import copy -from enum import IntEnum +from enum import Enum, IntEnum from functools import cached_property from typing import Any, Callable, Dict, List, Optional, Set, Union @@ -33,6 +33,15 @@ class SamplingType(IntEnum): to sample from.""" +class RequestOutputKind(Enum): + # Return entire output so far in every RequestOutput + CUMULATIVE = 0 + # Return only deltas in each RequestOutput + DELTA = 1 + # Do not return intermediate RequestOuputs + FINAL_ONLY = 2 + + class SamplingParams( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] @@ -147,6 +156,7 @@ class SamplingParams( logits_processors: Optional[Any] = None include_stop_str_in_output: bool = False truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None + output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE # The below fields are not supposed to be used as an input. # They are set in post_init. 
@@ -182,6 +192,7 @@ def from_optional( logits_processors: Optional[List[LogitsProcessor]] = None, truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None, + output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, ) -> "SamplingParams": return SamplingParams( n=1 if n is None else n, @@ -213,6 +224,7 @@ def from_optional( spaces_between_special_tokens=spaces_between_special_tokens, logits_processors=logits_processors, truncate_prompt_tokens=truncate_prompt_tokens, + output_kind=output_kind, ) def __post_init__(self) -> None: @@ -317,6 +329,9 @@ def _verify_args(self) -> None: raise ValueError( "stop strings are only supported when detokenize is True. " "Set detokenize=True to use stop.") + if self.best_of != self.n and self.output_kind == ( + RequestOutputKind.DELTA): + raise ValueError("best_of must equal n to use output_kind=DELTA") def _verify_beam_search(self) -> None: if self.best_of == 1: diff --git a/vllm/sequence.py b/vllm/sequence.py index 135586831e680..98a8b73586062 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,8 +5,9 @@ from array import array from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, - Optional, Set, Tuple, Union, cast) +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional +from typing import Sequence as GenericSequence +from typing import Set, Tuple, Union, cast import msgspec import torch @@ -407,6 +408,10 @@ def __init__( self.status = SequenceStatus.WAITING self.stop_reason: Union[int, str, None] = None + # These are used to keep track of delta outputs + self._last_token_ids_offset: int = 0 + self._last_output_text_offset: int = 0 + # Used for incremental detokenization self.prefix_offset = 0 self.read_offset = 0 @@ -462,11 +467,35 @@ def prompt_adapter_id(self) -> int: return self.prompt_adapter_request.prompt_adapter_id \ if self.prompt_adapter_request else 0 - def get_output_text_to_return(self, buffer_length: int): + def get_output_text_to_return(self, buffer_length: int, + delta: bool) -> str: + """If delta is True, only new text since the last call to + this method is returned""" + # We return the full output text if the sequence is finished. 
truncate = buffer_length and not self.is_finished() - return self.output_text[:-buffer_length] if truncate else ( - self.output_text) + if not delta: + return self.output_text[:-buffer_length] if truncate else ( + self.output_text) + length = len(self.output_text) - buffer_length + last_offset = self._last_output_text_offset + if last_offset < length: + self._last_output_text_offset = length + return self.output_text[last_offset:length] + return "" + + def get_output_token_ids_to_return(self, + delta: bool) -> GenericSequence[int]: + """If delta is True, only new tokens since the last call to + this method are returned""" + if not delta: + return self.get_output_token_ids() + length = self.get_output_len() + last_offset = self._last_token_ids_offset + if last_offset < length: + self._last_token_ids_offset = length + return self.data._output_token_ids[last_offset:] + return () def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size From 019877253be473bf0c12daaf2c29022150402052 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Thu, 12 Sep 2024 17:01:50 -0400 Subject: [PATCH 16/98] [Bugfix] multi-step + flashinfer: ensure cuda graph compatible (#8427) --- vllm/attention/backends/flashinfer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 58d62e02e8733..4054d337316fe 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -597,9 +597,19 @@ def build(self, seq_lens: List[int], query_lens: List[int], # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.runner.graph_block_tables[:batch_size] + max_blocks = input_block_tables.shape[1] for i, block_table in enumerate(self.block_tables): if block_table: - input_block_tables[i, :len(block_table)] = block_table + num_blocks = len(block_table) + if num_blocks <= max_blocks: + input_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. 
+ input_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + block_tables = torch.from_numpy(input_block_tables).to( device, non_blocking=True) From c16369455f9568b709d286be0857375a860842ab Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 12 Sep 2024 14:06:51 -0700 Subject: [PATCH 17/98] [Hotfix][Core][VLM] Disable chunked prefill by default and prefix caching for multimodal models (#8425) --- vllm/engine/arg_utils.py | 12 +++++++++++- vllm/model_executor/models/__init__.py | 4 ++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6f58c39162087..b5eba9ca3727a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -843,6 +843,13 @@ def create_engine_config(self) -> EngineConfig: device_config = DeviceConfig(device=self.device) model_config = self.create_model_config() + if model_config.is_multimodal_model: + if self.enable_prefix_caching: + logger.warning( + "--enable-prefix-caching is currently not " + "supported for multimodal models and has been disabled.") + self.enable_prefix_caching = False + cache_config = CacheConfig( block_size=self.block_size if self.device != "neuron" else self.max_model_len, # neuron needs block_size = max_model_len @@ -874,7 +881,10 @@ def create_engine_config(self) -> EngineConfig: # If not explicitly set, enable chunked prefill by default for # long context (> 32K) models. This is to avoid OOM errors in the # initial memory profiling phase. - if use_long_context: + + # Chunked prefill is currently disabled for multimodal models by + # default. + if use_long_context and not model_config.is_multimodal_model: is_gpu = device_config.device_type == "cuda" use_sliding_window = (model_config.get_sliding_window() is not None) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 2c01eb380c375..250f75b639a5b 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -90,12 +90,12 @@ "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), - "UltravoxModel": ("ultravox", "UltravoxModel"), - "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), + "UltravoxModel": ("ultravox", "UltravoxModel"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), From b61bd98f907180c70f65e21505b3af6d1cc2bf36 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:05:35 -0700 Subject: [PATCH 18/98] [CI/Build] Disable multi-node test for InternVL2 (#8428) --- tests/distributed/test_pipeline_parallel.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index d2219eed988e1..9a02f468f0a93 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -32,9 +32,10 @@ (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "ray"), - (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", 
"ray"), - (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "ray"), + # TODO: Enable internVL2 in a separate test if needed + # (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "ray"), + # (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "ray"), + # (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "ray"), ], ) @fork_new_process_for_each_test From d31174a4e1ff7ac1efbdb5d89a24f0e477f95cc8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 13 Sep 2024 00:21:51 +0200 Subject: [PATCH 19/98] [Hotfix][Pixtral] Fix multiple images bugs (#8415) --- tests/conftest.py | 2 +- tests/models/fixtures/pixtral_chat.pickle | Bin 0 -> 20865 bytes .../fixtures/pixtral_chat_engine.pickle | Bin 0 -> 20858 bytes tests/models/test_pixtral.py | 188 ++++++++++++++---- vllm/model_executor/models/pixtral.py | 83 ++++---- 5 files changed, 196 insertions(+), 77 deletions(-) create mode 100644 tests/models/fixtures/pixtral_chat.pickle create mode 100644 tests/models/fixtures/pixtral_chat_engine.pickle diff --git a/tests/conftest.py b/tests/conftest.py index c850e60a9ca6c..620f8b4983517 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -658,8 +658,8 @@ def generate( outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs + @staticmethod def _final_steps_generate_w_logprobs( - self, req_outputs: List[RequestOutput], ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] diff --git a/tests/models/fixtures/pixtral_chat.pickle b/tests/models/fixtures/pixtral_chat.pickle new file mode 100644 index 0000000000000000000000000000000000000000..43d4c883c3a49c314e1b5f529f6505f36b511181 GIT binary patch literal 20865 zcmai6X>?r0mF``O7fD8zyn)MBykpsBaY!K0W;qza8z2Tu2pQCtT2fnWwP>|0oET`# zl1bu8>|innNP@^A3ke~FZ4NQM{E+(R8(ZeW|GWbUFID5c)#$w&^Uj&(ojcE)(&824_$L04L2)}%VqqOU7A&?i!2RW{YuE!uOPqAyb@WO_P9t_S|}KrYvw z?h$?c`Hoaux_Ju>a-|w){_Ip%^U?Eh_%Kv4jNmjI$=&tJkJ7t?4lF~C%umd+Pa znI5$DX`(zX4!Fo!24%IPj_sB9Es=Srf7Dk?>Lx;N8SPz!P5)bc>oQ6)OksX~7~t&mry9i&V= zRNqP_QB0L+&-C?XQ|lD+gBj(KsHaM_<$7@CtfF^aSMW2n?@t!wmTiX7(1VdbtgUi$ zwHR!D{l_+LZY3ZshKfhDw{ns@YNYE4K`?Tuwrs9X17LH4!oA6G z^@a8#nIo^trt+N%cJ4gaT>bq9Mk-*N0BTAF1Nsk#@6h&8MGW^70JdEEKRg_xL*qoq z#&tsU_V@PUzT#3maDJ2GUA1)+>ek?O=yR6E*TgD!Bw*Z%4VsSv3`A>E^J zZ1!Zm#i0&6drAY3?3b5+u{F@3BYV@;&Y=Fg7J?Eg^s5V1@a=Zuorxhbg(~cJd*Twu zm~5D5Kos5h9z;{C?s8F+EsFT(`_*hH{+!&+j~(Igcgb;!bvHjIh#ZQD3vJQk>?Te& zD3k#6rrDi6&SlqNQKX@ezCx`*`ZO0c^+jU~HB^cL*tNUT7-34riU|(b-*fVNNusCY zg7N~Knb#M%z_%@s_Nfl=Wfip>+*gK-J-xL^_E5+eu?muq&U~()yzhj|9h0(t&n`(= zk^>n;Hx7I*dxJMv6v~N;-mFp5;BL(RuAv@1GAao-qosO=7`bBVAdQC-!bc;52Tk75bR3F+N&ud%A7IUf#r+l$@_PkVO_dMj#Jl6pzPn__-cbCNgjGVjhh*mqmLz6*Zj zb+>$uKVDabS?#1#yj09r$?=NW<%bZ*mw_^dm2d@{APUZV>BZq8PZAY24>Fm0) zS#I?)Xk~4xuOQZdFJi(WAaR9kS=r0?cXp=RMQ=LQ4F(JObfK*)-6z0l(|M5+@JjA) zUnlyy*V%NLC9M8Gx>gO*oWlBb9RYJb(Va~Hz~3iHdafLAp>jMIA){-?wh}_E+cHEP zU|~?Z^aeo)TsNIvx13SOU$1o!8Wb*DYz1U%xNw;8vS``%k-x7<=8*CqVYr!$(41mq zke(~bUq(U+=16ms27#0#&5H9pW@NrF5a8Uk8%)%(=L_s0$TmFBP(8zSe1H4CqlBdj zuBZmO$%V-1^ch0|+lZKnbWdaRH{L>!ZdA;lj~Kt%i9>nB4@_=Hv2hj_R0!jdSYKmc zYLay%s!8T4;1fSjw+76e6#wKOxx-f8a*>ANTvv*={(L6aj~Q{xJ5_vr$66?=hy%2s z2=XMbW{NRLI&2ciQ_nhYfgpOQBY}RdQctCXr-6?*+&Dx~wL^`T28OQ5BngH}4wD8} zSs?1D!sKq0G;pfCe~?U~7Mlh>%!-O5lLka9l2y5)(!sW5|CN*erU{}Bo1eh07yKYZ#Pm=AmppoA!FhDePkS07sp3ed^b7B@@qhb@FWsV~=$hFU(ofqPuM;}bm` zDi|ZuSM)*!*av$p4tw-uS8dg9gGC)MZ@CVWj;r^BgZ0}#o$knp;f9XGZ{5v@pM2_V zg4CjtSbs-{s>%MsR$}@F141nks&JIakxdN^+rAFe57U@9blm6L>T zFFF?e>ZhO6%*d7SZG=;D>ZuM&&Mqco$a!l63PEiRE@;l60nu8d*O8Yrvg~he$e>U{ zLaZ|*uDruFHeVPRN~jREUP%jvzR_&`#uds 
(base85 binary payload of tests/models/fixtures/pixtral_chat.pickle omitted)
diff --git a/tests/models/fixtures/pixtral_chat_engine.pickle b/tests/models/fixtures/pixtral_chat_engine.pickle
new file mode 100644
index 0000000000000000000000000000000000000000..19dbeaecc8dfffcddda1d66f83f24a1ed1ec16fc
GIT binary patch
literal 20858
(base85 binary payload of tests/models/fixtures/pixtral_chat_engine.pickle omitted)
zFx=bLLfD(lUH|-JKo3Izg?KgDc2<6jUjMxGCBWz3cySAPRWpxuwHk&3LP7=7br;)1 zJlS!rN_eS}zHBB;>h`TzsQ`x{gx*sIqu)yN3JTmU?C!(%AcSg(|Ffd94Lu>*2 zgdu!WRgxUi(BtJb9Eej{*;!>!NFNp^xG=8E>o6hERw;`M-1-U-@U7yOAmnE?B$>4K=fGjAu)IIa1LFXSA-4w9`}0ougmUNIwcI zzE7SH?)L9V;I3oG3JmrpbwAmms}SI?-lTlIhV8q^GlSHv@q(L>cJ$Qr-S6S}=(DMm z4F-ku{TNe&&kVY6JnSTV{tB^fDQfUf=lNO{yYOX6D3oA- z>*4$ZE1HKJAPT^% List[Dict[str, Any]]: + return [{ + "role": + "user", + "content": [{ + "type": "text", + "text": PROMPT, + }] + [{ + "type": "image_url", + "image_url": { + "url": url + } + } for url in urls], + }] + + +def _create_engine_inputs(urls: List[str]) -> TokensPrompt: + msg = _create_msg_format(urls) + + tokenizer = MistralTokenizer.from_model("pixtral") + + request = ChatCompletionRequest(messages=msg) # type: ignore[type-var] + tokenized = tokenizer.encode_chat_completion(request) + + engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens) + + images = [] + for chunk in request.messages[0].content: + if isinstance(chunk, ImageURLChunk): + images.append(image_from_chunk(chunk)) + + mm_data = MultiModalDataBuiltins(image=images) + engine_inputs["multi_modal_data"] = mm_data + + return engine_inputs + + +MSGS = [ + _create_msg_format(IMG_URLS[:1]), + _create_msg_format(IMG_URLS[:2]), + _create_msg_format(IMG_URLS), +] +ENGINE_INPUTS = [ + _create_engine_inputs(IMG_URLS[:1]), + _create_engine_inputs(IMG_URLS[:2]), + _create_engine_inputs(IMG_URLS), +] + +SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) +LIMIT_MM_PER_PROMPT = dict(image=4) + +MAX_MODEL_LEN = [8192, 65536] +FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle" +FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle" + + +def load_logprobs(filename: str) -> Any: + with open(filename, 'rb') as f: + return pickle.load(f) @pytest.mark.skip( @@ -16,49 +95,74 @@ "Model is too big, test passed on A100 locally but will OOM on CI machine." ) @pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models( +def test_chat( vllm_runner, - example_prompts, + max_model_len: int, model: str, dtype: str, - max_tokens: int, - num_logprobs: int, ) -> None: - image_urls = [ - "https://picsum.photos/id/237/200/300", - "https://picsum.photos/seed/picsum/200/300" - ] - expected = [ - "The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", # noqa - "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset." # noqa - ] - prompt = "Describe the image in one short sentence." 
- - sampling_params = SamplingParams(max_tokens=512, temperature=0.0) - - with vllm_runner(model, dtype=dtype, - tokenizer_mode="mistral") as vllm_model: - - for i, image_url in enumerate(image_urls): - messages = [ - { - "role": - "user", - "content": [{ - "type": "text", - "text": prompt - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }] - }, - ] - - outputs = vllm_model.model.chat(messages, - sampling_params=sampling_params) - assert outputs[0].outputs[0].text == expected[i] + EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT) + with vllm_runner( + model, + dtype=dtype, + tokenizer_mode="mistral", + enable_chunked_prefill=False, + max_model_len=max_model_len, + limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + ) as vllm_model: + outputs = [] + for msg in MSGS: + output = vllm_model.model.chat(msg, + sampling_params=SAMPLING_PARAMS) + + outputs.extend(output) + + logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs) + check_logprobs_close(outputs_0_lst=logprobs, + outputs_1_lst=EXPECTED_CHAT_LOGPROBS, + name_0="output", + name_1="h100_ref") + + +@pytest.mark.skip( + reason= + "Model is too big, test passed on A100 locally but will OOM on CI machine." +) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_model_engine(vllm_runner, model: str, dtype: str) -> None: + EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE) + args = EngineArgs( + model=model, + tokenizer_mode="mistral", + enable_chunked_prefill=False, + limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + dtype=dtype, + ) + engine = LLMEngine.from_engine_args(args) + + engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS) + engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS) + + outputs = [] + count = 0 + while True: + out = engine.step() + count += 1 + for request_output in out: + if request_output.finished: + outputs.append(request_output) + + if count == 2: + engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2], + SAMPLING_PARAMS) + if not engine.has_unfinished_requests(): + break + + logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs) + check_logprobs_close(outputs_0_lst=logprobs, + outputs_1_lst=EXPECTED_ENGINE_LOGPROBS, + name_0="output", + name_1="h100_ref") diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 010cf85f45e07..b26fd558fa1ea 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1,4 +1,3 @@ -import math from array import array from dataclasses import dataclass, fields from itertools import tee @@ -15,11 +14,12 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, tokenizer = cached_get_tokenizer( ctx.model_config.tokenizer, 
tokenizer_mode=ctx.model_config.tokenizer_mode) - mm_encoder = tokenizer.instruct.mm_encoder - mm_config = ctx.model_config.multimodal_config - max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1) + mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder + patch_size = mm_encoder.mm_config.image_patch_size + image_token_id = mm_encoder.special_ids.img - # approximate image size - size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size) + mm_config = ctx.model_config.multimodal_config + num_images = mm_config.limit_per_prompt.get("image", 1) + # dummy size + size = 256 image = Image.new("RGB", (size, size), color=0) - img_chunk = ImageChunk(image=image) - tokens = mm_encoder(img_chunk).tokens - token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE, - tokens) + image_feature_size = (size**2) // (patch_size**2) + + num_image_tokens = image_feature_size * num_images + + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [image_token_id]) * num_image_tokens + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - num_image_tokens) seq_data = SequenceData(token_ids) - mm_data = {"image": max_num_images_per_request * [image]} + mm_data = {"image": num_images * [image]} return seq_data, mm_data @@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext, return MultiModalInputs({"images": images}) -def merge_multimodal_embeddings(input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - image_features: Optional[List[torch.Tensor]], - image_id: int) -> torch.Tensor: - text_locations = input_ids != image_id - image_locations = input_ids == image_id - - seq_len = input_ids.shape[0] +def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is not None and "image" in multi_modal_data: + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + tokenizer_mode=ctx.model_config.tokenizer_mode) - N_txt = text_locations.sum().item() - _, D_txt = inputs_embeds.shape - N_img, D_img = image_features.shape + mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder + image_token_id = mm_encoder.special_ids.img - assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal " - "to image features dim {D_img}") - assert (seq_len == N_txt + - N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img " - f"{(N_txt, N_img, image_locations.sum().item())}") + if image_token_id not in llm_inputs['prompt_token_ids']: + raise ValueError( + (f"You've passed {llm_inputs=} without {image_token_id=}" + " Make sure to process your input via mistral_common's" + " tokenizer or pass a chat completion request. 
For more" + " For more info, see: " + "https://github.com/vllm-project/vllm/issues/8411.")) - inputs_embeds[image_locations, :] = image_features - return inputs_embeds + return llm_inputs @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral) +@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral) class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, @@ -201,11 +206,21 @@ def _parse_and_validate_image_input( return None if isinstance(images, torch.Tensor): - # always take last images - images = [images[-1][i] for i in range(images.size(1))] + # if passed as batch take all images + N, B, C, W, H = images.shape + images = images.reshape(N * B, C, W, H) + images = [images[i] for i in range(images.size(0))] elif isinstance(images, list): - # always take last images - images = [images[-1][i] for i in range(len(images[0]))] + # if passed as list flatten lists of tensors + flatten_images = [] + for imgs_per_req in images: + imgs_per_req = [ + imgs_per_req[i] for i in range(imgs_per_req.size(0)) + ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req + + flatten_images.extend(imgs_per_req) + + images = flatten_images return images From a480939e8e3b8e5b5571531c30212a1a947ee32e Mon Sep 17 00:00:00 2001 From: Wenxiang <8460860+wenxcs@users.noreply.github.com> Date: Fri, 13 Sep 2024 07:25:00 +0800 Subject: [PATCH 20/98] [Bugfix] Fix weight loading issue by rename variable. (#8293) --- vllm/model_executor/models/phimoe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 25bc0590c745c..5036f55803c20 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -600,7 +600,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader( param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id, ) From 360ddbd37ec82d5a83fd02ee94d7401684bc3c92 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 12 Sep 2024 17:31:18 -0700 Subject: [PATCH 21/98] [Misc] Update Pixtral example (#8431) --- examples/offline_inference_pixtral.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference_pixtral.py b/examples/offline_inference_pixtral.py index 738d890607e37..c12ff7021cf51 100644 --- a/examples/offline_inference_pixtral.py +++ b/examples/offline_inference_pixtral.py @@ -11,7 +11,7 @@ # - Server: # # ```bash -# vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral --limit_mm_per_prompt 'image=4' --max_num_batched_tokens 16384 +# vllm serve mistralai/Pixtral-12B-2409 --tokenizer-mode mistral --limit-mm-per-prompt 'image=4' --max-model-len 16384 # ``` # # - Client: @@ -45,6 +45,7 @@ def run_simple_demo(): model_name = "mistralai/Pixtral-12B-2409" sampling_params = SamplingParams(max_tokens=8192) + # Lower max_num_seqs or max_model_len on low-VRAM GPUs. llm = LLM(model=model_name, tokenizer_mode="mistral") prompt = "Describe this image in one sentence." 
@@ -83,7 +84,7 @@ def run_advanced_demo(): model=model_name, tokenizer_mode="mistral", limit_mm_per_prompt={"image": max_img_per_msg}, - max_num_batched_tokens=max_img_per_msg * max_tokens_per_img, + max_model_len=max_img_per_msg * max_tokens_per_img, ) prompt = "Describe the following image." From 8f44a92d852935c8378eaab85bad47ef3174e02b Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 12 Sep 2024 21:23:42 -0400 Subject: [PATCH 22/98] [BugFix] fix group_topk (#8430) --- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bd13d8fecbb96..a0cb4337f9dee 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -410,6 +410,7 @@ def fused_topk( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids @@ -443,7 +444,8 @@ def grouped_topk(hidden_states: torch.Tensor, if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids + + return topk_weights, topk_ids.to(torch.int32) def get_config_dtype_str(dtype: torch.dtype, From 5ec9c0fb3c667c30117eb1fd743e0e7c13ccf997 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 13 Sep 2024 10:56:13 +0800 Subject: [PATCH 23/98] [Core] Factor out input preprocessing to a separate class (#7329) --- tests/engine/test_skip_tokenizer_init.py | 5 +- vllm/engine/async_llm_engine.py | 145 +----- vllm/engine/llm_engine.py | 407 +---------------- vllm/inputs/parse.py | 37 +- vllm/inputs/preprocess.py | 536 +++++++++++++++++++++++ 5 files changed, 590 insertions(+), 540 deletions(-) create mode 100644 vllm/inputs/preprocess.py diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py index 338b208723ba9..b8818af5614cf 100644 --- a/tests/engine/test_skip_tokenizer_init.py +++ b/tests/engine/test_skip_tokenizer_init.py @@ -11,9 +11,10 @@ def test_skip_tokenizer_initialization(model: str): # token ids. 
llm = LLM(model=model, skip_tokenizer_init=True) sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) - with pytest.raises(ValueError) as err: + + with pytest.raises(ValueError, match="cannot pass text prompts when"): llm.generate("abc", sampling_params) - assert "prompts must be None if" in str(err.value) + outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params) assert len(outputs) > 0 diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 362b0f3a44b02..01114e9843ce4 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -4,22 +4,17 @@ from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union) -from typing_extensions import assert_never - import vllm.envs as envs from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout -from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine, - PromptComponents, SchedulerOutputState) +from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState from vllm.engine.metrics_types import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, - SingletonPromptInputs) -from vllm.inputs.parse import is_explicit_encoder_decoder_prompt +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -403,139 +398,6 @@ async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() - async def _tokenize_prompt_async( - self, - prompt: str, - request_id: str, - lora_request: Optional[LoRARequest], - ) -> List[int]: - """Async version of :meth:`_tokenize_prompt`.""" - tokenizer = self.get_tokenizer_group( - missing_msg="prompts must be None if skip_tokenizer_init is True") - - return await tokenizer.encode_async(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - - async def _extract_prompt_components_async( - self, - inputs: SingletonPromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: - """Async version of :meth:`_extract_prompt_components`.""" - if isinstance(inputs, str): - prompt = inputs - prompt_token_ids = await self._tokenize_prompt_async( - prompt, - request_id=request_id, - lora_request=lora_request, - ) - multi_modal_data = None - elif isinstance(inputs, dict): - if "prompt_token_ids" in inputs: - prompt = None - prompt_token_ids = inputs["prompt_token_ids"] - else: - # NOTE: This extra assignment is required to pass mypy - prompt = parsed_prompt = inputs["prompt"] - prompt_token_ids = await self._tokenize_prompt_async( - parsed_prompt, - request_id=request_id, - lora_request=lora_request, - ) - - multi_modal_data = inputs.get("multi_modal_data") - else: - assert_never(inputs) - - return prompt, prompt_token_ids, multi_modal_data - - async def _process_encoder_decoder_prompt_async( - self, - inputs: PromptInputs, - request_id: str, - ) -> EncoderDecoderLLMInputs: - """Async version of 
:meth:`_process_encoder_decoder_prompt`.""" - encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents - - if is_explicit_encoder_decoder_prompt(inputs): - encoder_task = self._extract_prompt_components_async( - inputs["encoder_prompt"], - request_id=request_id, - ) - - if (decoder_input := inputs["decoder_prompt"]) is None: - encoder_comps = await encoder_task - decoder_comps = None, None, None - else: - decoder_task = self._extract_prompt_components_async( - decoder_input, - request_id=request_id, - ) - - encoder_comps, decoder_comps = await asyncio.gather( - encoder_task, decoder_task) - else: - encoder_comps = await self._extract_prompt_components_async( - inputs, - request_id=request_id, - ) - - decoder_comps = None, None, None - - return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) - - async def _process_decoder_only_prompt_async( - self, - inputs: SingletonPromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: - """Async version of :meth:`_process_decoder_only_prompt`.""" - prompt_comps = await self._extract_prompt_components_async( - inputs, - request_id=request_id, - lora_request=lora_request, - ) - - return self._build_decoder_only_llm_inputs( - prompt_comps, - prompt_adapter_request=prompt_adapter_request, - ) - - async def process_model_inputs_async( - self, - inputs: PromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: - """Async version of :meth:`process_model_inputs`.""" - if self.is_encoder_decoder_model(): - # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder - model_inputs = await self._process_encoder_decoder_prompt_async( - inputs, - request_id=request_id, - ) - else: - if is_explicit_encoder_decoder_prompt(inputs): - raise ValueError("Cannot pass encoder-decoder prompt " - "to decoder-only models") - - # Decoder-only operation - model_inputs = await self._process_decoder_only_prompt_async( - inputs, - request_id=request_id, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) - - return self.input_processor(model_inputs) - async def add_request_async( self, request_id: str, @@ -553,12 +415,13 @@ async def add_request_async( if arrival_time is None: arrival_time = time.time() - processed_inputs = await self.process_model_inputs_async( + preprocessed_inputs = await self.input_preprocessor.preprocess_async( inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) + processed_inputs = self.input_processor(preprocessed_inputs) self._add_processed_request( request_id=request_id, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e07893b29ec38..c4d97c8f6d857 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,10 +6,10 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence -from typing import Set, Tuple, Type, Union +from typing import Set, Type, Union import torch -from typing_extensions import TypeVar, assert_never +from typing_extensions import TypeVar import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, @@ -28,13 +28,11 @@ from vllm.executor.executor_base import ExecutorBase from 
vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptInputs, - SingletonPromptInputs) -from vllm.inputs.parse import is_explicit_encoder_decoder_prompt + InputRegistry, LLMInputs, PromptInputs) +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MultiModalDataDict from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams @@ -75,11 +73,6 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) -PromptComponents = Tuple[Optional[str], List[int], - Optional[MultiModalDataDict]] -DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], - Optional[MultiModalDataDict]] - @dataclass class SchedulerOutputState: @@ -313,6 +306,9 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.generation_config_fields = _load_generation_config_dict( model_config) + self.input_preprocessor = InputPreprocessor(model_config, + self.tokenizer) + self.input_registry = input_registry self.input_processor = input_registry.create_input_processor( model_config) @@ -571,19 +567,15 @@ def __del__(self): if model_executor := getattr(self, "model_executor", None): model_executor.shutdown() - MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because " - "skip_tokenizer_init is True") - def get_tokenizer_group( self, group_type: Type[_G] = BaseTokenizerGroup, - *, - missing_msg: str = MISSING_TOKENIZER_GROUP_MSG, ) -> _G: tokenizer_group = self.tokenizer if tokenizer_group is None: - raise ValueError(missing_msg) + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") if not isinstance(tokenizer_group, group_type): raise TypeError("Invalid type of tokenizer group. " f"Expected type: {group_type}, but " @@ -615,52 +607,6 @@ def _verify_args(self) -> None: self.prompt_adapter_config.verify_with_model_config( self.model_config) - def _get_bos_token_id(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: - if self.tokenizer is None: - logger.warning("Using None for BOS token id because tokenizer " - "is not initialized") - return None - - return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id - - def _get_eos_token_id(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: - if self.tokenizer is None: - logger.warning("Using None for EOS token id because tokenizer " - "is not initialized") - return None - - return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id - - def _get_decoder_start_token_id(self) -> Optional[int]: - ''' - Obtain the decoder start token id employed by an encoder/decoder - model. Returns None for non-encoder/decoder models or if the - model config is unavailable. 
- ''' - - if not self.is_encoder_decoder_model(): - logger.warning("Using None for decoder start token id because " - "this is not an encoder/decoder model.") - return None - - if (self.model_config is None or self.model_config.hf_config is None): - logger.warning("Using None for decoder start token id because " - "model config is not available.") - return None - - dec_start_token_id = getattr(self.model_config.hf_config, - 'decoder_start_token_id', None) - if dec_start_token_id is None: - logger.warning("Falling back on for decoder start token id " - "because decoder start token id is not available.") - dec_start_token_id = self._get_bos_token_id() - - return dec_start_token_id - def _add_processed_request( self, request_id: str, @@ -675,7 +621,7 @@ def _add_processed_request( # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - eos_token_id = self._get_eos_token_id(lora_request) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, lora_request, prompt_adapter_request) @@ -725,334 +671,6 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() - _LLMInputComponentsType = Tuple[str, List[int]] - - def _prepare_decoder_input_ids_for_generation( - self, - decoder_input_ids: Optional[List[int]], - ) -> List[int]: - """ - Prepares `decoder_input_ids` for generation with encoder-decoder models. - - Based on - - https://github.com/huggingface/transformers/blob/ - 4037a2b5b1278736e566aec12e169100275545ea/ - src/transformers/generation/utils.py - - specifically GenerationMixin._prepare_decoder_input_ids_for_generation() - - Arguments: - - * decoder_input_ids: input token ids to preprocess - - Returns: - - * Processed token list - """ - - decoder_start_token_id = self._get_decoder_start_token_id() - assert decoder_start_token_id is not None - - if decoder_input_ids is None: - # no decoder prompt input -> - # use decoder_start_token_id as decoder_input_ids - decoder_input_ids = self._get_default_enc_dec_decoder_prompt() - - if (len(decoder_input_ids) == 0 - or decoder_input_ids[0] != decoder_start_token_id): - decoder_input_ids = [decoder_start_token_id] + decoder_input_ids - - return decoder_input_ids - - def _tokenize_prompt( - self, - prompt: str, - request_id: str, - lora_request: Optional[LoRARequest], - ) -> List[int]: - ''' - Wrapper around application of the model's tokenizer. - - Arguments: - - * prompt - * request_id - * lora_request - - Returns: - - * prompt token ids - ''' - - tokenizer = self.get_tokenizer_group( - missing_msg="prompts must be None if skip_tokenizer_init is True") - - return tokenizer.encode(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - - def _extract_prompt_components( - self, - inputs: SingletonPromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: - ''' - Extract the components of any single encoder or decoder input prompt. 
- - Arguments: - - * request_id - * inputs: single encoder or decoder input prompt - * lora_request: this is only valid for decoder prompts - - Returns: - - * prompt - * prompt_token_ids - * multi_modal_data - ''' - - if isinstance(inputs, str): - prompt = inputs - prompt_token_ids = self._tokenize_prompt( - prompt, - request_id=request_id, - lora_request=lora_request, - ) - multi_modal_data = None - elif isinstance(inputs, dict): - if "prompt_token_ids" in inputs: - prompt = None - prompt_token_ids = inputs["prompt_token_ids"] - else: - # NOTE: This extra assignment is required to pass mypy - prompt = parsed_prompt = inputs["prompt"] - prompt_token_ids = self._tokenize_prompt( - parsed_prompt, - request_id=request_id, - lora_request=lora_request, - ) - - multi_modal_data = inputs.get("multi_modal_data") - else: - assert_never(inputs) - - return prompt, prompt_token_ids, multi_modal_data - - def _apply_prompt_adapter( - self, - prompt_token_ids: List[int], - prompt_adapter_request: Optional[PromptAdapterRequest], - ) -> List[int]: - if prompt_adapter_request: - prompt_token_ids = ( - [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens - + prompt_token_ids) - - return prompt_token_ids - - def _get_default_enc_dec_decoder_prompt(self) -> List[int]: - ''' - Specifically for encoder/decoder models: - generate a default decoder prompt for when - the user specifies only the encoder prompt. - - Encoder/decoder models utilize the decoder - prompt in different ways; as new models are - added, it is intended that this function - will be extended to produce differing - default decoder prompts, depending on the - model variety. - - Absent a special case, the default behavior - of this method is to mirror the behavior of - the HuggingFace (HF) GenerationMixin for a None - decoder prompt, which is to employ a logit processor - setting to force the first decoded token to be . - Here, this behavior is approximated by having the - "default" decoder prompt be . - - However, it is possible that in the future - other models may have different or more - complex logic for the default decoder prompt. - This motivates having a special helper method - for default decoder prompts. - - Returns: - - * prompt_token_ids - ''' - - bos_token_id = self._get_bos_token_id() - assert bos_token_id is not None - return [bos_token_id] - - def _build_enc_dec_llm_inputs( - self, - encoder_comps: PromptComponents, - decoder_comps: DecoderPromptComponents, - ) -> EncoderDecoderLLMInputs: - encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps - decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps - - if encoder_mm_data is not None or decoder_mm_data is not None: - raise ValueError("Multi-modal encoder-decoder models are " - "not supported yet") - - decoder_prompt_ids = ( - self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) - - return EncoderDecoderLLMInputs( - prompt_token_ids=decoder_prompt_ids, - prompt=decoder_prompt, - encoder_prompt_token_ids=encoder_prompt_ids, - encoder_prompt=encoder_prompt, - ) - - def _process_encoder_decoder_prompt( - self, - inputs: PromptInputs, - request_id: str, - ) -> EncoderDecoderLLMInputs: - ''' - For encoder/decoder models only: - Process an input prompt into an - :class:`EncoderDecoderLLMInputs` instance. - - There are two types of input prompts: - singleton prompts which carry only the - encoder prompt, and explicit encoder/decoder - prompts which carry both the encoder and the - decoder prompts as member variables. 
- - This function handles the following scenarios: - * Singleton encoder prompt: extract encoder prompt - token ids & infer default decoder prompt token ids - * Explicit encoder/decoder prompt: extract encoder - and decoder prompt token ids - - Note that for Explicit encoder/decoder prompts, - each sub-prompt (encoder or decoder prompt) can - have any possible singleton type; thus this - method relies on helper functions to obtain - token ids for the sub-prompts. - - Arguments: - - * inputs: an input prompt - * request_id - - Returns: - - * :class:`EncoderDecoderLLMInputs` instance - ''' - - encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents - - if is_explicit_encoder_decoder_prompt(inputs): - encoder_comps = self._extract_prompt_components( - inputs["encoder_prompt"], - request_id=request_id, - ) - - if (decoder_input := inputs["decoder_prompt"]) is None: - decoder_comps = None, None, None - else: - decoder_comps = self._extract_prompt_components( - decoder_input, - request_id=request_id, - ) - else: - encoder_comps = self._extract_prompt_components( - inputs, - request_id=request_id, - ) - - decoder_comps = None, None, None - - return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) - - def _build_decoder_only_llm_inputs( - self, - prompt_comps: PromptComponents, - prompt_adapter_request: Optional[PromptAdapterRequest], - ) -> LLMInputs: - prompt, prompt_token_ids, multi_modal_data = prompt_comps - - prompt_token_ids = self._apply_prompt_adapter( - prompt_token_ids, prompt_adapter_request=prompt_adapter_request) - - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=prompt, - multi_modal_data=multi_modal_data) - - def _process_decoder_only_prompt( - self, - inputs: SingletonPromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: - ''' - For decoder-only models: - Process an input prompt into an :class:`LLMInputs` instance. 
- - Arguments: - - * inputs: input prompt - * request_id - * lora_request - * prompt_adapter_request - - Returns: - - * :class:`LLMInputs` instance - ''' - - prompt_comps = self._extract_prompt_components( - inputs, - request_id=request_id, - lora_request=lora_request, - ) - - return self._build_decoder_only_llm_inputs( - prompt_comps, - prompt_adapter_request=prompt_adapter_request, - ) - - def process_model_inputs( - self, - inputs: PromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: - - if self.is_encoder_decoder_model(): - # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder - model_inputs = self._process_encoder_decoder_prompt( - inputs, - request_id=request_id, - ) - else: - if is_explicit_encoder_decoder_prompt(inputs): - raise ValueError("Cannot pass encoder-decoder prompt " - "to decoder-only models") - - # Decoder-only operation - model_inputs = self._process_decoder_only_prompt( - inputs, - request_id=request_id, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) - - return self.input_processor(model_inputs) - def add_request( self, request_id: str, @@ -1111,12 +729,13 @@ def add_request( if arrival_time is None: arrival_time = time.time() - processed_inputs = self.process_model_inputs( + preprocessed_inputs = self.input_preprocessor.preprocess( inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) + processed_inputs = self.input_processor(preprocessed_inputs) self._add_processed_request( request_id=request_id, @@ -2043,7 +1662,7 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: metrics.model_execute_time) def is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model + return self.input_preprocessor.is_encoder_decoder_model() def is_embedding_model(self): return self.model_config.is_embedding_model diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index b5e8ef7860598..ac9d355c64c80 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -5,7 +5,8 @@ from vllm.utils import is_list_of from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs) + LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + TokensPrompt) class ParsedText(TypedDict): @@ -60,8 +61,38 @@ def parse_and_batch_prompt( for elem in prompt ] - raise ValueError("prompt must be a string, array of strings, " - "array of tokens, or array of token arrays") + raise TypeError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + +class ParsedStrPrompt(TypedDict): + type: Literal["str"] + content: str + + +class ParsedTextPrompt(TypedDict): + type: Literal["text"] + content: TextPrompt + + +class ParsedTokensPrompt(TypedDict): + type: Literal["tokens"] + content: TokensPrompt + + +def parse_singleton_prompt( + inputs: SingletonPromptInputs, +) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: + if isinstance(inputs, str): + return ParsedStrPrompt(type="str", content=inputs) + elif isinstance(inputs, dict): + if "prompt_token_ids" in inputs: + return ParsedTokensPrompt(type="tokens", + content=inputs) # type: ignore + elif "prompt" in inputs: + return ParsedTextPrompt(type="text", content=inputs) + + raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") def 
is_explicit_encoder_decoder_prompt( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py new file mode 100644 index 0000000000000..be2aa5f8cb7d0 --- /dev/null +++ b/vllm/inputs/preprocess.py @@ -0,0 +1,536 @@ +import asyncio +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +from typing_extensions import assert_never + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup + +from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, + SingletonPromptInputs) +from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt + +if TYPE_CHECKING: + from vllm.multimodal import MultiModalDataDict + +logger = init_logger(__name__) + +PromptComponents = Tuple[Optional[str], List[int], + Optional["MultiModalDataDict"]] +DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], + Optional["MultiModalDataDict"]] + + +class InputPreprocessor: + + def __init__( + self, + model_config: ModelConfig, + tokenizer: Optional[BaseTokenizerGroup], + ) -> None: + super().__init__() + + self.model_config = model_config + self.tokenizer = tokenizer + + def get_tokenizer_group(self) -> BaseTokenizerGroup: + if self.tokenizer is None: + raise ValueError("You cannot pass text prompts when " + "`skip_tokenizer_init` is True") + + return self.tokenizer + + def get_bos_token_id(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for BOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id + + def get_eos_token_id(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for EOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + + def get_decoder_start_token_id(self) -> Optional[int]: + ''' + Obtain the decoder start token id employed by an encoder/decoder + model. Returns None for non-encoder/decoder models or if the + model config is unavailable. + ''' + + if not self.is_encoder_decoder_model(): + logger.warning("Using None for decoder start token id because " + "this is not an encoder/decoder model.") + return None + + if (self.model_config is None or self.model_config.hf_config is None): + logger.warning("Using None for decoder start token id because " + "model config is not available.") + return None + + dec_start_token_id = getattr(self.model_config.hf_config, + 'decoder_start_token_id', None) + if dec_start_token_id is None: + logger.warning("Falling back on for decoder start token id " + "because decoder start token id is not available.") + dec_start_token_id = self.get_bos_token_id() + + return dec_start_token_id + + def _get_default_enc_dec_decoder_prompt(self) -> List[int]: + ''' + Specifically for encoder/decoder models: + generate a default decoder prompt for when + the user specifies only the encoder prompt. + + Encoder/decoder models utilize the decoder + prompt in different ways; as new models are + added, it is intended that this function + will be extended to produce differing + default decoder prompts, depending on the + model variety. 
+ + Absent a special case, the default behavior + of this method is to mirror the behavior of + the HuggingFace (HF) GenerationMixin for a None + decoder prompt, which is to employ a logit processor + setting to force the first decoded token to be . + Here, this behavior is approximated by having the + "default" decoder prompt be . + + However, it is possible that in the future + other models may have different or more + complex logic for the default decoder prompt. + This motivates having a special helper method + for default decoder prompts. + + Returns: + + * prompt_token_ids + ''' + + bos_token_id = self.get_bos_token_id() + assert bos_token_id is not None + return [bos_token_id] + + def _prepare_decoder_input_ids_for_generation( + self, + decoder_input_ids: Optional[List[int]], + ) -> List[int]: + """ + Prepares `decoder_input_ids` for generation with encoder-decoder models. + + Based on + + https://github.com/huggingface/transformers/blob/ + 4037a2b5b1278736e566aec12e169100275545ea/ + src/transformers/generation/utils.py + + specifically GenerationMixin._prepare_decoder_input_ids_for_generation() + + Arguments: + + * decoder_input_ids: input token ids to preprocess + + Returns: + + * Processed token list + """ + + decoder_start_token_id = self.get_decoder_start_token_id() + assert decoder_start_token_id is not None + + if decoder_input_ids is None: + # no decoder prompt input -> + # use decoder_start_token_id as decoder_input_ids + decoder_input_ids = self._get_default_enc_dec_decoder_prompt() + + if (len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id): + decoder_input_ids = [decoder_start_token_id] + decoder_input_ids + + return decoder_input_ids + + def _apply_prompt_adapter( + self, + prompt_token_ids: List[int], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> List[int]: + if prompt_adapter_request: + prompt_token_ids = ( + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + + prompt_token_ids) + + return prompt_token_ids + + def _tokenize_prompt( + self, + prompt: str, + request_id: str, + lora_request: Optional[LoRARequest], + ) -> List[int]: + """ + Apply the model's tokenizer to a text prompt, returning the + corresponding token IDs. + """ + tokenizer = self.get_tokenizer_group() + + return tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + + async def _tokenize_prompt_async( + self, + prompt: str, + request_id: str, + lora_request: Optional[LoRARequest], + ) -> List[int]: + """Async version of :meth:`_tokenize_prompt`.""" + tokenizer = self.get_tokenizer_group() + + return await tokenizer.encode_async(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + + def _extract_prompt_components( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + ) -> PromptComponents: + ''' + Extract the components of any single encoder or decoder input prompt. 
+ + Arguments: + + * request_id + * inputs: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + + Returns: + + * prompt + * prompt_token_ids + * multi_modal_data + ''' + + parsed = parse_singleton_prompt(inputs) + + if parsed["type"] == "str": + prompt = parsed["content"] + prompt_token_ids = self._tokenize_prompt( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = None + elif parsed["type"] == "tokens": + prompt = None + prompt_token_ids = parsed["content"]["prompt_token_ids"] + multi_modal_data = parsed["content"].get("multi_modal_data") + elif parsed["type"] == "text": + prompt = parsed["content"]["prompt"] + prompt_token_ids = self._tokenize_prompt( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = parsed["content"].get("multi_modal_data") + else: + assert_never(parsed) + + return prompt, prompt_token_ids, multi_modal_data + + async def _extract_prompt_components_async( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + ) -> PromptComponents: + """Async version of :meth:`_extract_prompt_components`.""" + parsed = parse_singleton_prompt(inputs) + + if parsed["type"] == "str": + prompt = parsed["content"] + prompt_token_ids = await self._tokenize_prompt_async( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = None + elif parsed["type"] == "tokens": + prompt = None + prompt_token_ids = parsed["content"]["prompt_token_ids"] + multi_modal_data = parsed["content"].get("multi_modal_data") + elif parsed["type"] == "text": + prompt = parsed["content"]["prompt"] + prompt_token_ids = await self._tokenize_prompt_async( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = parsed["content"].get("multi_modal_data") + else: + assert_never(parsed) + + return prompt, prompt_token_ids, multi_modal_data + + def _build_enc_dec_llm_inputs( + self, + encoder_comps: PromptComponents, + decoder_comps: DecoderPromptComponents, + ) -> EncoderDecoderLLMInputs: + encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps + decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps + + if encoder_mm_data is not None or decoder_mm_data is not None: + raise ValueError("Multi-modal encoder-decoder models are " + "not supported yet") + + decoder_prompt_ids = ( + self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) + + return EncoderDecoderLLMInputs( + prompt_token_ids=decoder_prompt_ids, + prompt=decoder_prompt, + encoder_prompt_token_ids=encoder_prompt_ids, + encoder_prompt=encoder_prompt, + ) + + def _process_encoder_decoder_prompt( + self, + inputs: PromptInputs, + request_id: str, + ) -> EncoderDecoderLLMInputs: + ''' + For encoder/decoder models only: + Process an input prompt into an + :class:`EncoderDecoderLLMInputs` instance. + + There are two types of input prompts: + singleton prompts which carry only the + encoder prompt, and explicit encoder/decoder + prompts which carry both the encoder and the + decoder prompts as member variables. 
+ + This function handles the following scenarios: + * Singleton encoder prompt: extract encoder prompt + token ids & infer default decoder prompt token ids + * Explicit encoder/decoder prompt: extract encoder + and decoder prompt token ids + + Note that for Explicit encoder/decoder prompts, + each sub-prompt (encoder or decoder prompt) can + have any possible singleton type; thus this + method relies on helper functions to obtain + token ids for the sub-prompts. + + Arguments: + + * inputs: an input prompt + * request_id + + Returns: + + * :class:`EncoderDecoderLLMInputs` instance + ''' + + encoder_comps: PromptComponents + decoder_comps: DecoderPromptComponents + + if is_explicit_encoder_decoder_prompt(inputs): + encoder_comps = self._extract_prompt_components( + inputs["encoder_prompt"], + request_id=request_id, + ) + + if (decoder_input := inputs["decoder_prompt"]) is None: + decoder_comps = None, None, None + else: + decoder_comps = self._extract_prompt_components( + decoder_input, + request_id=request_id, + ) + else: + encoder_comps = self._extract_prompt_components( + inputs, + request_id=request_id, + ) + + decoder_comps = None, None, None + + return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + + async def _process_encoder_decoder_prompt_async( + self, + inputs: PromptInputs, + request_id: str, + ) -> EncoderDecoderLLMInputs: + """Async version of :meth:`_process_encoder_decoder_prompt`.""" + encoder_comps: PromptComponents + decoder_comps: DecoderPromptComponents + + if is_explicit_encoder_decoder_prompt(inputs): + encoder_task = self._extract_prompt_components_async( + inputs["encoder_prompt"], + request_id=request_id, + ) + + if (decoder_input := inputs["decoder_prompt"]) is None: + encoder_comps = await encoder_task + decoder_comps = None, None, None + else: + decoder_task = self._extract_prompt_components_async( + decoder_input, + request_id=request_id, + ) + + encoder_comps, decoder_comps = await asyncio.gather( + encoder_task, decoder_task) + else: + encoder_comps = await self._extract_prompt_components_async( + inputs, + request_id=request_id, + ) + + decoder_comps = None, None, None + + return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + + def _build_decoder_only_llm_inputs( + self, + prompt_comps: PromptComponents, + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> LLMInputs: + prompt, prompt_token_ids, multi_modal_data = prompt_comps + + prompt_token_ids = self._apply_prompt_adapter( + prompt_token_ids, prompt_adapter_request=prompt_adapter_request) + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=prompt, + multi_modal_data=multi_modal_data) + + def _process_decoder_only_prompt( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> LLMInputs: + ''' + For decoder-only models: + Process an input prompt into an :class:`LLMInputs` instance. 
+ + Arguments: + + * inputs: input prompt + * request_id + * lora_request + * prompt_adapter_request + + Returns: + + * :class:`LLMInputs` instance + ''' + + prompt_comps = self._extract_prompt_components( + inputs, + request_id=request_id, + lora_request=lora_request, + ) + + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) + + async def _process_decoder_only_prompt_async( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> LLMInputs: + """Async version of :meth:`_process_decoder_only_prompt`.""" + prompt_comps = await self._extract_prompt_components_async( + inputs, + request_id=request_id, + lora_request=lora_request, + ) + + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) + + def preprocess( + self, + inputs: PromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + """Preprocess the input prompt.""" + if self.is_encoder_decoder_model(): + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + return self._process_encoder_decoder_prompt( + inputs, + request_id=request_id, + ) + + if is_explicit_encoder_decoder_prompt(inputs): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + + # Decoder-only operation + return self._process_decoder_only_prompt( + inputs, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + async def preprocess_async( + self, + inputs: PromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + """Async version of :meth:`preprocess`.""" + if self.is_encoder_decoder_model(): + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + return await self._process_encoder_decoder_prompt_async( + inputs, + request_id=request_id, + ) + + if is_explicit_encoder_decoder_prompt(inputs): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + + # Decoder-only operation + return await self._process_decoder_only_prompt_async( + inputs, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + def is_encoder_decoder_model(self): + return self.model_config.is_encoder_decoder_model From 40c396533d00b9b6efe08241525630dcf8d88c72 Mon Sep 17 00:00:00 2001 From: shangmingc Date: Fri, 13 Sep 2024 11:06:28 +0800 Subject: [PATCH 24/98] [Bugfix] Mapping physical device indices for e2e test utils (#8290) --- tests/utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/utils.py b/tests/utils.py index 6e5bc05b3901a..3c519fb6e50e0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -356,12 +356,23 @@ def error_on_warning(): yield +def get_physical_device_indices(devices): + visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES") + if visible_devices is None: + return devices + + visible_indices = [int(x) for x in visible_devices.split(",")] + index_mapping = {i: physical for i, physical in enumerate(visible_indices)} + return [index_mapping[i] for i in devices if i in index_mapping] + + @_nvml() def 
wait_for_gpu_memory_to_clear(devices: List[int], threshold_bytes: int, timeout_s: float = 120) -> None: # Use nvml instead of pytorch to reduce measurement error from torch cuda # context. + devices = get_physical_device_indices(devices) start_time = time.time() while True: output: Dict[int, str] = {} From 3f79bc3d1a65b7ed266702bb745c66b10283361f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 13 Sep 2024 11:21:42 +0800 Subject: [PATCH 25/98] [Bugfix] Bump fastapi and pydantic version (#8435) --- requirements-common.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index 3a9ae4aa77421..8432be61ed77d 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -7,11 +7,11 @@ py-cpuinfo transformers >= 4.43.2 # Required for Chameleon and Llama 3.1 hotfox. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. -fastapi +fastapi >= 0.114.1 aiohttp openai >= 1.40.0 # Ensure modern openai package (ensure types module present) uvicorn[standard] -pydantic >= 2.8 # Required for OpenAI server. +pydantic >= 2.9 # Required for fastapi >= 0.113.0 pillow # Required for image processing prometheus_client >= 0.18.0 prometheus-fastapi-instrumentator >= 7.0.0 From 84275504885ae5d4b3c63209f711706c8b758882 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 13 Sep 2024 11:47:52 +0800 Subject: [PATCH 26/98] [CI/Build] Update pixtral tests to use JSON (#8436) --- pyproject.toml | 2 +- tests/models/fixtures/pixtral_chat.json | 1 + tests/models/fixtures/pixtral_chat.pickle | Bin 20865 -> 0 bytes .../models/fixtures/pixtral_chat_engine.json | 1 + .../fixtures/pixtral_chat_engine.pickle | Bin 20858 -> 0 bytes tests/models/test_pixtral.py | 56 ++++++++++++------ 6 files changed, 42 insertions(+), 18 deletions(-) create mode 100644 tests/models/fixtures/pixtral_chat.json delete mode 100644 tests/models/fixtures/pixtral_chat.pickle create mode 100644 tests/models/fixtures/pixtral_chat_engine.json delete mode 100644 tests/models/fixtures/pixtral_chat_engine.pickle diff --git a/pyproject.toml b/pyproject.toml index 22a25d9cf32e6..d9e3278db4d19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ exclude = [ [tool.codespell] ignore-words-list = "dout, te, indicies, subtile" -skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" +skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" [tool.isort] use_parentheses = true diff --git a/tests/models/fixtures/pixtral_chat.json b/tests/models/fixtures/pixtral_chat.json new file mode 100644 index 0000000000000..643afb83d29b8 --- /dev/null +++ b/tests/models/fixtures/pixtral_chat.json @@ -0,0 +1 @@ +[[[1784, 3937, 6122, 1261, 7244, 10575, 18970, 1408, 1261, 32656, 4691, 1046, 2], "The image shows a black dog sitting on a wooden surface.", [{"1784": {"logprob": -0.11687260121107101, "rank": 1, "decoded_token": "The"}, "4380": {"logprob": -2.366872549057007, "rank": 2, "decoded_token": "This"}, "1049": {"logprob": -4.741872787475586, "rank": 3, "decoded_token": "1"}, "117991": {"logprob": -5.991872787475586, "rank": 4, "decoded_token": "Certain"}, "1785": {"logprob": -5.991872787475586, "rank": 5, "decoded_token": "In"}}, {"3937": {"logprob": -0.28887900710105896, "rank": 1, "decoded_token": " image"}, "2158": {"logprob": -1.4138790369033813, "rank": 2, "decoded_token": " first"}, "3977": {"logprob": -5.788878917694092, "rank": 3, "decoded_token": " top"}, "7244": {"logprob": 
-6.163878917694092, "rank": 4, "decoded_token": " black"}, "8061": {"logprob": -6.788878917694092, "rank": 5, "decoded_token": " images"}}, {"6122": {"logprob": -0.9653709530830383, "rank": 1, "decoded_token": " shows"}, "51948": {"logprob": -1.4653708934783936, "rank": 2, "decoded_token": " depicts"}, "6971": {"logprob": -1.4653708934783936, "rank": 3, "decoded_token": " features"}, "25981": {"logprob": -2.8403708934783936, "rank": 4, "decoded_token": " displays"}, "8688": {"logprob": -2.8403708934783936, "rank": 5, "decoded_token": " contains"}}, {"1261": {"logprob": -0.003059827256947756, "rank": 1, "decoded_token": " a"}, "1420": {"logprob": -6.2530598640441895, "rank": 2, "decoded_token": " an"}, "2295": {"logprob": -7.8780598640441895, "rank": 3, "decoded_token": " two"}, "2342": {"logprob": -7.8780598640441895, "rank": 4, "decoded_token": " only"}, "1278": {"logprob": -8.628059387207031, "rank": 5, "decoded_token": " the"}}, {"7244": {"logprob": -0.17616479098796844, "rank": 1, "decoded_token": " black"}, "6231": {"logprob": -2.3011648654937744, "rank": 2, "decoded_token": " close"}, "4249": {"logprob": -3.4261648654937744, "rank": 3, "decoded_token": " single"}, "4329": {"logprob": -5.113664627075195, "rank": 4, "decoded_token": " large"}, "10575": {"logprob": -5.176164627075195, "rank": 5, "decoded_token": " dog"}}, {"10575": {"logprob": -0.10940006375312805, "rank": 1, "decoded_token": " dog"}, "116572": {"logprob": -2.4844000339508057, "rank": 2, "decoded_token": " puppy"}, "119075": {"logprob": -4.109400272369385, "rank": 3, "decoded_token": " Labrador"}, "15812": {"logprob": -7.296900272369385, "rank": 4, "decoded_token": " Lab"}, "7990": {"logprob": -7.421900272369385, "rank": 5, "decoded_token": " cat"}}, {"18970": {"logprob": -0.8322296738624573, "rank": 1, "decoded_token": " sitting"}, "1454": {"logprob": -1.5822296142578125, "rank": 2, "decoded_token": " with"}, "28528": {"logprob": -1.9572296142578125, "rank": 3, "decoded_token": " lying"}, "7283": {"logprob": -2.2072296142578125, "rank": 4, "decoded_token": " looking"}, "15866": {"logprob": -3.0197296142578125, "rank": 5, "decoded_token": " standing"}}, {"1408": {"logprob": -0.08769982308149338, "rank": 1, "decoded_token": " on"}, "1321": {"logprob": -3.7126998901367188, "rank": 2, "decoded_token": " and"}, "3675": {"logprob": -3.9626998901367188, "rank": 3, "decoded_token": " against"}, "41132": {"logprob": -4.587699890136719, "rank": 4, "decoded_token": " attent"}, "1454": {"logprob": -5.087699890136719, "rank": 5, "decoded_token": " with"}}, {"1261": {"logprob": -0.5400654673576355, "rank": 1, "decoded_token": " a"}, "32656": {"logprob": -0.9150654673576355, "rank": 2, "decoded_token": " wooden"}, "3977": {"logprob": -5.415065288543701, "rank": 3, "decoded_token": " top"}, "12603": {"logprob": -5.540065288543701, "rank": 4, "decoded_token": " wood"}, "44130": {"logprob": -6.290065288543701, "rank": 5, "decoded_token": " rust"}}, {"32656": {"logprob": -0.02516966126859188, "rank": 1, "decoded_token": " wooden"}, "44130": {"logprob": -4.400169849395752, "rank": 2, "decoded_token": " rust"}, "12603": {"logprob": -5.275169849395752, "rank": 3, "decoded_token": " wood"}, "3403": {"logprob": -5.525169849395752, "rank": 4, "decoded_token": " text"}, "17253": {"logprob": -6.962669849395752, "rank": 5, "decoded_token": " weather"}}, {"4691": {"logprob": -0.7264319658279419, "rank": 1, "decoded_token": " surface"}, "11237": {"logprob": -0.8514319658279419, "rank": 2, "decoded_token": " floor"}, "7042": {"logprob": 
-2.6014318466186523, "rank": 3, "decoded_token": " background"}, "28984": {"logprob": -5.226431846618652, "rank": 4, "decoded_token": " deck"}, "1615": {"logprob": -5.726431846618652, "rank": 5, "decoded_token": " pl"}}, {"1046": {"logprob": -0.4668232202529907, "rank": 1, "decoded_token": "."}, "1044": {"logprob": -1.9668232202529907, "rank": 2, "decoded_token": ","}, "1321": {"logprob": -2.466823101043701, "rank": 3, "decoded_token": " and"}, "7283": {"logprob": -2.716823101043701, "rank": 4, "decoded_token": " looking"}, "1454": {"logprob": -2.716823101043701, "rank": 5, "decoded_token": " with"}}, {"2": {"logprob": -0.002247072057798505, "rank": 1, "decoded_token": ""}, "1531": {"logprob": -6.627246856689453, "rank": 2, "decoded_token": " The"}, "1032": {"logprob": -7.127246856689453, "rank": 3, "decoded_token": " "}, "3730": {"logprob": -9.877246856689453, "rank": 4, "decoded_token": " There"}, "1256": {"logprob": -11.127246856689453, "rank": 5, "decoded_token": " "}}]], [[1049, 1046, 1349, 7244, 10575, 1454, 2327, 94766, 32961, 53048, 41132, 3923, 1408, 1261, 32656, 4691, 1626, 1050, 1046, 1349, 15375, 24361, 4521, 1454, 122203, 27469, 94973, 2425, 1261, 16152, 1121, 21283, 1046, 2], "1. A black dog with floppy ears sits attentively on a wooden surface.\n2. A vast mountain range with rugged peaks stretches under a cloudy sky.", [{"1049": {"logprob": -0.42824622988700867, "rank": 1, "decoded_token": "1"}, "1045": {"logprob": -1.553246259689331, "rank": 2, "decoded_token": "-"}, "1065": {"logprob": -2.428246259689331, "rank": 3, "decoded_token": "A"}, "1784": {"logprob": -4.053246021270752, "rank": 4, "decoded_token": "The"}, "69957": {"logprob": -4.428246021270752, "rank": 5, "decoded_token": "Sure"}}, {"1046": {"logprob": -1.9788545614574105e-05, "rank": 1, "decoded_token": "."}, "1058": {"logprob": -11.750020027160645, "rank": 2, "decoded_token": ":"}, "3590": {"logprob": -12.125020027160645, "rank": 3, "decoded_token": ".A"}, "1065": {"logprob": -13.062520027160645, "rank": 4, "decoded_token": "A"}, "1041": {"logprob": -13.750020027160645, "rank": 5, "decoded_token": ")"}}, {"1349": {"logprob": -0.14020134508609772, "rank": 1, "decoded_token": " A"}, "1429": {"logprob": -2.3902013301849365, "rank": 2, "decoded_token": " \""}, "1603": {"logprob": -3.7652013301849365, "rank": 3, "decoded_token": " **"}, "11967": {"logprob": -4.890201568603516, "rank": 4, "decoded_token": " Image"}, "1531": {"logprob": -5.015201568603516, "rank": 5, "decoded_token": " The"}}, {"7244": {"logprob": -0.2003599852323532, "rank": 1, "decoded_token": " black"}, "38462": {"logprob": -3.075360059738159, "rank": 2, "decoded_token": " curious"}, "68076": {"logprob": -3.575360059738159, "rank": 3, "decoded_token": " cute"}, "4329": {"logprob": -3.887860059738159, "rank": 4, "decoded_token": " large"}, "6231": {"logprob": -4.32535982131958, "rank": 5, "decoded_token": " close"}}, {"10575": {"logprob": -0.18818901479244232, "rank": 1, "decoded_token": " dog"}, "116572": {"logprob": -2.0631890296936035, "rank": 2, "decoded_token": " puppy"}, "119075": {"logprob": -3.1881890296936035, "rank": 3, "decoded_token": " Labrador"}, "15812": {"logprob": -6.9381890296936035, "rank": 4, "decoded_token": " Lab"}, "8636": {"logprob": -7.3131890296936035, "rank": 5, "decoded_token": " lab"}}, {"1454": {"logprob": -0.5699259042739868, "rank": 1, "decoded_token": " with"}, "53048": {"logprob": -1.2574259042739868, "rank": 2, "decoded_token": " sits"}, "1395": {"logprob": -3.0699257850646973, "rank": 3, "decoded_token": " is"}, 
"22524": {"logprob": -3.6324257850646973, "rank": 4, "decoded_token": " lies"}, "18970": {"logprob": -3.7574257850646973, "rank": 5, "decoded_token": " sitting"}}, {"2327": {"logprob": -1.2377738952636719, "rank": 1, "decoded_token": " fl"}, "1261": {"logprob": -1.3627738952636719, "rank": 2, "decoded_token": " a"}, "17300": {"logprob": -1.9252738952636719, "rank": 3, "decoded_token": " soul"}, "100089": {"logprob": -2.675273895263672, "rank": 4, "decoded_token": " expressive"}, "6444": {"logprob": -3.237773895263672, "rank": 5, "decoded_token": " soft"}}, {"94766": {"logprob": -0.0025601964443922043, "rank": 1, "decoded_token": "oppy"}, "124603": {"logprob": -6.315060138702393, "rank": 2, "decoded_token": "uffy"}, "1484": {"logprob": -7.877560138702393, "rank": 3, "decoded_token": "op"}, "24897": {"logprob": -8.81506061553955, "rank": 4, "decoded_token": "appy"}, "102477": {"logprob": -9.69006061553955, "rank": 5, "decoded_token": "opping"}}, {"32961": {"logprob": -5.113947918289341e-05, "rank": 1, "decoded_token": " ears"}, "16962": {"logprob": -11.250051498413086, "rank": 2, "decoded_token": " ear"}, "5731": {"logprob": -11.812551498413086, "rank": 3, "decoded_token": " eyes"}, "3351": {"logprob": -12.000051498413086, "rank": 4, "decoded_token": " years"}, "42071": {"logprob": -13.062551498413086, "rank": 5, "decoded_token": " cheeks"}}, {"53048": {"logprob": -0.6179640889167786, "rank": 1, "decoded_token": " sits"}, "10637": {"logprob": -1.9929640293121338, "rank": 2, "decoded_token": " looks"}, "1321": {"logprob": -2.430464029312134, "rank": 3, "decoded_token": " and"}, "1395": {"logprob": -2.617964029312134, "rank": 4, "decoded_token": " is"}, "18970": {"logprob": -3.055464029312134, "rank": 5, "decoded_token": " sitting"}}, {"41132": {"logprob": -0.3746516704559326, "rank": 1, "decoded_token": " attent"}, "1408": {"logprob": -2.3121516704559326, "rank": 2, "decoded_token": " on"}, "106534": {"logprob": -2.3746516704559326, "rank": 3, "decoded_token": " calmly"}, "12276": {"logprob": -2.6246516704559326, "rank": 4, "decoded_token": " alert"}, "6482": {"logprob": -5.124651908874512, "rank": 5, "decoded_token": " patient"}}, {"3923": {"logprob": -8.463501580990851e-05, "rank": 1, "decoded_token": "ively"}, "1556": {"logprob": -9.50008487701416, "rank": 2, "decoded_token": "ive"}, "6655": {"logprob": -11.87508487701416, "rank": 3, "decoded_token": "atively"}, "3929": {"logprob": -14.00008487701416, "rank": 4, "decoded_token": "ently"}, "47885": {"logprob": -14.62508487701416, "rank": 5, "decoded_token": "edly"}}, {"1408": {"logprob": -0.06439964473247528, "rank": 1, "decoded_token": " on"}, "3675": {"logprob": -3.0643997192382812, "rank": 2, "decoded_token": " against"}, "1294": {"logprob": -4.939399719238281, "rank": 3, "decoded_token": " in"}, "7283": {"logprob": -5.689399719238281, "rank": 4, "decoded_token": " looking"}, "1044": {"logprob": -5.814399719238281, "rank": 5, "decoded_token": ","}}, {"1261": {"logprob": -0.2108541578054428, "rank": 1, "decoded_token": " a"}, "32656": {"logprob": -1.710854172706604, "rank": 2, "decoded_token": " wooden"}, "17253": {"logprob": -5.5858540534973145, "rank": 3, "decoded_token": " weather"}, "44130": {"logprob": -6.0858540534973145, "rank": 4, "decoded_token": " rust"}, "12603": {"logprob": -6.9608540534973145, "rank": 5, "decoded_token": " wood"}}, {"32656": {"logprob": -0.08556432276964188, "rank": 1, "decoded_token": " wooden"}, "44130": {"logprob": -2.710564374923706, "rank": 2, "decoded_token": " rust"}, "17253": {"logprob": 
-4.710564136505127, "rank": 3, "decoded_token": " weather"}, "12603": {"logprob": -5.960564136505127, "rank": 4, "decoded_token": " wood"}, "3403": {"logprob": -5.960564136505127, "rank": 5, "decoded_token": " text"}}, {"4691": {"logprob": -0.7751782536506653, "rank": 1, "decoded_token": " surface"}, "11237": {"logprob": -0.7751782536506653, "rank": 2, "decoded_token": " floor"}, "7042": {"logprob": -2.9001781940460205, "rank": 3, "decoded_token": " background"}, "28984": {"logprob": -4.1501784324646, "rank": 4, "decoded_token": " deck"}, "92504": {"logprob": -6.1501784324646, "rank": 5, "decoded_token": " backdrop"}}, {"1626": {"logprob": -0.12918435037136078, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -2.3791842460632324, "rank": 2, "decoded_token": ","}, "1046": {"logprob": -4.129184246063232, "rank": 3, "decoded_token": "."}, "1338": {"logprob": -5.129184246063232, "rank": 4, "decoded_token": ".\n\n"}, "7283": {"logprob": -5.629184246063232, "rank": 5, "decoded_token": " looking"}}, {"1050": {"logprob": -0.00017474555352237076, "rank": 1, "decoded_token": "2"}, "1256": {"logprob": -9.000174522399902, "rank": 2, "decoded_token": " "}, "1032": {"logprob": -10.875174522399902, "rank": 3, "decoded_token": " "}, "1293": {"logprob": -11.625174522399902, "rank": 4, "decoded_token": " "}, "1051": {"logprob": -12.125174522399902, "rank": 5, "decoded_token": "3"}}, {"1046": {"logprob": -7.629365427419543e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -12.875007629394531, "rank": 2, "decoded_token": ".A"}, "1626": {"logprob": -13.062507629394531, "rank": 3, "decoded_token": ".\n"}, "1338": {"logprob": -14.562507629394531, "rank": 4, "decoded_token": ".\n\n"}, "1058": {"logprob": -14.812507629394531, "rank": 5, "decoded_token": ":"}}, {"1349": {"logprob": -0.558266282081604, "rank": 1, "decoded_token": " A"}, "11826": {"logprob": -1.495766282081604, "rank": 2, "decoded_token": " Maj"}, "37159": {"logprob": -2.2457661628723145, "rank": 3, "decoded_token": " Snow"}, "113465": {"logprob": -3.9957661628723145, "rank": 4, "decoded_token": " Rug"}, "1531": {"logprob": -3.9957661628723145, "rank": 5, "decoded_token": " The"}}, {"15375": {"logprob": -0.6446555852890015, "rank": 1, "decoded_token": " vast"}, "37849": {"logprob": -2.019655704498291, "rank": 2, "decoded_token": " breat"}, "61082": {"logprob": -2.394655704498291, "rank": 3, "decoded_token": " panor"}, "10726": {"logprob": -3.082155704498291, "rank": 4, "decoded_token": " scen"}, "2169": {"logprob": -3.207155704498291, "rank": 5, "decoded_token": " ser"}}, {"24361": {"logprob": -0.7034653425216675, "rank": 1, "decoded_token": " mountain"}, "127945": {"logprob": -1.9534653425216675, "rank": 2, "decoded_token": " mountainous"}, "1044": {"logprob": -2.078465461730957, "rank": 3, "decoded_token": ","}, "4521": {"logprob": -2.328465461730957, "rank": 4, "decoded_token": " range"}, "28035": {"logprob": -2.453465461730957, "rank": 5, "decoded_token": " landscape"}}, {"4521": {"logprob": -0.07058106362819672, "rank": 1, "decoded_token": " range"}, "28035": {"logprob": -2.6955809593200684, "rank": 2, "decoded_token": " landscape"}, "37691": {"logprob": -8.320581436157227, "rank": 3, "decoded_token": " valley"}, "12248": {"logprob": -9.445581436157227, "rank": 4, "decoded_token": " peak"}, "13327": {"logprob": -9.695581436157227, "rank": 5, "decoded_token": " scene"}}, {"1454": {"logprob": -1.1448894739151, "rank": 1, "decoded_token": " with"}, "94973": {"logprob": -1.1448894739151, "rank": 2, "decoded_token": " stretches"}, 
"2425": {"logprob": -1.8948894739151, "rank": 3, "decoded_token": " under"}, "1395": {"logprob": -2.5198893547058105, "rank": 4, "decoded_token": " is"}, "13875": {"logprob": -3.0198893547058105, "rank": 5, "decoded_token": " covered"}}, {"122203": {"logprob": -1.0288245677947998, "rank": 1, "decoded_token": " rugged"}, "58127": {"logprob": -1.6538245677947998, "rank": 2, "decoded_token": " jag"}, "27469": {"logprob": -2.1538245677948, "rank": 3, "decoded_token": " peaks"}, "23745": {"logprob": -2.6538245677948, "rank": 4, "decoded_token": " snow"}, "95746": {"logprob": -2.8413245677948, "rank": 5, "decoded_token": " rocky"}}, {"27469": {"logprob": -0.20564845204353333, "rank": 1, "decoded_token": " peaks"}, "24765": {"logprob": -2.580648422241211, "rank": 2, "decoded_token": " terrain"}, "130655": {"logprob": -2.955648422241211, "rank": 3, "decoded_token": ""}, "1044": {"logprob": -3.580648422241211, "rank": 4, "decoded_token": ","}, "61263": {"logprob": -4.455648422241211, "rank": 5, "decoded_token": " slopes"}}, {"94973": {"logprob": -1.0839273929595947, "rank": 1, "decoded_token": " stretches"}, "1321": {"logprob": -1.1464273929595947, "rank": 2, "decoded_token": " and"}, "2425": {"logprob": -1.7714273929595947, "rank": 3, "decoded_token": " under"}, "13875": {"logprob": -3.0839273929595947, "rank": 4, "decoded_token": " covered"}, "1395": {"logprob": -3.2714273929595947, "rank": 5, "decoded_token": " is"}}, {"2425": {"logprob": -0.9016233682632446, "rank": 1, "decoded_token": " under"}, "5669": {"logprob": -1.0266233682632446, "rank": 2, "decoded_token": " across"}, "1848": {"logprob": -1.9016233682632446, "rank": 3, "decoded_token": " out"}, "2203": {"logprob": -3.151623249053955, "rank": 4, "decoded_token": " into"}, "8994": {"logprob": -4.026623249053955, "rank": 5, "decoded_token": " towards"}}, {"1261": {"logprob": -0.00555459875613451, "rank": 1, "decoded_token": " a"}, "1420": {"logprob": -5.380554676055908, "rank": 2, "decoded_token": " an"}, "1278": {"logprob": -7.630554676055908, "rank": 3, "decoded_token": " the"}, "2136": {"logprob": -9.31805419921875, "rank": 4, "decoded_token": " over"}, "16152": {"logprob": -9.38055419921875, "rank": 5, "decoded_token": " cloud"}}, {"16152": {"logprob": -0.6862213015556335, "rank": 1, "decoded_token": " cloud"}, "6133": {"logprob": -1.4362213611602783, "rank": 2, "decoded_token": " clear"}, "18416": {"logprob": -2.6862213611602783, "rank": 3, "decoded_token": " haz"}, "27254": {"logprob": -3.0612213611602783, "rank": 4, "decoded_token": " partly"}, "4391": {"logprob": -3.1862213611602783, "rank": 5, "decoded_token": " light"}}, {"1121": {"logprob": -0.10446903109550476, "rank": 1, "decoded_token": "y"}, "4527": {"logprob": -2.854469060897827, "rank": 2, "decoded_token": "less"}, "1286": {"logprob": -3.479469060897827, "rank": 3, "decoded_token": "ed"}, "114525": {"logprob": -5.479468822479248, "rank": 4, "decoded_token": "-covered"}, "77187": {"logprob": -5.479468822479248, "rank": 5, "decoded_token": "-filled"}}, {"21283": {"logprob": -0.003459066851064563, "rank": 1, "decoded_token": " sky"}, "10991": {"logprob": -6.3784589767456055, "rank": 2, "decoded_token": " blue"}, "1044": {"logprob": -6.8784589767456055, "rank": 3, "decoded_token": ","}, "26549": {"logprob": -7.8784589767456055, "rank": 4, "decoded_token": " gray"}, "34052": {"logprob": -8.503458976745605, "rank": 5, "decoded_token": " grey"}}, {"1046": {"logprob": -0.01103890035301447, "rank": 1, "decoded_token": "."}, "1044": {"logprob": -4.636038780212402, "rank": 2, 
"decoded_token": ","}, "1338": {"logprob": -7.261038780212402, "rank": 3, "decoded_token": ".\n\n"}, "1294": {"logprob": -8.136038780212402, "rank": 4, "decoded_token": " in"}, "1454": {"logprob": -8.761038780212402, "rank": 5, "decoded_token": " with"}}, {"2": {"logprob": -9.059865078597795e-06, "rank": 1, "decoded_token": ""}, "1032": {"logprob": -11.625008583068848, "rank": 2, "decoded_token": " "}, "1256": {"logprob": -16.125009536743164, "rank": 3, "decoded_token": " "}, "1319": {"logprob": -17.375009536743164, "rank": 4, "decoded_token": " ("}, "1766": {"logprob": -18.750009536743164, "rank": 5, "decoded_token": " ["}}]], [[1049, 1046, 1349, 7244, 10575, 53048, 41132, 3923, 1408, 1261, 32656, 11237, 1626, 1050, 1046, 1349, 15375, 24361, 4521, 94973, 5669, 1278, 48932, 2425, 1261, 16152, 1121, 21283, 1626, 1051, 1046, 8342, 71284, 7377, 1394, 22140, 1294, 1278, 27208, 1513, 97558, 1626, 1052, 1046, 1349, 53301, 59396, 3549, 13335, 2645, 1261, 1295, 3506, 11223, 12097, 1046, 2], "1. A black dog sits attentively on a wooden floor.\n2. A vast mountain range stretches across the horizon under a cloudy sky.\n3. Surfers wait for waves in the ocean at sunset.\n4. A winding gravel path leads through a lush green park.", [{"1049": {"logprob": -0.05001257359981537, "rank": 1, "decoded_token": "1"}, "1045": {"logprob": -3.1750125885009766, "rank": 2, "decoded_token": "-"}, "69957": {"logprob": -5.925012588500977, "rank": 3, "decoded_token": "Sure"}, "11745": {"logprob": -6.425012588500977, "rank": 4, "decoded_token": "Here"}, "1065": {"logprob": -6.425012588500977, "rank": 5, "decoded_token": "A"}}, {"1046": {"logprob": -9.536697689327411e-06, "rank": 1, "decoded_token": "."}, "1058": {"logprob": -11.875009536743164, "rank": 2, "decoded_token": ":"}, "3590": {"logprob": -13.375009536743164, "rank": 3, "decoded_token": ".A"}, "1041": {"logprob": -14.750009536743164, "rank": 4, "decoded_token": ")"}, "1065": {"logprob": -15.687509536743164, "rank": 5, "decoded_token": "A"}}, {"1349": {"logprob": -0.12580634653568268, "rank": 1, "decoded_token": " A"}, "1429": {"logprob": -2.3758063316345215, "rank": 2, "decoded_token": " \""}, "1531": {"logprob": -4.6258063316345215, "rank": 3, "decoded_token": " The"}, "11967": {"logprob": -4.6258063316345215, "rank": 4, "decoded_token": " Image"}, "1603": {"logprob": -5.6258063316345215, "rank": 5, "decoded_token": " **"}}, {"7244": {"logprob": -0.15412142872810364, "rank": 1, "decoded_token": " black"}, "68076": {"logprob": -3.3416213989257812, "rank": 2, "decoded_token": " cute"}, "6231": {"logprob": -3.9666213989257812, "rank": 3, "decoded_token": " close"}, "38462": {"logprob": -4.216621398925781, "rank": 4, "decoded_token": " curious"}, "4329": {"logprob": -4.404121398925781, "rank": 5, "decoded_token": " large"}}, {"10575": {"logprob": -0.12086891382932663, "rank": 1, "decoded_token": " dog"}, "116572": {"logprob": -2.3708689212799072, "rank": 2, "decoded_token": " puppy"}, "119075": {"logprob": -3.9958689212799072, "rank": 3, "decoded_token": " Labrador"}, "15812": {"logprob": -7.683368682861328, "rank": 4, "decoded_token": " Lab"}, "8636": {"logprob": -7.808368682861328, "rank": 5, "decoded_token": " lab"}}, {"53048": {"logprob": -0.8729249238967896, "rank": 1, "decoded_token": " sits"}, "1454": {"logprob": -1.1229249238967896, "rank": 2, "decoded_token": " with"}, "1395": {"logprob": -2.4354248046875, "rank": 3, "decoded_token": " is"}, "18970": {"logprob": -2.6854248046875, "rank": 4, "decoded_token": " sitting"}, "22524": {"logprob": -3.6854248046875, 
"rank": 5, "decoded_token": " lies"}}, {"41132": {"logprob": -0.5888903737068176, "rank": 1, "decoded_token": " attent"}, "106534": {"logprob": -1.2763903141021729, "rank": 2, "decoded_token": " calmly"}, "12276": {"logprob": -2.838890314102173, "rank": 3, "decoded_token": " alert"}, "1408": {"logprob": -2.901390314102173, "rank": 4, "decoded_token": " on"}, "6482": {"logprob": -5.026390552520752, "rank": 5, "decoded_token": " patient"}}, {"3923": {"logprob": -9.16677454370074e-05, "rank": 1, "decoded_token": "ively"}, "1556": {"logprob": -9.625091552734375, "rank": 2, "decoded_token": "ive"}, "6655": {"logprob": -10.875091552734375, "rank": 3, "decoded_token": "atively"}, "3929": {"logprob": -13.125091552734375, "rank": 4, "decoded_token": "ently"}, "47885": {"logprob": -13.750091552734375, "rank": 5, "decoded_token": "edly"}}, {"1408": {"logprob": -0.052677519619464874, "rank": 1, "decoded_token": " on"}, "3675": {"logprob": -3.802677631378174, "rank": 2, "decoded_token": " against"}, "1454": {"logprob": -4.302677631378174, "rank": 3, "decoded_token": " with"}, "1294": {"logprob": -5.177677631378174, "rank": 4, "decoded_token": " in"}, "7283": {"logprob": -5.427677631378174, "rank": 5, "decoded_token": " looking"}}, {"1261": {"logprob": -0.36706605553627014, "rank": 1, "decoded_token": " a"}, "32656": {"logprob": -1.2420660257339478, "rank": 2, "decoded_token": " wooden"}, "17253": {"logprob": -4.617065906524658, "rank": 3, "decoded_token": " weather"}, "44130": {"logprob": -5.742065906524658, "rank": 4, "decoded_token": " rust"}, "12603": {"logprob": -6.617065906524658, "rank": 5, "decoded_token": " wood"}}, {"32656": {"logprob": -0.07824385166168213, "rank": 1, "decoded_token": " wooden"}, "44130": {"logprob": -2.8282437324523926, "rank": 2, "decoded_token": " rust"}, "17253": {"logprob": -4.703243732452393, "rank": 3, "decoded_token": " weather"}, "12603": {"logprob": -5.828243732452393, "rank": 4, "decoded_token": " wood"}, "3403": {"logprob": -5.953243732452393, "rank": 5, "decoded_token": " text"}}, {"11237": {"logprob": -0.5853750705718994, "rank": 1, "decoded_token": " floor"}, "4691": {"logprob": -1.0853750705718994, "rank": 2, "decoded_token": " surface"}, "7042": {"logprob": -2.7103750705718994, "rank": 3, "decoded_token": " background"}, "28984": {"logprob": -3.5853750705718994, "rank": 4, "decoded_token": " deck"}, "92504": {"logprob": -6.08537483215332, "rank": 5, "decoded_token": " backdrop"}}, {"1626": {"logprob": -0.7340722680091858, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -0.8590722680091858, "rank": 2, "decoded_token": ","}, "1454": {"logprob": -3.359072208404541, "rank": 3, "decoded_token": " with"}, "7283": {"logprob": -3.609072208404541, "rank": 4, "decoded_token": " looking"}, "1321": {"logprob": -4.109072208404541, "rank": 5, "decoded_token": " and"}}, {"1050": {"logprob": -1.1324817933200393e-05, "rank": 1, "decoded_token": "2"}, "1051": {"logprob": -11.625011444091797, "rank": 2, "decoded_token": "3"}, "1256": {"logprob": -14.000011444091797, "rank": 3, "decoded_token": " "}, "1049": {"logprob": -14.625011444091797, "rank": 4, "decoded_token": "1"}, "1032": {"logprob": -14.625011444091797, "rank": 5, "decoded_token": " "}}, {"1046": {"logprob": -2.50339189733495e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -13.56250286102295, "rank": 2, "decoded_token": ".A"}, "1626": {"logprob": -15.43750286102295, "rank": 3, "decoded_token": ".\n"}, "4700": {"logprob": -15.50000286102295, "rank": 4, "decoded_token": ".M"}, "3051": {"logprob": 
-16.000001907348633, "rank": 5, "decoded_token": ".S"}}, {"1349": {"logprob": -0.6769706010818481, "rank": 1, "decoded_token": " A"}, "11826": {"logprob": -1.9269706010818481, "rank": 2, "decoded_token": " Maj"}, "37159": {"logprob": -2.1144704818725586, "rank": 3, "decoded_token": " Snow"}, "27260": {"logprob": -2.6144704818725586, "rank": 4, "decoded_token": " Mountain"}, "113465": {"logprob": -2.8644704818725586, "rank": 5, "decoded_token": " Rug"}}, {"15375": {"logprob": -0.9251430034637451, "rank": 1, "decoded_token": " vast"}, "10726": {"logprob": -2.300143003463745, "rank": 2, "decoded_token": " scen"}, "4521": {"logprob": -2.362643003463745, "rank": 3, "decoded_token": " range"}, "122203": {"logprob": -2.425143003463745, "rank": 4, "decoded_token": " rugged"}, "61082": {"logprob": -2.800143003463745, "rank": 5, "decoded_token": " panor"}}, {"24361": {"logprob": -0.5277582406997681, "rank": 1, "decoded_token": " mountain"}, "127945": {"logprob": -1.902758240699768, "rank": 2, "decoded_token": " mountainous"}, "28035": {"logprob": -2.5277581214904785, "rank": 3, "decoded_token": " landscape"}, "4521": {"logprob": -2.5277581214904785, "rank": 4, "decoded_token": " range"}, "1044": {"logprob": -2.7777581214904785, "rank": 5, "decoded_token": ","}}, {"4521": {"logprob": -0.055658817291259766, "rank": 1, "decoded_token": " range"}, "28035": {"logprob": -2.9306588172912598, "rank": 2, "decoded_token": " landscape"}, "37691": {"logprob": -8.430658340454102, "rank": 3, "decoded_token": " valley"}, "13327": {"logprob": -9.055658340454102, "rank": 4, "decoded_token": " scene"}, "3719": {"logprob": -9.805658340454102, "rank": 5, "decoded_token": " view"}}, {"94973": {"logprob": -0.6880245208740234, "rank": 1, "decoded_token": " stretches"}, "2425": {"logprob": -1.7505245208740234, "rank": 2, "decoded_token": " under"}, "1395": {"logprob": -2.3130245208740234, "rank": 3, "decoded_token": " is"}, "1454": {"logprob": -2.6880245208740234, "rank": 4, "decoded_token": " with"}, "7038": {"logprob": -3.2505245208740234, "rank": 5, "decoded_token": " extends"}}, {"5669": {"logprob": -0.4545598328113556, "rank": 1, "decoded_token": " across"}, "2425": {"logprob": -1.4545598030090332, "rank": 2, "decoded_token": " under"}, "1848": {"logprob": -2.454559803009033, "rank": 3, "decoded_token": " out"}, "2203": {"logprob": -4.204559803009033, "rank": 4, "decoded_token": " into"}, "25136": {"logprob": -4.642059803009033, "rank": 5, "decoded_token": " beneath"}}, {"1278": {"logprob": -0.23015151917934418, "rank": 1, "decoded_token": " the"}, "1261": {"logprob": -1.6051515340805054, "rank": 2, "decoded_token": " a"}, "1420": {"logprob": -5.605151653289795, "rank": 3, "decoded_token": " an"}, "2425": {"logprob": -7.167651653289795, "rank": 4, "decoded_token": " under"}, "1454": {"logprob": -10.167651176452637, "rank": 5, "decoded_token": " with"}}, {"48932": {"logprob": -0.2797861397266388, "rank": 1, "decoded_token": " horizon"}, "21283": {"logprob": -2.0297861099243164, "rank": 2, "decoded_token": " sky"}, "3937": {"logprob": -3.2797861099243164, "rank": 3, "decoded_token": " image"}, "28035": {"logprob": -3.6547861099243164, "rank": 4, "decoded_token": " landscape"}, "3044": {"logprob": -3.7797861099243164, "rank": 5, "decoded_token": " sk"}}, {"2425": {"logprob": -0.28862035274505615, "rank": 1, "decoded_token": " under"}, "1044": {"logprob": -2.4136204719543457, "rank": 2, "decoded_token": ","}, "1454": {"logprob": -2.5386204719543457, "rank": 3, "decoded_token": " with"}, "1626": {"logprob": 
-3.7886204719543457, "rank": 4, "decoded_token": ".\n"}, "1408": {"logprob": -3.9136204719543457, "rank": 5, "decoded_token": " on"}}, {"1261": {"logprob": -0.04524127021431923, "rank": 1, "decoded_token": " a"}, "16152": {"logprob": -4.045241355895996, "rank": 2, "decoded_token": " cloud"}, "1420": {"logprob": -4.045241355895996, "rank": 3, "decoded_token": " an"}, "2136": {"logprob": -6.107741355895996, "rank": 4, "decoded_token": " over"}, "6133": {"logprob": -6.357741355895996, "rank": 5, "decoded_token": " clear"}}, {"16152": {"logprob": -0.19613930583000183, "rank": 1, "decoded_token": " cloud"}, "6133": {"logprob": -2.883639335632324, "rank": 2, "decoded_token": " clear"}, "27254": {"logprob": -3.508639335632324, "rank": 3, "decoded_token": " partly"}, "18416": {"logprob": -3.883639335632324, "rank": 4, "decoded_token": " haz"}, "4391": {"logprob": -4.321139335632324, "rank": 5, "decoded_token": " light"}}, {"1121": {"logprob": -0.05146069824695587, "rank": 1, "decoded_token": "y"}, "1286": {"logprob": -3.8014607429504395, "rank": 2, "decoded_token": "ed"}, "77187": {"logprob": -4.5514607429504395, "rank": 3, "decoded_token": "-filled"}, "114525": {"logprob": -4.9264607429504395, "rank": 4, "decoded_token": "-covered"}, "4527": {"logprob": -4.9264607429504395, "rank": 5, "decoded_token": "less"}}, {"21283": {"logprob": -0.00033122775494121015, "rank": 1, "decoded_token": " sky"}, "10991": {"logprob": -8.875330924987793, "rank": 2, "decoded_token": " blue"}, "1044": {"logprob": -9.500330924987793, "rank": 3, "decoded_token": ","}, "26549": {"logprob": -10.500330924987793, "rank": 4, "decoded_token": " gray"}, "34052": {"logprob": -11.375330924987793, "rank": 5, "decoded_token": " grey"}}, {"1626": {"logprob": -0.00012683063687290996, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -9.500126838684082, "rank": 2, "decoded_token": ","}, "1046": {"logprob": -10.500126838684082, "rank": 3, "decoded_token": "."}, "1454": {"logprob": -10.875126838684082, "rank": 4, "decoded_token": " with"}, "1294": {"logprob": -13.375126838684082, "rank": 5, "decoded_token": " in"}}, {"1051": {"logprob": -3.2186455882765586e-06, "rank": 1, "decoded_token": "3"}, "1052": {"logprob": -12.75000286102295, "rank": 2, "decoded_token": "4"}, "1050": {"logprob": -15.00000286102295, "rank": 3, "decoded_token": "2"}, "1049": {"logprob": -17.000003814697266, "rank": 4, "decoded_token": "1"}, "1032": {"logprob": -17.937503814697266, "rank": 5, "decoded_token": " "}}, {"1046": {"logprob": -1.9073468138230965e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -14.625001907348633, "rank": 2, "decoded_token": ".A"}, "5226": {"logprob": -15.625001907348633, "rank": 3, "decoded_token": ".D"}, "6847": {"logprob": -15.750001907348633, "rank": 4, "decoded_token": ".T"}, "4700": {"logprob": -16.750001907348633, "rank": 5, "decoded_token": ".M"}}, {"8342": {"logprob": -0.5928499102592468, "rank": 1, "decoded_token": " Sur"}, "1349": {"logprob": -1.6553499698638916, "rank": 2, "decoded_token": " A"}, "22468": {"logprob": -2.5303499698638916, "rank": 3, "decoded_token": " Several"}, "1488": {"logprob": -2.7178499698638916, "rank": 4, "decoded_token": " W"}, "15035": {"logprob": -3.2178499698638916, "rank": 5, "decoded_token": " People"}}, {"71284": {"logprob": -0.003268140833824873, "rank": 1, "decoded_token": "fers"}, "1102": {"logprob": -5.878268241882324, "rank": 2, "decoded_token": "f"}, "1726": {"logprob": -7.753268241882324, "rank": 3, "decoded_token": "fer"}, "61888": {"logprob": -12.315768241882324, "rank": 
4, "decoded_token": "fline"}, "2119": {"logprob": -13.065768241882324, "rank": 5, "decoded_token": "fter"}}, {"7377": {"logprob": -1.4883846044540405, "rank": 1, "decoded_token": " wait"}, "1584": {"logprob": -1.7383846044540405, "rank": 2, "decoded_token": " are"}, "88014": {"logprob": -1.9258846044540405, "rank": 3, "decoded_token": " paddle"}, "1294": {"logprob": -1.9258846044540405, "rank": 4, "decoded_token": " in"}, "24434": {"logprob": -2.23838472366333, "rank": 5, "decoded_token": " ride"}}, {"1394": {"logprob": -0.6120346188545227, "rank": 1, "decoded_token": " for"}, "1294": {"logprob": -0.9870346188545227, "rank": 2, "decoded_token": " in"}, "1408": {"logprob": -2.737034559249878, "rank": 3, "decoded_token": " on"}, "6482": {"logprob": -4.487034797668457, "rank": 4, "decoded_token": " patient"}, "1321": {"logprob": -5.612034797668457, "rank": 5, "decoded_token": " and"}}, {"22140": {"logprob": -0.008224429562687874, "rank": 1, "decoded_token": " waves"}, "1278": {"logprob": -5.5082244873046875, "rank": 2, "decoded_token": " the"}, "1261": {"logprob": -5.6332244873046875, "rank": 3, "decoded_token": " a"}, "39460": {"logprob": -8.133224487304688, "rank": 4, "decoded_token": " incoming"}, "1321": {"logprob": -9.758224487304688, "rank": 5, "decoded_token": " and"}}, {"1294": {"logprob": -0.3204176723957062, "rank": 1, "decoded_token": " in"}, "1408": {"logprob": -2.195417642593384, "rank": 2, "decoded_token": " on"}, "1513": {"logprob": -2.320417642593384, "rank": 3, "decoded_token": " at"}, "3016": {"logprob": -3.695417642593384, "rank": 4, "decoded_token": " while"}, "1435": {"logprob": -3.820417642593384, "rank": 5, "decoded_token": " as"}}, {"1278": {"logprob": -0.004615250043570995, "rank": 1, "decoded_token": " the"}, "1261": {"logprob": -6.192115306854248, "rank": 2, "decoded_token": " a"}, "1420": {"logprob": -6.942115306854248, "rank": 3, "decoded_token": " an"}, "40466": {"logprob": -7.317115306854248, "rank": 4, "decoded_token": " shallow"}, "26517": {"logprob": -7.879615306854248, "rank": 5, "decoded_token": " calm"}}, {"27208": {"logprob": -0.06491076946258545, "rank": 1, "decoded_token": " ocean"}, "7786": {"logprob": -3.439910888671875, "rank": 2, "decoded_token": " distance"}, "5124": {"logprob": -5.314910888671875, "rank": 3, "decoded_token": " early"}, "26517": {"logprob": -5.377410888671875, "rank": 4, "decoded_token": " calm"}, "11196": {"logprob": -5.377410888671875, "rank": 5, "decoded_token": " sea"}}, {"1513": {"logprob": -1.144903540611267, "rank": 1, "decoded_token": " at"}, "1435": {"logprob": -1.269903540611267, "rank": 2, "decoded_token": " as"}, "3184": {"logprob": -1.394903540611267, "rank": 3, "decoded_token": " during"}, "3016": {"logprob": -3.0199036598205566, "rank": 4, "decoded_token": " while"}, "6117": {"logprob": -3.1449036598205566, "rank": 5, "decoded_token": " near"}}, {"97558": {"logprob": -0.12556149065494537, "rank": 1, "decoded_token": " sunset"}, "11729": {"logprob": -2.875561475753784, "rank": 2, "decoded_token": " sun"}, "1266": {"logprob": -3.375561475753784, "rank": 3, "decoded_token": " d"}, "54507": {"logprob": -4.000561714172363, "rank": 4, "decoded_token": " dawn"}, "1261": {"logprob": -5.125561714172363, "rank": 5, "decoded_token": " a"}}, {"1626": {"logprob": -0.26737067103385925, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -2.2673707008361816, "rank": 2, "decoded_token": ","}, "3016": {"logprob": -2.7673707008361816, "rank": 3, "decoded_token": " while"}, "1454": {"logprob": -3.5173707008361816, "rank": 4, 
"decoded_token": " with"}, "6117": {"logprob": -4.142370700836182, "rank": 5, "decoded_token": " near"}}, {"1052": {"logprob": -2.9802276912960224e-06, "rank": 1, "decoded_token": "4"}, "1051": {"logprob": -13.37500286102295, "rank": 2, "decoded_token": "3"}, "1049": {"logprob": -14.00000286102295, "rank": 3, "decoded_token": "1"}, "1053": {"logprob": -14.56250286102295, "rank": 4, "decoded_token": "5"}, "1032": {"logprob": -16.750003814697266, "rank": 5, "decoded_token": " "}}, {"1046": {"logprob": -1.6689286894688848e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -13.500001907348633, "rank": 2, "decoded_token": ".A"}, "6847": {"logprob": -16.562501907348633, "rank": 3, "decoded_token": ".T"}, "1044": {"logprob": -17.312501907348633, "rank": 4, "decoded_token": ","}, "1349": {"logprob": -17.500001907348633, "rank": 5, "decoded_token": " A"}}, {"1349": {"logprob": -0.004883386194705963, "rank": 1, "decoded_token": " A"}, "2048": {"logprob": -5.504883289337158, "rank": 2, "decoded_token": " An"}, "10638": {"logprob": -7.754883289337158, "rank": 3, "decoded_token": " Two"}, "111463": {"logprob": -9.754883766174316, "rank": 4, "decoded_token": " Trees"}, "1531": {"logprob": -10.692383766174316, "rank": 5, "decoded_token": " The"}}, {"53301": {"logprob": -1.5612412691116333, "rank": 1, "decoded_token": " winding"}, "15192": {"logprob": -1.7487412691116333, "rank": 2, "decoded_token": " narrow"}, "47945": {"logprob": -2.1237411499023438, "rank": 3, "decoded_token": " dirt"}, "2169": {"logprob": -2.5612411499023438, "rank": 4, "decoded_token": " ser"}, "59396": {"logprob": -2.6862411499023438, "rank": 5, "decoded_token": " gravel"}}, {"59396": {"logprob": -0.9024254083633423, "rank": 1, "decoded_token": " gravel"}, "3549": {"logprob": -1.1524254083633423, "rank": 2, "decoded_token": " path"}, "47945": {"logprob": -1.6524254083633423, "rank": 3, "decoded_token": " dirt"}, "14801": {"logprob": -3.1524252891540527, "rank": 4, "decoded_token": " pathway"}, "15551": {"logprob": -4.277425289154053, "rank": 5, "decoded_token": " stone"}}, {"3549": {"logprob": -0.021290099248290062, "rank": 1, "decoded_token": " path"}, "14801": {"logprob": -3.8962900638580322, "rank": 2, "decoded_token": " pathway"}, "33659": {"logprob": -7.896290302276611, "rank": 3, "decoded_token": " trail"}, "9480": {"logprob": -9.521289825439453, "rank": 4, "decoded_token": " road"}, "7368": {"logprob": -9.646289825439453, "rank": 5, "decoded_token": "path"}}, {"13335": {"logprob": -0.16593234241008759, "rank": 1, "decoded_token": " leads"}, "39985": {"logprob": -2.8534324169158936, "rank": 2, "decoded_token": " cuts"}, "1639": {"logprob": -3.9784324169158936, "rank": 3, "decoded_token": " me"}, "11500": {"logprob": -4.1034321784973145, "rank": 4, "decoded_token": " runs"}, "2645": {"logprob": -4.2909321784973145, "rank": 5, "decoded_token": " through"}}, {"2645": {"logprob": -0.05767015367746353, "rank": 1, "decoded_token": " through"}, "8994": {"logprob": -4.0576701164245605, "rank": 2, "decoded_token": " towards"}, "2396": {"logprob": -4.1826701164245605, "rank": 3, "decoded_token": " between"}, "2203": {"logprob": -4.5576701164245605, "rank": 4, "decoded_token": " into"}, "1317": {"logprob": -5.5576701164245605, "rank": 5, "decoded_token": " to"}}, {"1261": {"logprob": -0.017209367826581, "rank": 1, "decoded_token": " a"}, "11223": {"logprob": -4.892209529876709, "rank": 2, "decoded_token": " green"}, "1295": {"logprob": -5.017209529876709, "rank": 3, "decoded_token": " l"}, "23170": {"logprob": -6.767209529876709, 
"rank": 4, "decoded_token": " grass"}, "1420": {"logprob": -7.267209529876709, "rank": 5, "decoded_token": " an"}}, {"1295": {"logprob": -0.9430665969848633, "rank": 1, "decoded_token": " l"}, "11223": {"logprob": -1.3180665969848633, "rank": 2, "decoded_token": " green"}, "23170": {"logprob": -1.9430665969848633, "rank": 3, "decoded_token": " grass"}, "12097": {"logprob": -2.4430665969848633, "rank": 4, "decoded_token": " park"}, "26428": {"logprob": -3.3180665969848633, "rank": 5, "decoded_token": " garden"}}, {"3506": {"logprob": -6.556489552167477e-06, "rank": 1, "decoded_token": "ush"}, "1374": {"logprob": -12.000006675720215, "rank": 2, "decoded_token": "us"}, "90716": {"logprob": -15.625006675720215, "rank": 3, "decoded_token": "USH"}, "16938": {"logprob": -15.875006675720215, "rank": 4, "decoded_token": "usher"}, "13326": {"logprob": -17.1875057220459, "rank": 5, "decoded_token": "inden"}}, {"11223": {"logprob": -0.36697858572006226, "rank": 1, "decoded_token": " green"}, "1044": {"logprob": -1.366978645324707, "rank": 2, "decoded_token": ","}, "26428": {"logprob": -3.491978645324707, "rank": 3, "decoded_token": " garden"}, "12097": {"logprob": -4.116978645324707, "rank": 4, "decoded_token": " park"}, "23170": {"logprob": -5.866978645324707, "rank": 5, "decoded_token": " grass"}}, {"12097": {"logprob": -0.5570574402809143, "rank": 1, "decoded_token": " park"}, "3727": {"logprob": -1.9320573806762695, "rank": 2, "decoded_token": " field"}, "28035": {"logprob": -2.1820573806762695, "rank": 3, "decoded_token": " landscape"}, "26428": {"logprob": -2.4320573806762695, "rank": 4, "decoded_token": " garden"}, "4457": {"logprob": -2.8070573806762695, "rank": 5, "decoded_token": " area"}}, {"1046": {"logprob": -0.7940837144851685, "rank": 1, "decoded_token": "."}, "1454": {"logprob": -1.2940837144851685, "rank": 2, "decoded_token": " with"}, "8994": {"logprob": -2.794083595275879, "rank": 3, "decoded_token": " towards"}, "54410": {"logprob": -3.544083595275879, "rank": 4, "decoded_token": " lined"}, "2425": {"logprob": -3.544083595275879, "rank": 5, "decoded_token": " under"}}, {"2": {"logprob": -2.145764938177308e-06, "rank": 1, "decoded_token": ""}, "1032": {"logprob": -13.125001907348633, "rank": 2, "decoded_token": " "}, "1256": {"logprob": -16.000001907348633, "rank": 3, "decoded_token": " "}, "1293": {"logprob": -18.750001907348633, "rank": 4, "decoded_token": " "}, "1319": {"logprob": -19.687501907348633, "rank": 5, "decoded_token": " ("}}]]] \ No newline at end of file diff --git a/tests/models/fixtures/pixtral_chat.pickle b/tests/models/fixtures/pixtral_chat.pickle deleted file mode 100644 index 43d4c883c3a49c314e1b5f529f6505f36b511181..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20865 zcmai6X>?r0mF``O7fD8zyn)MBykpsBaY!K0W;qza8z2Tu2pQCtT2fnWwP>|0oET`# zl1bu8>|innNP@^A3ke~FZ4NQM{E+(R8(ZeW|GWbUFID5c)#$w&^Uj&(ojcE)(&824_$L04L2)}%VqqOU7A&?i!2RW{YuE!uOPqAyb@WO_P9t_S|}KrYvw z?h$?c`Hoaux_Ju>a-|w){_Ip%^U?Eh_%Kv4jNmjI$=&tJkJ7t?4lF~C%umd+Pa znI5$DX`(zX4!Fo!24%IPj_sB9Es=Srf7Dk?>Lx;N8SPz!P5)bc>oQ6)OksX~7~t&mry9i&V= zRNqP_QB0L+&-C?XQ|lD+gBj(KsHaM_<$7@CtfF^aSMW2n?@t!wmTiX7(1VdbtgUi$ zwHR!D{l_+LZY3ZshKfhDw{ns@YNYE4K`?Tuwrs9X17LH4!oA6G z^@a8#nIo^trt+N%cJ4gaT>bq9Mk-*N0BTAF1Nsk#@6h&8MGW^70JdEEKRg_xL*qoq z#&tsU_V@PUzT#3maDJ2GUA1)+>ek?O=yR6E*TgD!Bw*Z%4VsSv3`A>E^J zZ1!Zm#i0&6drAY3?3b5+u{F@3BYV@;&Y=Fg7J?Eg^s5V1@a=Zuorxhbg(~cJd*Twu zm~5D5Kos5h9z;{C?s8F+EsFT(`_*hH{+!&+j~(Igcgb;!bvHjIh#ZQD3vJQk>?Te& 
zD3k#6rrDi6&SlqNQKX@ezCx`*`ZO0c^+jU~HB^cL*tNUT7-34riU|(b-*fVNNusCY zg7N~Knb#M%z_%@s_Nfl=Wfip>+*gK-J-xL^_E5+eu?muq&U~()yzhj|9h0(t&n`(= zk^>n;Hx7I*dxJMv6v~N;-mFp5;BL(RuAv@1GAao-qosO=7`bBVAdQC-!bc;52Tk75bR3F+N&ud%A7IUf#r+l$@_PkVO_dMj#Jl6pzPn__-cbCNgjGVjhh*mqmLz6*Zj zb+>$uKVDabS?#1#yj09r$?=NW<%bZ*mw_^dm2d@{APUZV>BZq8PZAY24>Fm0) zS#I?)Xk~4xuOQZdFJi(WAaR9kS=r0?cXp=RMQ=LQ4F(JObfK*)-6z0l(|M5+@JjA) zUnlyy*V%NLC9M8Gx>gO*oWlBb9RYJb(Va~Hz~3iHdafLAp>jMIA){-?wh}_E+cHEP zU|~?Z^aeo)TsNIvx13SOU$1o!8Wb*DYz1U%xNw;8vS``%k-x7<=8*CqVYr!$(41mq zke(~bUq(U+=16ms27#0#&5H9pW@NrF5a8Uk8%)%(=L_s0$TmFBP(8zSe1H4CqlBdj zuBZmO$%V-1^ch0|+lZKnbWdaRH{L>!ZdA;lj~Kt%i9>nB4@_=Hv2hj_R0!jdSYKmc zYLay%s!8T4;1fSjw+76e6#wKOxx-f8a*>ANTvv*={(L6aj~Q{xJ5_vr$66?=hy%2s z2=XMbW{NRLI&2ciQ_nhYfgpOQBY}RdQctCXr-6?*+&Dx~wL^`T28OQ5BngH}4wD8} zSs?1D!sKq0G;pfCe~?U~7Mlh>%!-O5lLka9l2y5)(!sW5|CN*erU{}Bo1eh07yKYZ#Pm=AmppoA!FhDePkS07sp3ed^b7B@@qhb@FWsV~=$hFU(ofqPuM;}bm` zDi|ZuSM)*!*av$p4tw-uS8dg9gGC)MZ@CVWj;r^BgZ0}#o$knp;f9XGZ{5v@pM2_V zg4CjtSbs-{s>%MsR$}@F141nks&JIakxdN^+rAFe57U@9blm6L>T zFFF?e>ZhO6%*d7SZG=;D>ZuM&&Mqco$a!l63PEiRE@;l60nu8d*O8Yrvg~he$e>U{ zLaZ|*uDruFHeVPRN~jREUP%jvzR_&`#uds z4d5uM#^srKgT)|!R$jBrUp6i?#y~=s)EN=qGoSqY^dVw~(_gkA{p_+|4H5+B-mUp~ z1=g_gEkQWs?p9zW<>n*{R0}XlaX59SF^3Cai$k>nyeNHO&-e7esKF)3i_-Y-PHH8B zC?qIKxr_-rf2yq=B7;}2jTT5U%2%%eie%{Ol_G%^rEmQ8)Ip+$GT?nUT2UGRQ3ssI zmM=;xj@_i~p^(EDrL{Ld-b!PW7G0FKT%6%M#UzVEHC3w6ZO^M$k5Y7r8@0at&a@%2 zs+_C2S_L*shNG36B|$(7`Y6xXW$ks#wFwZRGWIXCfT_37k$AMZT8&W>FcyQP-GEV? zr?`at*?ZRO0g-d~7^KXL#F?Q6jlk)rBGKao=`3j=$oe@jF^;jGlQhp2YoQ8FZ>mc~ z&b1h19X=L_D7w6&d1>STOrL%csQKab$~+m?&lOp!SJ*1a`NBZZPa?ERe3lnlh&#)? zc3KE86=DK>yY^|eyr_A3*ZE1pD2Je;<`}mc@%O*y?lv*urB>q-%-QU2&f)zRXgf$b zE>`s91_0IarLEjuYb*|V$HZ6r)o~>+YyuEYv^io>b6s9*>&T{JiSo4q6kx=i0h)fmOh>k^LvXdQXx@CihZD zWjOL5TiuakKVJZ-AiFE1^Lc8^wp=#M2d}VD)Dl8lJZBqg0iZ&~(Yf!tW7R!-Q<26( z_RtZ7LqSzGsdc~;MF;KgUHM3o&?ulfIOW5ESsml=vrt+JeF(gcdDK5GsKld6yyZ8q z2pSfb!u6IC*Q;J&fT%{ho>~Y#?mK^2Xk+MsXcd+O2~^qAyNnVy1bpU!M@Ov&1*rwJO80;7-{NafpcOeKRS z+iPXFHi^qowy!(!`g^oJlwkV0ycv6x#SndJ#(Ebs`ZIs_0`NkoB>EVzIAfDeap}+e zlR;51s6$^voT~9huj)+8p+Bsj@gyk}-p3oz~_vbqxIr-Y}N&=yZIP?GnA}QD;7-#)@6yBis^9@Vy2Mh>=;F*VnwN5!-K zImxgf5VFQ>pBD^^MQD(oD`%l(e-87fC9MX6q$6T3uIO&Jn3Bgpjzw0Gb3}X{AfxoA z4(&$0zQ8NjjiMIRtOmI4;eEj`?M1;armFtck3>tM0eV$Crg{0exBen;d>ntVe6{yx zyrtnvcmmLArqy-)GhXZsZ`U}l@i_0oYV2;+#JzjRc{>+)%NFSd*Bd1buEX}TRDIh# zdN<3aKzhT++bNezI)sasn@iB4BYri=%oTTR%0Z{vubj{}5=PClI3gum~g+*fm zxON=`6P)IvR`cAY$;(SQB`YjV2$6CV9H?C<8A3==)Vl3GbP0+k4s&5mYq$@ZLTQ^yh-rN3yTM7Sd`#2 z2}&CP{zutKNkXEG(Hj7X2geG+kzd&c06P+w6-x~awU{IEk|x2_gf@jTDseX*y-6^s z;=hMzT(ubXV85hEkQ@sj6C2U8O@dj{U{FEDs7-?T))q<_tx2$;Xsju38of#I>jN8; zfdDfjyh$*=`xTABMPN+==5(^|_`QORp@zWecpg;$fI-nBlxrnSv=~~1;HI8(Ctkhm zyvp2bz(pwAD){iF;vg|2KcZ{b=))yv>;prNg2S{58ulBYT82?t1yig!3JHA@uQc#h z!Tl9a0zY&S3WX|Z`_+C78__k!+%x zf7L43bLI7#fpV@gtpfjT3jsB%IN)JX{ZcAP1M;xo&495Mh4kHK06!3swYhr+$sF=G zja}tatYffkkH23K5U8(yI~)puzw@(vs1S+-0qUziAqX!Ljt&_Jve#V&cHv4cv{rz+ z${Kkc50ZBcTasYPFUoZchM#{)TS6}%rDG6!D_0~t zD-c`cW*XDrRA-#o!#gczeZ$V8tZyUUO%cFW|ln}HH_(6Z9%}^UX~8cU^rRFf^|M#}p@5 zarx((9Ii_Vxsu+&|6T3qaC-;5D=>4K!I22;>8Qh*Z?mh0O+Pj;TpIA3Od6Y9xH5-y z1$M{Iam17y^zyURe@13&C<=1UQh8S(a?KGMhKpiyR=(rLEe3@IiO5y#a&&drSVA3@ zaF+^h(FaGQ{0Mhf;|2la+81e0xZoHH*t`*D=|JKl&1;)UcZ%~1*{}WVHpM1m`0&9%th$3{{lG*=y|8G2TK^NfZu1%_6Xg@t?h7yQYUa z98c$mj?r_S;-^er5wp)xK;Y&4tvKGIkiYJt^RlSMLWsBefJJSzJkkE;yj_B(6?tpF zV2AOtwt0ml2)VyKjOSm81u)lrc%xuak1<8wZyE*cSYJ-IDC8a->nr+k&44K3V!{y@tPi} zSp8r`Wp%s>Fh0|s3>4@lK51xMU(kf@;2X%~3BLlFQ`MND@n z)JpJ9!SN5&x6;_S5;i`0r(lL$ph4(Lqjd_JLIwpgRFLpa!7Tfnyr_JqV746#ne$Eo 
z3nyRP_J^T>x(+8Ew@sniL4p8rMTw=la+ZCLl<~4Z-w^LwJKH$*txb~Zyf-=dJR zZYQX@W-0!*Cr3UPmas|jr@6f+Ij?B4Vw~ylak!zfQbIP{ycsB)Qi2x&OG4&c3RssZ z0+xJWo&(=eiU5CP7^v!(Y?&fpVba(@qWU9;)%EVQEV2-E>qztID8`?{AVp&Z=Qq-P z0YZ44uvqdU0P6%bw?7h&CI6KJgRPb?_%R&bqb*)399(~1l0b?yRSJ)=d+D*0=L&)# z=}NI84NcLMx=H>FTlDxg160(-N*r$MmmN;B%0kdDD#Ur|T(9~x75S@niG%i6KkDAQ ziDByjP5vA6kH%{+mVttB2O_qy9?kK95-hNjrSy67?k(NZJOW z-Gd7FkndX%DlkLx7kK>jd@C7J)33ik<>#7p`yNUX2C2K}nkDB84S54pE1*ABRV%Rm zB4nm;tRJp!%1=aQf&tX zR6$FuEw@I0zRRwR$foNI@(=}lZlAyAAwkqoLHW52Z~CIf(EQWay>Ge;DI*p`^H2Bc z9%sqnIEz90Fk4+2@~~T(sqq$u+~J;uy|Te()z9zgl>zngiM*Wj$U{Q}LhfY~`KcD7 zsjFLj6MpK>{E?>{GvqpbU2yqKU2ODj;%)is43JzO@9UfBUG5%e`1kx|7>Jp2DnX4+ z^qtj5xg(Fc?gI@&0qPGKz^B;o=P&>nWBdNZQ?JlS)RDli5$X0kXE?=e&d~-%$ zuyZG>ocre_jaTACue*0$WR-={IzT~ac2U(9cZZe}Iz%7=E^evLGEVRg*0Ze$&u%3g z%7C1$UwO(t_NEwPL8v1l)YZDTpLfQSwV@l6WQSIuJvE>V@%^B_Xi%ttJXdW&JR93F z?dBmeh5=kQ8;djz_BlC@jmx`PT#mIcqHU5gKJm4i&5nM^#KS!VXBxIW zKJt0dGGF~btjyu;eCecRCUfIS*U@fXrjo8B-BIL1VN4+F@bS2;{A_lyG|3P!9wjZ? zIx-w^Et8fl2pJpb8agfGIoIM%CZSS}K{*#so0pttAiz~E6w9OJlA?J|E`d#S%DN0a3~|L?yjI%Cc-~xQC!51FFCek)d}+@*+LO2!|58JdtttN zUtZQto7hS~6au%O3ptgFDsSTaZhC5hhN6yuA0N49!7YLeq6#+hx2|<5PQU3X4MZUc zQGolnD$TE)<-(e(cNfVXDuEO-mr^_F<8B5ln_3J?ITc>}F^kdY-uCt%{mL1wWDZ5Z zEeX6EgtNTl$L^?!=bqOvS_}A^S)n@Jj9S)M6iSGSHR@&j&e5)=+RR&m2GsoFor2?C zSnZkJG9M1*pr4VSo6Jk>&wI$BE6txKW;0(WFZEPNrQuuL(lO1kM7?r=sE4BCSBxvRpOH6MM9wj*a7V&eeb(0X&!5VC;)bF zI6F~O!Ws)QfYDNzjtedYzdyWvWLVP3xDxhn-0<9RzCB9=5#{a=+g3T!*H;7h2ZeAs dSkO2~JgPq^{|7y`q4EF# diff --git a/tests/models/fixtures/pixtral_chat_engine.json b/tests/models/fixtures/pixtral_chat_engine.json new file mode 100644 index 0000000000000..60e4ae6cebf59 --- /dev/null +++ b/tests/models/fixtures/pixtral_chat_engine.json @@ -0,0 +1 @@ +[[[1784, 3937, 6122, 1261, 7244, 10575, 18970, 1408, 1261, 32656, 4691, 1046, 2], "The image shows a black dog sitting on a wooden surface.", [{"1784": {"logprob": -0.11685245484113693, "rank": 1, "decoded_token": "The"}, "4380": {"logprob": -2.3668525218963623, "rank": 2, "decoded_token": "This"}, "1049": {"logprob": -4.741852283477783, "rank": 3, "decoded_token": "1"}, "117991": {"logprob": -5.991852283477783, "rank": 4, "decoded_token": "Certain"}, "1785": {"logprob": -5.991852283477783, "rank": 5, "decoded_token": "In"}}, {"3937": {"logprob": -0.2591013014316559, "rank": 1, "decoded_token": " image"}, "2158": {"logprob": -1.5091012716293335, "rank": 2, "decoded_token": " first"}, "3977": {"logprob": -5.884101390838623, "rank": 3, "decoded_token": " top"}, "7244": {"logprob": -6.259101390838623, "rank": 4, "decoded_token": " black"}, "8061": {"logprob": -6.759101390838623, "rank": 5, "decoded_token": " images"}}, {"6122": {"logprob": -0.9660423994064331, "rank": 1, "decoded_token": " shows"}, "51948": {"logprob": -1.466042399406433, "rank": 2, "decoded_token": " depicts"}, "6971": {"logprob": -1.466042399406433, "rank": 3, "decoded_token": " features"}, "25981": {"logprob": -2.8410425186157227, "rank": 4, "decoded_token": " displays"}, "8688": {"logprob": -2.8410425186157227, "rank": 5, "decoded_token": " contains"}}, {"1261": {"logprob": -0.0030613720882683992, "rank": 1, "decoded_token": " a"}, "1420": {"logprob": -6.253061294555664, "rank": 2, "decoded_token": " an"}, "2295": {"logprob": -7.878061294555664, "rank": 3, "decoded_token": " two"}, "2342": {"logprob": -7.878061294555664, "rank": 4, "decoded_token": " only"}, "1278": {"logprob": -8.628061294555664, "rank": 5, "decoded_token": " the"}}, {"7244": {"logprob": -0.17649099230766296, "rank": 1, "decoded_token": " black"}, 
"6231": {"logprob": -2.3014910221099854, "rank": 2, "decoded_token": " close"}, "4249": {"logprob": -3.4264910221099854, "rank": 3, "decoded_token": " single"}, "4329": {"logprob": -5.113990783691406, "rank": 4, "decoded_token": " large"}, "10575": {"logprob": -5.176490783691406, "rank": 5, "decoded_token": " dog"}}, {"10575": {"logprob": -0.10929587483406067, "rank": 1, "decoded_token": " dog"}, "116572": {"logprob": -2.4842958450317383, "rank": 2, "decoded_token": " puppy"}, "119075": {"logprob": -4.109295845031738, "rank": 3, "decoded_token": " Labrador"}, "15812": {"logprob": -7.296795845031738, "rank": 4, "decoded_token": " Lab"}, "7990": {"logprob": -7.484295845031738, "rank": 5, "decoded_token": " cat"}}, {"18970": {"logprob": -0.830376148223877, "rank": 1, "decoded_token": " sitting"}, "1454": {"logprob": -1.580376148223877, "rank": 2, "decoded_token": " with"}, "28528": {"logprob": -1.955376148223877, "rank": 3, "decoded_token": " lying"}, "7283": {"logprob": -2.205376148223877, "rank": 4, "decoded_token": " looking"}, "15866": {"logprob": -3.017876148223877, "rank": 5, "decoded_token": " standing"}}, {"1408": {"logprob": -0.08554735779762268, "rank": 1, "decoded_token": " on"}, "1321": {"logprob": -3.71054744720459, "rank": 2, "decoded_token": " and"}, "3675": {"logprob": -3.96054744720459, "rank": 3, "decoded_token": " against"}, "41132": {"logprob": -4.71054744720459, "rank": 4, "decoded_token": " attent"}, "1454": {"logprob": -5.08554744720459, "rank": 5, "decoded_token": " with"}}, {"1261": {"logprob": -0.540847897529602, "rank": 1, "decoded_token": " a"}, "32656": {"logprob": -0.915847897529602, "rank": 2, "decoded_token": " wooden"}, "12603": {"logprob": -5.4158477783203125, "rank": 3, "decoded_token": " wood"}, "3977": {"logprob": -5.4158477783203125, "rank": 4, "decoded_token": " top"}, "17253": {"logprob": -6.2908477783203125, "rank": 5, "decoded_token": " weather"}}, {"32656": {"logprob": -0.025753861293196678, "rank": 1, "decoded_token": " wooden"}, "44130": {"logprob": -4.400753974914551, "rank": 2, "decoded_token": " rust"}, "12603": {"logprob": -5.275753974914551, "rank": 3, "decoded_token": " wood"}, "3403": {"logprob": -5.400753974914551, "rank": 4, "decoded_token": " text"}, "17253": {"logprob": -6.963253974914551, "rank": 5, "decoded_token": " weather"}}, {"4691": {"logprob": -0.7265751957893372, "rank": 1, "decoded_token": " surface"}, "11237": {"logprob": -0.8515751957893372, "rank": 2, "decoded_token": " floor"}, "7042": {"logprob": -2.6015751361846924, "rank": 3, "decoded_token": " background"}, "28984": {"logprob": -5.2265753746032715, "rank": 4, "decoded_token": " deck"}, "1615": {"logprob": -5.7265753746032715, "rank": 5, "decoded_token": " pl"}}, {"1046": {"logprob": -0.4868825674057007, "rank": 1, "decoded_token": "."}, "1044": {"logprob": -1.9868825674057007, "rank": 2, "decoded_token": ","}, "1321": {"logprob": -2.3618826866149902, "rank": 3, "decoded_token": " and"}, "1454": {"logprob": -2.6118826866149902, "rank": 4, "decoded_token": " with"}, "7283": {"logprob": -2.7368826866149902, "rank": 5, "decoded_token": " looking"}}, {"2": {"logprob": -0.0026643513701856136, "rank": 1, "decoded_token": ""}, "1531": {"logprob": -6.502664566040039, "rank": 2, "decoded_token": " The"}, "1032": {"logprob": -6.877664566040039, "rank": 3, "decoded_token": " "}, "3730": {"logprob": -9.752664566040039, "rank": 4, "decoded_token": " There"}, "1256": {"logprob": -11.002664566040039, "rank": 5, "decoded_token": " "}}]], [[1049, 1046, 1349, 7244, 10575, 1454, 2327, 
94766, 32961, 53048, 41132, 3923, 1408, 1261, 32656, 4691, 1626, 1050, 1046, 1349, 15375, 24361, 4521, 94973, 5669, 1278, 48932, 2425, 1261, 16152, 1121, 21283, 1046, 2], "1. A black dog with floppy ears sits attentively on a wooden surface.\n2. A vast mountain range stretches across the horizon under a cloudy sky.", [{"1049": {"logprob": -0.42824622988700867, "rank": 1, "decoded_token": "1"}, "1045": {"logprob": -1.553246259689331, "rank": 2, "decoded_token": "-"}, "1065": {"logprob": -2.428246259689331, "rank": 3, "decoded_token": "A"}, "1784": {"logprob": -4.053246021270752, "rank": 4, "decoded_token": "The"}, "69957": {"logprob": -4.428246021270752, "rank": 5, "decoded_token": "Sure"}}, {"1046": {"logprob": -1.811964830267243e-05, "rank": 1, "decoded_token": "."}, "1058": {"logprob": -11.875018119812012, "rank": 2, "decoded_token": ":"}, "3590": {"logprob": -12.250018119812012, "rank": 3, "decoded_token": ".A"}, "1065": {"logprob": -13.062518119812012, "rank": 4, "decoded_token": "A"}, "1041": {"logprob": -13.750018119812012, "rank": 5, "decoded_token": ")"}}, {"1349": {"logprob": -0.13647246360778809, "rank": 1, "decoded_token": " A"}, "1429": {"logprob": -2.386472463607788, "rank": 2, "decoded_token": " \""}, "1603": {"logprob": -3.886472463607788, "rank": 3, "decoded_token": " **"}, "11967": {"logprob": -5.011472702026367, "rank": 4, "decoded_token": " Image"}, "1531": {"logprob": -5.011472702026367, "rank": 5, "decoded_token": " The"}}, {"7244": {"logprob": -0.18561004102230072, "rank": 1, "decoded_token": " black"}, "38462": {"logprob": -3.185610055923462, "rank": 2, "decoded_token": " curious"}, "68076": {"logprob": -3.623110055923462, "rank": 3, "decoded_token": " cute"}, "4329": {"logprob": -3.935610055923462, "rank": 4, "decoded_token": " large"}, "74168": {"logprob": -4.373109817504883, "rank": 5, "decoded_token": " gloss"}}, {"10575": {"logprob": -0.17297746241092682, "rank": 1, "decoded_token": " dog"}, "116572": {"logprob": -2.1729774475097656, "rank": 2, "decoded_token": " puppy"}, "119075": {"logprob": -3.1729774475097656, "rank": 3, "decoded_token": " Labrador"}, "15812": {"logprob": -6.985477447509766, "rank": 4, "decoded_token": " Lab"}, "8636": {"logprob": -7.360477447509766, "rank": 5, "decoded_token": " lab"}}, {"1454": {"logprob": -0.5785807967185974, "rank": 1, "decoded_token": " with"}, "53048": {"logprob": -1.2660808563232422, "rank": 2, "decoded_token": " sits"}, "1395": {"logprob": -3.016080856323242, "rank": 3, "decoded_token": " is"}, "22524": {"logprob": -3.578580856323242, "rank": 4, "decoded_token": " lies"}, "18970": {"logprob": -3.703580856323242, "rank": 5, "decoded_token": " sitting"}}, {"2327": {"logprob": -1.2709298133850098, "rank": 1, "decoded_token": " fl"}, "1261": {"logprob": -1.3959298133850098, "rank": 2, "decoded_token": " a"}, "17300": {"logprob": -1.8959298133850098, "rank": 3, "decoded_token": " soul"}, "100089": {"logprob": -2.6459298133850098, "rank": 4, "decoded_token": " expressive"}, "6444": {"logprob": -3.1459298133850098, "rank": 5, "decoded_token": " soft"}}, {"94766": {"logprob": -0.002432247158139944, "rank": 1, "decoded_token": "oppy"}, "124603": {"logprob": -6.377432346343994, "rank": 2, "decoded_token": "uffy"}, "1484": {"logprob": -7.877432346343994, "rank": 3, "decoded_token": "op"}, "24897": {"logprob": -8.877431869506836, "rank": 4, "decoded_token": "appy"}, "102477": {"logprob": -9.752431869506836, "rank": 5, "decoded_token": "opping"}}, {"32961": {"logprob": -5.113947918289341e-05, "rank": 1, "decoded_token": " ears"}, 
"16962": {"logprob": -11.312551498413086, "rank": 2, "decoded_token": " ear"}, "5731": {"logprob": -11.750051498413086, "rank": 3, "decoded_token": " eyes"}, "3351": {"logprob": -12.000051498413086, "rank": 4, "decoded_token": " years"}, "42071": {"logprob": -13.000051498413086, "rank": 5, "decoded_token": " cheeks"}}, {"53048": {"logprob": -0.6131591200828552, "rank": 1, "decoded_token": " sits"}, "10637": {"logprob": -1.9881591796875, "rank": 2, "decoded_token": " looks"}, "1321": {"logprob": -2.4256591796875, "rank": 3, "decoded_token": " and"}, "1395": {"logprob": -2.6756591796875, "rank": 4, "decoded_token": " is"}, "18970": {"logprob": -3.0506591796875, "rank": 5, "decoded_token": " sitting"}}, {"41132": {"logprob": -0.36187249422073364, "rank": 1, "decoded_token": " attent"}, "1408": {"logprob": -2.361872434616089, "rank": 2, "decoded_token": " on"}, "106534": {"logprob": -2.424372434616089, "rank": 3, "decoded_token": " calmly"}, "12276": {"logprob": -2.611872434616089, "rank": 4, "decoded_token": " alert"}, "6482": {"logprob": -5.174372673034668, "rank": 5, "decoded_token": " patient"}}, {"3923": {"logprob": -8.451581379631534e-05, "rank": 1, "decoded_token": "ively"}, "1556": {"logprob": -9.50008487701416, "rank": 2, "decoded_token": "ive"}, "6655": {"logprob": -11.87508487701416, "rank": 3, "decoded_token": "atively"}, "3929": {"logprob": -14.00008487701416, "rank": 4, "decoded_token": "ently"}, "47885": {"logprob": -14.75008487701416, "rank": 5, "decoded_token": "edly"}}, {"1408": {"logprob": -0.058125678449869156, "rank": 1, "decoded_token": " on"}, "3675": {"logprob": -3.1831257343292236, "rank": 2, "decoded_token": " against"}, "1294": {"logprob": -4.9331254959106445, "rank": 3, "decoded_token": " in"}, "7283": {"logprob": -5.8081254959106445, "rank": 4, "decoded_token": " looking"}, "1044": {"logprob": -5.9331254959106445, "rank": 5, "decoded_token": ","}}, {"1261": {"logprob": -0.21029606461524963, "rank": 1, "decoded_token": " a"}, "32656": {"logprob": -1.7102960348129272, "rank": 2, "decoded_token": " wooden"}, "17253": {"logprob": -5.710296154022217, "rank": 3, "decoded_token": " weather"}, "44130": {"logprob": -6.085296154022217, "rank": 4, "decoded_token": " rust"}, "12603": {"logprob": -6.960296154022217, "rank": 5, "decoded_token": " wood"}}, {"32656": {"logprob": -0.08548421412706375, "rank": 1, "decoded_token": " wooden"}, "44130": {"logprob": -2.710484266281128, "rank": 2, "decoded_token": " rust"}, "17253": {"logprob": -4.710484027862549, "rank": 3, "decoded_token": " weather"}, "12603": {"logprob": -5.960484027862549, "rank": 4, "decoded_token": " wood"}, "3403": {"logprob": -5.960484027862549, "rank": 5, "decoded_token": " text"}}, {"4691": {"logprob": -0.7172377109527588, "rank": 1, "decoded_token": " surface"}, "11237": {"logprob": -0.8422377109527588, "rank": 2, "decoded_token": " floor"}, "7042": {"logprob": -2.842237710952759, "rank": 3, "decoded_token": " background"}, "28984": {"logprob": -4.21723747253418, "rank": 4, "decoded_token": " deck"}, "92504": {"logprob": -6.21723747253418, "rank": 5, "decoded_token": " backdrop"}}, {"1626": {"logprob": -0.12971943616867065, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -2.3797194957733154, "rank": 2, "decoded_token": ","}, "1046": {"logprob": -4.129719257354736, "rank": 3, "decoded_token": "."}, "1338": {"logprob": -5.129719257354736, "rank": 4, "decoded_token": ".\n\n"}, "7283": {"logprob": -5.504719257354736, "rank": 5, "decoded_token": " looking"}}, {"1050": {"logprob": -0.00015698630886618048, 
"rank": 1, "decoded_token": "2"}, "1256": {"logprob": -9.125157356262207, "rank": 2, "decoded_token": " "}, "1032": {"logprob": -10.875157356262207, "rank": 3, "decoded_token": " "}, "1293": {"logprob": -11.750157356262207, "rank": 4, "decoded_token": " "}, "1051": {"logprob": -12.125157356262207, "rank": 5, "decoded_token": "3"}}, {"1046": {"logprob": -6.6756979322235566e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -13.062506675720215, "rank": 2, "decoded_token": ".A"}, "1626": {"logprob": -13.187506675720215, "rank": 3, "decoded_token": ".\n"}, "1338": {"logprob": -14.750006675720215, "rank": 4, "decoded_token": ".\n\n"}, "1058": {"logprob": -14.937506675720215, "rank": 5, "decoded_token": ":"}}, {"1349": {"logprob": -0.5863217115402222, "rank": 1, "decoded_token": " A"}, "11826": {"logprob": -1.4613217115402222, "rank": 2, "decoded_token": " Maj"}, "37159": {"logprob": -2.2113218307495117, "rank": 3, "decoded_token": " Snow"}, "113465": {"logprob": -3.8988218307495117, "rank": 4, "decoded_token": " Rug"}, "1531": {"logprob": -3.9613218307495117, "rank": 5, "decoded_token": " The"}}, {"15375": {"logprob": -0.639299213886261, "rank": 1, "decoded_token": " vast"}, "37849": {"logprob": -2.014299154281616, "rank": 2, "decoded_token": " breat"}, "61082": {"logprob": -2.389299154281616, "rank": 3, "decoded_token": " panor"}, "10726": {"logprob": -3.139299154281616, "rank": 4, "decoded_token": " scen"}, "2169": {"logprob": -3.201799154281616, "rank": 5, "decoded_token": " ser"}}, {"24361": {"logprob": -0.702845573425293, "rank": 1, "decoded_token": " mountain"}, "127945": {"logprob": -1.952845573425293, "rank": 2, "decoded_token": " mountainous"}, "1044": {"logprob": -2.077845573425293, "rank": 3, "decoded_token": ","}, "4521": {"logprob": -2.327845573425293, "rank": 4, "decoded_token": " range"}, "28035": {"logprob": -2.452845573425293, "rank": 5, "decoded_token": " landscape"}}, {"4521": {"logprob": -0.07058162242174149, "rank": 1, "decoded_token": " range"}, "28035": {"logprob": -2.6955816745758057, "rank": 2, "decoded_token": " landscape"}, "37691": {"logprob": -8.320581436157227, "rank": 3, "decoded_token": " valley"}, "12248": {"logprob": -9.445581436157227, "rank": 4, "decoded_token": " peak"}, "13327": {"logprob": -9.695581436157227, "rank": 5, "decoded_token": " scene"}}, {"94973": {"logprob": -1.1164050102233887, "rank": 1, "decoded_token": " stretches"}, "1454": {"logprob": -1.1789050102233887, "rank": 2, "decoded_token": " with"}, "2425": {"logprob": -1.8664050102233887, "rank": 3, "decoded_token": " under"}, "1395": {"logprob": -2.5539050102233887, "rank": 4, "decoded_token": " is"}, "13875": {"logprob": -2.9914050102233887, "rank": 5, "decoded_token": " covered"}}, {"5669": {"logprob": -0.3286789357662201, "rank": 1, "decoded_token": " across"}, "1848": {"logprob": -2.078678846359253, "rank": 2, "decoded_token": " out"}, "2425": {"logprob": -2.328678846359253, "rank": 3, "decoded_token": " under"}, "2203": {"logprob": -3.328678846359253, "rank": 4, "decoded_token": " into"}, "8994": {"logprob": -4.766179084777832, "rank": 5, "decoded_token": " towards"}}, {"1278": {"logprob": -0.039004355669021606, "rank": 1, "decoded_token": " the"}, "1261": {"logprob": -3.289004325866699, "rank": 2, "decoded_token": " a"}, "1420": {"logprob": -7.414004325866699, "rank": 3, "decoded_token": " an"}, "2425": {"logprob": -9.0390043258667, "rank": 4, "decoded_token": " under"}, "1454": {"logprob": -9.2265043258667, "rank": 5, "decoded_token": " with"}}, {"48932": {"logprob": 
-0.2659883201122284, "rank": 1, "decoded_token": " horizon"}, "21283": {"logprob": -2.140988349914551, "rank": 2, "decoded_token": " sky"}, "3937": {"logprob": -3.015988349914551, "rank": 3, "decoded_token": " image"}, "28035": {"logprob": -3.515988349914551, "rank": 4, "decoded_token": " landscape"}, "3044": {"logprob": -4.265988349914551, "rank": 5, "decoded_token": " sk"}}, {"2425": {"logprob": -0.5356141328811646, "rank": 1, "decoded_token": " under"}, "1044": {"logprob": -1.5356141328811646, "rank": 2, "decoded_token": ","}, "1454": {"logprob": -1.7856141328811646, "rank": 3, "decoded_token": " with"}, "25136": {"logprob": -3.785614013671875, "rank": 4, "decoded_token": " beneath"}, "1408": {"logprob": -5.785614013671875, "rank": 5, "decoded_token": " on"}}, {"1261": {"logprob": -0.006081883795559406, "rank": 1, "decoded_token": " a"}, "1420": {"logprob": -5.506082057952881, "rank": 2, "decoded_token": " an"}, "16152": {"logprob": -7.631082057952881, "rank": 3, "decoded_token": " cloud"}, "6133": {"logprob": -7.881082057952881, "rank": 4, "decoded_token": " clear"}, "2136": {"logprob": -8.006081581115723, "rank": 5, "decoded_token": " over"}}, {"16152": {"logprob": -0.6749536991119385, "rank": 1, "decoded_token": " cloud"}, "6133": {"logprob": -1.4249536991119385, "rank": 2, "decoded_token": " clear"}, "18416": {"logprob": -2.8624536991119385, "rank": 3, "decoded_token": " haz"}, "27254": {"logprob": -2.9874536991119385, "rank": 4, "decoded_token": " partly"}, "4391": {"logprob": -3.2374536991119385, "rank": 5, "decoded_token": " light"}}, {"1121": {"logprob": -0.10860869288444519, "rank": 1, "decoded_token": "y"}, "4527": {"logprob": -2.9836087226867676, "rank": 2, "decoded_token": "less"}, "1286": {"logprob": -3.4836087226867676, "rank": 3, "decoded_token": "ed"}, "77187": {"logprob": -4.608608722686768, "rank": 4, "decoded_token": "-filled"}, "114525": {"logprob": -4.858608722686768, "rank": 5, "decoded_token": "-covered"}}, {"21283": {"logprob": -0.002785732736811042, "rank": 1, "decoded_token": " sky"}, "10991": {"logprob": -6.252785682678223, "rank": 2, "decoded_token": " blue"}, "1044": {"logprob": -7.627785682678223, "rank": 3, "decoded_token": ","}, "26549": {"logprob": -8.627785682678223, "rank": 4, "decoded_token": " gray"}, "34052": {"logprob": -9.377785682678223, "rank": 5, "decoded_token": " grey"}}, {"1046": {"logprob": -0.047878943383693695, "rank": 1, "decoded_token": "."}, "1044": {"logprob": -3.1728789806365967, "rank": 2, "decoded_token": ","}, "1454": {"logprob": -5.547878742218018, "rank": 3, "decoded_token": " with"}, "1338": {"logprob": -7.172878742218018, "rank": 4, "decoded_token": ".\n\n"}, "1294": {"logprob": -9.172879219055176, "rank": 5, "decoded_token": " in"}}, {"2": {"logprob": -1.3351351299206726e-05, "rank": 1, "decoded_token": ""}, "1032": {"logprob": -11.25001335144043, "rank": 2, "decoded_token": " "}, "1256": {"logprob": -16.00001335144043, "rank": 3, "decoded_token": " "}, "1319": {"logprob": -17.25001335144043, "rank": 4, "decoded_token": " ("}, "1766": {"logprob": -18.50001335144043, "rank": 5, "decoded_token": " ["}}]], [[1049, 1046, 1349, 7244, 10575, 53048, 41132, 3923, 1408, 1261, 32656, 11237, 1626, 1050, 1046, 1349, 15375, 24361, 4521, 94973, 5669, 1278, 48932, 2425, 1261, 16152, 1121, 21283, 1626, 1051, 1046, 8342, 71284, 7377, 1394, 22140, 1294, 1278, 27208, 1513, 97558, 1626, 1052, 1046, 1349, 53301, 59396, 3549, 13335, 2645, 1261, 1295, 3506, 11223, 12097, 1046, 2], "1. A black dog sits attentively on a wooden floor.\n2. 
A vast mountain range stretches across the horizon under a cloudy sky.\n3. Surfers wait for waves in the ocean at sunset.\n4. A winding gravel path leads through a lush green park.", [{"1049": {"logprob": -0.05001257359981537, "rank": 1, "decoded_token": "1"}, "1045": {"logprob": -3.1750125885009766, "rank": 2, "decoded_token": "-"}, "69957": {"logprob": -5.925012588500977, "rank": 3, "decoded_token": "Sure"}, "11745": {"logprob": -6.425012588500977, "rank": 4, "decoded_token": "Here"}, "1065": {"logprob": -6.425012588500977, "rank": 5, "decoded_token": "A"}}, {"1046": {"logprob": -8.702239938429557e-06, "rank": 1, "decoded_token": "."}, "1058": {"logprob": -12.000008583068848, "rank": 2, "decoded_token": ":"}, "3590": {"logprob": -13.375008583068848, "rank": 3, "decoded_token": ".A"}, "1041": {"logprob": -14.750008583068848, "rank": 4, "decoded_token": ")"}, "1065": {"logprob": -15.687508583068848, "rank": 5, "decoded_token": "A"}}, {"1349": {"logprob": -0.14196155965328217, "rank": 1, "decoded_token": " A"}, "1429": {"logprob": -2.2669615745544434, "rank": 2, "decoded_token": " \""}, "1531": {"logprob": -4.516961574554443, "rank": 3, "decoded_token": " The"}, "11967": {"logprob": -4.516961574554443, "rank": 4, "decoded_token": " Image"}, "1603": {"logprob": -5.391961574554443, "rank": 5, "decoded_token": " **"}}, {"7244": {"logprob": -0.14889711141586304, "rank": 1, "decoded_token": " black"}, "68076": {"logprob": -3.398897171020508, "rank": 2, "decoded_token": " cute"}, "6231": {"logprob": -3.961397171020508, "rank": 3, "decoded_token": " close"}, "38462": {"logprob": -4.273897171020508, "rank": 4, "decoded_token": " curious"}, "4329": {"logprob": -4.398897171020508, "rank": 5, "decoded_token": " large"}}, {"10575": {"logprob": -0.12091328203678131, "rank": 1, "decoded_token": " dog"}, "116572": {"logprob": -2.37091326713562, "rank": 2, "decoded_token": " puppy"}, "119075": {"logprob": -3.99591326713562, "rank": 3, "decoded_token": " Labrador"}, "15812": {"logprob": -7.683413505554199, "rank": 4, "decoded_token": " Lab"}, "8636": {"logprob": -7.808413505554199, "rank": 5, "decoded_token": " lab"}}, {"53048": {"logprob": -0.8691943287849426, "rank": 1, "decoded_token": " sits"}, "1454": {"logprob": -1.1191942691802979, "rank": 2, "decoded_token": " with"}, "1395": {"logprob": -2.431694269180298, "rank": 3, "decoded_token": " is"}, "18970": {"logprob": -2.744194269180298, "rank": 4, "decoded_token": " sitting"}, "22524": {"logprob": -3.681694269180298, "rank": 5, "decoded_token": " lies"}}, {"41132": {"logprob": -0.5939557552337646, "rank": 1, "decoded_token": " attent"}, "106534": {"logprob": -1.2814557552337646, "rank": 2, "decoded_token": " calmly"}, "12276": {"logprob": -2.8439557552337646, "rank": 3, "decoded_token": " alert"}, "1408": {"logprob": -2.8439557552337646, "rank": 4, "decoded_token": " on"}, "6482": {"logprob": -4.968955993652344, "rank": 5, "decoded_token": " patient"}}, {"3923": {"logprob": -0.00010084597306558862, "rank": 1, "decoded_token": "ively"}, "1556": {"logprob": -9.500101089477539, "rank": 2, "decoded_token": "ive"}, "6655": {"logprob": -10.875101089477539, "rank": 3, "decoded_token": "atively"}, "3929": {"logprob": -13.000101089477539, "rank": 4, "decoded_token": "ently"}, "47885": {"logprob": -13.750101089477539, "rank": 5, "decoded_token": "edly"}}, {"1408": {"logprob": -0.056158196181058884, "rank": 1, "decoded_token": " on"}, "3675": {"logprob": -3.6811583042144775, "rank": 2, "decoded_token": " against"}, "1454": {"logprob": -4.306158065795898, "rank": 
3, "decoded_token": " with"}, "1294": {"logprob": -5.181158065795898, "rank": 4, "decoded_token": " in"}, "7283": {"logprob": -5.431158065795898, "rank": 5, "decoded_token": " looking"}}, {"1261": {"logprob": -0.33056098222732544, "rank": 1, "decoded_token": " a"}, "32656": {"logprob": -1.3305609226226807, "rank": 2, "decoded_token": " wooden"}, "17253": {"logprob": -4.70556116104126, "rank": 3, "decoded_token": " weather"}, "44130": {"logprob": -5.83056116104126, "rank": 4, "decoded_token": " rust"}, "12603": {"logprob": -6.58056116104126, "rank": 5, "decoded_token": " wood"}}, {"32656": {"logprob": -0.07081110030412674, "rank": 1, "decoded_token": " wooden"}, "44130": {"logprob": -2.9458110332489014, "rank": 2, "decoded_token": " rust"}, "17253": {"logprob": -4.6958112716674805, "rank": 3, "decoded_token": " weather"}, "12603": {"logprob": -5.8208112716674805, "rank": 4, "decoded_token": " wood"}, "3403": {"logprob": -6.0708112716674805, "rank": 5, "decoded_token": " text"}}, {"11237": {"logprob": -0.6428436636924744, "rank": 1, "decoded_token": " floor"}, "4691": {"logprob": -1.0178437232971191, "rank": 2, "decoded_token": " surface"}, "7042": {"logprob": -2.642843723297119, "rank": 3, "decoded_token": " background"}, "28984": {"logprob": -3.517843723297119, "rank": 4, "decoded_token": " deck"}, "92504": {"logprob": -6.017843723297119, "rank": 5, "decoded_token": " backdrop"}}, {"1626": {"logprob": -0.7337945103645325, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -0.8587945103645325, "rank": 2, "decoded_token": ","}, "1454": {"logprob": -3.3587944507598877, "rank": 3, "decoded_token": " with"}, "7283": {"logprob": -3.6087944507598877, "rank": 4, "decoded_token": " looking"}, "1321": {"logprob": -4.108794689178467, "rank": 5, "decoded_token": " and"}}, {"1050": {"logprob": -1.0132738680113107e-05, "rank": 1, "decoded_token": "2"}, "1051": {"logprob": -11.75001049041748, "rank": 2, "decoded_token": "3"}, "1256": {"logprob": -14.00001049041748, "rank": 3, "decoded_token": " "}, "1049": {"logprob": -14.62501049041748, "rank": 4, "decoded_token": "1"}, "1032": {"logprob": -14.62501049041748, "rank": 5, "decoded_token": " "}}, {"1046": {"logprob": -2.861018856492592e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -13.43750286102295, "rank": 2, "decoded_token": ".A"}, "4700": {"logprob": -15.37500286102295, "rank": 3, "decoded_token": ".M"}, "1626": {"logprob": -15.37500286102295, "rank": 4, "decoded_token": ".\n"}, "3051": {"logprob": -15.87500286102295, "rank": 5, "decoded_token": ".S"}}, {"1349": {"logprob": -0.6794427633285522, "rank": 1, "decoded_token": " A"}, "11826": {"logprob": -1.9294427633285522, "rank": 2, "decoded_token": " Maj"}, "37159": {"logprob": -2.116942882537842, "rank": 3, "decoded_token": " Snow"}, "27260": {"logprob": -2.616942882537842, "rank": 4, "decoded_token": " Mountain"}, "113465": {"logprob": -2.866942882537842, "rank": 5, "decoded_token": " Rug"}}, {"15375": {"logprob": -0.9194075465202332, "rank": 1, "decoded_token": " vast"}, "10726": {"logprob": -2.294407606124878, "rank": 2, "decoded_token": " scen"}, "4521": {"logprob": -2.356907606124878, "rank": 3, "decoded_token": " range"}, "122203": {"logprob": -2.419407606124878, "rank": 4, "decoded_token": " rugged"}, "61082": {"logprob": -2.856907606124878, "rank": 5, "decoded_token": " panor"}}, {"24361": {"logprob": -0.5804797410964966, "rank": 1, "decoded_token": " mountain"}, "127945": {"logprob": -1.8304797410964966, "rank": 2, "decoded_token": " mountainous"}, "28035": {"logprob": 
-2.455479621887207, "rank": 3, "decoded_token": " landscape"}, "4521": {"logprob": -2.455479621887207, "rank": 4, "decoded_token": " range"}, "1044": {"logprob": -2.705479621887207, "rank": 5, "decoded_token": ","}}, {"4521": {"logprob": -0.0493546724319458, "rank": 1, "decoded_token": " range"}, "28035": {"logprob": -3.0493545532226562, "rank": 2, "decoded_token": " landscape"}, "37691": {"logprob": -8.424354553222656, "rank": 3, "decoded_token": " valley"}, "13327": {"logprob": -9.049354553222656, "rank": 4, "decoded_token": " scene"}, "3719": {"logprob": -9.799354553222656, "rank": 5, "decoded_token": " view"}}, {"94973": {"logprob": -0.6676871180534363, "rank": 1, "decoded_token": " stretches"}, "2425": {"logprob": -1.792687177658081, "rank": 2, "decoded_token": " under"}, "1395": {"logprob": -2.292687177658081, "rank": 3, "decoded_token": " is"}, "1454": {"logprob": -2.730187177658081, "rank": 4, "decoded_token": " with"}, "7038": {"logprob": -3.292687177658081, "rank": 5, "decoded_token": " extends"}}, {"5669": {"logprob": -0.4542117118835449, "rank": 1, "decoded_token": " across"}, "2425": {"logprob": -1.454211711883545, "rank": 2, "decoded_token": " under"}, "1848": {"logprob": -2.454211711883545, "rank": 3, "decoded_token": " out"}, "2203": {"logprob": -4.204211711883545, "rank": 4, "decoded_token": " into"}, "25136": {"logprob": -4.641711711883545, "rank": 5, "decoded_token": " beneath"}}, {"1278": {"logprob": -0.23009441792964935, "rank": 1, "decoded_token": " the"}, "1261": {"logprob": -1.6050944328308105, "rank": 2, "decoded_token": " a"}, "1420": {"logprob": -5.6050944328308105, "rank": 3, "decoded_token": " an"}, "2425": {"logprob": -7.2300944328308105, "rank": 4, "decoded_token": " under"}, "1454": {"logprob": -10.167593955993652, "rank": 5, "decoded_token": " with"}}, {"48932": {"logprob": -0.3072167932987213, "rank": 1, "decoded_token": " horizon"}, "21283": {"logprob": -1.932216763496399, "rank": 2, "decoded_token": " sky"}, "3937": {"logprob": -3.1822168827056885, "rank": 3, "decoded_token": " image"}, "28035": {"logprob": -3.6822168827056885, "rank": 4, "decoded_token": " landscape"}, "3044": {"logprob": -3.6822168827056885, "rank": 5, "decoded_token": " sk"}}, {"2425": {"logprob": -0.2914469838142395, "rank": 1, "decoded_token": " under"}, "1044": {"logprob": -2.4164469242095947, "rank": 2, "decoded_token": ","}, "1454": {"logprob": -2.5414469242095947, "rank": 3, "decoded_token": " with"}, "1626": {"logprob": -3.7914469242095947, "rank": 4, "decoded_token": ".\n"}, "1408": {"logprob": -3.7914469242095947, "rank": 5, "decoded_token": " on"}}, {"1261": {"logprob": -0.0460360012948513, "rank": 1, "decoded_token": " a"}, "1420": {"logprob": -3.9210360050201416, "rank": 2, "decoded_token": " an"}, "16152": {"logprob": -4.1085357666015625, "rank": 3, "decoded_token": " cloud"}, "2136": {"logprob": -6.1710357666015625, "rank": 4, "decoded_token": " over"}, "6133": {"logprob": -6.4210357666015625, "rank": 5, "decoded_token": " clear"}}, {"16152": {"logprob": -0.20367540419101715, "rank": 1, "decoded_token": " cloud"}, "6133": {"logprob": -2.8286755084991455, "rank": 2, "decoded_token": " clear"}, "27254": {"logprob": -3.5161755084991455, "rank": 3, "decoded_token": " partly"}, "18416": {"logprob": -3.8286755084991455, "rank": 4, "decoded_token": " haz"}, "4391": {"logprob": -4.328675270080566, "rank": 5, "decoded_token": " light"}}, {"1121": {"logprob": -0.05241352692246437, "rank": 1, "decoded_token": "y"}, "1286": {"logprob": -3.8024134635925293, "rank": 2, 
"decoded_token": "ed"}, "77187": {"logprob": -4.552413463592529, "rank": 3, "decoded_token": "-filled"}, "4527": {"logprob": -4.802413463592529, "rank": 4, "decoded_token": "less"}, "114525": {"logprob": -4.927413463592529, "rank": 5, "decoded_token": "-covered"}}, {"21283": {"logprob": -0.0003716255014296621, "rank": 1, "decoded_token": " sky"}, "10991": {"logprob": -8.750371932983398, "rank": 2, "decoded_token": " blue"}, "1044": {"logprob": -9.375371932983398, "rank": 3, "decoded_token": ","}, "26549": {"logprob": -10.375371932983398, "rank": 4, "decoded_token": " gray"}, "34052": {"logprob": -11.250371932983398, "rank": 5, "decoded_token": " grey"}}, {"1626": {"logprob": -0.00012730741582345217, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -9.500126838684082, "rank": 2, "decoded_token": ","}, "1046": {"logprob": -10.500126838684082, "rank": 3, "decoded_token": "."}, "1454": {"logprob": -10.875126838684082, "rank": 4, "decoded_token": " with"}, "1294": {"logprob": -13.250126838684082, "rank": 5, "decoded_token": " in"}}, {"1051": {"logprob": -3.2186455882765586e-06, "rank": 1, "decoded_token": "3"}, "1052": {"logprob": -12.75000286102295, "rank": 2, "decoded_token": "4"}, "1050": {"logprob": -15.00000286102295, "rank": 3, "decoded_token": "2"}, "1049": {"logprob": -16.937503814697266, "rank": 4, "decoded_token": "1"}, "1032": {"logprob": -17.875003814697266, "rank": 5, "decoded_token": " "}}, {"1046": {"logprob": -1.6689286894688848e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -14.687501907348633, "rank": 2, "decoded_token": ".A"}, "5226": {"logprob": -15.687501907348633, "rank": 3, "decoded_token": ".D"}, "6847": {"logprob": -15.812501907348633, "rank": 4, "decoded_token": ".T"}, "48426": {"logprob": -16.812501907348633, "rank": 5, "decoded_token": ".The"}}, {"8342": {"logprob": -0.5730464458465576, "rank": 1, "decoded_token": " Sur"}, "1349": {"logprob": -1.6980464458465576, "rank": 2, "decoded_token": " A"}, "22468": {"logprob": -2.5730464458465576, "rank": 3, "decoded_token": " Several"}, "1488": {"logprob": -2.6980464458465576, "rank": 4, "decoded_token": " W"}, "15035": {"logprob": -3.1980464458465576, "rank": 5, "decoded_token": " People"}}, {"71284": {"logprob": -0.0033258858602494, "rank": 1, "decoded_token": "fers"}, "1102": {"logprob": -5.878325939178467, "rank": 2, "decoded_token": "f"}, "1726": {"logprob": -7.628325939178467, "rank": 3, "decoded_token": "fer"}, "61888": {"logprob": -12.253325462341309, "rank": 4, "decoded_token": "fline"}, "2119": {"logprob": -13.003325462341309, "rank": 5, "decoded_token": "fter"}}, {"7377": {"logprob": -1.4996429681777954, "rank": 1, "decoded_token": " wait"}, "1584": {"logprob": -1.7496429681777954, "rank": 2, "decoded_token": " are"}, "88014": {"logprob": -1.9371429681777954, "rank": 3, "decoded_token": " paddle"}, "1294": {"logprob": -1.9371429681777954, "rank": 4, "decoded_token": " in"}, "24434": {"logprob": -2.187142848968506, "rank": 5, "decoded_token": " ride"}}, {"1394": {"logprob": -0.6126739382743835, "rank": 1, "decoded_token": " for"}, "1294": {"logprob": -0.9876739382743835, "rank": 2, "decoded_token": " in"}, "1408": {"logprob": -2.7376739978790283, "rank": 3, "decoded_token": " on"}, "6482": {"logprob": -4.425173759460449, "rank": 4, "decoded_token": " patient"}, "1321": {"logprob": -5.612673759460449, "rank": 5, "decoded_token": " and"}}, {"22140": {"logprob": -0.00729279313236475, "rank": 1, "decoded_token": " waves"}, "1278": {"logprob": -5.632292747497559, "rank": 2, "decoded_token": " the"}, 
"1261": {"logprob": -5.757292747497559, "rank": 3, "decoded_token": " a"}, "39460": {"logprob": -8.257292747497559, "rank": 4, "decoded_token": " incoming"}, "1321": {"logprob": -9.757292747497559, "rank": 5, "decoded_token": " and"}}, {"1294": {"logprob": -0.3071398138999939, "rank": 1, "decoded_token": " in"}, "1408": {"logprob": -2.1821398735046387, "rank": 2, "decoded_token": " on"}, "1513": {"logprob": -2.4321398735046387, "rank": 3, "decoded_token": " at"}, "3016": {"logprob": -3.6821398735046387, "rank": 4, "decoded_token": " while"}, "1435": {"logprob": -3.8071398735046387, "rank": 5, "decoded_token": " as"}}, {"1278": {"logprob": -0.004646694287657738, "rank": 1, "decoded_token": " the"}, "1261": {"logprob": -6.1921467781066895, "rank": 2, "decoded_token": " a"}, "1420": {"logprob": -6.9421467781066895, "rank": 3, "decoded_token": " an"}, "40466": {"logprob": -7.2546467781066895, "rank": 4, "decoded_token": " shallow"}, "26517": {"logprob": -7.8796467781066895, "rank": 5, "decoded_token": " calm"}}, {"27208": {"logprob": -0.0658877044916153, "rank": 1, "decoded_token": " ocean"}, "7786": {"logprob": -3.440887689590454, "rank": 2, "decoded_token": " distance"}, "5124": {"logprob": -5.253387928009033, "rank": 3, "decoded_token": " early"}, "26517": {"logprob": -5.315887928009033, "rank": 4, "decoded_token": " calm"}, "11196": {"logprob": -5.378387928009033, "rank": 5, "decoded_token": " sea"}}, {"1513": {"logprob": -1.1504861116409302, "rank": 1, "decoded_token": " at"}, "1435": {"logprob": -1.2754861116409302, "rank": 2, "decoded_token": " as"}, "3184": {"logprob": -1.4004861116409302, "rank": 3, "decoded_token": " during"}, "3016": {"logprob": -2.9004859924316406, "rank": 4, "decoded_token": " while"}, "6117": {"logprob": -3.1504859924316406, "rank": 5, "decoded_token": " near"}}, {"97558": {"logprob": -0.12151996046304703, "rank": 1, "decoded_token": " sunset"}, "11729": {"logprob": -2.8715200424194336, "rank": 2, "decoded_token": " sun"}, "1266": {"logprob": -3.4965200424194336, "rank": 3, "decoded_token": " d"}, "54507": {"logprob": -3.9965200424194336, "rank": 4, "decoded_token": " dawn"}, "1261": {"logprob": -5.121520042419434, "rank": 5, "decoded_token": " a"}}, {"1626": {"logprob": -0.3073118329048157, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -2.182311773300171, "rank": 2, "decoded_token": ","}, "3016": {"logprob": -2.557311773300171, "rank": 3, "decoded_token": " while"}, "1454": {"logprob": -3.432311773300171, "rank": 4, "decoded_token": " with"}, "6117": {"logprob": -4.05731201171875, "rank": 5, "decoded_token": " near"}}, {"1052": {"logprob": -3.3378546504536644e-06, "rank": 1, "decoded_token": "4"}, "1051": {"logprob": -13.25000286102295, "rank": 2, "decoded_token": "3"}, "1049": {"logprob": -13.93750286102295, "rank": 3, "decoded_token": "1"}, "1053": {"logprob": -14.43750286102295, "rank": 4, "decoded_token": "5"}, "1032": {"logprob": -16.687503814697266, "rank": 5, "decoded_token": " "}}, {"1046": {"logprob": -1.6689286894688848e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -13.500001907348633, "rank": 2, "decoded_token": ".A"}, "6847": {"logprob": -16.437501907348633, "rank": 3, "decoded_token": ".T"}, "1044": {"logprob": -17.312501907348633, "rank": 4, "decoded_token": ","}, "1349": {"logprob": -17.375001907348633, "rank": 5, "decoded_token": " A"}}, {"1349": {"logprob": -0.004292916506528854, "rank": 1, "decoded_token": " A"}, "2048": {"logprob": -5.629292964935303, "rank": 2, "decoded_token": " An"}, "10638": {"logprob": 
-7.879292964935303, "rank": 3, "decoded_token": " Two"}, "111463": {"logprob": -10.004292488098145, "rank": 4, "decoded_token": " Trees"}, "1531": {"logprob": -10.879292488098145, "rank": 5, "decoded_token": " The"}}, {"53301": {"logprob": -1.5473321676254272, "rank": 1, "decoded_token": " winding"}, "15192": {"logprob": -1.7348321676254272, "rank": 2, "decoded_token": " narrow"}, "47945": {"logprob": -2.109832286834717, "rank": 3, "decoded_token": " dirt"}, "2169": {"logprob": -2.609832286834717, "rank": 4, "decoded_token": " ser"}, "59396": {"logprob": -2.672332286834717, "rank": 5, "decoded_token": " gravel"}}, {"59396": {"logprob": -0.8954829573631287, "rank": 1, "decoded_token": " gravel"}, "3549": {"logprob": -1.1454830169677734, "rank": 2, "decoded_token": " path"}, "47945": {"logprob": -1.6454830169677734, "rank": 3, "decoded_token": " dirt"}, "14801": {"logprob": -3.2704830169677734, "rank": 4, "decoded_token": " pathway"}, "15551": {"logprob": -4.270483016967773, "rank": 5, "decoded_token": " stone"}}, {"3549": {"logprob": -0.02117946185171604, "rank": 1, "decoded_token": " path"}, "14801": {"logprob": -3.896179437637329, "rank": 2, "decoded_token": " pathway"}, "33659": {"logprob": -8.14617919921875, "rank": 3, "decoded_token": " trail"}, "9480": {"logprob": -9.64617919921875, "rank": 4, "decoded_token": " road"}, "7368": {"logprob": -9.64617919921875, "rank": 5, "decoded_token": "path"}}, {"13335": {"logprob": -0.18962937593460083, "rank": 1, "decoded_token": " leads"}, "39985": {"logprob": -2.752129316329956, "rank": 2, "decoded_token": " cuts"}, "1639": {"logprob": -3.877129316329956, "rank": 3, "decoded_token": " me"}, "11500": {"logprob": -3.939629316329956, "rank": 4, "decoded_token": " runs"}, "2645": {"logprob": -4.189629554748535, "rank": 5, "decoded_token": " through"}}, {"2645": {"logprob": -0.05349981039762497, "rank": 1, "decoded_token": " through"}, "8994": {"logprob": -4.053499698638916, "rank": 2, "decoded_token": " towards"}, "2396": {"logprob": -4.303499698638916, "rank": 3, "decoded_token": " between"}, "2203": {"logprob": -4.678499698638916, "rank": 4, "decoded_token": " into"}, "1317": {"logprob": -5.678499698638916, "rank": 5, "decoded_token": " to"}}, {"1261": {"logprob": -0.017386287450790405, "rank": 1, "decoded_token": " a"}, "11223": {"logprob": -4.892386436462402, "rank": 2, "decoded_token": " green"}, "1295": {"logprob": -5.017386436462402, "rank": 3, "decoded_token": " l"}, "23170": {"logprob": -6.642386436462402, "rank": 4, "decoded_token": " grass"}, "1420": {"logprob": -7.267386436462402, "rank": 5, "decoded_token": " an"}}, {"1295": {"logprob": -0.9453322887420654, "rank": 1, "decoded_token": " l"}, "11223": {"logprob": -1.3203322887420654, "rank": 2, "decoded_token": " green"}, "23170": {"logprob": -1.9453322887420654, "rank": 3, "decoded_token": " grass"}, "12097": {"logprob": -2.4453322887420654, "rank": 4, "decoded_token": " park"}, "26428": {"logprob": -3.3203322887420654, "rank": 5, "decoded_token": " garden"}}, {"3506": {"logprob": -6.556489552167477e-06, "rank": 1, "decoded_token": "ush"}, "1374": {"logprob": -12.000006675720215, "rank": 2, "decoded_token": "us"}, "90716": {"logprob": -15.625006675720215, "rank": 3, "decoded_token": "USH"}, "16938": {"logprob": -15.875006675720215, "rank": 4, "decoded_token": "usher"}, "13326": {"logprob": -17.1875057220459, "rank": 5, "decoded_token": "inden"}}, {"11223": {"logprob": -0.3668670654296875, "rank": 1, "decoded_token": " green"}, "1044": {"logprob": -1.3668670654296875, "rank": 2, 
"decoded_token": ","}, "26428": {"logprob": -3.4918670654296875, "rank": 3, "decoded_token": " garden"}, "12097": {"logprob": -4.1168670654296875, "rank": 4, "decoded_token": " park"}, "23170": {"logprob": -5.8668670654296875, "rank": 5, "decoded_token": " grass"}}, {"12097": {"logprob": -0.5530153512954712, "rank": 1, "decoded_token": " park"}, "3727": {"logprob": -2.0530152320861816, "rank": 2, "decoded_token": " field"}, "28035": {"logprob": -2.1780152320861816, "rank": 3, "decoded_token": " landscape"}, "26428": {"logprob": -2.3030152320861816, "rank": 4, "decoded_token": " garden"}, "4457": {"logprob": -2.8030152320861816, "rank": 5, "decoded_token": " area"}}, {"1046": {"logprob": -0.7924000024795532, "rank": 1, "decoded_token": "."}, "1454": {"logprob": -1.2924000024795532, "rank": 2, "decoded_token": " with"}, "8994": {"logprob": -2.7923998832702637, "rank": 3, "decoded_token": " towards"}, "54410": {"logprob": -3.5423998832702637, "rank": 4, "decoded_token": " lined"}, "2425": {"logprob": -3.5423998832702637, "rank": 5, "decoded_token": " under"}}, {"2": {"logprob": -1.9073468138230965e-06, "rank": 1, "decoded_token": ""}, "1032": {"logprob": -13.250001907348633, "rank": 2, "decoded_token": " "}, "1256": {"logprob": -16.250001907348633, "rank": 3, "decoded_token": " "}, "1293": {"logprob": -19.000001907348633, "rank": 4, "decoded_token": " "}, "1319": {"logprob": -20.000001907348633, "rank": 5, "decoded_token": " ("}}]]] \ No newline at end of file diff --git a/tests/models/fixtures/pixtral_chat_engine.pickle b/tests/models/fixtures/pixtral_chat_engine.pickle deleted file mode 100644 index 19dbeaecc8dfffcddda1d66f83f24a1ed1ec16fc..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20858 zcmb_kdw5mVmA~)ghDQk%jC4e= zV(sS)ZAH-;AMM;u?WnC)tE1nHC}%_cs?}OS>(J_m5d~3H6jAX-X05%>-e<3HL%*4C z{Q8IW{Ib^XxA$86ti9Jhw-~uK`|W`G=dumc;eXD-n3Jk@-feK+k2vQ{bk3dOjBj;{ zQT)F*UE?es<D_T@!NEbB_OtrYFq4w26k3YqQ>k?n^6 z+?UO^r@KYIH@7_1mTuVqPSz+t>qcDB)wQZ6pT4{|-QAYnu&&~qY)4NnyKKWP8)jU) z0YK|2y1f5#ru*c(RsFzoT_l(4Ub*4y;JOj*={B(3zO;~CneN`OE)0w|^e%HI)^=9V zpLbTBDMS0tOTwRZ4m&&4IecB@{LV}sLE|DeXw+x-3Ic@z6r2l)*0=u<4A__kgNXp9 z;*4~zkjiwUtL)*NlaCKwnBs32xIXiCMq z4Hnk@;GJ!f{E`Y43TX1ZXKmPLg*_fUR8c9~(>0RUqv(A*6x>YRj1t*JMIn(wHS40b z!Exb6&o!zEi@F!eVV6Z z`1SwtV5Fj_x2FdOueE6#2i9cwCdeLz)H$hTxm0^LhXeD}-`W%>&DKa%5)ptz(1|QN zChgN8($SW}bW}RVw_4WRcW=HN+~T;Pj!L~IqQwX7aB@FCYkrY{s3a=-GKEe=^<-^x z0|tdMVxns`%m>2qEXRX)-<2R^sG(AHWwR?i;1^76WDyEv?J(}k(LZuZ9piaQlvUC71fG%f5`@p+J3wAg;a=C zp^)xYI5v8+Zi&XB4l{a61JCR?zEw()MbzMB_D$DVPTltn1SM2>uP#(^%3KR4$K3ZP zu!M7qD!?hc;t<6bWgkWH|=G(k8Dsa!7&@gNoBHHxb^dxULkjIo6(qN0$#QcVMwG2s<$tE`O0qb!r|+llcd zUk0QGl?FJ1yJp(Jy#|On;Oz(DtP=M+`>i1(TVtL*6f#UKgF2)mm+hqhoPW7xQr7JW zC&@R-fg)lhIzE@Z#OnoIi113RN?W3waGeIBfS^D#H?yp7y7saJnL+k3G?P>Bh&{~EF+mMO`XPaq zdecAHa2Zrj<04RaIt!y-P5MC8X^6MBPD!r#ds1 zI``GKevN-IwH{uv&J@>K9mleE?CDM_;xsMxD%&5=Z4nEsVit!QBM;s$BAv?RaV3Xx zSFJ~xE7D!7%a*xS^U%r_seD1Kf=I-QLqO$<<#HjHF0^%~!A4syo6qA5U36x1neRcl z3^AL|LD38^<=*zyBENFADVSNw>b}~!te=Jyrmu4}4EY#)EZu{zjFt3U#oa6wcQ!(X zR}D`RLQUHOL>y#UP;=r9f)F@uUVYuVu$I4IYai1noVCbesMm1bu;gXQvT6Uo0ZHbN z@{{1MkrA3ELrHqBD4&dkVl0lP7c~MYhguZp8O%tt&=BCvgb3631l0bp{Xvk z}`FYl00j&8LfGTd-=zt1)#8jC_hat&P&1HO%Z~iYKa=E z6bvr=2B2`VP)keeQuU&m6SjD} zJJ)XPCxWQM)F`lQSySR3Sg^zMi87*UnZZEFv#_fa`#coI6m4g}T4Gs}!%G^2?4zP9 zLpvx|pvn;jh4k$LRY_>Dw<_*|3-lU~62h<^A~jxMeQGc$z$>se*wv1E$UrET^4VTA z)cW}iJV>I=uO(>2P{A;fzOn~)fq7VN(e1J0Y_)NNMIA71+2y($bi3tR%ZC zqHs+|_ZxR}_v5ZxQz95DiS#aCuJ&brU^8*F0il)cPDvVN|JLe|u 
zuIXA0Qz-$gtU`Qq(X!}Pz3`GCW)w>JLc$^W#y2dKJf;+rgi`$RwS}OzdK)yYs6jLr z>D3e^4Xpf|9?&S15EHAl$0c{$#+nTbB~%D_zol22wB~{S=1&Lu17>mc1O12ghS42( zp$H&9tFe|*QhD_h?vNe~Rha6DCr#9VP()Zz((rVsG?=pKz;-9@_UU>Th|^&2PT0Ol zhIUO#`Xao2Q|mD$KL6pgMxh26yj0gJIte&1+aBbIu|{)|>ViWkoY?QMgs6M?M>r<1 z9+HxtR3SrqCzaTX?UN(7e%h82fmJ}Mf(A|2u{T*$V3;5j22|L~!vW@Yn$#Gj0-R!T zgtq?9_9%}vwopYBm=S8kMqBHwqPB)2BI$MoQUcfzyT9J@Zyq%iptq)c*BS4tD*5!; zu5%>I>Cqd%pCBM=fIw0ETwY{1oT4$vpKY;O=`S0MF_6$Hy_$%xna{rarRNaYm)(?p zdD+q=al)~;H+_5yR=<5oKjDzOy_0nK?r`i94O9~#%I^**7;`uQW{<08fbU8l-Fdbk z$CM0o;ENN_{YD=gl_VSri5kCL#;EN#de#n;!MCps21qjWCY*0yeH6*?wy%^4Y*)Jb zt7c19A-n+>?gB?sWZ=!c(|oX z`QQmtAJI0Dw4E@D^LU$(JLQ(!Jcl?7!FfoT?-CmmbtR&N!%w?Jhwn(IN)18QFM){> zjPZ2iY)C|*euEB zq=sO?d$Ytxc~((F>`~^sr&;|P1AII8du;i(X6J9GO2Zb0PHk)E+4~Xqv((KN;fQcj zD{u;?t+yBF;6uMJk{zTR6^pyGeSm5`X(JEU>Ny&Ryd&Z}z3Rx4Z*8iLF&t=p#-PT! z`c%uS`&MsLiM(QHMAhaRx&7r93nc@S7WET~vU6DuPSg-~#|JhnY+xt>HVvt6I$GLo zS0OTPV5k6Ul{RQKfw9%>v_Jk6?UN+r{R6xiOD(VN?lq4AH@&Zno1+07yx)J!2G1Qh z?NtwpLWUbGZPWNMlMMXJsv;T2h{rs>kGH|hbJP)|qX-nTc9q3Xh7AmbREaJ)#Lc&* zdd!ZVS#=+-{kb5oI3q(^jo$P_4~9y(Rp%5od*be<1X;vssen1|>Po93;?JxUYJ5xM zXkL2KDRqojeY`EFsa>N`PE5_7mZFY=EU^30e?M0w*C?jSKc~TA;v2gC28x3G4c%(C zXZqTpD8$s+4jSXdX{0+X$uCOL#p!V;THIvC2O5JyV7s2Z0#4BI%vt5hRJ*WuUyhGX z`a+;MRmV)|cz%U#H1M-V4~HruBHOF-go%(L0~#daPL=^x21fId=`N^-S7VdyadeH{ z%BqAmWnU`SZng=m@8$mdU|T+w|mzUP!x=e|)S5NasfuA6kvhms(uK)36-%Er+L zlZ2u4FjK|mT{h-545z`&yPAV3uBFu77X7%N3}P@<>hnFEgr>z7PzJ@OA^~X<;KUC> ztXV884TvT|_1zK1vv;_Gp^8Cy_SP`lZEcNxULS^oMlh#zuz_*kXp#w{M;UGghLyVR z!S(&Zs8Guf8#wnX^Be`3fpg|8vq_Cy^dT6*8Atw=Vp+NyyCEo))%>zy%9S1v=fSJ_ z6?!=$vhK64q&LXQ%>}I$;+Q2A?+Pb`b z_eCsKM0k=Qj}%0h%6jmbUw7B-x-&?FU*50#=^yS&_!44pcI^56@Jw4pZ1=Z3zo;T2 zI#b_64OgtNK@|ps8YMdPi9d<1drvB;J#){ryLi0B|#um5rqyyuR5(_nIez6PNP7E z%9KYtY%7x&tP*4e6+}fxE~R`$zHg%%H)|BdP)AO^BvlHq0@b~D?QN3m0d-ZNaz628 zBa`$TSRbCbvVLxbpyxR3e&+H5b>eR{3R&xI0N-0TUbMDEmXNczw?@XSFSTj!z{iqc zYOHl%iWm6D6oVcJS!03E_Xu+Wx)@f$LiPPL7Ef~%8iJHVVmi*~ZhQYE=M^;s8HdCr z02!hmb$ADA*G%Vs-+_X^n4o%Aw~5w5J+!K}O>}Zm=c)^xkx}efEn4Bc6K$~IIoRx8jg1^sKTCBb<}xygtL97vtYK@*4i_ut#!mMmTI)N{Qvc^s*Y!+n!(-qbfKl{cTE{-#r*n zZ^DW(xX;z_$cdf>j49pca^9F_C+0W-N9BVcuCPaSc;0T>dHNTU_~^Xl z+r8<5prGf-7q)wYT4l{rKUWz*8&BQ}n^n{h;M(~GnBX*LC-rliCf`wRnfY9Qz!Yg^ z8)c!!CnQ0DDP)mql<_U#sM_x)38IG9>uwF|l=&bW+L$n%GTw9>vtHYxWkxmK)By8F z*>9_>i{um68NX3>SpRF-oR0wKMFy;Zy8p#*RaHNvQ78dl!rrYa-gm1Vd!^?UMc92e zhxgNJ|5*c32E1UsYgI}KZ5YIZ>so-fAoOH&|G-8R`ET+`)MPLF;rmsX-Wd`3%nV8ut;uEg4GuoeCF>SC~lX8 zL>WV$`N!nSe!@{)nKOT05&JOIqEEqjUto0qiC_r>8D&)BYC3dZV0`x85`ooZn6v!B zeSr}PLZOQCeSxV`V^BfGkbQyXfX1POq51+diyEb(Y3RPd{qm`#FF=n7?+c8`hlNj{ z^#xea$=Z{yEfNql_(8`joVx$iD4K+F-Gea(LzBR-=42gCZGrcYt3-d+X+7`^hoNunztC z)wU1MsxJ}@Rj?)kzq1yd@{b7uqL3qZ2ZkFgX&JIRFlu3mY@(XK)*X2J?@#m4D!9sY z2i&~|0&2=^4yty%;F;iQ!O#DP8x+#F&;JMAgsaJR^asoWrF*$5rdX3;)2`D$2OO?s zKz+sDU?2c%Hb1_fRtzQy0@PQ0N)WzDI5wam$lhxtuv=DgRzgF7x~dxaKE!_xBaZtJ zMV9TrHa$J?wKQbQwFvqHK2v?E=~BBjC0nG1D5fnAzN2^-!H3`=dtHRR?96Ajv4Y_d zYFVsb!@Q*1_2eO+FAPhWCc&W{0~X~WngoG;*LZd)IJ_o-a}_&qh!_;|jtj@TEC<{U zYbV4E76s^bf*Nw(DR_V7?>&AP+J2{ipY4Y(-fyY!&-U3Fw+y|aQ7FSd<8Hm=HGZNH zHs(-C1uUE$9qR01Q0F0L45A9Z^PqI_cEQeDS`z^s=P)?>CLs>a;C8{+XZ=DD7FV9b z+6DZiFnsrmMFOL!K_`V=zYj%OzWzZCh^VoMnV_59%#?fI!U#C^($+S`$yKK7JQyk& zL~c;K;4ce3DC)4=1-uc^I8Nh81h#S1;mi-%{lcc(H4LW);wF>Ez822RVeNulm;EN` ziy71|;D@MRJ@98qP!trSOP*kNP$M9;q|392Vsy>Q3z6t{jX@!PA#ydl5M5;~p^i$p zJ_VQO{d0ESE_ki>k3de(irp^YDn9$B?okg?!Mh6Ke?3W(IaFb~3cOuVvu$EOfk=jK z7x44oQR_8|WccU7Dpz>B;Q9PU&m`u2`F4T(`2&4L@{1bE9|*7fiD!pqfxhlxLm?;b z)ffsVb1Kd!Ysv+G^h{AWd6%{odEPE~={|ewGjzMaJ@oFS!WUppB6+(Y-rM4VP=fVs z8PBs328JpIC_T@f+QQGg%i{FOJ- zppd`!%H!K2A40scynOS 
zI{;%J)G&&EuLCfsoiNUzkh_1Zujn(s?z5xf9yc0xARQsBEWTa&e9%5ULoe zQ_v_i3S_7x;hlmh<~c=C`A)$x#uu{3euCdAVCm$F2Oqbnm%T026sRc~;))V0b>&p^ z94X@sf4(42mHO|Avi^#NZ=;)w8ikatMgcGJH$6M>wIDVaFJ6g1(e6FT8AV+<<46aO z$CobKO30RgMgV2KmEfCzIUfcR*HHlLndl~io95Y z?lzumaguQZLGl&iymYoleUpm(Rl~$WyDLw7)AOuy5t|Qa@>?uI8rF972h@_SB4jxt zHnA?yQ3inm!e9n5@*~TYoPO;io)s0?%eyjEUY^N1zs*67Qe(fo;5UA^*_A8t7A5BjUb!XDsPUh+xA96>yjP3rKs5v2H>xW7ue#8tzvlJf39^VY5Tjld3SdR2tp7ou z=a{AdI#~!}#f5CB&>&Rc6&F+Pu-r*uz7RnHacGLQWmkFM?ehEXO$K?G0zS4c-8;*p zh6>7$ZLl|L(bs3B@7NP5!v;g+PnYSQVJ8bC3j|pGt zAyG#RKj)w;^PJ(BGc*jPK%SnMynKM= z%dVOI zfh73@@~Z5bWP6ct9`8Lz*1;2S#`x*%oQHHd1dRWWh1tx2wxO08D$ID*HK#|HR4Op2 z>f#&Tx#wvJa8(n<>L@uksh?9w;F+rm2?ld7zcud`Sd)RF5-1_FDYcS5 zV^_e^N2#E~i$8iY8tg0H?lTL2S0t|}0`5uR?I0ZGxp&!KW06af1f#itOB@Q-VOP|$ z+MrNESgcYnNLQlk%Ir?MABy+NUx zFx=bLLfD(lUH|-JKo3Izg?KgDc2<6jUjMxGCBWz3cySAPRWpxuwHk&3LP7=7br;)1 zJlS!rN_eS}zHBB;>h`TzsQ`x{gx*sIqu)yN3JTmU?C!(%AcSg(|Ffd94Lu>*2 zgdu!WRgxUi(BtJb9Eej{*;!>!NFNp^xG=8E>o6hERw;`M-1-U-@U7yOAmnE?B$>4K=fGjAu)IIa1LFXSA-4w9`}0ougmUNIwcI zzE7SH?)L9V;I3oG3JmrpbwAmms}SI?-lTlIhV8q^GlSHv@q(L>cJ$Qr-S6S}=(DMm z4F-ku{TNe&&kVY6JnSTV{tB^fDQfUf=lNO{yYOX6D3oA- z>*4$ZE1HKJAPT^% TokensPrompt: LIMIT_MM_PER_PROMPT = dict(image=4) MAX_MODEL_LEN = [8192, 65536] -FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle" -FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle" +FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.json" +FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.json" +OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]] -def load_logprobs(filename: str) -> Any: - with open(filename, 'rb') as f: - return pickle.load(f) + +# For the test author to store golden output in JSON +def _dump_outputs_w_logprobs(outputs: OutputsLogprobs, filename: str) -> None: + json_data = [(tokens, text, + [{k: asdict(v) + for k, v in token_logprobs.items()} + for token_logprobs in (logprobs or [])]) + for tokens, text, logprobs in outputs] + + with open(filename, "w") as f: + json.dump(json_data, f) + + +def load_outputs_w_logprobs(filename: str) -> OutputsLogprobs: + with open(filename, "rb") as f: + json_data = json.load(f) + + return [(tokens, text, + [{int(k): Logprob(**v) + for k, v in token_logprobs.items()} + for token_logprobs in logprobs]) + for tokens, text, logprobs in json_data] @pytest.mark.skip( @@ -103,7 +125,7 @@ def test_chat( model: str, dtype: str, ) -> None: - EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT) + EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT) with vllm_runner( model, dtype=dtype, @@ -120,10 +142,10 @@ def test_chat( outputs.extend(output) logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs) - check_logprobs_close(outputs_0_lst=logprobs, - outputs_1_lst=EXPECTED_CHAT_LOGPROBS, - name_0="output", - name_1="h100_ref") + check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS, + outputs_1_lst=logprobs, + name_0="h100_ref", + name_1="output") @pytest.mark.skip( @@ -133,7 +155,7 @@ def test_chat( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_model_engine(vllm_runner, model: str, dtype: str) -> None: - EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE) + EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE) args = EngineArgs( model=model, tokenizer_mode="mistral", @@ -162,7 +184,7 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None: break logprobs = 
vllm_runner._final_steps_generate_w_logprobs(outputs) - check_logprobs_close(outputs_0_lst=logprobs, - outputs_1_lst=EXPECTED_ENGINE_LOGPROBS, - name_0="output", - name_1="h100_ref") + check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS, + outputs_1_lst=logprobs, + name_0="h100_ref", + name_1="output") From 68210201099e6ce1c0a1453633c77fc0185af488 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Thu, 12 Sep 2024 23:48:59 -0400 Subject: [PATCH 27/98] [Bugfix] Fix async log stats (#8417) --- tests/basic_correctness/test_preemption.py | 1 + vllm/engine/llm_engine.py | 20 ++++++++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 7e77037da07d3..50d399bef1878 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -64,6 +64,7 @@ def test_chunked_prefill_recompute( enable_chunked_prefill=enable_chunked_prefill, max_num_seqs=max_num_seqs, worker_use_ray=worker_use_ray, + disable_log_stats=False, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c4d97c8f6d857..0573921a40fc3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1056,7 +1056,8 @@ def _process_model_outputs(self, # LLMEngine/AsyncLLMEngine directly if is_async: # Log stats. - self.do_log_stats(scheduler_outputs, outputs, finished_before) + self.do_log_stats(scheduler_outputs, outputs, finished_before, + skip) # Tracing self.do_tracing(scheduler_outputs) @@ -1363,18 +1364,20 @@ def remove_logger(self, logger_name: str) -> None: def do_log_stats(self, scheduler_outputs: Optional[SchedulerOutputs] = None, model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None) -> None: + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> None: """Forced log when no requests active.""" if self.log_stats: stats = self._get_stats(scheduler_outputs, model_output, - finished_before) + finished_before, skip) for logger in self.stat_loggers.values(): logger.log(stats) def _get_stats(self, scheduler_outputs: Optional[SchedulerOutputs], model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None) -> Stats: + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> Stats: """Get Stats to be Logged to Prometheus. Args: @@ -1382,6 +1385,10 @@ def _get_stats(self, the scheduled batch, model_output: Optional, used to emit speculative decoding metrics which are created by the workers. + finished_before: Optional, indices of sequences that were finished + before. These sequences will be ignored. + skip: Optional, indices of sequences that were preempted. These + sequences will be ignored. 
""" now = time.time() @@ -1456,6 +1463,11 @@ def _get_stats(self, actual_num_batched_tokens -= 1 continue + # Currently, skip == preempted sequences, so we need to skip + # their log stats + if skip and idx in skip: + continue + group_was_prefill = idx < scheduler_outputs.num_prefill_groups seq_group = scheduled_seq_group.seq_group From ba7752795567e3f2bfcc1dca340d107e003d32ad Mon Sep 17 00:00:00 2001 From: William Lin Date: Thu, 12 Sep 2024 21:30:00 -0700 Subject: [PATCH 28/98] [bugfix] torch profiler bug for single gpu with GPUExecutor (#8354) --- examples/offline_inference_with_profiler.py | 2 +- vllm/engine/async_llm_engine.py | 15 +++++++++++++-- vllm/engine/llm_engine.py | 15 +++++++++++++-- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference_with_profiler.py b/examples/offline_inference_with_profiler.py index 906c9502800d8..1f00d26808771 100644 --- a/examples/offline_inference_with_profiler.py +++ b/examples/offline_inference_with_profiler.py @@ -16,7 +16,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1) llm.start_profile() diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 01114e9843ce4..8a07ce1c965e1 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -13,6 +13,7 @@ from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState from vllm.engine.metrics_types import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptInputs from vllm.logger import init_logger @@ -1019,7 +1020,17 @@ def remove_logger(self, logger_name: str) -> None: self.engine.remove_logger(logger_name=logger_name) async def start_profile(self) -> None: - self.engine.model_executor._run_workers("start_profile") + # using type instead of isinstance to check to avoid capturing + # inherited classes + if type(self.engine.model_executor) == GPUExecutorAsync: + self.engine.model_executor.start_profile() + else: + self.engine.model_executor._run_workers("start_profile") async def stop_profile(self) -> None: - self.engine.model_executor._run_workers("stop_profile") + # using type instead of isinstance to check to avoid capturing + # inherited classes + if type(self.engine.model_executor) == GPUExecutorAsync: + self.engine.model_executor.stop_profile() + else: + self.engine.model_executor._run_workers("stop_profile") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0573921a40fc3..dfdbc22ef00e1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -26,6 +26,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase +from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, InputRegistry, LLMInputs, PromptInputs) @@ -1597,10 +1598,20 @@ def check_health(self) -> None: self.model_executor.check_health() def start_profile(self) -> None: - self.model_executor.start_profile() + # using type instead of isinstance to check to avoid capturing + # inherited classes (MultiprocessingGPUExecutor) + if type(self.model_executor) == 
GPUExecutor: + self.model_executor.start_profile() + else: + self.model_executor._run_workers("start_profile") def stop_profile(self) -> None: - self.model_executor.stop_profile() + # using type instead of isinstance to check to avoid capturing + # inherited classes (MultiprocessingGPUExecutor) + if type(self.model_executor) == GPUExecutor: + self.model_executor.stop_profile() + else: + self.model_executor._run_workers("stop_profile") def is_tracing_enabled(self) -> bool: return self.tracer is not None From acda0b35d00e733982aa4c1198f2bd381d368cb5 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 12 Sep 2024 21:39:49 -0700 Subject: [PATCH 29/98] bump version to v0.6.1.post1 (#8440) --- vllm/version.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/version.py b/vllm/version.py index 1f492a24bf078..975e695ac7821 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -2,6 +2,7 @@ try: import vllm.commit_id + __commit__ = vllm.commit_id.__commit__ except Exception as e: warnings.warn(f"Failed to read commit hash:\n{e}", @@ -9,4 +10,4 @@ stacklevel=2) __commit__ = "COMMIT_HASH_PLACEHOLDER" -__version__ = "0.6.1" +__version__ = "0.6.1.post1" From 9b4a3b235e5bdf0df7901c77a4b01f5358db3638 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 13 Sep 2024 14:35:20 +0800 Subject: [PATCH 30/98] [CI/Build] Enable InternVL2 PP test only on single node (#8437) --- tests/distributed/test_pipeline_parallel.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 9a02f468f0a93..02288dc9dac90 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -32,10 +32,11 @@ (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - # TODO: Enable internVL2 in a separate test if needed - # (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "ray"), - # (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "ray"), - # (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "ray"), + # NOTE: InternVL2 multi-node tests are flaky, + # use mp backend to skip the multi-node tests + (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"), + (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"), + (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"), ], ) @fork_new_process_for_each_test From cab69a15e49aa592db7042f0dc675bbe9b684f83 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 12 Sep 2024 23:52:41 -0700 Subject: [PATCH 31/98] [doc] recommend pip instead of conda (#8446) --- docs/source/getting_started/installation.rst | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index f0e54c29fcad7..50a761b49490c 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -26,6 +26,10 @@ You can install vLLM using pip: $ # Install vLLM with CUDA 12.1. $ pip install vllm +.. note:: + + Although we recommend using ``conda`` to create and manage Python environments, it is highly recommended to use ``pip`` to install vLLM. This is because ``pip`` can install ``torch`` with separate library packages like ``NCCL``, while ``conda`` installs ``torch`` with statically linked ``NCCL``. This can cause issues when vLLM tries to use ``NCCL``. See `this issue `_ for more details. + .. 
note:: As of now, vLLM's binaries are compiled with CUDA 12.1 and public PyTorch release versions by default. @@ -34,7 +38,7 @@ You can install vLLM using pip: .. code-block:: console $ # Install vLLM with CUDA 11.8. - $ export VLLM_VERSION=0.4.0 + $ export VLLM_VERSION=0.6.1.post1 $ export PYTHON_VERSION=310 $ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 @@ -48,7 +52,7 @@ You can install vLLM using pip: .. code-block:: console - $ export VLLM_VERSION=0.5.4 # vLLM's main branch version is currently set to latest released tag + $ export VLLM_VERSION=0.6.1.post1 # vLLM's main branch version is currently set to latest released tag $ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-${VLLM_VERSION}-cp38-abi3-manylinux1_x86_64.whl $ # You can also access a specific commit $ # export VLLM_COMMIT=... @@ -80,11 +84,11 @@ You can also build and install vLLM from source: .. tip:: - Building from source requires quite a lot compilation. If you are building from source for multiple times, it is beneficial to cache the compilation results. For example, you can install `ccache `_ via either `conda install ccache` or `apt install ccache` . As long as `which ccache` command can find the `ccache` binary, it will be used automatically by the build system. After the first build, the subsequent builds will be much faster. + Building from source requires quite a lot compilation. If you are building from source for multiple times, it is beneficial to cache the compilation results. For example, you can install `ccache `_ via either ``conda install ccache`` or ``apt install ccache`` . As long as ``which ccache`` command can find the ``ccache`` binary, it will be used automatically by the build system. After the first build, the subsequent builds will be much faster. .. tip:: To avoid your system being overloaded, you can limit the number of compilation jobs - to be run simultaneously, via the environment variable `MAX_JOBS`. For example: + to be run simultaneously, via the environment variable ``MAX_JOBS``. For example: .. code-block:: console @@ -99,7 +103,7 @@ You can also build and install vLLM from source: $ # Use `--ipc=host` to make sure the shared memory is large enough. $ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3 - If you don't want to use docker, it is recommended to have a full installation of CUDA Toolkit. You can download and install it from `the official website `_. After installation, set the environment variable `CUDA_HOME` to the installation path of CUDA Toolkit, and make sure that the `nvcc` compiler is in your `PATH`, e.g.: + If you don't want to use docker, it is recommended to have a full installation of CUDA Toolkit. You can download and install it from `the official website `_. After installation, set the environment variable ``CUDA_HOME`` to the installation path of CUDA Toolkit, and make sure that the ``nvcc`` compiler is in your ``PATH``, e.g.: .. 
code-block:: console From 06311e295666916d3456a357cdd91dd2a03c34e2 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 13 Sep 2024 15:58:28 +0800 Subject: [PATCH 32/98] [Misc] Skip loading extra bias for Qwen2-VL GPTQ-Int8 (#8442) --- vllm/model_executor/models/qwen2_vl.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 3f8c590a39b00..179399a12a3d5 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1055,6 +1055,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -1078,6 +1081,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.reshape(-1) try: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] except KeyError: print(params_dict.keys()) From a2469127db6144eedb38d0b505287c0044e4ce06 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 13 Sep 2024 02:20:14 -0700 Subject: [PATCH 33/98] [misc][ci] fix quant test (#8449) --- tests/quantization/test_bitsandbytes.py | 32 +++++++++++++++---------- tests/quantization/utils.py | 4 +--- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 3f0c6cbc051a7..87200b1dcc534 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -10,6 +10,8 @@ from tests.quantization.utils import is_quant_method_supported +from ..utils import fork_new_process_for_each_test + models_4bit_to_test = [ ('huggyllama/llama-7b', 'quantize model inflight'), ] @@ -29,6 +31,7 @@ @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_4bit_to_test) +@fork_new_process_for_each_test def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -41,6 +44,7 @@ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_pre_qaunt_4bit_to_test) +@fork_new_process_for_each_test def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -52,6 +56,7 @@ def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_pre_quant_8bit_to_test) +@fork_new_process_for_each_test def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -77,18 +82,8 @@ def validate_generated_texts(hf_runner, model_name, hf_model_kwargs=None): - if hf_model_kwargs is None: - hf_model_kwargs = {} - - # Run with HF runner - with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: - hf_outputs = llm.generate_greedy(prompts, 8) - hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") - - # 
Clean up the GPU memory for the next test - torch.cuda.synchronize() - gc.collect() - torch.cuda.empty_cache() + # NOTE: run vLLM first, as it requires a clean process + # when using distributed inference #Run with vLLM runner with vllm_runner(model, @@ -104,6 +99,19 @@ def validate_generated_texts(hf_runner, gc.collect() torch.cuda.empty_cache() + if hf_model_kwargs is None: + hf_model_kwargs = {} + + # Run with HF runner + with hf_runner(model, model_kwargs=hf_model_kwargs) as llm: + hf_outputs = llm.generate_greedy(prompts, 8) + hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") + + # Clean up the GPU memory for the next test + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + # Compare the generated strings for hf_log, vllm_log in zip(hf_logs, vllm_logs): hf_str = hf_log["generated_text"] diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py index 65bb80ed70c6a..5fad06878f4a3 100644 --- a/tests/quantization/utils.py +++ b/tests/quantization/utils.py @@ -1,12 +1,10 @@ -import torch - from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.platforms import current_platform def is_quant_method_supported(quant_method: str) -> bool: # Currently, all quantization methods require Nvidia or AMD GPUs - if not torch.cuda.is_available(): + if not (current_platform.is_cuda() or current_platform.is_rocm()): return False capability = current_platform.get_device_capability() From ecd7a1d5b69589257d36626195ece6658b61b93c Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 14 Sep 2024 00:02:26 +0800 Subject: [PATCH 34/98] [Installation] Gate FastAPI version for Python 3.8 (#8456) --- requirements-common.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index 8432be61ed77d..c5f003c3c7ddc 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -7,7 +7,8 @@ py-cpuinfo transformers >= 4.43.2 # Required for Chameleon and Llama 3.1 hotfix. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. 
-fastapi >= 0.114.1 +fastapi < 0.113.0; python_version < '3.9' +fastapi >= 0.114.1; python_version >= '3.9' aiohttp openai >= 1.40.0 # Ensure modern openai package (ensure types module present) uvicorn[standard] From 0a4806f0a99880df1f74b10a6dceaf638cd3981c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 13 Sep 2024 09:32:42 -0700 Subject: [PATCH 35/98] [plugin][torch.compile] allow to add custom compile backend (#8445) --- vllm/plugins/__init__.py | 13 +++++++++++++ vllm/worker/model_runner.py | 4 +++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 765f74fe7356f..7939688ef0da3 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,4 +1,5 @@ import logging +from typing import Callable, Optional, Union import vllm.envs as envs @@ -29,3 +30,15 @@ def load_general_plugins(): except Exception: logger.exception("Failed to load general plugin: %s", plugin.name) + + +_torch_compile_backend: Optional[Union[Callable, str]] = None + + +def set_torch_compile_backend(backend: Union[Callable, str]): + global _torch_compile_backend + _torch_compile_backend = backend + + +def get_torch_compile_backend() -> Optional[Union[Callable, str]]: + return _torch_compile_backend diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index acb7bafefc204..bff789c429710 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1064,10 +1064,12 @@ def load_model(self) -> None: "This may lead to less accurate results!") if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): + from vllm.plugins import get_torch_compile_backend + backend = get_torch_compile_backend() or "eager" self.model = torch.compile( self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend="eager") + backend=backend) def save_sharded_state( self, From a84e598e2125960d3b4f716b78863f24ac562947 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 14 Sep 2024 01:20:06 +0800 Subject: [PATCH 36/98] [CI/Build] Reorganize models tests (#7820) --- .buildkite/run-cpu-test.sh | 10 +- .buildkite/test-pipeline.yaml | 70 +++++++---- docs/source/models/supported_models.rst | 2 +- pyproject.toml | 3 +- .../test_basic_correctness.py | 62 ++++++++++ .../basic_correctness/test_chunked_prefill.py | 55 +++++++++ tests/basic_correctness/test_preemption.py | 11 +- tests/conftest.py | 29 ++--- .../test_basic_distributed_correctness.py | 80 ------------ ...t_basic_distributed_correctness_enc_dec.py | 102 ---------------- .../test_chunked_prefill_distributed.py | 75 ------------ .../distributed/test_multimodal_broadcast.py | 58 --------- tests/distributed/test_same_node.py | 14 +-- tests/kernels/utils.py | 4 +- tests/models/decoder_only/__init__.py | 0 .../decoder_only/audio_language/__init__.py | 0 .../audio_language}/test_ultravox.py | 6 +- .../models/decoder_only/language/__init__.py | 0 .../{ => decoder_only/language}/test_aqlm.py | 0 .../language}/test_big_models.py | 2 +- .../language}/test_danube3_4b.py | 2 +- .../{ => decoder_only/language}/test_fp8.py | 2 +- .../{ => decoder_only/language}/test_gguf.py | 2 +- .../language}/test_gptq_marlin.py | 2 +- .../language}/test_gptq_marlin_24.py | 3 +- .../language}/test_granite.py | 2 +- .../{ => decoder_only/language}/test_jamba.py | 3 +- .../language}/test_marlin.py | 2 +- .../language}/test_mistral.py | 2 +- .../language}/test_modelopt.py | 0 .../language}/test_models.py | 2 +- .../language}/test_phimoe.py | 2 +- .../decoder_only/vision_language/__init__.py | 0 
.../vision_language}/test_blip2.py | 8 +- .../vision_language/test_broadcast.py | 42 +++++++ .../vision_language}/test_chameleon.py | 8 +- .../vision_language}/test_fuyu.py | 8 +- .../vision_language}/test_intern_vit.py | 4 +- .../vision_language}/test_internvl.py | 10 +- .../vision_language}/test_llava.py | 12 +- .../test_llava_image_embeds.py | 8 +- .../vision_language}/test_llava_next.py | 10 +- .../vision_language}/test_llava_next_video.py | 6 +- .../vision_language}/test_minicpmv.py | 8 +- .../vision_language}/test_paligemma.py | 8 +- .../vision_language}/test_phi3v.py | 8 +- .../vision_language}/test_pixtral.py | 23 ++-- .../vision_language}/test_qwen.py | 8 +- tests/models/embedding/__init__.py | 0 tests/models/embedding/language/__init__.py | 0 .../language}/test_embedding.py | 0 tests/models/encoder_decoder/__init__.py | 0 .../encoder_decoder/language/__init__.py | 0 .../language}/test_bart.py | 115 +++++++++++++----- tests/utils.py | 20 ++- 55 files changed, 415 insertions(+), 498 deletions(-) delete mode 100644 tests/distributed/test_basic_distributed_correctness.py delete mode 100644 tests/distributed/test_basic_distributed_correctness_enc_dec.py delete mode 100644 tests/distributed/test_chunked_prefill_distributed.py delete mode 100644 tests/distributed/test_multimodal_broadcast.py create mode 100644 tests/models/decoder_only/__init__.py create mode 100644 tests/models/decoder_only/audio_language/__init__.py rename tests/models/{ => decoder_only/audio_language}/test_ultravox.py (98%) create mode 100644 tests/models/decoder_only/language/__init__.py rename tests/models/{ => decoder_only/language}/test_aqlm.py (100%) rename tests/models/{ => decoder_only/language}/test_big_models.py (97%) rename tests/models/{ => decoder_only/language}/test_danube3_4b.py (97%) rename tests/models/{ => decoder_only/language}/test_fp8.py (98%) rename tests/models/{ => decoder_only/language}/test_gguf.py (98%) rename tests/models/{ => decoder_only/language}/test_gptq_marlin.py (98%) rename tests/models/{ => decoder_only/language}/test_gptq_marlin_24.py (97%) rename tests/models/{ => decoder_only/language}/test_granite.py (97%) rename tests/models/{ => decoder_only/language}/test_jamba.py (99%) rename tests/models/{ => decoder_only/language}/test_marlin.py (98%) rename tests/models/{ => decoder_only/language}/test_mistral.py (98%) rename tests/models/{ => decoder_only/language}/test_modelopt.py (100%) rename tests/models/{ => decoder_only/language}/test_models.py (97%) rename tests/models/{ => decoder_only/language}/test_phimoe.py (98%) create mode 100644 tests/models/decoder_only/vision_language/__init__.py rename tests/models/{ => decoder_only/vision_language}/test_blip2.py (95%) create mode 100644 tests/models/decoder_only/vision_language/test_broadcast.py rename tests/models/{ => decoder_only/vision_language}/test_chameleon.py (95%) rename tests/models/{ => decoder_only/vision_language}/test_fuyu.py (95%) rename tests/models/{ => decoder_only/vision_language}/test_intern_vit.py (97%) rename tests/models/{ => decoder_only/vision_language}/test_internvl.py (98%) rename tests/models/{ => decoder_only/vision_language}/test_llava.py (96%) rename tests/models/{ => decoder_only/vision_language}/test_llava_image_embeds.py (96%) rename tests/models/{ => decoder_only/vision_language}/test_llava_next.py (97%) rename tests/models/{ => decoder_only/vision_language}/test_llava_next_video.py (98%) rename tests/models/{ => decoder_only/vision_language}/test_minicpmv.py (97%) rename tests/models/{ => 
decoder_only/vision_language}/test_paligemma.py (96%) rename tests/models/{ => decoder_only/vision_language}/test_phi3v.py (97%) rename tests/models/{ => decoder_only/vision_language}/test_pixtral.py (90%) rename tests/models/{ => decoder_only/vision_language}/test_qwen.py (98%) create mode 100644 tests/models/embedding/__init__.py create mode 100644 tests/models/embedding/language/__init__.py rename tests/models/{ => embedding/language}/test_embedding.py (100%) create mode 100644 tests/models/encoder_decoder/__init__.py create mode 100644 tests/models/encoder_decoder/language/__init__.py rename tests/models/{ => encoder_decoder/language}/test_bart.py (69%) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index d2ae926daa7c0..f4ead8d277736 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -23,12 +23,10 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" # Run basic model test docker exec cpu-test bash -c " pip install pytest matplotlib einops transformers_stream_generator - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py \ - --ignore=tests/models/test_oot_registration.py \ - --ignore=tests/models/test_registry.py \ - --ignore=tests/models/test_fp8.py \ - --ignore=tests/models/test_jamba.py \ - --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported + pytest -v -s tests/models/decoder_only/language \ + --ignore=tests/models/test_fp8.py \ + --ignore=tests/models/decoder_only/language/test_jamba.py \ + --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported # Run compressed-tensor test docker exec cpu-test bash -c " diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d0732ec3fe2fb..9b0cb6663a55b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -94,7 +94,6 @@ steps: - pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - - label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" num_gpus: 4 @@ -164,15 +163,6 @@ steps: - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py -- label: Models Test # 1hr10min - source_file_dependencies: - - vllm/ - - tests/models - commands: - - pip install -e ./plugins/vllm_add_dummy_model - - pytest -v -s models/test_oot_registration.py # it needs a clean process - - pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py - - label: torch compile integration test source_file_dependencies: - vllm/ @@ -180,14 +170,6 @@ steps: - pytest -v -s ./compile/test_full_graph.py - pytest -v -s ./compile/test_wrapper.py - -- label: Vision Language Models Test # 42min - #mirror_hardwares: [amd] - source_file_dependencies: - - vllm/ - commands: - - pytest -v -s models -m vlm - - label: Prefix Caching Test # 7min #mirror_hardwares: [amd] source_file_dependencies: @@ -286,6 +268,45 @@ steps: commands: - pytest -v -s tool_use +##### models test ##### + +- label: Basic Models Test # 3min + source_file_dependencies: + - vllm/ + - tests/models + commands: + - pip install -e ./plugins/vllm_add_dummy_model + - pytest -v -s models/test_oot_registration.py # it needs a clean process + 
- pytest -v -s models/*.py --ignore=models/test_oot_registration.py + +- label: Decoder-only Language Models Test # 1h3min + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/language + commands: + - pytest -v -s models/decoder_only/language + +- label: Decoder-only Multi-Modal Models Test # 56min + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/audio_language + - tests/models/decoder_only/vision_language + commands: + - pytest -v -s models/decoder_only/audio_language + - pytest -v -s models/decoder_only/vision_language + +- label: Other Models Test # 5min + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/models/embedding/language + - tests/models/encoder_decoder/language + commands: + - pytest -v -s models/embedding/language + - pytest -v -s models/encoder_decoder/language + ##### 1 GPU test ##### ##### multi gpus test ##### @@ -311,11 +332,11 @@ steps: - tests/distributed/ commands: - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' - label: Distributed Tests (2 GPUs) # 28min #mirror_hardwares: [amd] @@ -328,11 +349,10 @@ steps: - vllm/model_executor/models/ - tests/distributed/ commands: - - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py - - pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py - - pytest -v -s distributed/test_chunked_prefill_distributed.py - - pytest -v -s distributed/test_multimodal_broadcast.py + - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' + - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus + # Avoid importing model tests that cause CUDA reinitialization error + - pytest models/encoder_decoder/language/test_bart.py models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index faac2b97722b7..6c7f7f7d5d992 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -342,7 +342,7 @@ Note that, as an inference engine, vLLM does not introduce new models. Therefore We have the following levels of testing for models: -1. 
**Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to `test_models.py `_ and `test_big_models.py `_ for the models that have passed this test. +1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to `models tests `_ for the models that have passed this test. 2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. 3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to `functionality tests `_ and `examples `_ for the models that have passed this test. 4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. diff --git a/pyproject.toml b/pyproject.toml index d9e3278db4d19..6b682f5d4dd4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,5 +85,6 @@ skip_gitignore = true [tool.pytest.ini_options] markers = [ "skip_global_cleanup", - "vlm: run tests for vision language models only", + "core_model: run this model test in each PR instead of just daily", + "distributed_2_gpus: run this test only in distributed tests for 2 GPUs", ] diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index b970cd48f9170..0fe88e792520a 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -15,12 +15,15 @@ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from ..models.utils import check_outputs_equal +from ..utils import multi_gpu_test MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", ] +TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") + def test_vllm_gc_ed(): """Verify vllm instance is GC'ed when it is deleted""" @@ -70,6 +73,65 @@ def test_models( ) +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "model, distributed_executor_backend, attention_backend, " + "test_suite", [ + ("facebook/opt-125m", "ray", "", "L4"), + ("facebook/opt-125m", "mp", "", "L4"), + ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"), + ("meta-llama/Llama-2-7b-hf", "mp", "", "L4"), + ("facebook/opt-125m", "ray", "", "A100"), + ("facebook/opt-125m", "mp", "", "A100"), + ("facebook/opt-125m", "mp", "FLASHINFER", "A100"), + ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"), + ]) +def test_models_distributed( + hf_runner, + vllm_runner, + example_prompts, + model: str, + distributed_executor_backend: str, + attention_backend: str, + test_suite: str, +) -> None: + + if test_suite != TARGET_TEST_SUITE: + pytest.skip(f"Skip test for {test_suite}") + + if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa + # test ray adag + os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" + os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" + + if attention_backend: + os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend + + dtype = "half" + max_tokens = 5 + + # NOTE: take care 
of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + def test_model_with_failure(vllm_runner) -> None: try: with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward", diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 9c34b2a13fd53..14c5447680729 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -6,11 +6,13 @@ Run `pytest tests/models/test_chunked_prefill.py`. """ +import os from contextlib import nullcontext import pytest from ..models.utils import check_logprobs_close, check_outputs_equal +from ..utils import multi_gpu_test MODELS = [ "facebook/opt-125m", @@ -66,6 +68,59 @@ def test_models( ) +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) +@pytest.mark.parametrize("model", MODELS) +def test_models_distributed( + hf_runner, + vllm_runner, + example_prompts, + model: str, + distributed_executor_backend: str, +) -> None: + if (model == "meta-llama/Llama-2-7b-hf" + and distributed_executor_backend == "ray"): + # test ray adag + os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" + os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" + + dtype = "half" + max_tokens = 5 + chunked_prefill_token_size = 16 + + # Add a chunked prefill config. + max_num_seqs = min(chunked_prefill_token_size, 256) + assert chunked_prefill_token_size != -1 + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + + with vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + max_num_seqs=max_num_seqs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + @pytest.mark.parametrize( "kv_cache_dtype,model", [("fp8_e4m3", diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 50d399bef1878..00806c3e129b1 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -19,10 +19,13 @@ "facebook/opt-125m", ] -assert ENABLE_ARTIFICIAL_PREEMPT is True, ( - "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. 
" - "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest " - "tests/basic_correctness/test_preemption.py`") + +@pytest.fixture(scope="module", autouse=True) +def check_settings(): + assert ENABLE_ARTIFICIAL_PREEMPT is True, ( + "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. " + "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest " + "tests/basic_correctness/test_preemption.py`") @pytest.fixture diff --git a/tests/conftest.py b/tests/conftest.py index 620f8b4983517..e4c7b96e82429 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,8 +6,8 @@ import tempfile from collections import UserList from enum import Enum -from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict, - TypeVar, Union) +from typing import (Any, Callable, Dict, List, Optional, Tuple, Type, + TypedDict, TypeVar, Union) import numpy as np import pytest @@ -18,6 +18,7 @@ from PIL import Image from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding, BatchFeature) +from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset @@ -260,7 +261,7 @@ def __init__( *, model_kwargs: Optional[Dict[str, Any]] = None, is_embedding_model: bool = False, - auto_cls=AutoModelForCausalLM, + auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM, postprocess_inputs: Callable[[BatchEncoding], BatchEncoding] = identity, ) -> None: @@ -292,20 +293,14 @@ def __init__( trust_remote_code=True, ) - try: - # don't put this import at the top level - # it will call torch.cuda.device_count() - from transformers import AutoProcessor # noqa: F401 - self.processor = AutoProcessor.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ) - except Exception as exc: - logger.warning( - "Unable to auto-load HuggingFace processor for model (%s). " - "Using tokenizer instead. Reason: %s", model_name, exc) - self.processor = self.tokenizer + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor # noqa: F401 + self.processor = AutoProcessor.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) self.postprocess_inputs = postprocess_inputs diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py deleted file mode 100644 index e254686f269b1..0000000000000 --- a/tests/distributed/test_basic_distributed_correctness.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Compare the outputs of HF and distributed vLLM when using greedy sampling. 
- -Run: -```sh -cd $VLLM_PATH/tests - -pytest distributed/test_basic_distributed_correctness.py -``` -""" -import os - -import pytest - -from vllm.utils import cuda_device_count_stateless - -from ..models.utils import check_outputs_equal -from ..utils import fork_new_process_for_each_test - -TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") - - -@pytest.mark.skipif(cuda_device_count_stateless() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize( - "model, distributed_executor_backend, attention_backend, " - "test_suite", [ - ("facebook/opt-125m", "ray", "", "L4"), - ("facebook/opt-125m", "mp", "", "L4"), - ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"), - ("meta-llama/Llama-2-7b-hf", "mp", "", "L4"), - ("facebook/opt-125m", "ray", "", "A100"), - ("facebook/opt-125m", "mp", "", "A100"), - ("facebook/opt-125m", "mp", "FLASHINFER", "A100"), - ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"), - ]) -@fork_new_process_for_each_test -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - distributed_executor_backend: str, - attention_backend: str, - test_suite: str, -) -> None: - - if test_suite != TARGET_TEST_SUITE: - pytest.skip(f"Skip test for {test_suite}") - - if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa - # test ray adag - os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" - os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" - - if attention_backend: - os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend - - dtype = "half" - max_tokens = 5 - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=2, - distributed_executor_backend=distributed_executor_backend - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) diff --git a/tests/distributed/test_basic_distributed_correctness_enc_dec.py b/tests/distributed/test_basic_distributed_correctness_enc_dec.py deleted file mode 100644 index f00d5ef584a2a..0000000000000 --- a/tests/distributed/test_basic_distributed_correctness_enc_dec.py +++ /dev/null @@ -1,102 +0,0 @@ -"""For encoder/decoder models only: -Compare the outputs of HF and distributed vLLM when using greedy sampling. 
- -Run: -```sh -cd $VLLM_PATH/tests - -pytest distributed/test_basic_distributed_correctness_enc_dec.py -``` -""" - -import pytest -from transformers import AutoModelForSeq2SeqLM - -from vllm.utils import cuda_device_count_stateless - -from ..conftest import DecoderPromptType -from ..models.utils import check_logprobs_close -from ..utils import fork_new_process_for_each_test - - -@pytest.mark.skipif(cuda_device_count_stateless() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("model, distributed_executor_backend", [ - ("facebook/bart-large-cnn", "ray"), - ("facebook/bart-large-cnn", "mp"), -]) -@fork_new_process_for_each_test -def test_models( - model: str, - distributed_executor_backend: str, - hf_runner, - vllm_runner, - example_encoder_decoder_prompts, -) -> None: - ''' - Test vLLM BART inference on more than one GPU, comparing - outputs against HF as a baseline. - - Fork a new process for each test, to prevent CUDA from - being re-initialized by successive tests within the same - process. - - Arguments: - - * model: the HF ID of the specific BART variant under test - * distributed_executor_backend - * hf_runner: HuggingFace (HF) test model runner - * vllm_runner: vLLM test model runner - * example_encoder_decoder_prompts: test fixture which provides a - dictionary of dummy prompts - ''' - - dtype = "float" - max_tokens = 64 - num_logprobs = 5 - - # Example inputs with non-trivial (i.e. not None/empty) encoder & - # decoder prompts. - test_prompts = example_encoder_decoder_prompts[DecoderPromptType.CUSTOM] - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True, - ) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - test_prompts, max_tokens, num_logprobs) - - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } - - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( - test_prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py deleted file mode 100644 index 262845f19822f..0000000000000 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Compare the outputs of HF and distributed vLLM when using greedy sampling. 
- -Run: -```sh -pytest test_chunked_prefill_distributed.py -``` -""" - -import os - -import pytest - -from vllm.utils import cuda_device_count_stateless - -from ..models.utils import check_outputs_equal -from ..utils import fork_new_process_for_each_test - - -@pytest.mark.skipif(cuda_device_count_stateless() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("model, distributed_executor_backend", [ - ("facebook/opt-125m", "ray"), - ("meta-llama/Llama-2-7b-hf", "ray"), - ("facebook/opt-125m", "mp"), - ("meta-llama/Llama-2-7b-hf", "mp"), -]) -@fork_new_process_for_each_test -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - distributed_executor_backend: str, -) -> None: - if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray": # noqa - assert distributed_executor_backend == "ray" - # test ray adag - os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" - os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" - - dtype = "half" - max_tokens = 5 - chunked_prefill_token_size = 16 - - # Add a chunked prefill config. - max_num_seqs = min(chunked_prefill_token_size, 256) - assert chunked_prefill_token_size != -1 - enable_chunked_prefill = True - max_num_batched_tokens = chunked_prefill_token_size - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - - with vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - max_num_seqs=max_num_seqs, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - distributed_executor_backend=distributed_executor_backend, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) diff --git a/tests/distributed/test_multimodal_broadcast.py b/tests/distributed/test_multimodal_broadcast.py deleted file mode 100644 index 73ef863c2f193..0000000000000 --- a/tests/distributed/test_multimodal_broadcast.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Compare the outputs of HF and distributed vLLM when using greedy sampling. 
- -Run: -```sh -pytest -s -v test_multimodal_broadcast.py -``` -""" - -import pytest - -from vllm.utils import cuda_device_count_stateless - -from ..utils import fork_new_process_for_each_test - - -@pytest.mark.skipif(cuda_device_count_stateless() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("model, distributed_executor_backend", [ - ("llava-hf/llava-1.5-7b-hf", "ray"), - ("llava-hf/llava-v1.6-mistral-7b-hf", "ray"), - ("facebook/chameleon-7b", "ray"), - ("llava-hf/llava-1.5-7b-hf", "mp"), - ("llava-hf/llava-v1.6-mistral-7b-hf", "mp"), - ("facebook/chameleon-7b", "mp"), -]) -@fork_new_process_for_each_test -def test_models(hf_runner, vllm_runner, image_assets, model: str, - distributed_executor_backend: str) -> None: - - dtype = "half" - max_tokens = 5 - num_logprobs = 5 - tensor_parallel_size = 2 - - if model.startswith("llava-hf/llava-1.5"): - from ..models.test_llava import models, run_test - elif model.startswith("llava-hf/llava-v1.6"): - from ..models.test_llava_next import run_test # type: ignore[no-redef] - from ..models.test_llava_next import models - elif model.startswith("facebook/chameleon"): - from ..models.test_chameleon import run_test # type: ignore[no-redef] - from ..models.test_chameleon import models - else: - raise NotImplementedError(f"Unsupported model: {model}") - - run_test( - hf_runner, - vllm_runner, - image_assets, - model=models[0], - # So that LLaVA-NeXT processor may return nested list - size_factors=[0.25, 0.5, 1.0], - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - ) diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py index 07e84d0ad54cd..defc4e23c8ce2 100644 --- a/tests/distributed/test_same_node.py +++ b/tests/distributed/test_same_node.py @@ -1,13 +1,13 @@ import os -import torch +import torch.distributed as dist from vllm.distributed.parallel_state import in_the_same_node_as -torch.distributed.init_process_group(backend="gloo") -test_result = all( - in_the_same_node_as(torch.distributed.group.WORLD, source_rank=0)) +if __name__ == "__main__": + dist.init_process_group(backend="gloo") + test_result = all(in_the_same_node_as(dist.group.WORLD, source_rank=0)) -expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1" -assert test_result == expected, f"Expected {expected}, got {test_result}" -print("Same node test passed!") + expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1" + assert test_result == expected, f"Expected {expected}, got {test_result}" + print("Same node test passed!") diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index dbddd69c07dbc..5746932c30a45 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -10,7 +10,6 @@ import torch from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType -from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) @@ -521,6 +520,9 @@ def make_backend(backend_name: str) -> AttentionBackend: * Backend instance ''' if backend_name == STR_XFORMERS_ATTN_VAL: + # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. 
+ from vllm.attention.backends.xformers import XFormersBackend + return XFormersBackend() raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") diff --git a/tests/models/decoder_only/__init__.py b/tests/models/decoder_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/decoder_only/audio_language/__init__.py b/tests/models/decoder_only/audio_language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py similarity index 98% rename from tests/models/test_ultravox.py rename to tests/models/decoder_only/audio_language/test_ultravox.py index e98db9b65f484..bfffd34d1142c 100644 --- a/tests/models/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -7,10 +7,8 @@ from vllm.sequence import SampleLogprobs from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from ..conftest import HfRunner, VllmRunner -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import HfRunner, VllmRunner +from ...utils import check_logprobs_close MODEL_NAME = "fixie-ai/ultravox-v0_3" diff --git a/tests/models/decoder_only/language/__init__.py b/tests/models/decoder_only/language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/test_aqlm.py b/tests/models/decoder_only/language/test_aqlm.py similarity index 100% rename from tests/models/test_aqlm.py rename to tests/models/decoder_only/language/test_aqlm.py diff --git a/tests/models/test_big_models.py b/tests/models/decoder_only/language/test_big_models.py similarity index 97% rename from tests/models/test_big_models.py rename to tests/models/decoder_only/language/test_big_models.py index c3e48b56ee58f..c5783fe19dd3f 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/decoder_only/language/test_big_models.py @@ -7,7 +7,7 @@ import pytest import torch -from .utils import check_outputs_equal +from ...utils import check_outputs_equal MODELS = [ "meta-llama/Llama-2-7b-hf", diff --git a/tests/models/test_danube3_4b.py b/tests/models/decoder_only/language/test_danube3_4b.py similarity index 97% rename from tests/models/test_danube3_4b.py rename to tests/models/decoder_only/language/test_danube3_4b.py index bfaa275f73c19..bdd498edc293d 100644 --- a/tests/models/test_danube3_4b.py +++ b/tests/models/decoder_only/language/test_danube3_4b.py @@ -6,7 +6,7 @@ """ import pytest -from .utils import check_outputs_equal +from ...utils import check_outputs_equal MODELS = ["h2oai/h2o-danube3-4b-base"] diff --git a/tests/models/test_fp8.py b/tests/models/decoder_only/language/test_fp8.py similarity index 98% rename from tests/models/test_fp8.py rename to tests/models/decoder_only/language/test_fp8.py index 17acdb52322fd..5a947ce62c785 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/decoder_only/language/test_fp8.py @@ -10,7 +10,7 @@ from tests.kernels.utils import override_backend_env_variable from tests.quantization.utils import is_quant_method_supported -from ..models.utils import check_logprobs_close +from ...utils import check_logprobs_close os.environ["TOKENIZERS_PARALLELISM"] = "true" diff --git a/tests/models/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py similarity index 98% rename from tests/models/test_gguf.py rename to tests/models/decoder_only/language/test_gguf.py index 196cd88e039a1..8fc64a10c84af 100644 --- a/tests/models/test_gguf.py +++ 
b/tests/models/decoder_only/language/test_gguf.py @@ -11,7 +11,7 @@ from tests.quantization.utils import is_quant_method_supported -from .utils import check_logprobs_close +from ...utils import check_logprobs_close os.environ["TOKENIZERS_PARALLELISM"] = "true" diff --git a/tests/models/test_gptq_marlin.py b/tests/models/decoder_only/language/test_gptq_marlin.py similarity index 98% rename from tests/models/test_gptq_marlin.py rename to tests/models/decoder_only/language/test_gptq_marlin.py index 4abbc41c9c287..2155e83dbe915 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/decoder_only/language/test_gptq_marlin.py @@ -15,7 +15,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT -from .utils import check_logprobs_close +from ...utils import check_logprobs_close os.environ["TOKENIZERS_PARALLELISM"] = "true" diff --git a/tests/models/test_gptq_marlin_24.py b/tests/models/decoder_only/language/test_gptq_marlin_24.py similarity index 97% rename from tests/models/test_gptq_marlin_24.py rename to tests/models/decoder_only/language/test_gptq_marlin_24.py index 60d9ae2f1c629..d65be05f141b4 100644 --- a/tests/models/test_gptq_marlin_24.py +++ b/tests/models/decoder_only/language/test_gptq_marlin_24.py @@ -10,9 +10,10 @@ import pytest -from tests.models.utils import check_logprobs_close from tests.quantization.utils import is_quant_method_supported +from ...utils import check_logprobs_close + @dataclass class ModelPair: diff --git a/tests/models/test_granite.py b/tests/models/decoder_only/language/test_granite.py similarity index 97% rename from tests/models/test_granite.py rename to tests/models/decoder_only/language/test_granite.py index 2435b5dc3ff88..82c753855e714 100644 --- a/tests/models/test_granite.py +++ b/tests/models/decoder_only/language/test_granite.py @@ -6,7 +6,7 @@ import pytest -from .utils import check_logprobs_close +from ...utils import check_logprobs_close TRANSFORMERS_VERSION = tuple( map(int, diff --git a/tests/models/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py similarity index 99% rename from tests/models/test_jamba.py rename to tests/models/decoder_only/language/test_jamba.py index efb7b1c607721..36fa67a22b0f6 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -1,8 +1,9 @@ import pytest -from tests.models.utils import check_outputs_equal from vllm.worker.model_runner import _get_graph_batch_size +from ...utils import check_outputs_equal + MODELS = ["ai21labs/Jamba-tiny-random"] diff --git a/tests/models/test_marlin.py b/tests/models/decoder_only/language/test_marlin.py similarity index 98% rename from tests/models/test_marlin.py rename to tests/models/decoder_only/language/test_marlin.py index e86f6e29d1567..c802346dee8af 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/decoder_only/language/test_marlin.py @@ -16,7 +16,7 @@ from tests.quantization.utils import is_quant_method_supported -from .utils import check_logprobs_close +from ...utils import check_logprobs_close @dataclass diff --git a/tests/models/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py similarity index 98% rename from tests/models/test_mistral.py rename to tests/models/decoder_only/language/test_mistral.py index 0741174497e32..687ba6a03a691 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -4,7 +4,7 @@ """ import pytest -from .utils import check_logprobs_close 
+from ...utils import check_logprobs_close MODELS = [ "mistralai/Mistral-7B-Instruct-v0.1", diff --git a/tests/models/test_modelopt.py b/tests/models/decoder_only/language/test_modelopt.py similarity index 100% rename from tests/models/test_modelopt.py rename to tests/models/decoder_only/language/test_modelopt.py diff --git a/tests/models/test_models.py b/tests/models/decoder_only/language/test_models.py similarity index 97% rename from tests/models/test_models.py rename to tests/models/decoder_only/language/test_models.py index 4cd2cb665c8f0..68055cbe29095 100644 --- a/tests/models/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -7,7 +7,7 @@ """ import pytest -from .utils import check_outputs_equal +from ...utils import check_outputs_equal MODELS = [ "facebook/opt-125m", diff --git a/tests/models/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py similarity index 98% rename from tests/models/test_phimoe.py rename to tests/models/decoder_only/language/test_phimoe.py index 2fb2eecc94672..dbdf5a1b934a6 100644 --- a/tests/models/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -7,7 +7,7 @@ from vllm.utils import is_cpu -from .utils import check_logprobs_close +from ...utils import check_logprobs_close MODELS = [ "microsoft/Phi-3.5-MoE-instruct", diff --git a/tests/models/decoder_only/vision_language/__init__.py b/tests/models/decoder_only/vision_language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/test_blip2.py b/tests/models/decoder_only/vision_language/test_blip2.py similarity index 95% rename from tests/models/test_blip2.py rename to tests/models/decoder_only/vision_language/test_blip2.py index 5d48bad0d7b35..e1e32b96d89ac 100644 --- a/tests/models/test_blip2.py +++ b/tests/models/decoder_only/vision_language/test_blip2.py @@ -6,10 +6,8 @@ from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs -from ..conftest import IMAGE_ASSETS -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS +from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -56,7 +54,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, dtype: str, max_tokens: int, num_logprobs: int) -> None: """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalData objects and corresponding MultiModalConfig as input. 
diff --git a/tests/models/decoder_only/vision_language/test_broadcast.py b/tests/models/decoder_only/vision_language/test_broadcast.py new file mode 100644 index 0000000000000..d01490d74bd4d --- /dev/null +++ b/tests/models/decoder_only/vision_language/test_broadcast.py @@ -0,0 +1,42 @@ +import pytest + +from ....utils import multi_gpu_test + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) +@pytest.mark.parametrize("model", [ + "llava-hf/llava-1.5-7b-hf", + "llava-hf/llava-v1.6-mistral-7b-hf", + "facebook/chameleon-7b", +]) +def test_models(hf_runner, vllm_runner, image_assets, + distributed_executor_backend, model) -> None: + + dtype = "half" + max_tokens = 5 + num_logprobs = 5 + tensor_parallel_size = 2 + + if model.startswith("llava-hf/llava-1.5"): + from .test_llava import models, run_test + elif model.startswith("llava-hf/llava-v1.6"): + from .test_llava_next import models, run_test # type: ignore[no-redef] + elif model.startswith("facebook/chameleon"): + from .test_chameleon import models, run_test # type: ignore[no-redef] + else: + raise NotImplementedError(f"Unsupported model: {model}") + + run_test( + hf_runner, + vllm_runner, + image_assets, + model=models[0], + # So that LLaVA-NeXT processor may return nested list + size_factors=[0.25, 0.5, 1.0], + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + ) diff --git a/tests/models/test_chameleon.py b/tests/models/decoder_only/vision_language/test_chameleon.py similarity index 95% rename from tests/models/test_chameleon.py rename to tests/models/decoder_only/vision_language/test_chameleon.py index e02b4b1ed72bd..8334451970a4f 100644 --- a/tests/models/test_chameleon.py +++ b/tests/models/decoder_only/vision_language/test_chameleon.py @@ -6,10 +6,8 @@ from vllm.multimodal.utils import rescale_image_size from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets -from .utils import check_outputs_equal - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ...utils import check_outputs_equal HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -36,7 +34,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding vision language config as input. 
diff --git a/tests/models/test_fuyu.py b/tests/models/decoder_only/vision_language/test_fuyu.py similarity index 95% rename from tests/models/test_fuyu.py rename to tests/models/decoder_only/vision_language/test_fuyu.py index 0d666d8f71a92..94b8431424db5 100644 --- a/tests/models/test_fuyu.py +++ b/tests/models/decoder_only/vision_language/test_fuyu.py @@ -6,10 +6,8 @@ from vllm.sequence import SampleLogprobs from vllm.utils import is_cpu -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -46,7 +44,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. diff --git a/tests/models/test_intern_vit.py b/tests/models/decoder_only/vision_language/test_intern_vit.py similarity index 97% rename from tests/models/test_intern_vit.py rename to tests/models/decoder_only/vision_language/test_intern_vit.py index 816f846f69bae..3c3b95b38baac 100644 --- a/tests/models/test_intern_vit.py +++ b/tests/models/decoder_only/vision_language/test_intern_vit.py @@ -6,9 +6,7 @@ from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoModel, CLIPImageProcessor -from ..conftest import _ImageAssets, cleanup - -pytestmark = pytest.mark.vlm +from ....conftest import _ImageAssets, cleanup # we use snapshot_download to prevent conflicts between # dynamic_module and trust_remote_code for hf_runner diff --git a/tests/models/test_internvl.py b/tests/models/decoder_only/vision_language/test_internvl.py similarity index 98% rename from tests/models/test_internvl.py rename to tests/models/decoder_only/vision_language/test_internvl.py index 881068b3afe41..a756f8214edee 100644 --- a/tests/models/test_internvl.py +++ b/tests/models/decoder_only/vision_language/test_internvl.py @@ -9,11 +9,9 @@ from vllm.multimodal.utils import rescale_image_size from vllm.utils import is_cpu -from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, - _ImageAssets) -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -78,7 +76,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. 
diff --git a/tests/models/test_llava.py b/tests/models/decoder_only/vision_language/test_llava.py similarity index 96% rename from tests/models/test_llava.py rename to tests/models/decoder_only/vision_language/test_llava.py index 84ca23f6222a9..fd28a9367b4b2 100644 --- a/tests/models/test_llava.py +++ b/tests/models/decoder_only/vision_language/test_llava.py @@ -8,11 +8,9 @@ from vllm.sequence import SampleLogprobs from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, - _ImageAssets) -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ...utils import check_logprobs_close _LIMIT_IMAGE_PER_PROMPT = 4 @@ -143,7 +141,7 @@ def _run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. @@ -239,7 +237,7 @@ def process(hf_inputs: BatchEncoding): @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype: str, max_tokens: int, num_logprobs: int) -> None: + dtype, max_tokens, num_logprobs) -> None: run_test( hf_runner, vllm_runner, diff --git a/tests/models/test_llava_image_embeds.py b/tests/models/decoder_only/vision_language/test_llava_image_embeds.py similarity index 96% rename from tests/models/test_llava_image_embeds.py rename to tests/models/decoder_only/vision_language/test_llava_image_embeds.py index cc444fe32e79b..66414032509ed 100644 --- a/tests/models/test_llava_image_embeds.py +++ b/tests/models/decoder_only/vision_language/test_llava_image_embeds.py @@ -5,10 +5,8 @@ from vllm.sequence import SampleLogprobs -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -62,7 +60,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding vision language config as input. 
diff --git a/tests/models/test_llava_next.py b/tests/models/decoder_only/vision_language/test_llava_next.py similarity index 97% rename from tests/models/test_llava_next.py rename to tests/models/decoder_only/vision_language/test_llava_next.py index d5fe0cbe32880..f833fe0c8bbb4 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/decoder_only/vision_language/test_llava_next.py @@ -6,11 +6,9 @@ from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs -from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, - _ImageAssets) -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ...utils import check_logprobs_close _LIMIT_IMAGE_PER_PROMPT = 4 @@ -197,7 +195,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, dtype, max_tokens, num_logprobs) -> None: """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. diff --git a/tests/models/test_llava_next_video.py b/tests/models/decoder_only/vision_language/test_llava_next_video.py similarity index 98% rename from tests/models/test_llava_next_video.py rename to tests/models/decoder_only/vision_language/test_llava_next_video.py index 6856b15f22ec3..373c8964054cd 100644 --- a/tests/models/test_llava_next_video.py +++ b/tests/models/decoder_only/vision_language/test_llava_next_video.py @@ -8,10 +8,8 @@ sample_frames_from_video) from vllm.sequence import SampleLogprobs -from ..conftest import VIDEO_ASSETS, HfRunner, VllmRunner, _VideoAssets -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import VIDEO_ASSETS, HfRunner, VllmRunner, _VideoAssets +from ...utils import check_logprobs_close _PREFACE = ( "A chat between a curious human and an artificial intelligence assistant. " diff --git a/tests/models/test_minicpmv.py b/tests/models/decoder_only/vision_language/test_minicpmv.py similarity index 97% rename from tests/models/test_minicpmv.py rename to tests/models/decoder_only/vision_language/test_minicpmv.py index 99e49c14f1f26..7bf5d75f400f9 100644 --- a/tests/models/test_minicpmv.py +++ b/tests/models/decoder_only/vision_language/test_minicpmv.py @@ -9,10 +9,8 @@ from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner +from ...utils import check_logprobs_close # The image token is placed before "user" on purpose so that the test can pass HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ @@ -65,7 +63,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. 
diff --git a/tests/models/test_paligemma.py b/tests/models/decoder_only/vision_language/test_paligemma.py similarity index 96% rename from tests/models/test_paligemma.py rename to tests/models/decoder_only/vision_language/test_paligemma.py index beddaaf608a18..d7e29ea76ba4e 100644 --- a/tests/models/test_paligemma.py +++ b/tests/models/decoder_only/vision_language/test_paligemma.py @@ -8,10 +8,8 @@ from vllm.sequence import SampleLogprobs from vllm.utils import is_hip -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -69,7 +67,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. diff --git a/tests/models/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py similarity index 97% rename from tests/models/test_phi3v.py rename to tests/models/decoder_only/vision_language/test_phi3v.py index 6ecbf07a08b7c..e248151c40a60 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -9,10 +9,8 @@ from vllm.sequence import SampleLogprobs from vllm.utils import is_cpu, is_hip -from ..conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -71,7 +69,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. 
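The test_pixtral.py diff that follows replaces hard-coded fixture path strings with VLLM_PATH-relative Path objects annotated as `_typeshed.StrPath`; because _typeshed exists only as a typing stub, the import has to be guarded by TYPE_CHECKING. A small self-contained sketch of that pattern, with illustrative names that are not taken from the patch:

    import json
    from pathlib import Path
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Stub-only module: visible to type checkers, absent at runtime.
        from _typeshed import StrPath  # roughly str | os.PathLike[str]

    FIXTURES_DIR = Path(__file__).parent / "fixtures"  # illustrative location

    def load_fixture(filename: "StrPath") -> object:
        # The quoted annotation keeps this runnable even though StrPath is
        # not importable at runtime; accepts either a str or a Path.
        with open(filename, "rb") as f:
            return json.load(f)
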
diff --git a/tests/models/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py similarity index 90% rename from tests/models/test_pixtral.py rename to tests/models/decoder_only/vision_language/test_pixtral.py index 1fbfd77218ca7..072bedfc01a1f 100644 --- a/tests/models/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -5,7 +5,7 @@ import json import uuid from dataclasses import asdict -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import pytest from mistral_common.protocol.instruct.messages import ImageURLChunk @@ -17,9 +17,11 @@ from vllm.multimodal import MultiModalDataBuiltins from vllm.sequence import Logprob, SampleLogprobs -from .utils import check_logprobs_close +from ....utils import VLLM_PATH +from ...utils import check_logprobs_close -pytestmark = pytest.mark.vlm +if TYPE_CHECKING: + from _typeshed import StrPath MODELS = ["mistralai/Pixtral-12B-2409"] IMG_URLS = [ @@ -83,14 +85,21 @@ def _create_engine_inputs(urls: List[str]) -> TokensPrompt: LIMIT_MM_PER_PROMPT = dict(image=4) MAX_MODEL_LEN = [8192, 65536] -FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.json" -FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.json" + +FIXTURES_PATH = VLLM_PATH / "tests/models/fixtures" +assert FIXTURES_PATH.exists() + +FIXTURE_LOGPROBS_CHAT = FIXTURES_PATH / "pixtral_chat.json" +FIXTURE_LOGPROBS_ENGINE = FIXTURES_PATH / "pixtral_chat_engine.json" OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]] # For the test author to store golden output in JSON -def _dump_outputs_w_logprobs(outputs: OutputsLogprobs, filename: str) -> None: +def _dump_outputs_w_logprobs( + outputs: OutputsLogprobs, + filename: "StrPath", +) -> None: json_data = [(tokens, text, [{k: asdict(v) for k, v in token_logprobs.items()} @@ -101,7 +110,7 @@ def _dump_outputs_w_logprobs(outputs: OutputsLogprobs, filename: str) -> None: json.dump(json_data, f) -def load_outputs_w_logprobs(filename: str) -> OutputsLogprobs: +def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs: with open(filename, "rb") as f: json_data = json.load(f) diff --git a/tests/models/test_qwen.py b/tests/models/decoder_only/vision_language/test_qwen.py similarity index 98% rename from tests/models/test_qwen.py rename to tests/models/decoder_only/vision_language/test_qwen.py index 5e7f1de99d6c3..e4f79092b7606 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/decoder_only/vision_language/test_qwen.py @@ -10,11 +10,9 @@ from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size -from ..conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput, - VllmRunner, _ImageAssets) -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput, + VllmRunner, _ImageAssets) +from ...utils import check_logprobs_close text_only_models = [ "Qwen/Qwen-7B-Chat" # Has no visual component diff --git a/tests/models/embedding/__init__.py b/tests/models/embedding/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/embedding/language/__init__.py b/tests/models/embedding/language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/test_embedding.py b/tests/models/embedding/language/test_embedding.py similarity index 100% rename from 
tests/models/test_embedding.py rename to tests/models/embedding/language/test_embedding.py diff --git a/tests/models/encoder_decoder/__init__.py b/tests/models/encoder_decoder/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/encoder_decoder/language/__init__.py b/tests/models/encoder_decoder/language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py similarity index 69% rename from tests/models/test_bart.py rename to tests/models/encoder_decoder/language/test_bart.py index 660b61d1a7ade..758a9b743b397 100644 --- a/tests/models/test_bart.py +++ b/tests/models/encoder_decoder/language/test_bart.py @@ -1,8 +1,8 @@ """Compare the outputs of HF and vLLM for BART models using greedy sampling. -Run `pytest tests/models/test_bart.py`. +Run `pytest tests/models/encoder_decoder/language/test_bart.py`. """ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Type from vllm.utils import is_cpu @@ -16,8 +16,10 @@ from vllm.sequence import SampleLogprobs - from ..conftest import DecoderPromptType - from .utils import check_logprobs_close + from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, + HfRunner, VllmRunner) + from ....utils import multi_gpu_test + from ...utils import check_logprobs_close MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] @@ -34,20 +36,18 @@ def vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs - @pytest.mark.parametrize("model", MODELS) - @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) - @pytest.mark.parametrize("max_tokens", [64]) - @pytest.mark.parametrize("num_logprobs", [5]) - @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) - def test_models( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts, + def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + prompts: List[ExplicitEncoderDecoderPrompt[str, str]], + decoder_prompt_type: DecoderPromptType, model: str, + *, dtype: str, max_tokens: int, num_logprobs: int, - decoder_prompt_type: DecoderPromptType, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, ) -> None: ''' Test the vLLM BART model for a variety of encoder/decoder input prompts, @@ -116,8 +116,29 @@ def test_models( token during the process of validating the vLLM decoded output. ''' - test_case_prompts = example_encoder_decoder_prompts[ - decoder_prompt_type] + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default). + + # Note: currently encoder/decoder models are only compatible with + # enforce_eager=True. Normally this is not a problem because + # for encoder/decoder models vLLM will + # default to enforce_eager=True if enforce_eager + # is left unspecified. However, the + # VllmRunner test fixture (which wraps around the LLM class) defaults to + # enforce_eager=False (a behavior which a number of already-exisitng + # decoder-only unit tests expect), so when testing an encoder/decoder + # model we must explicitly specify enforce_eager=True in the VllmRunner + # constructor. 
+ with vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + prompts, max_tokens, num_logprobs) # Configuration settings for HF baseline hf_kwargs = { @@ -135,26 +156,12 @@ def test_models( auto_cls=AutoModelForSeq2SeqLM) as hf_model: hf_outputs = ( hf_model.generate_encoder_decoder_greedy_logprobs_limit( - test_case_prompts, + prompts, max_tokens, num_logprobs, **hf_kwargs, )) - # Note: currently encoder/decoder models are only compatible with - # enforce_eager=True. Normally this is not a problem because - # for encoder/decoder models vLLM will - # default to enforce_eager=True if enforce_eager - # is left unspecified. However, the - # VllmRunner test fixture (which wraps around the LLM class) defaults to - # enforce_eager=False (a behavior which a number of already-exisitng - # decoder-only unit tests expect), so when testing an encoder/decoder - # model we must explicitly specify enforce_eager=True in the VllmRunner - # constructor. - with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - test_case_prompts, max_tokens, num_logprobs) - hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE else 0) @@ -168,3 +175,49 @@ def test_models( name_1="vllm", num_outputs_0_skip_tokens=hf_skip_tokens, ) + + @pytest.mark.parametrize("model", MODELS) + @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) + @pytest.mark.parametrize("max_tokens", [64]) + @pytest.mark.parametrize("num_logprobs", [5]) + @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) + def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, + model, dtype, max_tokens, num_logprobs, + decoder_prompt_type) -> None: + + run_test( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + @multi_gpu_test(num_gpus=2) + @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) + @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) + @pytest.mark.parametrize("dtype", ["float"]) + @pytest.mark.parametrize("max_tokens", [64]) + @pytest.mark.parametrize("num_logprobs", [5]) + @pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) + def test_models_distributed(hf_runner, vllm_runner, + example_encoder_decoder_prompts, + distributed_executor_backend, model, dtype, + max_tokens, num_logprobs, + decoder_prompt_type) -> None: + run_test( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend, + ) diff --git a/tests/utils.py b/tests/utils.py index 3c519fb6e50e0..f6c2be17ebdcf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, List, Optional import openai +import pytest import requests from openai.types.completion import Completion from transformers import AutoTokenizer @@ -22,7 +23,8 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.model_executor.model_loader.loader import get_model_loader from vllm.platforms import 
current_platform -from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip +from vllm.utils import (FlexibleArgumentParser, cuda_device_count_stateless, + get_open_port, is_hip) if current_platform.is_rocm(): from amdsmi import (amdsmi_get_gpu_vram_usage, @@ -452,6 +454,22 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: return wrapper +def multi_gpu_test(*, num_gpus: int): + """ + Decorate a test to be run only when multiple GPUs are available. + """ + test_selector = getattr(pytest.mark, f"distributed_{num_gpus}_gpus") + test_skipif = pytest.mark.skipif( + cuda_device_count_stateless() < num_gpus, + reason=f"Need at least {num_gpus} GPUs to run the test.", + ) + + def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: + return test_selector(test_skipif(fork_new_process_for_each_test(f))) + + return wrapper + + async def completions_with_server_args( prompts: List[str], model_name: str, From f57092c00b53d6da887f2b8071af332d42ccb6d4 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 14 Sep 2024 02:06:30 +0800 Subject: [PATCH 37/98] [Doc] Add oneDNN installation to CPU backend documentation (#8467) --- docs/source/getting_started/cpu-installation.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index 7fc469e06844f..816e0a29ef28b 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -59,6 +59,20 @@ Build from source $ pip install wheel packaging ninja "setuptools>=49.4.0" numpy $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu +- Third, build and install oneDNN library from source: + +.. code-block:: console + + $ git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git + $ cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \ + -DONEDNN_BUILD_DOC=OFF \ + -DONEDNN_BUILD_EXAMPLES=OFF \ + -DONEDNN_BUILD_TESTS=OFF \ + -DONEDNN_BUILD_GRAPH=OFF \ + -DONEDNN_ENABLE_WORKLOAD=INFERENCE \ + -DONEDNN_ENABLE_PRIMITIVE=MATMUL + $ cmake --build ./oneDNN/build --target install --config Release + - Finally, build and install vLLM CPU backend: .. 
code-block:: console From 18e9e1f7b34c46857466fe24e9f9bdee17542f2c Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 13 Sep 2024 19:31:12 +0100 Subject: [PATCH 38/98] [HotFix] Fix final output truncation with stop string + streaming (#8468) --- tests/async_engine/test_async_llm_engine.py | 26 +++++++++++++++++---- vllm/sequence.py | 4 +++- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index bab42942d311f..a093a2b29278a 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -159,7 +159,8 @@ def should_do_global_cleanup_after_test(request) -> bool: @pytest.mark.asyncio(scope="module") -async def test_asyncio_run(async_engine): +@pytest.mark.parametrize("stop", [None, ["a stop string"]]) +async def test_asyncio_run(async_engine, stop): scheduler_config = await async_engine.get_scheduler_config() num_scheduler_steps = scheduler_config.num_scheduler_steps @@ -169,6 +170,7 @@ async def run(prompt: str): temperature=0, max_tokens=32, min_tokens=32, + stop=stop, ) output_count = 0 @@ -203,7 +205,8 @@ async def run(prompt: str): @pytest.mark.asyncio(scope="module") -async def test_output_kinds(async_engine): +@pytest.mark.parametrize("stop", [None, ["a stop string"]]) +async def test_output_kinds(async_engine, stop): """Test that output_kind works as expected and that results are equivalent across different kinds.""" @@ -214,6 +217,7 @@ async def test_output_kinds(async_engine): temperature=0, max_tokens=32, min_tokens=32, + stop=stop, ) async def run(prompt: str, kind: RequestOutputKind): @@ -229,6 +233,8 @@ async def run(prompt: str, kind: RequestOutputKind): final_output = output assert final_output is not None + assert final_output.finished + return (final_output.prompt_token_ids, final_output.outputs[0].token_ids, final_output.outputs[0].text, output_count) @@ -241,16 +247,18 @@ async def run_deltas(prompt: str): output_tokens: List[int] = [] output_text = "" output_count = 0 + final_output = None async for output in async_engine.generate(prompt, params, request_id=uid()): token_ids = output.outputs[0].token_ids text = output.outputs[0].text + final_output = output # Ensure we get prompt ids iff we haven't yet received output tokens if output_tokens: assert 1 <= len(token_ids) <= num_scheduler_steps - assert text + assert stop or text assert not output.prompt_token_ids else: assert output.prompt_token_ids @@ -260,6 +268,10 @@ async def run_deltas(prompt: str): output_text += text output_count += 1 + + assert final_output is not None + assert final_output.finished + return prompt_tokens, output_tokens, output_text, output_count results = await asyncio.gather( @@ -291,7 +303,8 @@ async def run_deltas(prompt: str): @pytest.mark.asyncio(scope="module") -async def test_cancellation(async_engine): +@pytest.mark.parametrize("stop", [None, ["a stop string"]]) +async def test_cancellation(async_engine, stop): scheduler_config = await async_engine.get_scheduler_config() num_scheduler_steps = scheduler_config.num_scheduler_steps @@ -299,6 +312,7 @@ async def test_cancellation(async_engine): temperature=0, min_tokens=13, max_tokens=13, + stop=stop, ) stop_at = 5 if num_scheduler_steps == 1 else 1 @@ -319,7 +333,8 @@ async def test_cancellation(async_engine): @pytest.mark.asyncio(scope="module") -async def test_delayed_generator(async_engine): +@pytest.mark.parametrize("stop", [None, ["a stop string"]]) +async def 
test_delayed_generator(async_engine, stop): scheduler_config = await async_engine.get_scheduler_config() if scheduler_config.num_scheduler_steps != 1: @@ -329,6 +344,7 @@ async def test_delayed_generator(async_engine): temperature=0, min_tokens=10, max_tokens=10, + stop=stop, ) stream = async_engine.generate("test3", sampling_params, request_id=uid()) diff --git a/vllm/sequence.py b/vllm/sequence.py index 98a8b73586062..07ceccf123541 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -477,7 +477,9 @@ def get_output_text_to_return(self, buffer_length: int, if not delta: return self.output_text[:-buffer_length] if truncate else ( self.output_text) - length = len(self.output_text) - buffer_length + length = len(self.output_text) + if truncate: + length -= buffer_length last_offset = self._last_output_text_offset if last_offset < length: self._last_output_text_offset = length From 9ba0817ff1eb514f51cc6de9cb8e16c98d6ee44f Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Fri, 13 Sep 2024 11:35:00 -0700 Subject: [PATCH 39/98] bump version to v0.6.1.post2 (#8473) --- vllm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/version.py b/vllm/version.py index 975e695ac7821..0ddc7fb99ad45 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -10,4 +10,4 @@ stacklevel=2) __commit__ = "COMMIT_HASH_PLACEHOLDER" -__version__ = "0.6.1.post1" +__version__ = "0.6.1.post2" From 851725202af36dafecd47af802db1d465b25b815 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Sat, 14 Sep 2024 07:54:34 +0800 Subject: [PATCH 40/98] [Hardware][intel GPU] bump up ipex version to 2.3 (#8365) Co-authored-by: Yan Ma --- Dockerfile.xpu | 12 ++- requirements-xpu.txt | 9 ++- vllm/_ipex_ops.py | 98 +++++++----------------- vllm/attention/backends/ipex_attn.py | 8 +- vllm/model_executor/layers/activation.py | 15 ++-- vllm/model_executor/layers/layernorm.py | 5 +- 6 files changed, 60 insertions(+), 87 deletions(-) diff --git a/Dockerfile.xpu b/Dockerfile.xpu index 321da98cf6c89..50bbd8f7dad87 100644 --- a/Dockerfile.xpu +++ b/Dockerfile.xpu @@ -1,15 +1,23 @@ -FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu20.04 +FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \ echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \ chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \ - rm /etc/apt/sources.list.d/intel-graphics.list && \ wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \ echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ chmod 644 /usr/share/keyrings/intel-graphics.gpg RUN apt-get update -y \ && apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1 + +RUN git clone https://github.com/intel/pti-gpu && \ + cd pti-gpu/sdk && \ + mkdir build && \ + cd build && \ + cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/icpx_toolchain.cmake -DBUILD_TESTING=OFF .. && \ + make -j && \ + cmake --install . 
--config Release --prefix "/usr/local" + COPY ./ /workspace/vllm WORKDIR /workspace/vllm diff --git a/requirements-xpu.txt b/requirements-xpu.txt index 48d899ec70eda..f07211b48b68d 100644 --- a/requirements-xpu.txt +++ b/requirements-xpu.txt @@ -3,9 +3,10 @@ setuptools < 70.0.0 # IPEX's torch have some dependency. to be removed. -torch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl -intel_extension_for_pytorch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl -oneccl_bind_pt @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.1.200%2Bxpu-cp310-cp310-linux_x86_64.whl +torch == 2.3.1+cxx11.abi +intel-extension-for-pytorch == 2.3.110+xpu +oneccl_bind_pt == 2.3.100+xpu -triton @ https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +triton-xpu == 3.0.0b2 +--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index 2156f6b18adb6..31fcc4c3256a8 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -27,29 +27,27 @@ def _reshape_activation_tensor( @staticmethod def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - x1, x2 = ipex_ops._reshape_activation_tensor(x) - ipex.llm.functional.silu_mul(x1, x2, out) + ipex.llm.functional.silu_and_mul(x, out) @staticmethod def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - x1, x2 = ipex_ops._reshape_activation_tensor(x) - ipex.llm.functional.gelu_mul(x1, x2, out, "none") + ipex.llm.functional.gelu_and_mul(x, out) @staticmethod def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - x1, x2 = ipex_ops._reshape_activation_tensor(x) - ipex.llm.functional.gelu_mul(x1, x2, out, "tanh") + ipex.llm.functional.gelu_and_mul(x, out) @staticmethod - def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: - out.copy_(torch.nn.functional.gelu(x)) + def gelu_fast(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x) @staticmethod - def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: - out.copy_(torch.nn.functional.gelu(x)) + def gelu_new(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x) - # TODO add implementation of gelu_quick here - # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: + @staticmethod + def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_quick(x, out) @staticmethod def paged_attention_v1( @@ -160,29 +158,10 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, # [cos_sin_dim, rot_dim] is_neox: bool, ) -> None: - if positions.dim() == 1: - positions = positions.unsqueeze(0) - query = query.unsqueeze(0) - key = key.unsqueeze(0) - - rotary_dim = cos_sin_cache.size(1) - query = query.view(*query.shape[:-1], -1, head_size) - key = key.view(*key.shape[:-1], -1, head_size) - - query_rot = query[..., :rotary_dim] - key_rot = key[..., :rotary_dim] - - cos_sin = cos_sin_cache[positions.long()] - cos, sin = cos_sin.chunk(2, dim=-1) - - if is_neox: - cos = cos.repeat(1, 1, 2).unsqueeze(-2) - sin = sin.repeat(1, 1, 2).unsqueeze(-2) - else: - cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) - sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) - ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos, - rotary_dim, is_neox, positions) + rot_dim = 
cos_sin_cache.size(1) + ipex.llm.functional.rotary_embedding_batched(positions, query, key, + head_size, cos_sin_cache, + is_neox, rot_dim) @staticmethod def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, @@ -190,37 +169,15 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, cos_sin_cache: torch.Tensor, is_neox: bool, rot_dim: int, cos_sin_cache_offsets: torch.Tensor) -> None: - if positions.dim() == 1: - positions = positions.unsqueeze(0) - query = query.unsqueeze(0) - key = key.unsqueeze(0) - cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions) - rotary_dim = cos_sin_cache.size(1) - query = query.view(*query.shape[:-1], -1, head_size) - key = key.view(*key.shape[:-1], -1, head_size) - - query_rot = query[..., :rotary_dim] - key_rot = key[..., :rotary_dim] - - cos_sin = cos_sin_cache[torch.add(positions, - cos_sin_cache_offsets).long()] - cos, sin = cos_sin.chunk(2, dim=-1) - - if is_neox: - cos = cos.repeat(1, 1, 2).unsqueeze(-2) - sin = sin.repeat(1, 1, 2).unsqueeze(-2) - else: - cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) - sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) - - ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos, - rotary_dim, is_neox, positions) + ipex.llm.functional.rotary_embedding_batched(positions, query, key, + head_size, cos_sin_cache, + is_neox, rot_dim, + cos_sin_cache_offsets) @staticmethod - def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - epsilon: float) -> None: - tmp = ipex.llm.functional.rms_norm(input, weight, epsilon) - out.copy_(tmp) + def rms_norm(input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> torch.Tensor: + return ipex.llm.functional.rms_norm(input, weight, epsilon) @staticmethod def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, @@ -246,11 +203,14 @@ def varlen_attention( return_softmax: bool, gen_: torch.Generator, ) -> None: - ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q, - seqlen_k, max_seqlen_q, - max_seqlen_k, pdropout, - softmax_scale, zero_tensors, - is_causal, return_softmax, gen_) + ipex.llm.functional.varlen_attention(query.contiguous(), + key.contiguous(), + value.contiguous(), out, + seqlen_q.int(), seqlen_k.int(), + max_seqlen_q, max_seqlen_k, + pdropout, softmax_scale, + zero_tensors, is_causal, + return_softmax, gen_) @staticmethod def reshape_and_cache( diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 64d60e4e47e48..113a2788eacd3 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -49,14 +49,18 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + from vllm._ipex_ops import ipex_ops as ops + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) + from vllm._ipex_ops import ipex_ops as ops + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) @dataclass diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 4c14fe476ee4a..43056786d35c9 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -114,9 +114,7 @@ def 
forward_cuda(self, x: torch.Tensor) -> torch.Tensor: def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: from vllm._ipex_ops import ipex_ops as ops - out = torch.empty_like(x) - ops.gelu_new(out, x) - return out + return ops.gelu_new(x) class FastGELU(CustomOp): @@ -136,9 +134,7 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: from vllm._ipex_ops import ipex_ops as ops - out = torch.empty_like(x) - ops.gelu_fast(out, x) - return out + return ops.gelu_fast(x) class QuickGELU(CustomOp): @@ -155,6 +151,13 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: ops.gelu_quick(out, x) return out + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + from vllm._ipex_ops import ipex_ops as ops + + out = torch.empty_like(x) + ops.gelu_quick(out, x) + return out + # TODO implement forward_xpu for QuickGELU # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index e3d588efd9b6d..14f60e9172f29 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -82,14 +82,11 @@ def forward_xpu( self.variance_epsilon, ) return x, residual - out = torch.empty_like(x) - ops.rms_norm( - out, + return ops.rms_norm( x, self.weight.data, self.variance_epsilon, ) - return out def extra_repr(self) -> str: s = f"hidden_size={self.weight.data.size(0)}" From 1ef0d2efd07f93bc7b0cfb597d8947b49e2fdaac Mon Sep 17 00:00:00 2001 From: Charlie Fu Date: Fri, 13 Sep 2024 19:01:11 -0500 Subject: [PATCH 41/98] [Kernel][Hardware][Amd]Custom paged attention kernel for rocm (#8310) --- CMakeLists.txt | 23 + csrc/rocm/attention.cu | 1038 ++++++++++++++++++++ csrc/rocm/ops.h | 13 + csrc/rocm/torch_bindings.cpp | 33 + setup.py | 3 + tests/kernels/test_attention.py | 166 +++- vllm/_custom_ops.py | 27 + vllm/attention/backends/rocm_flash_attn.py | 84 +- 8 files changed, 1371 insertions(+), 16 deletions(-) create mode 100644 csrc/rocm/attention.cu create mode 100644 csrc/rocm/ops.h create mode 100644 csrc/rocm/torch_bindings.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f8d6a2be9feae..c8f19de94e59b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -324,6 +324,25 @@ define_gpu_extension_target( WITH_SOABI) +if(VLLM_GPU_LANG STREQUAL "HIP") + # + # _rocm_C extension + # + set(VLLM_ROCM_EXT_SRC + "csrc/rocm/torch_bindings.cpp" + "csrc/rocm/attention.cu") + + define_gpu_extension_target( + _rocm_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_ROCM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) +endif() + if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") @@ -331,5 +350,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling moe extension.") add_dependencies(default _moe_C) +endif() +if(VLLM_GPU_LANG STREQUAL "HIP") + message(STATUS "Enabling rocm extension.") + add_dependencies(default _rocm_C) endif() diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu new file mode 100644 index 0000000000000..8fa7c862fbfa8 --- /dev/null +++ b/csrc/rocm/attention.cu @@ -0,0 +1,1038 @@ +/* + * Copyright (c) 2024, The vLLM team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ + defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300_MI250__ +#endif + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +#define WARP_SIZE 64 + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support + + #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 + #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 + +using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; +using float16x4 = + __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; +typedef float16x4 _Half4; +typedef struct _Half8 { + _Half4 xy[2]; +} _Half8; + +using bit16_t = uint16_t; +using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; +typedef bit16x4 _B16x4; +typedef struct _B16x8 { + _B16x4 xy[2]; +} _B16x8; + +////// Non temporal load stores /////// + +template +__device__ __forceinline__ T load(T* addr) { + return addr[0]; +} + +template +__device__ __forceinline__ void store(T value, T* addr) { + addr[0] = value; +} + +template +__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid, + blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid, + blgp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float(const T& inp) { + if constexpr (std::is_same::value) { + return (float)inp; + } else if constexpr (std::is_same::value) { + return __bfloat162float(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ T from_float(const float& inp) { + if constexpr (std::is_same::value) { + return (_Float16)inp; + } else if constexpr (std::is_same::value) { + return __float2bfloat16(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { + union tmpcvt { + uint16_t u; + _Float16 f; + __hip_bfloat16 b; + } t16; + _B16x4 ret; + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t16.f = (_Float16)inp[i]; + ret[i] = t16.u; + } + return ret; + } else if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t16.b = __float2bfloat16(inp[i]); + ret[i] = t16.u; + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, + const _B16x4& inp2) { + union tmpcvt { + uint16_t u; + _Float16 f; + __hip_bfloat16 b; + } t1, t2, 
res; + _B16x4 ret; + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t1.u = inp1[i]; + t2.u = inp2[i]; + res.f = t1.f + t2.f; + ret[i] = res.u; + } + return ret; + } else if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t1.u = inp1[i]; + t2.u = inp2[i]; + res.b = t1.b + t2.b; + ret[i] = res.u; + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +/////////////////////////////////////// + +// grid (num_seqs, num_partitions,num_heads/gqa_ratio) +// block (partition size) +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + #if 0 + scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] + #endif + int max_ctx_blocks) { + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid % 4; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + const int partition_size = blockDim.x; + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; + const int partition_start_token_idx = partition_idx * partition_size; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + constexpr int QHLOOP = + DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads, + // total qheads =8, so qhloop is 2 + constexpr int GQA_RATIO4 = 4 * QHLOOP; + __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; + __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; + _B16x8 Qlocal[QHLOOP]; + constexpr int x = 16 / sizeof(scalar_t); + constexpr int KHELOOP = HEAD_SIZE / x; + _B16x8 Klocal[KHELOOP]; + constexpr int VHELOOP = + HEAD_SIZE / + WARP_SIZE; // v head_size dimension is distributed across lanes + constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 + // 8xtokens + _B16x8 Vlocal[VHELOOP][VTLOOP]; + floatx4 dout[QHLOOP]; + float qk_max[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] = {0}; + qk_max[h] = -FLT_MAX; + } + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + + const int warp_start_token_idx = + partition_start_token_idx + warpid * WARP_SIZE; + + if (warp_start_token_idx >= context_len) { // warp out of context + #pragma unroll + for (int h = 0; h < GQA_RATIO4; h++) { + shared_qk_max[warpid][h] = -FLT_MAX; + shared_exp_sum[warpid][h] = 0.0f; + } + } else { // warp within context 
+ + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + const int local_token_idx = threadIdx.x; + const int global_token_idx = partition_start_token_idx + local_token_idx; + + const int block_idx = (global_token_idx < context_len) + ? global_token_idx / BLOCK_SIZE + : last_ctx_block; + // fetch block number for q and k + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // fetch vphysical block numbers up front + constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; + int vphysical_blocks[VBLOCKS]; + + const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + const int vblock_idx = warp_start_block_idx + b; + const int vblock_idx_ctx = + (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; + vphysical_blocks[b] = block_table[vblock_idx_ctx]; + } + // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems + const scalar_t* q_ptr = + q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; + const _B16x8* q_ptrh8 = reinterpret_cast(q_ptr); + const int qhead_elemh8 = laneid / 4; + #pragma unroll + for (int h = 0; h < QHLOOP - 1; h++) { + const int qhead_idx = h * 4 + lane4id; + Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } + const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id; + if (final_qhead_idx < GQA_RATIO) { + Qlocal[QHLOOP - 1] = + q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } else { + Qlocal[QHLOOP - 1].xy[0] = {0}; + Qlocal[QHLOOP - 1].xy[1] = {0}; + } + + const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + wg_start_kv_head_idx * kv_head_stride; + + const int physical_block_offset = + local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset + // is already cast as _H8 + + const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + } + + float alibi_slope[QHLOOP]; + if (alibi_slopes != nullptr) { + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + const int qhead_idx = h * 4 + lane4id; + alibi_slope[h] = (qhead_idx < GQA_RATIO) + ? 
alibi_slopes[wg_start_head_idx + qhead_idx] + : 0.f; + } + } + + const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B16x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + } + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[0].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[0].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[1].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[1].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[2].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[2].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[3].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[3].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[4].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[4].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[5].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[5].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[6].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[6].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[7].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[7].xy[1], dout[h]); + if constexpr (KHELOOP > 8) { + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[8].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[8].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[9].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[9].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[10].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[10].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[11].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[11].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[12].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[12].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[13].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[13].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[14].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[14].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[15].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[15].xy[1], dout[h]); + } // KHELOOP>8 + dout[h] *= scale; + } + // transpose dout so that 4 token ids are in each lane, and 4 heads are across + // 4 lanes + #pragma unroll + for (int h = 0; h < QHLOOP; h++) 
{ + floatx4 tmp = {0}; + #pragma unroll + for (int i = 0; i < 4; i++) { + const float B = (lane4id == i) ? 1.0f : 0.0f; + // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f; + tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0); + // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); + } + dout[h] = tmp; + } + + const int lane4_token_idx = 4 * (global_token_idx >> 2); + const int alibi_offset = lane4_token_idx - context_len + 1; + if (alibi_slopes != nullptr) { + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + #pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] += alibi_slope[h] * (alibi_offset + i); + } + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + qk_max[h] = -FLT_MAX; + #pragma unroll + for (int i = 0; i < 4; i++) { + qk_max[h] = (lane4_token_idx + i < context_len) + ? fmaxf(qk_max[h], dout[h][i]) + : qk_max[h]; + } + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); + } + } + + float exp_sum[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + exp_sum[h] = 0.0f; + #pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] = (lane4_token_idx + i < context_len) + ? __expf(dout[h][i] - qk_max[h]) + : 0.0f; + exp_sum[h] += dout[h][i]; + } + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + exp_sum[h] += __shfl_xor(exp_sum[h], mask); + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + const int head_idx = 4 * h + lane4id; + shared_qk_max[warpid][head_idx] = qk_max[h]; + shared_exp_sum[warpid][head_idx] = exp_sum[h]; + } + } // warp within context + + __syncthreads(); + + const int num_heads = gridDim.z * GQA_RATIO; + float* max_logits_ptr = + max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; + float* exp_sums_ptr = + exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + float global_qk_max = -FLT_MAX; + float warp_qk_max[NWARPS]; + const int head_idx = 4 * h + lane4id; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + warp_qk_max[w] = shared_qk_max[w][head_idx]; + global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]); + } + float global_exp_sum = 0.0f; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + global_exp_sum += + shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max); + } + if (head_idx < GQA_RATIO) { + max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = + global_qk_max; + exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = + global_exp_sum; + } + const float global_inv_sum_scale = __fdividef(1.f, global_exp_sum + 1e-6f) * + __expf(qk_max[h] - global_qk_max); + dout[h] *= global_inv_sum_scale; + } + // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there + // are 4x16 tokens across warp + _B16x4 logits[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + logits[h] = from_floatx4(dout[h]); + } + + __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; + + if (warp_start_token_idx >= context_len) { // warp out of context + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout_shared[qh][vh][laneid][warpid] = {0}; + } + } + } else { // warp in context + // iterate across heads + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + // iterate over each v head elem (within head_size) + #pragma unroll + for 
(int vh = 0; vh < VHELOOP; vh++) { + floatx4 acc = {0}; + // iterate over tokens + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][5].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][5].xy[1], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][6].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][6].xy[1], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][7].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][7].xy[1], acc); + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc); + } + } + } // warp in context + + __syncthreads(); + + if (warpid == 0) { + _B16x4 vout[QHLOOP][VHELOOP]; + // iterate across heads + scalar_t* out_ptr; + int out_num_partitions; + if (context_len > partition_size) { + out_num_partitions = max_num_partitions; + out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + } else { + out_num_partitions = 1; + out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; + } + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + // iterate over each v head elem (within head_size) + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout[qh][vh] = {0}; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + vout[qh][vh] = + addx4(vout[qh][vh], vout_shared[qh][vh][laneid][w]); + } + const int head_size_elem = vh * WARP_SIZE + laneid; + bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); + #pragma unroll + for (int i = 0; i < 4; i++) { + const int head_idx = 4 * qh + i; + if (head_idx < GQA_RATIO) { + out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions * + HEAD_SIZE + + head_size_elem] = vout[qh][vh][i]; + } + } + } + } + } +} + +// Grid: (num_heads, num_seqs). 
+template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // if num_partitions==1, main kernel will write to out directly, no work in + // reduction kernel + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + __shared__ float shared_exp_sums[2 * WARP_SIZE]; + + if (warpid == 0) { + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + // valid partition is the last valid partition in case threadid > num + // partitions + const int valid_partition = + (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1; + const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions) + ? WARP_SIZE + threadIdx.x + : num_partitions - 1; + float reg_max_logit = max_logits_ptr[valid_partition]; + float reg_max_logit2 = max_logits_ptr[valid_partition2]; + float max_logit = fmaxf(reg_max_logit, reg_max_logit2); + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); + } + + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + float global_exp_sum = 0.0f; + float rescaled_exp_sum = exp_sums_ptr[valid_partition]; + float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2]; + rescaled_exp_sum *= + (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f; + rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions) + ? expf(reg_max_logit2 - max_logit) + : 0.0f; + global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2; + shared_exp_sums[threadIdx.x] = rescaled_exp_sum; + shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2; + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + global_exp_sum += __shfl_xor(global_exp_sum, mask); + } + if (threadIdx.x == 0) { + shared_global_exp_sum = global_exp_sum; + } + } // warpid == 0 + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 64; + scalar_t tmps[MAX_NPAR]; + const float dzero = 0.0f; + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = from_float(dzero); + } + const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx = 0; + + constexpr int JCHUNK = 16; + + #pragma unroll + for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? 
j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + __syncthreads(); + + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + } + } // num_partitions > JCHUNK + + // Aggregate tmp_out to out. + float acc = 0.0f; + #pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK; j < 2 * JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + } + } + + if (num_partitions > MAX_NPAR) { + idx = 0; + #pragma unroll + for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + MAX_NPAR]; + } + } + + const float inv_global_exp_sum = + __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + acc *= inv_global_exp_sum; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + out_ptr[threadIdx.x] = from_float(acc); +} + +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + #if 0 + scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] + #endif + int max_ctx_blocks) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). 
+template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions){UNREACHABLE_CODE} + +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ + paged_attention_ll4mi_QKV_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks); + +template +void paged_attention_custom_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, const int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + int max_context_len, +#if 0 + torch::Tensor& qk_out, + torch::Tensor& softmax_out, +#endif + const c10::optional& alibi_slopes) { + + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); +#if 0 + T* qk_out_ptr = reinterpret_cast(qk_out.data_ptr()); + T* softmax_out_ptr = reinterpret_cast(softmax_out.data_ptr()); +#endif + + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + const int max_num_partitions = + DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads / num_kv_heads; + assert(num_heads % num_kv_heads == 0); + assert(head_size == HEAD_SIZE); + assert(max_num_partitions <= 128); + + constexpr int NTHR = PARTITION_SIZE; + dim3 grid(num_seqs, max_num_partitions, num_kv_heads); + dim3 block(NTHR); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (gqa_ratio) { + case 1: + LAUNCH_CUSTOM_ATTENTION(1); + break; + case 2: + LAUNCH_CUSTOM_ATTENTION(2); + break; + case 3: + LAUNCH_CUSTOM_ATTENTION(3); + break; + case 4: + LAUNCH_CUSTOM_ATTENTION(4); + break; + case 5: + LAUNCH_CUSTOM_ATTENTION(5); + break; + case 6: + LAUNCH_CUSTOM_ATTENTION(6); + break; + case 7: + LAUNCH_CUSTOM_ATTENTION(7); + break; + case 8: + LAUNCH_CUSTOM_ATTENTION(8); + break; + case 9: + LAUNCH_CUSTOM_ATTENTION(9); + break; + case 10: + LAUNCH_CUSTOM_ATTENTION(10); + break; + case 11: + LAUNCH_CUSTOM_ATTENTION(11); + break; + case 12: + LAUNCH_CUSTOM_ATTENTION(12); + break; + case 13: + LAUNCH_CUSTOM_ATTENTION(13); + break; + case 14: + LAUNCH_CUSTOM_ATTENTION(14); + break; + case 15: + LAUNCH_CUSTOM_ATTENTION(15); + break; + case 16: + LAUNCH_CUSTOM_ATTENTION(16); + break; + default: + TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); + break; + } + // dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); + // dim3 block2(1024); + // LAUNCH_CUSTOM_ATTENTION2; + + // reduction kernel is only required if max_context_len > partition size, + // otherwise main kernel writes directly to final output + // note there are cases with graphing where max_context_len is the max + // supported by graphing, not the actual max among all the sequences: in that + // case reduction kernel will still run but return immediately + if (max_context_len > PARTITION_SIZE) { + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); + paged_attention_ll4mi_reduce_kernel + <<>>( + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, + context_lens_ptr, max_num_partitions); + } +} + +#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes); + +#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE) \ + switch (block_size) { \ + case 16: \ + CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE); \ + break; \ + case 32: \ + CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \ + switch 
(head_size) { \ + case 64: \ + CALL_CUSTOM_LAUNCHER_BLK(T, 64); \ + break; \ + case 128: \ + CALL_CUSTOM_LAUNCHER_BLK(T, 128); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported head size: ", head_size); \ + break; \ + } + +void paged_attention( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int64_t block_size, int64_t max_context_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype) { + assert(kv_cache_dtype == "auto"); + const int head_size = query.size(2); + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h new file mode 100644 index 0000000000000..4a07a3f1775bd --- /dev/null +++ b/csrc/rocm/ops.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, + double scale, torch::Tensor& block_tables, + torch::Tensor& context_lens, int64_t block_size, + int64_t max_context_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp new file mode 100644 index 0000000000000..082e314587908 --- /dev/null +++ b/csrc/rocm/torch_bindings.cpp @@ -0,0 +1,33 @@ +#include "core/registration.h" +#include "rocm/ops.h" + +// Note on op signatures: +// The X_meta signatures are for the meta functions corresponding to op X. +// They must be kept in sync with the signature for X. Generally, only +// functions that return Tensors require a meta function. +// +// See the following links for detailed docs on op registration and function +// schemas. +// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { + // vLLM custom ops for rocm + + // Custom attention op + // Compute the attention between an input query and the cached + // keys/values using PagedAttention. + rocm_ops.def( + "paged_attention(Tensor! out, Tensor exp_sums," + " Tensor max_logits, Tensor tmp_out," + " Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads," + " float scale, Tensor block_tables," + " Tensor context_lens, int block_size," + " int max_context_len," + " Tensor? 
alibi_slopes," + " str kv_cache_dtype) -> ()"); + rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/setup.py b/setup.py index 10770b8c9aa4a..8930ea7239dc9 100644 --- a/setup.py +++ b/setup.py @@ -462,6 +462,9 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) +if _is_hip(): + ext_modules.append(CMakeExtension(name="vllm._rocm_C")) + if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 7995f11f19e98..46831b506aff3 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -3,8 +3,6 @@ import pytest import torch -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from tests.kernels.utils import opcheck from vllm import _custom_ops as ops @@ -12,6 +10,10 @@ from .allclose_default import get_default_atol, get_default_rtol +if not is_hip(): + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer @@ -328,6 +330,165 @@ def ref_multi_query_kv_attention( return torch.cat(ref_outputs, dim=0) +@pytest.mark.parametrize("version", ["rocm"]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", [64, 128]) # only test 64 128 +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", ["auto"]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(not is_hip(), reason="only for rocm") +def test_paged_attention_rocm( + kv_cache_factory, + version: str, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + seed: int, + device: str, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) + + context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + context_lens[-1] = MAX_SEQ_LEN + #context_lens = [8192 for _ in range(num_seqs)] + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int) + #print('>>> ctx lens', context_lens) + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int) + + # Create the KV caches. 
+ key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + device) + key_cache, value_cache = key_caches[0], value_caches[0] + + # TODO(charlifu) enable fp8 kv cache + # Using default kv_scale + # kv_scale = 1.0 + + # Call the paged attention kernel. + output = torch.empty_like(query) + PARTITION_SIZE_ROCM = 256 + num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) // + PARTITION_SIZE_ROCM) + assert PARTITION_SIZE_ROCM % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + ) + max_logits = torch.empty_like(exp_sums) + if version == "rocm": + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + ) + else: + raise AssertionError(f"Unknown version: {version}") + + # Run the reference implementation. + if kv_cache_dtype == "fp8": + # Convert cache data back to dtype. + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, + block_size, x) + dequantized_key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(key_cache, dequantized_key_cache) + key_cache = dequantized_key_cache + + value_cache_shape = value_cache.shape + dequantized_value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(value_cache, dequantized_value_cache) + value_cache = dequantized_value_cache + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + context_lens, + scale, + alibi_slopes, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. + atol = get_default_atol(output) if is_hip() else 1e-3 + rtol = get_default_rtol(output) if is_hip() else 1e-5 + + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. + atol, rtol = 1e-4, 1e-5 + if dtype == torch.bfloat16: + atol, rtol = 2e-4, 1e-5 + if use_alibi: + if dtype == torch.half: + atol, rtol = 5e-4, 1e-5 + if dtype == torch.bfloat16: + atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8": + atol, rtol = 1e-2, 1e-5 + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) + + # TODO(woosuk): Add tests for USE_ALIBI=True. 
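The block tables built in this test are the only link between a sequence's logical token positions and the physical KV-cache blocks the kernel reads; both the custom kernel and the reference implementation resolve a token through the same two-step lookup. A small illustrative helper (not part of the patch) showing that mapping:

def kv_location(block_table, token_pos, block_size):
    # Map a logical token position to (physical block, offset within block).
    physical_block = block_table[token_pos // block_size]
    return physical_block, token_pos % block_size

# With block_size=16, token 37 lives at offset 5 of whichever physical block
# block_table[2] points to for that sequence.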
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -335,6 +496,7 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(is_hip(), reason="skip for rocm") @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index efa02d36c4acd..ed08878f14875 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -17,6 +17,9 @@ except ImportError as e: logger.warning("Failed to import from vllm._C with %r", e) +if current_platform.is_rocm(): + import vllm._rocm_C # noqa: F401 + with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 @@ -127,6 +130,30 @@ def paged_attention_v2( blocksparse_block_size, blocksparse_head_sliding_step) +def paged_attention_rocm( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, +) -> None: + torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, + key_cache, value_cache, num_kv_heads, + scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, + kv_cache_dtype) + + # pos encoding ops def rotary_embedding( positions: torch.Tensor, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b0f4d0530b7f0..f1404b8b6bfe7 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -5,6 +5,7 @@ import torch import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import (CommonAttentionState, @@ -15,6 +16,9 @@ logger = init_logger(__name__) +_PARTITION_SIZE = 256 +ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName + class ROCmFlashAttentionBackend(AttentionBackend): @@ -480,20 +484,61 @@ def forward( if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
- output[num_prefill_tokens:] = PagedAttention.forward_decode( - decode_query, - key_cache, - value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_decode_seq_len, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - k_scale, - v_scale, - ) + # Whether to use rocm custom paged attention or not + num_seqs, num_heads, head_size = decode_query.shape + block_size = value_cache.shape[3] + gqa_ratio = num_heads // self.num_kv_heads + use_custom = use_rocm_custom_paged_attention( + decode_query.dtype, head_size, block_size, self.kv_cache_dtype, + gqa_ratio, decode_meta.max_decode_seq_len) + if use_custom: + max_seq_len = decode_meta.max_decode_seq_len + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_rocm( + output[num_prefill_tokens:], + exp_sums, + max_logits, + tmp_output, + decode_query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + ) + else: + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, + key_cache, + value_cache, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + decode_meta.max_decode_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + k_scale, + v_scale, + ) # Reshape the output tensor. 
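The temporary buffers allocated in this branch scale with the longest decode sequence in the batch: one partition per 256-token chunk of context (_PARTITION_SIZE), plus a float32 exp-sum and max-logit per (sequence, head, partition). A short sketch of the sizing, using illustrative names:

def rocm_decode_workspace_shapes(num_seqs, num_heads, head_size,
                                 max_decode_seq_len, partition_size=256):
    max_num_partitions = -(-max_decode_seq_len // partition_size)  # ceil div
    tmp_output = (num_seqs, num_heads, max_num_partitions, head_size)
    exp_sums = (num_seqs, num_heads, max_num_partitions)   # float32
    max_logits = exp_sums                                   # float32, same shape
    return tmp_output, exp_sums, max_logits

# e.g. 8 sequences, 32 heads, head_size 128 and a 4096-token decode give
# 16 partitions and a tmp_output of shape (8, 32, 16, 128).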
return output.view(num_tokens, hidden_size) @@ -532,3 +577,14 @@ def _sdpa_attention( start = end return output + + +def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, + block_size: int, kv_cache_dtype: str, + gqa_ratio: int, max_seq_len: int) -> bool: + # rocm custom page attention not support on navi (gfx1*) + return (not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and kv_cache_dtype == "auto" + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) From 8a0cf1ddc323a571c9f46a85da067d44af5d2453 Mon Sep 17 00:00:00 2001 From: ywfang <47963924+SUDA-HLT-ywfang@users.noreply.github.com> Date: Sat, 14 Sep 2024 22:50:26 +0800 Subject: [PATCH 42/98] [Model] support minicpm3 (#8297) Co-authored-by: DarkLight1337 --- .buildkite/run-cpu-test.sh | 2 +- docs/source/models/supported_models.rst | 4 + requirements-test.txt | 1 + .../decoder_only/language/test_big_models.py | 15 +- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/minicpm.py | 79 ++++--- vllm/model_executor/models/minicpm3.py | 216 ++++++++++++++++++ 7 files changed, 281 insertions(+), 37 deletions(-) create mode 100644 vllm/model_executor/models/minicpm3.py diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index f4ead8d277736..73ce82c5857ab 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -22,7 +22,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" # Run basic model test docker exec cpu-test bash -c " - pip install pytest matplotlib einops transformers_stream_generator + pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator pytest -v -s tests/models/decoder_only/language \ --ignore=tests/models/test_fp8.py \ --ignore=tests/models/decoder_only/language/test_jamba.py \ diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 6c7f7f7d5d992..3dcc242803752 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -107,6 +107,10 @@ Decoder-only Language Models - MiniCPM - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. - + * - :code:`MiniCPM3ForCausalLM` + - MiniCPM3 + - :code:`openbmb/MiniCPM3-4B`, etc. + - * - :code:`MistralForCausalLM` - Mistral, Mistral-Instruct - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc. diff --git a/requirements-test.txt b/requirements-test.txt index ca3bfa7aff629..16a883b81ce50 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -21,6 +21,7 @@ compressed-tensors==0.4.0 # required for compressed-tensors timm # required for internvl test transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test +datamodel_code_generator # required for minicpm3 test # TODO: Add this after fully implementing llava(mantis) # git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test diff --git a/tests/models/decoder_only/language/test_big_models.py b/tests/models/decoder_only/language/test_big_models.py index c5783fe19dd3f..fcc158639748d 100644 --- a/tests/models/decoder_only/language/test_big_models.py +++ b/tests/models/decoder_only/language/test_big_models.py @@ -5,7 +5,8 @@ Run `pytest tests/models/test_big_models.py`. 
""" import pytest -import torch + +from vllm.platforms import current_platform from ...utils import check_outputs_equal @@ -19,10 +20,12 @@ # "Qwen/Qwen1.5-0.5B" # Broken, ] +if not current_platform.is_cpu(): + # MiniCPM requires fused_moe which is not supported by CPU + MODELS.append("openbmb/MiniCPM3-4B") + #TODO: remove this after CPU float16 support ready -target_dtype = "float" -if torch.cuda.is_available(): - target_dtype = "half" +target_dtype = "float" if current_platform.is_cpu() else "half" @pytest.mark.parametrize("model", MODELS) @@ -39,7 +42,7 @@ def test_models( with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( @@ -57,7 +60,7 @@ def test_model_print( model: str, dtype: str, ) -> None: - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: # This test is for verifying whether the model's extra_repr # can be printed correctly. print(vllm_model.model.llm_engine.model_executor.driver_worker. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 250f75b639a5b..41c8e754377c7 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -43,6 +43,7 @@ "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), + "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index a135118bc748e..963ad7553fe1d 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -270,38 +270,47 @@ def __init__( ) -> None: super().__init__() self.config = config + self.cache_config = cache_config + self.quant_config = quant_config self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + self.rope_theta = getattr(config, "rope_theta", 10000) + self.rope_scaling = getattr(config, "rope_scaling", None) + self.max_position_embeddings = getattr(config, + "max_position_embeddings", 8192) + self._init_attn_block() + self._init_ffn_block() + + def _init_attn_block(self): + self.input_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) self.self_attn = MiniCPMAttention( hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, + num_heads=self.config.num_attention_heads, + num_kv_heads=self.config.num_key_value_heads, + rope_theta=self.rope_theta, + rope_scaling=self.rope_scaling, + max_position_embeddings=self.max_position_embeddings, + cache_config=self.cache_config, + quant_config=self.quant_config, ) + + def _init_ffn_block(self): + self.post_attention_layernorm = RMSNorm(self.config.hidden_size, + 
eps=self.config.rms_norm_eps) self.num_experts = getattr(self.config, "num_experts", 0) if self.num_experts == 0: self.mlp = MiniCPMMLP( hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, + intermediate_size=self.config.intermediate_size, + hidden_act=self.config.hidden_act, + quant_config=self.quant_config, ) else: - self.mlp = MiniCPMMoE(num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.mlp = MiniCPMMoE( + num_experts=self.config.num_experts, + top_k=self.config.num_experts_per_tok, + hidden_size=self.config.hidden_size, + intermediate_size=self.config.intermediate_size) def forward( self, @@ -344,6 +353,8 @@ def __init__( ) -> None: super().__init__() self.config = config + self.cache_config = cache_config + self.quant_config = quant_config self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 @@ -354,11 +365,15 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) + self._init_layers() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def _init_layers(self): self.layers = nn.ModuleList([ - MiniCPMDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) + MiniCPMDecoderLayer(self.config, self.cache_config, + self.quant_config) + for _ in range(self.config.num_hidden_layers) ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: embedding = self.embed_tokens(input_ids) @@ -431,13 +446,11 @@ def __init__( self.config = config self.lora_config = lora_config + self.cache_config = cache_config + self.quant_config = quant_config self.num_experts = getattr(self.config, "num_experts", 0) - self.quant_config = quant_config - self.model = MiniCPMModel(config, - cache_config, - quant_config, - lora_config=lora_config) + self._init_model() unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -458,6 +471,12 @@ def __init__( config.vocab_size) self.sampler = Sampler() + def _init_model(self): + self.model = MiniCPMModel(config=self.config, + cache_config=self.cache_config, + quant_config=self.quant_config, + lora_config=self.lora_config) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py new file mode 100644 index 0000000000000..a048a3dba0415 --- /dev/null +++ b/vllm/model_executor/models/minicpm3.py @@ -0,0 +1,216 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2024 The ModelBest team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only MiniCPM3 model compatible with HuggingFace weights.""" +from typing import Any, Dict, Optional + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, + MiniCPMForCausalLM, + MiniCPMModel) + + +class MiniCPM3Attention(nn.Module): + + def __init__( + self, + config, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + + tp_size = get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config) + + self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size, + self.kv_lora_rank + + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config) + # O projection. 
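MiniCPM3Attention follows a multi-head-latent layout: queries are compressed to q_lora_rank and re-expanded, keys/values pass through a kv_lora_rank bottleneck that also carries a shared rotary component, and each head's query/key splits into a non-rotary ("nope") part and a rotary part. A shape sketch of the projections defined above, with purely illustrative numbers (the real values come from the HF config):

# Illustrative dimensions only -- not the actual MiniCPM3 configuration.
hidden_size, num_heads = 2048, 16
q_lora_rank, kv_lora_rank = 768, 256
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 64, 32, 64

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim                        # 96
q_a_proj  = (hidden_size, q_lora_rank)                                   # down-projection
q_b_proj  = (q_lora_rank, num_heads * qk_head_dim)                       # up-projection
kv_a_proj = (hidden_size, kv_lora_rank + qk_rope_head_dim)               # down + shared rope
kv_b_proj = (kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))  # up-projection
o_proj    = (num_heads * v_head_dim, hidden_size)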
+ self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config) + + self.rotary_emb = get_rope( + self.qk_rope_head_dim, + rotary_dim=self.qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + q, _ = self.q_a_proj(hidden_states) + q = self.q_a_layernorm(q) + q, _ = self.q_b_proj(q) + q = q.view(-1, self.num_local_heads, self.qk_head_dim) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states) + kv_a, _ = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv, _ = self.kv_b_proj(kv_a) + kv = kv.view(-1, self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_pe = latent_cache[:, :, self.kv_lora_rank:] + + q_pe, k_pe = self.rotary_emb( + positions, + q_pe.reshape(-1, self.num_local_heads * self.qk_rope_head_dim), + k_pe.reshape(-1, self.qk_rope_head_dim)) + q_pe = q_pe.view(-1, self.num_local_heads, self.qk_rope_head_dim) + k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim) + + q[..., self.qk_nope_head_dim:] = q_pe + + k = torch.empty_like(q) + + k[..., :self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim:] = k_pe + + q = q.reshape(-1, self.num_local_heads * self.qk_head_dim) + k = k.view(-1, self.num_local_heads * self.qk_head_dim) + v = torch.nn.functional.pad( + v, [0, self.qk_head_dim - self.v_head_dim], + value=0).view(-1, self.num_local_heads * self.qk_head_dim) + + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = attn_output.view( + -1, self.num_local_heads, + self.qk_head_dim)[..., :self.v_head_dim].reshape( + -1, self.num_local_heads * self.v_head_dim) + + output, _ = self.o_proj(attn_output) + return output + + +class MiniCPM3DecoderLayer(MiniCPMDecoderLayer): + + def _init_attn_block(self): + self.input_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.self_attn = MiniCPM3Attention( + config=self.config, + hidden_size=self.hidden_size, + num_heads=self.config.num_attention_heads, + qk_nope_head_dim=self.config.qk_nope_head_dim, + qk_rope_head_dim=self.config.qk_rope_head_dim, + v_head_dim=self.config.v_head_dim, + q_lora_rank=self.config.q_lora_rank, + kv_lora_rank=self.config.kv_lora_rank, + rope_theta=self.rope_theta, + rope_scaling=self.rope_scaling, + max_position_embeddings=self.max_position_embeddings, + cache_config=self.cache_config, + quant_config=self.quant_config, + ) + + +class MiniCPM3Model(MiniCPMModel): + + def _init_layers(self): + self.layers = nn.ModuleList([ + MiniCPM3DecoderLayer(self.config, self.cache_config, + self.quant_config) + for _ in range(self.config.num_hidden_layers) + ]) + + +class MiniCPM3ForCausalLM(MiniCPMForCausalLM): + + def _init_model(self): + self.model = MiniCPM3Model(config=self.config, + cache_config=self.cache_config, + quant_config=self.quant_config, + lora_config=self.lora_config) From a36e070dad7d7098f69324b8275a533140221809 Mon 
Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 14 Sep 2024 09:46:04 -0700 Subject: [PATCH 43/98] [torch.compile] fix functionalization (#8480) --- tests/compile/test_full_graph.py | 13 ++- vllm/compilation/backends.py | 156 +++++++++++++++++++++++++++++++ vllm/worker/model_runner.py | 3 +- 3 files changed, 167 insertions(+), 5 deletions(-) create mode 100644 vllm/compilation/backends.py diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 0a6e781e18834..43905082b7caf 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -16,7 +16,12 @@ def test_full_graph(model): "The future of AI is", ] sampling_params = SamplingParams(temperature=0) - llm = LLM(model="meta-llama/Meta-Llama-3-8B", - enforce_eager=True, - load_format="dummy") - llm.generate(prompts, sampling_params) + llm = LLM(model=model, enforce_eager=True) + + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py new file mode 100644 index 0000000000000..de0b1d8a75757 --- /dev/null +++ b/vllm/compilation/backends.py @@ -0,0 +1,156 @@ +import operator + +import torch +import torch.fx as fx + + +def fix_functionalization(graph: fx.Graph): + """ + Rewrite the graph module to replace the pattern involving + torch._higher_order_ops.auto_functionalize.auto_functionalized + with a direct call to the inplace custom op. + + # TODO: check if PyTorch nightly has fixed this issue + """ + + # debug code, if we want to see the graph before the transformation + # with open("before.py", "w") as f: + # print(graph.python_code(root_module="self", verbose=True).src, file=f) + + nodes_to_remove = [] + + for node in graph.nodes: + # Identify the auto_functionalized node + if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized: # noqa + if node.args[0] == torch.ops._C.rotary_embedding.default: + # manual replace for rotary_embedding + + # Now, collect the arguments + kwargs = node.kwargs + + query = kwargs['query'] + mm_node = query.args[0].args[0] + + # Create a new call to torch.ops._C.rotary_embedding.default + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function(torch.ops._C.rotary_embedding.default, + kwargs=kwargs) + + # Remove the auto_functionalized node + # Since the node may have outputs, we need to handle its users + # Replace uses of the outputs (getitem nodes) with mm_node + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + # Remove the getitem node + for getitem_user in list(user.users): + if (getitem_user.op == 'call_function' + and getitem_user.target + == torch.ops.aten.slice_scatter.default): + # Replace the uses of slice_scatter node + # with mm_node + getitem_user.replace_all_uses_with(mm_node) + nodes_to_remove.append(getitem_user) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + elif node.args[0] == torch.ops._C.fused_add_rms_norm.default: + # manual replace for fused_add_rms_norm + # this is the most effective optimization for llama + # failing to do this will result in many unnecessary copies + + kwargs = node.kwargs + + input = kwargs['input'] + residual = 
kwargs['residual'] + + # Create a new call to torch.ops._C.rotary_embedding.default + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs) + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + # Remove the getitem node + if user.args[1] == 1: + replace_node = input + elif user.args[1] == 2: + replace_node = residual + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + elif node.args[0] == torch.ops._C.rms_norm.default: + # manual replace for rms_norm + + kwargs = node.kwargs + + input = kwargs['input'] + out = kwargs['out'] + weight = kwargs['weight'] + epsilon = kwargs['epsilon'] + # Create a new call to torch.ops._C.rotary_embedding.default + # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.rms_norm.default, + args=(out, input, weight, epsilon), + ) + + replace_node = out + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + elif node.args[0] == torch.ops._C.silu_and_mul.default: + # manual replace for silu_and_mul + + kwargs = node.kwargs + + input = kwargs['input'] + out = kwargs['out'] + + # Create a new call to torch.ops._C.rotary_embedding.default + # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.silu_and_mul.default, + args=(out, input), + ) + replace_node = out + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + # Remove the nodes all at once + for node in nodes_to_remove: + graph.erase_node(node) + + # debug code, if we want to see the graph after the transformation + # with open("after.py", "w") as f: + # print(graph.python_code(root_module="self", verbose=True).src, file=f) + + +def vllm_backend(graph, example_inputs): + from torch._inductor import config + current_config = config.shallow_copy_dict() + from torch._inductor.compile_fx import compile_fx + current_config['post_grad_custom_post_pass'] = fix_functionalization + return compile_fx(graph, example_inputs, config_patches=current_config) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index bff789c429710..9df9ae783b9fa 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1064,8 +1064,9 @@ def load_model(self) -> None: "This may lead to less accurate results!") if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): + from vllm.compilation.backends import vllm_backend from vllm.plugins import get_torch_compile_backend 
- backend = get_torch_compile_backend() or "eager" + backend = get_torch_compile_backend() or vllm_backend self.model = torch.compile( self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, From 47790f3e328f1fbf250d8f858b6390496c1e61c0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 14 Sep 2024 13:07:16 -0700 Subject: [PATCH 44/98] [torch.compile] add a flag to disable custom op (#8488) --- tests/compile/test_full_graph.py | 3 ++- vllm/envs.py | 5 +++++ vllm/model_executor/custom_op.py | 5 +++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 43905082b7caf..5452ce6be8110 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -6,7 +6,8 @@ @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) def test_full_graph(model): # make sure these models can be captured in full graph mode - os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" + if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: + os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" from vllm import LLM, SamplingParams prompts = [ diff --git a/vllm/envs.py b/vllm/envs.py index b3678399fe207..2003ede95d2d8 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -202,6 +202,11 @@ def get_default_config_root(): (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in ("true", "1")), + # Internal flag to control whether we use custom op, + # or use the native pytorch implementation + "VLLM_TEST_COMPILE_NO_CUSTOM_OPS": + lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")), + # Internal flag to enable Dynamo fullgraph capture "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": lambda: bool( diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 49247cd5de42a..9102b5e19ebec 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,5 +1,6 @@ import torch.nn as nn +import vllm.envs as envs from vllm.platforms import current_platform from vllm.utils import is_cpu, is_hip, is_xpu @@ -53,6 +54,10 @@ def forward_gaudi(self, *args, **kwargs): def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. + + if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS: + return self.forward_native + if is_hip(): return self.forward_hip elif is_cpu(): From 50e9ec41fc2dbd1b80e7ec488650c327bdf81798 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 14 Sep 2024 16:58:31 -0700 Subject: [PATCH 45/98] [TPU] Implement multi-step scheduling (#8489) --- vllm/config.py | 2 +- vllm/executor/ray_tpu_executor.py | 8 +- vllm/executor/tpu_executor.py | 16 +- vllm/worker/multi_step_tpu_worker.py | 105 +++++++++++++ vllm/worker/tpu_model_runner.py | 224 +++++++++++++++++++-------- 5 files changed, 279 insertions(+), 76 deletions(-) create mode 100644 vllm/worker/multi_step_tpu_worker.py diff --git a/vllm/config.py b/vllm/config.py index 9684cea813134..89cffc8b306b2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -379,7 +379,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config, self.use_async_output_proc = False return - if self.enforce_eager: + if device_config.device_type == "cuda" and self.enforce_eager: logger.warning( "To see benefits of async output processing, enable CUDA " "graph. 
Since, enforce-eager is enabled, async output " diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 8c8b5f741488b..732b69d6e5954 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -68,8 +68,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", ) assert self.speculative_config is None - worker_module_name = "vllm.worker.tpu_worker" - worker_class_name = "TPUWorker" + if self.scheduler_config.is_multi_step: + worker_module_name = "vllm.worker.multi_step_tpu_worker" + worker_class_name = "MultiStepTPUWorker" + else: + worker_module_name = "vllm.worker.tpu_worker" + worker_class_name = "TPUWorker" # GKE does not fetch environment information from metadata server # and instead sets these from within the Ray process. Therefore we diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py index 0af8ba41e24d5..972649dedf33e 100644 --- a/vllm/executor/tpu_executor.py +++ b/vllm/executor/tpu_executor.py @@ -62,11 +62,17 @@ def _create_worker( rank: int = 0, distributed_init_method: Optional[str] = None, ): - from vllm.worker.tpu_worker import TPUWorker - - worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank, - distributed_init_method)) - return worker + if self.scheduler_config.is_multi_step: + from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker + worker = MultiStepTPUWorker(**self._get_worker_kwargs( + local_rank, rank, distributed_init_method)) + return worker + else: + from vllm.worker.tpu_worker import TPUWorker + + worker = TPUWorker(**self._get_worker_kwargs( + local_rank, rank, distributed_init_method)) + return worker def initialize_cache( self, diff --git a/vllm/worker/multi_step_tpu_worker.py b/vllm/worker/multi_step_tpu_worker.py new file mode 100644 index 0000000000000..e654f7172b266 --- /dev/null +++ b/vllm/worker/multi_step_tpu_worker.py @@ -0,0 +1,105 @@ +import dataclasses +from typing import Dict, Optional, Tuple + +import torch + +from vllm.distributed import broadcast_tensor_dict +from vllm.sequence import ExecuteModelRequest +from vllm.worker.tpu_model_runner import ModelInputForTPU +from vllm.worker.tpu_worker import TPUWorker +from vllm.worker.worker_base import WorkerInput + + +class MultiStepTPUWorker(TPUWorker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cached_model_input: Optional[ModelInputForTPU] = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]: + assert self.is_driver_worker + assert execute_model_req.virtual_engine == 0 + + is_first_multi_step = execute_model_req.is_first_multi_step + is_last_step = execute_model_req.is_last_step + if is_first_multi_step: + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + worker_input = dataclasses.replace( + worker_input, + num_steps=execute_model_req.num_lookahead_slots + 1) + model_input: ModelInputForTPU = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input = dataclasses.replace( + model_input, + async_callback=execute_model_req.async_callback) + else: + assert self.cached_model_input is not None + model_input = self.cached_model_input + worker_input = WorkerInput() + model_input = dataclasses.replace( + model_input, + 
is_first_multi_step=is_first_multi_step, + is_last_step=is_last_step) + + if self.do_metadata_broadcast: + if is_first_multi_step: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update( + model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + else: + broadcast_data = { + "is_first_multi_step": is_first_multi_step, + "is_last_step": is_last_step, + } + broadcast_tensor_dict(broadcast_data, src=0) + + # Retuning empty dict here to keep this compatible with + # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` + return model_input, worker_input, {} + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str, + torch.Tensor]]]: + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + broadcast_tensor_dict({}, src=0) + return None + + model_input, worker_input, _ = self._get_driver_input_and_broadcast( + execute_model_req) + if model_input.is_first_multi_step: + self.cached_model_input = model_input + return model_input, worker_input, {} + else: + broadcast_data = broadcast_tensor_dict(src=0) + if not broadcast_data: + return None + + if len(broadcast_data) == 2: + assert self.cached_model_input is not None + self.cached_model_input = dataclasses.replace( + self.cached_model_input, + is_first_multi_step=broadcast_data["is_first_multi_step"], + is_last_step=broadcast_data["is_last_step"]) + empty_worker_input = WorkerInput() + return self.cached_model_input, empty_worker_input, {} + + worker_input = WorkerInput.from_broadcasted_tensor_dict( + broadcast_data) + model_input = ( + self.model_runner. + make_model_input_from_broadcasted_tensor_dict(broadcast_data)) + self.cached_model_input = model_input + return model_input, worker_input, {} diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index db306bc743d3a..575769ca1aa4a 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -51,6 +51,8 @@ class ModelInputForTPU(ModelRunnerInputBase): num_samples: int best_of: List[int] seq_groups: List[List[int]] + is_first_multi_step: bool = True + is_last_step: bool = True virtual_engine: int = 0 async_callback: Optional[Callable] = None @@ -65,6 +67,8 @@ def as_broadcastable_tensor_dict( "num_samples": self.num_samples, "best_of": self.best_of, "seq_groups": self.seq_groups, + "is_first_multi_step": self.is_first_multi_step, + "is_last_step": self.is_last_step, "virtual_engine": self.virtual_engine, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) @@ -118,6 +122,7 @@ def __init__( self.block_size, False, ) + self.cached_step_outputs: List[torch.Tensor] = [] def load_model(self) -> None: self.device = self.device_config.device @@ -518,97 +523,159 @@ def execute_model( num_steps: int = 1, ) -> List[SamplerOutput]: assert intermediate_tensors is None - if num_steps > 1: - raise ValueError( - "TPUModelRunner does not support multi-step execution.") - - def _execute_model(*args): - """Move input args from CPU to device and execute the model.""" - - new_args = [] - for arg in args: - if isinstance(arg, torch.Tensor): - arg = arg.to(self.device) - elif isinstance(arg, AttentionMetadata): - arg.slot_mapping = arg.slot_mapping.to(self.device) - if getattr(arg, "block_tables", None) is not None: - arg.block_tables = arg.block_tables.to(self.device) - if getattr(arg, "context_lens", None) is not None: - arg.context_lens = 
arg.context_lens.to(self.device) - new_args.append(arg) - return self.model(*new_args, is_prompt=is_prompt) - - num_prefills = model_input.attn_metadata.num_prefills - is_prompt = num_prefills > 0 + if not model_input.is_first_multi_step: + if not model_input.is_last_step: + return [] + + use_async_out_proc = model_input.async_callback is not None + sampler_outputs = [] + num_outputs = len(self.cached_step_outputs) + for i in range(num_outputs): + next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + sampler_outputs.append(sampler_output) + + if i < num_outputs - 1 and use_async_out_proc: + assert model_input.async_callback is not None + ctx = model_input.async_callback.keywords[ # type: ignore + "ctx"] + ctx.append_output( + outputs=[sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False) + model_input.async_callback() + if use_async_out_proc: + return [sampler_outputs[-1]] + else: + return sampler_outputs + + is_prompt = model_input.attn_metadata.num_prefills > 0 if is_prompt: + assert num_steps == 1 # NOTE(woosuk): Since the FlashAttention kernel does not support # ragged inputs, we split the prompts into different batches and # process them separately. This is a temporary hack that should be # optimized by using SplashAttention. - next_token_ids = [] orig_slot_mapping = model_input.attn_metadata.slot_mapping batch_size = model_input.input_lens.shape[0] start_idx = 0 + next_token_ids = [] for i in range(batch_size): # Get the actual prefill_len. prefill_len = model_input.input_lens[i:i + 1].item() prefill_len = _get_padded_prefill_len(prefill_len) end_idx = start_idx + prefill_len - model_input.attn_metadata.slot_mapping = orig_slot_mapping[ - None, start_idx:end_idx] - model_input.attn_metadata.num_prefills = 1 - output_token_ids = _execute_model( - model_input.token_ids[None, start_idx:end_idx], - model_input.position_ids[None, start_idx:end_idx], - model_input.attn_metadata, model_input.input_lens[i:i + 1], - model_input.t[i:i + 1], model_input.p[i:i + 1], - model_input.num_samples, kv_caches) - if i == 0 and model_input.async_callback is not None: - model_input.async_callback() - # Retrieve the outputs to CPU. - next_token_ids += output_token_ids.cpu().tolist() + token_ids = model_input.token_ids[None, start_idx:end_idx].to( + self.device) + position_ids = model_input.position_ids[None, + start_idx:end_idx].to( + self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.num_prefills = 1 + attn_metadata.slot_mapping = orig_slot_mapping[ + None, start_idx:end_idx].to(self.device) + input_lens = model_input.input_lens[i:i + 1].to(self.device) + t = model_input.t[i:i + 1].to(self.device) + p = model_input.p[i:i + 1].to(self.device) + output_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + model_input.num_samples, + kv_caches, + is_prompt=True) + next_token_ids.append(output_token_ids[0]) start_idx = end_idx - else: - # Execute the model. - output_token_ids = _execute_model( - model_input.token_ids, model_input.position_ids, - model_input.attn_metadata, model_input.input_lens, - model_input.t, model_input.p, model_input.num_samples, - kv_caches) + if model_input.async_callback is not None: model_input.async_callback() # Retrieve the outputs to CPU. 
- next_token_ids = output_token_ids.cpu().tolist() - - # NOTE(woosuk): Minimal code to construct the sampler outputs. - # The TPU backend does not reuse the sampler, since the TPU backend - # does not support the advanced sampling parameters such as logprobs. - zero_logprob = Logprob(0.0) - batch_idx = 0 - sampler_outputs = [] - for seq_group in model_input.seq_groups: - seq_ids = seq_group - seq_outputs = [] - if is_prompt: + next_token_ids = [ + output_token_ids.cpu().tolist() + for output_token_ids in next_token_ids + ] + + # NOTE(woosuk): Minimal code to construct the sampler outputs. + # The TPU backend does not reuse the sampler, since the TPU backend + # does not support advanced sampling parameters such as logprobs. + zero_logprob = Logprob(0.0) + sampler_outputs = [] + for i, seq_group in enumerate(model_input.seq_groups): + seq_ids = seq_group assert len(seq_ids) == 1 seq_id = seq_ids[0] - for i in range(model_input.best_of[batch_idx]): - next_token_id = next_token_ids[batch_idx][i] + seq_outputs = [] + for j in range(model_input.best_of[i]): + next_token_id = next_token_ids[i][j] seq_outputs.append( SequenceOutput(seq_id, next_token_id, {next_token_id: zero_logprob})) - batch_idx += 1 - else: - for seq_id in seq_ids: - next_token_id = next_token_ids[batch_idx] - seq_outputs.append( - SequenceOutput(seq_id, next_token_id, - {next_token_id: zero_logprob})) - batch_idx += 1 - sampler_outputs.append( - CompletionSequenceGroupOutput(seq_outputs, None)) - return [SamplerOutput(sampler_outputs)] + sampler_outputs.append( + CompletionSequenceGroupOutput(seq_outputs, None)) + return [SamplerOutput(sampler_outputs)] + else: + token_ids = model_input.token_ids.to(self.device) + position_ids = model_input.position_ids.to(self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.slot_mapping = attn_metadata.slot_mapping.to( + self.device) + attn_metadata.block_tables = attn_metadata.block_tables.to( + self.device) + attn_metadata.context_lens = attn_metadata.context_lens.to( + self.device) + t = model_input.t.to(self.device) + p = model_input.p.to(self.device) + input_lens = model_input.input_lens.to(self.device) + for i in range(num_steps): + slot_mapping = attn_metadata.slot_mapping + output_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + model_input.num_samples, + kv_caches, + is_prompt=False) + self.cached_step_outputs.append(output_token_ids) + + if i < num_steps - 1: + # Prepare the inputs for the next step. + token_ids = output_token_ids.unsqueeze(dim=1).int() + position_ids = position_ids + 1 + attn_metadata.context_lens = attn_metadata.context_lens + 1 + + block_tables = attn_metadata.block_tables + block_number = block_tables.gather( + 1, + position_ids.long() // self.block_size) + block_offset = position_ids % self.block_size + + is_padding = slot_mapping == _PAD_SLOT_ID + slot_mapping = block_number * self.block_size + block_offset + slot_mapping = slot_mapping.long() + slot_mapping = torch.where(is_padding, _PAD_SLOT_ID, + slot_mapping) + attn_metadata.slot_mapping = slot_mapping + + if model_input.async_callback is not None: + model_input.async_callback() + + if num_steps > 1: + return [] + # Retrieve the outputs to CPU. 
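+            # When num_steps > 1 the early return above leaves the sampled
+            # tokens cached on device; they are materialized on the last step
+            # by the not-is_first_multi_step branch at the top of this method.
+            # For a single decode step, the one cached tensor is consumed here.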
+ next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + return [sampler_output] class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): @@ -756,3 +823,24 @@ def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor: cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index) logits = logits.masked_fill_(logits < cutoff_logit, -float("inf")) return logits + + +def _make_decode_output( + next_token_ids: List[int], + seq_groups: List[List[int]], +) -> SamplerOutput: + zero_logprob = Logprob(0.0) + sampler_outputs = [] + batch_idx = 0 + for seq_group in seq_groups: + seq_ids = seq_group + seq_outputs = [] + for seq_id in seq_ids: + next_token_id = next_token_ids[batch_idx] + seq_outputs.append( + SequenceOutput(seq_id, next_token_id, + {next_token_id: zero_logprob})) + batch_idx += 1 + sampler_outputs.append(CompletionSequenceGroupOutput( + seq_outputs, None)) + return SamplerOutput(sampler_outputs) From 3724d5f6b59d9859e5b47c047535bb8edc124eab Mon Sep 17 00:00:00 2001 From: Chris <34248815+chrisociepa@users.noreply.github.com> Date: Sun, 15 Sep 2024 06:20:05 +0200 Subject: [PATCH 46/98] [Bugfix][Model] Fix Python 3.8 compatibility in Pixtral model by updating type annotations (#8490) --- vllm/model_executor/models/pixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index b26fd558fa1ea..682b78bbed093 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -454,7 +454,7 @@ def forward( return x -def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor: +def position_meshgrid(patch_embeds_list: List[torch.Tensor], ) -> torch.Tensor: positions = torch.cat([ torch.stack( torch.meshgrid( From fc990f97958636ce25e4471acfd5651b096b0311 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 16 Sep 2024 06:51:44 +0800 Subject: [PATCH 47/98] [Bugfix][Kernel] Add `IQ1_M` quantization implementation to GGUF kernel (#8357) --- csrc/quantization/gguf/dequantize.cuh | 55 ++- csrc/quantization/gguf/ggml-common.h | 408 ++++++++++++------ csrc/quantization/gguf/gguf_kernel.cu | 5 + csrc/quantization/gguf/mmvq.cuh | 8 + csrc/quantization/gguf/vecdotq.cuh | 101 ++++- requirements-common.txt | 2 +- tests/kernels/test_gguf.py | 126 ++++++ .../layers/quantization/gguf.py | 5 +- 8 files changed, 548 insertions(+), 162 deletions(-) create mode 100644 tests/kernels/test_gguf.py diff --git a/csrc/quantization/gguf/dequantize.cuh b/csrc/quantization/gguf/dequantize.cuh index 2069fba759ea0..c012262e49015 100644 --- a/csrc/quantization/gguf/dequantize.cuh +++ b/csrc/quantization/gguf/dequantize.cuh @@ -353,18 +353,47 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ template static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq1_s * x = (const block_iq1_s *) vx; - const int tid = threadIdx.x; - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const float delta = x[i].qh[ib] & 0x8000 ? 
-1 - IQ1S_DELTA : -1 + IQ1S_DELTA; + const float d = __half2float(x[i].d) * (2*((x[i].qh[ib] >> 12) & 7) + 1); + uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; + grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)]; + grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; + grid32[0] &= 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + y[j] = __float2half(d * (q[j] + delta)); + } +} + +template +static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_iq1_m * x = (const block_iq1_m *) vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; - const int i8 = 4*ib+il; - uint8_t h = x[i].scales[i8/2] >> 4*(i8%2); - const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5))); - const float d = __half2float(x[i].d) * (2*(h & 7) + 1); - for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j]); + const uint16_t * sc = (const uint16_t *)x[i].scales; + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4); + const float d = __half2float(scale.f16) * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1); + const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA; + uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; + grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)]; + grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; + grid32[0] &= 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + y[j] = __float2half(d * (q[j] + delta)); + } } template @@ -475,6 +504,12 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, c dequantize_block_iq1_s<<>>(vx, y); } +template +static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq1_m<<>>(vx, y); +} + template static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; @@ -525,6 +560,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) { return dequantize_row_iq2_s_cuda; case 23: return dequantize_row_iq4_xs_cuda; + case 29: + return dequantize_row_iq1_m_cuda; default: return nullptr; } diff --git a/csrc/quantization/gguf/ggml-common.h b/csrc/quantization/gguf/ggml-common.h index d7989d84bf68e..fba94fd1d157b 100644 --- a/csrc/quantization/gguf/ggml-common.h +++ b/csrc/quantization/gguf/ggml-common.h @@ -149,14 +149,30 @@ typedef struct { uint8_t scales[IQ3S_N_SCALE]; } block_iq3_s; +// 1.5625 bpw #define QR1_S 8 #define QI1_S (QK_K / (4*QR1_S)) typedef struct { half d; - uint8_t qs[QK_K/8]; - uint8_t scales[QK_K/16]; + uint8_t qs[QK_K/8]; + uint16_t qh[QK_K/32]; } block_iq1_s; +// 1.75 bpw +#define QR1_M 8 +#define QI1_M (QK_K / (4*QR1_M)) +typedef struct { + uint8_t qs[QK_K/8]; // grid index, low 8 bits + uint8_t qh[QK_K/16]; // grid index, high 3 bits + grid shift bit (for two groups of 8) + uint8_t scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64) +} block_iq1_m; + +// Used by IQ1_M quants +typedef union { + half f16; + uint16_t u16; +} iq1m_scale_t; + #define QK4_NL 32 #define QR4_NL 2 #define QI4_NL (QK4_NL / (4*QR4_NL)) @@ -733,135 +749,265 @@ static const __device__ 
uint32_t iq3xs_grid[512] = { 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, }; -static const __device__ uint64_t iq1s_grid[512] = { - 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, - 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01, - 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100, - 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00, - 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101, - 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100, - 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00, - 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff, - 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000, - 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000, - 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001, - 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff, - 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01, - 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001, - 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00, - 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001, - 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100, - 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000, - 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000, - 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000, - 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff, - 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff, - 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01, - 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100, - 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff, - 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000, - 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101, - 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff, - 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff, - 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001, - 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01, - 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101, - 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100, - 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00, - 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001, - 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff, - 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000, - 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000, - 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100, - 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100, - 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01, - 
0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff, - 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101, - 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000, - 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff, - 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000, - 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff, - 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00, - 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101, - 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000, - 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000, - 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000, - 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100, - 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000, - 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001, - 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff, - 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000, - 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000, - 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000, - 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000, - 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff, - 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000, - 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, - 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01, - 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100, - 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000, - 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00, - 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100, - 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000, - 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, - 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00, - 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff, - 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100, - 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff, - 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000, - 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff, - 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff, - 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00, - 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001, - 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001, - 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01, - 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000, - 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101, - 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00, - 0x00010000000000ff, 
0x0001000000000000, 0x0001000000000001, 0x0001000000000100, - 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101, - 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101, - 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000, - 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff, - 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff, - 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101, - 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, - 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101, - 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001, - 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff, - 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff, - 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01, - 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff, - 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100, - 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001, - 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00, - 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff, - 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff, - 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000, - 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000, - 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101, - 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001, - 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000, - 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101, - 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000, - 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, - 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000, - 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100, - 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000, - 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000, - 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100, - 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff, - 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff, - 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00, - 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101, - 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000, - 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00, - 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000, - 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff, - 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101, - 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff, - 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00, - 0x0101010001010000, 0x0101010100ff0001, 
0x010101010001ff01, 0x010101010101ffff, +#define IQ1S_DELTA 0.125f +#define IQ1M_DELTA 0.125f +static const __device__ uint64_t iq1s_grid_gpu[2048] = { + 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, + 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, + 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, + 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, + 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011, + 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111, + 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220, + 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022, + 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220, + 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, + 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110, + 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, + 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, + 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210, + 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221, + 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021, + 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002, + 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101, + 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, + 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, + 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, + 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022, + 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121, + 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220, + 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001, + 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101, + 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102, + 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, + 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010, + 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, + 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, + 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222, + 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001, + 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102, + 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 
0x01111001, 0x01111100, 0x01111101, + 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000, + 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, + 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, + 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, + 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, + 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012, + 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111, + 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120, + 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122, + 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121, + 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, + 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, + 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, + 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, + 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011, + 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111, + 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011, + 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122, + 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121, + 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, + 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, + 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, + 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, + 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110, + 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112, + 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222, + 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021, + 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121, + 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, + 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, + 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, + 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, + 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010, + 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211, + 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121, + 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 
0x00202000, + 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202, + 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, + 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, + 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, + 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, + 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121, + 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222, + 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102, + 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100, + 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110, + 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, + 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, + 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, + 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, + 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222, + 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201, + 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102, + 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201, + 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012, + 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, + 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, + 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, + 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, + 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212, + 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021, + 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021, + 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021, + 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101, + 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, + 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, + 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, + 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, + 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010, + 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111, + 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120, + 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120, + 0x12022121, 
0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101, + 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, + 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, + 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, + 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, + 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111, + 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112, + 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211, + 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010, + 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021, + 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, + 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, + 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, + 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, + 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101, + 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101, + 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101, + 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012, + 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110, + 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, + 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, + 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, + 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, + 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010, + 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110, + 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122, + 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020, + 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021, + 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, + 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, + 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222, + 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, + 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001, + 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102, + 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201, + 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012, + 0x10102111, 0x10102212, 0x10112011, 
0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111, + 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, + 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, + 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, + 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, + 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221, + 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220, + 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222, + 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000, + 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201, + 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012, + 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, + 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, + 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, + 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121, + 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202, + 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202, + 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002, + 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101, + 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, + 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, + 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, + 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, + 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210, + 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020, + 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220, + 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222, + 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222, + 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, + 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, + 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111, + 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, + 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110, + 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221, + 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122, + 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202, + 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 
0x21010000, 0x21010001, 0x21010100, + 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, + 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, + 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, + 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, + 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222, + 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221, + 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022, + 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101, + 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211, + 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, + 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, + 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, + 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, + 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222, + 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000, + 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202, + 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000, + 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202, + 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, + 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, + 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, + 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, + 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022, + 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101, + 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202, + 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110, + 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110, + 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, + 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, + 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, + 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, + 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001, + 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202, + 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001, + 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200, + 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 
0x20111011, + 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, + 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, + 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, + 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, + 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111, + 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020, + 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121, + 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222, + 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102, + 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, + 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, + 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, + 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, + 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111, + 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212, + 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221, + 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121, + 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002, + 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, + 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, + 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, + 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, + 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020, + 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221, + 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022, + 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100, + 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201, + 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, + 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, + 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, + 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, + 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020, + 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120, + 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200, + 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200, + 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110, + 0x20212211, 
0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, + 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, + 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, + 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, }; static const __device__ uint8_t ksigns_iq2xs[128] = { diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index 966d9992b25fd..37e4de4e14dd3 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -166,6 +166,11 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream); break; + case 29: + mul_mat_vec_iq1_m_q8_1_cuda((void*)W.data_ptr(), + (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; } return Y; } diff --git a/csrc/quantization/gguf/mmvq.cuh b/csrc/quantization/gguf/mmvq.cuh index ef2ea072392d2..b221ae7896138 100644 --- a/csrc/quantization/gguf/mmvq.cuh +++ b/csrc/quantization/gguf/mmvq.cuh @@ -157,6 +157,14 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * <<>>(vx, vy, dst, ncols, nrows); } +static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); diff --git a/csrc/quantization/gguf/vecdotq.cuh b/csrc/quantization/gguf/vecdotq.cuh index 78c749d3f3bc1..ff339753bcbb5 100644 --- a/csrc/quantization/gguf/vecdotq.cuh +++ b/csrc/quantization/gguf/vecdotq.cuh @@ -1,5 +1,18 @@ // copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/vecdotq.cuh // and https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu +static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) { + const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment + + int x32 = x16[2*i32 + 0] << 0; + x32 |= x16[2*i32 + 1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) { + return ((const int *) x)[i32]; // assume at least 4 byte alignment +} + static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment int x32 = 0; @@ -1658,28 +1671,76 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 const block_iq1_s * bq1 = (const block_iq1_s *) vbq; - const int ib32 = iqs; - int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; - const uint8_t h1 = bq1->scales[2*ib32+0]; - const uint8_t h2 = bq1->scales[2*ib32+1]; - const int * q8 = (const int *)bq8_1[ib32].qs; - const int * grid1 = (const int 
*)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5))); - const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1))); - const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5))); - const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1))); - for (int j = 0; j < 2; ++j) { - sumi1 = __dp4a(q8[j+0], grid1[j], sumi1); - sumi2 = __dp4a(q8[j+2], grid2[j], sumi2); - sumi3 = __dp4a(q8[j+4], grid3[j], sumi3); - sumi4 = __dp4a(q8[j+6], grid4[j], sumi4); - } - const float d = __half2float(bq1->d) * __low2float(bq8_1[ib32].ds); - return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) + - sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1)); -#endif + const int qs_packed = get_int_b2(bq1->qs, iqs); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bq1->qh[iqs]; + + int sumi = 0; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + + const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1); + + sumi = __dp4a(grid0, u0, sumi); + sumi = __dp4a(grid1, u1, sumi); + } + + const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1); + const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); + const float2 ds = __half22float2(bq8_1[iqs].ds); + return d1q * (ds.x*sumi + ds.y*delta); +} + +static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_iq1_m * bq1 = (const block_iq1_m *) vbq; + + const int qs_packed = get_int_b4(bq1->qs, iqs); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + int sumi[2] = {0}; + float sumf[2] = {0.0f}; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2)); + + const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + + const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1); + + sumi[l0/4] = __dp4a(grid0, u0, sumi[l0/4]); + sumi[l0/4] = __dp4a(grid1, u1, sumi[l0/4]); + + const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08); + int sumy = 0; + sumy = __dp4a(u0, 0x01010101, sumy); + sumy = __dp4a(u1, 0x01010101, sumy); + sumf[l0/4] += delta*sumy; + } + + const uint16_t * sc = (const uint16_t *) bq1->scales; + + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000); + const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds); + + const int tmp = sc[iqs/2] >> (6*(iqs%2)); + const int sc0 = 2*((tmp >> 0) & 0x07) + 1; + const int sc1 = 2*((tmp >> 3) & 0x07) + 1; + return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1); } static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values, diff --git a/requirements-common.txt b/requirements-common.txt index c5f003c3c7ddc..ad950d0313454 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -24,7 +24,7 @@ filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 partial-json-parser # used for parsing partial JSON outputs pyzmq msgspec -gguf == 0.9.1 +gguf == 
0.10.0 importlib_metadata mistral_common >= 1.4.0 pyyaml diff --git a/tests/kernels/test_gguf.py b/tests/kernels/test_gguf.py new file mode 100644 index 0000000000000..ee29ed93b61fc --- /dev/null +++ b/tests/kernels/test_gguf.py @@ -0,0 +1,126 @@ +from pathlib import Path +from typing import List + +import pytest +import torch +from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize +from huggingface_hub import snapshot_download + +import vllm._custom_ops as ops + +GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample") + + +def get_gguf_sample_tensors( + hidden_size: int, + quant_type: GGMLQuantizationType) -> List[ReaderTensor]: + sample_dir = GGUF_SAMPLE + filename = f"Quant_{quant_type.name}_{hidden_size}.gguf" + sample_file = Path(sample_dir) / filename + return GGUFReader(sample_file).tensors + + +DTYPES = [torch.half] +# Hidden_size for testing, must match the sample file in HF repo, +# we have `hidden_size = 256, 1024` for test in HF repo currently. +HIDDEN_SIZES = [256, 1024] +NUM_TOKENS = [7, 83, 128, 2048] # Arbitrary values for testing +SEEDS = [0] +QUANT_TYPES = [ + # i-matrix + GGMLQuantizationType.IQ1_M, + GGMLQuantizationType.IQ1_S, + GGMLQuantizationType.IQ2_S, + GGMLQuantizationType.IQ2_XS, + GGMLQuantizationType.IQ3_S, + GGMLQuantizationType.IQ3_XXS, + GGMLQuantizationType.IQ4_NL, + GGMLQuantizationType.IQ4_XS, + # k-quants + GGMLQuantizationType.Q2_K, + GGMLQuantizationType.Q3_K, + GGMLQuantizationType.Q4_K, + GGMLQuantizationType.Q5_K, + GGMLQuantizationType.Q6_K, + # standard quantization + GGMLQuantizationType.Q4_0, + GGMLQuantizationType.Q5_0, + GGMLQuantizationType.Q8_0, +] + + +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("quant_type", QUANT_TYPES) +@torch.inference_mode() +def test_dequantize(hidden_size: int, dtype: torch.dtype, + quant_type: GGMLQuantizationType): + tensors = get_gguf_sample_tensors(hidden_size, quant_type) + for tensor in tensors: + shape_str = tensor.name.split("_")[-1] + shape = map(int, shape_str.split("x")) + + ref_output = torch.tensor(dequantize(tensor.data, quant_type), + device="cuda").to(dtype) + output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"), + quant_type, *list(shape)).to(dtype) + + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2) + + +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("quant_type", QUANT_TYPES) +@torch.inference_mode() +def test_mmvq(hidden_size: int, dtype: torch.dtype, + quant_type: GGMLQuantizationType): + torch.cuda.manual_seed_all(0) + + tensors = get_gguf_sample_tensors(hidden_size, quant_type) + x = torch.rand((1, hidden_size), dtype=dtype, device="cuda") + for tensor in tensors: + weight = torch.tensor(dequantize(tensor.data, quant_type), + device="cuda").to(dtype) + ref_output = x @ weight.T + + qweight = torch.tensor(tensor.data, device="cuda") + output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, + qweight.shape[0]).to(dtype) + + torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize( + "quant_type", + [ + # k-quants + GGMLQuantizationType.Q2_K, + GGMLQuantizationType.Q3_K, + GGMLQuantizationType.Q4_K, + GGMLQuantizationType.Q5_K, + GGMLQuantizationType.Q6_K, + # standard quants + 
GGMLQuantizationType.Q4_0, + GGMLQuantizationType.Q5_0, + GGMLQuantizationType.Q8_0, + ]) +@torch.inference_mode() +def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, + quant_type: GGMLQuantizationType): + torch.cuda.manual_seed_all(0) + + tensors = get_gguf_sample_tensors(hidden_size, quant_type) + x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda") + for tensor in tensors: + weight = torch.tensor(dequantize(tensor.data, quant_type), + device="cuda").to(dtype) + ref_output = x @ weight.T + + qweight = torch.tensor(tensor.data, device="cuda") + output = ops.ggml_mul_mat_a8(qweight, x, quant_type, + qweight.shape[0]).to(dtype) + + torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index a6a1ed5b0dee5..dc83017bcc7f9 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -55,7 +55,10 @@ def get_scaled_act_names(self) -> List[str]: def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor: # use dequantize mulmat for IQmatrix, mmq for k-quants - if qweight_type >= 16: + if x.shape[0] == 1: + # enable mmvq in contiguous batching + y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) + elif qweight_type >= 16: block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) weight = ops.ggml_dequantize(qweight, qweight_type, *shape) From a091e2da3e3fcb4c63c8206839d7240a2a2a176a Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 16 Sep 2024 17:47:19 +0200 Subject: [PATCH 48/98] [Kernel] Enable 8-bit weights in Fused Marlin MoE (#8032) Co-authored-by: Dipika --- csrc/moe/marlin_moe_ops.cu | 537 +++++++++++++----- csrc/moe/marlin_moe_ops.h | 7 +- csrc/moe/torch_bindings.cpp | 8 +- tests/kernels/test_moe.py | 18 +- tests/weight_loading/models-large.txt | 3 +- .../run_model_weight_loading_test.sh | 0 vllm/_custom_ops.py | 2 +- .../layers/fused_moe/fused_marlin_moe.py | 44 +- .../layers/fused_moe/fused_moe.py | 2 +- .../compressed_tensors_moe.py | 8 +- .../layers/quantization/gptq_marlin.py | 1 + vllm/model_executor/model_loader/utils.py | 8 +- 12 files changed, 453 insertions(+), 185 deletions(-) mode change 100644 => 100755 tests/weight_loading/run_model_weight_loading_test.sh diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 92184f43c9eb0..666d87eb92595 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -25,6 +25,8 @@ #include +#include "core/scalar_type.hpp" + template inline std::string str(T x) { return std::to_string(x); @@ -131,11 +133,26 @@ __device__ inline int lop3(int a, int b, int c) { return res; } -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. 
We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline FragB dequant(int q); + +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +template <> +__device__ inline FragB dequant(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -156,6 +173,28 @@ __device__ inline FragB dequant(int q) { return frag_b; } +// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// Reference: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { @@ -296,7 +335,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -template ( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } }; bool is_same_group[stages]; int same_group_id[stages]; auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); @@ -840,10 +902,19 @@ __device__ inline void MarlinMoESingle( // dequantization and matmul operations. 
#pragma unroll for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } - FragB frag_b0 = dequant(b_quant); + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -855,8 +926,6 @@ __device__ inline void MarlinMoESingle( } } - FragB frag_b1 = dequant(b_quant_shift); - // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -881,13 +950,13 @@ __device__ inline void MarlinMoESingle( // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; + constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); // Parallel logarithmic shared memory reduction. We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only @@ -1035,8 +1104,10 @@ __device__ inline void MarlinMoESingle( auto write = [&](int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - // For per-column quantization we finally apply the scale here - if constexpr (!has_act_order && group_blocks == -1) { + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { res = __hmul2(res, s[0]); } @@ -1088,9 +1159,9 @@ __device__ inline void MarlinMoESingle( // Start global fetch and register load pipelines. 
auto start_pipes = [&]() { - // TODO re-enable after fixing this function - // fetch_sorted_ids_to_shared(); - __syncthreads(); + // TODO re-enable after fixing this function + // fetch_sorted_ids_to_shared(); + // __syncthreads(); #pragma unroll for (int i = 0; i < stages - 1; i++) { @@ -1166,28 +1237,70 @@ __device__ inline void MarlinMoESingle( if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); + } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } } } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice barrier_acquire(&locks[slice_col], slice_idx); @@ -1227,7 +1340,8 @@ __device__ inline void MarlinMoESingle( } } -template 4) { + if (max_block > cfg_max_m_blocks) { // Note that parallel > 1 currently only works for inputs without any // padding - par = (16 * max_block - pad) / 64; - par = min((16 * max_block - pad) / 64, max_par); - prob_m = 64 * par; - m_block_ctr += 4 * (par - 1); - max_block = 4; + par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * cfg_max_m_blocks) * par; + m_block_ctr += cfg_max_m_blocks * (par - 1); + max_block = cfg_max_m_blocks; } if (max_block == 1) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } 
else if (max_block == 3) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, @@ -1342,7 +1457,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, return; } -template , \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ + MarlinMoE \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par); \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks); \ } typedef struct { @@ -1423,6 +1543,11 @@ typedef struct { int num_threads; } thread_config_t; +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + thread_config_t small_batch_thread_configs[] = { // Ordered by priority @@ -1443,8 +1568,77 @@ thread_config_t large_batch_thread_configs[] = { {128, 64, 128}, // Reduce N 4X, increase K 2X }; -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = ceildiv(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * STAGES; + } +} + +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = ceildiv(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * STAGES; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ 
-1472,64 +1666,88 @@ bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, return false; } + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + return true; } -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } } } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage } - return thread_config_t{-1, -1, -1}; + return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, 
false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, const void* topk_ids, const void* s, const void* g_idx, const void* perm, void* a_tmp, void* expert_offsets, int prob_m, int prob_n, int prob_k, void* workspace, - bool has_act_order, bool is_k_full, int num_groups, - int group_size, int num_experts, int topk, - int moe_block_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int max_par, - bool replicate_input, bool apply_weights) { + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, int num_groups, int group_size, + int num_experts, int topk, int moe_block_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, + int sms, int max_par, bool replicate_input, + bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1537,26 +1755,42 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); } + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + int num_bits = q_type.size_bits(); + // Set thread config - thread_config_t th_config; + exec_config_t exec_cfg; if (thread_k != -1 && thread_n != -1) { // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}}; } else { // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); } - TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + - " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + - str(prob_n) + "]"); - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks 
= ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; @@ -1590,11 +1824,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, } } - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - int tot_m = prob_m; const int* topk_ids_ptr = (const int*)topk_ids; @@ -1611,10 +1840,13 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, has_act_order = false; } + int pack_factor = 32 / q_type.size_bits(); + for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { const int4* A_ptr = (const int4*)A; int4* a_tmp_ptr = (int4*)a_tmp; - const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; + const int4* B_ptr = + (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; int4* C_ptr = (int4*)C; const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; @@ -1636,19 +1868,22 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, A_ptr = a_tmp_ptr; } - int max_m_blocks = ceildiv(tot_m, 16); - for (int m_block = 0; m_block < max_m_blocks; m_block += 16) { - // Define kernel configurations - + int tot_m_blocks = ceildiv(tot_m, 16); + for (int m_block = 0; m_block < tot_m_blocks; + m_block += 4 * exec_cfg.max_m_blocks) { // make it max possible value - int thread_m_blocks = 4; + int thread_m_blocks = exec_cfg.max_m_blocks; if (false) { } - CALL_IF_MOE(16, 4, 256) - CALL_IF_MOE(8, 8, 256) - CALL_IF_MOE(8, 4, 128) - CALL_IF_MOE(4, 8, 128) + CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -1670,9 +1905,15 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights) { + TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128. 
Got = ", b_q_type->str()); + + int pack_factor = 32 / b_q_type->size_bits(); + int max_par = 4; int dev = a.get_device(); @@ -1733,8 +1974,8 @@ torch::Tensor marlin_gemm_moe( topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - has_act_order, is_k_full, num_groups, group_size, num_experts, topk, - moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, + topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; } diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 43d264e0770d6..adee8399a4d6f 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -2,11 +2,14 @@ #include +#include "core/scalar_type.hpp" + torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 8a0e625b43fa1..cd65a8ee92b94 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -13,9 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " - "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " - "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " - "bool replicate_input, bool apply_weights) -> Tensor"); + "g_idx, Tensor! perm, Tensor! 
workspace, " + "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " + "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " + "int moe_block_size, bool replicate_input, bool apply_weights)" + " -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif } diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 2250cf1598b8b..8072cf09e5b65 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -140,6 +140,7 @@ def compute_max_diff(output, output_ref): @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) +@pytest.mark.parametrize("num_bits", [4, 8]) def test_fused_marlin_moe( m: int, n: int, @@ -148,6 +149,7 @@ def test_fused_marlin_moe( topk: int, group_size: int, act_order: bool, + num_bits: int, ): torch.manual_seed(7) @@ -161,13 +163,12 @@ def test_fused_marlin_moe( if group_size in (k, n): return - quant_type = scalar_types.uint4b8 + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - for i in range(w2.shape[0]): - w2[0] = torch.eye(k, n, device="cuda", dtype=dtype) w_ref1_l = [] qweight1_l = [] @@ -240,6 +241,7 @@ def test_fused_marlin_moe( topk_ids, w1_scale=scales1, w2_scale=scales2, + num_bits=num_bits, ) assert compute_max_diff(marlin_output, triton_output) < 4e-2 @@ -254,7 +256,8 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) -def test_marlin_moe_mmm( +@pytest.mark.parametrize("num_bits", [4, 8]) +def test_single_marlin_moe_multiply( m: int, n: int, k: int, @@ -262,6 +265,7 @@ def test_marlin_moe_mmm( topk: int, group_size: int, act_order: bool, + num_bits: int, ): if topk > e: return @@ -273,7 +277,8 @@ def test_marlin_moe_mmm( if group_size == k: return - quant_type = scalar_types.uint4b8 + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 @@ -308,7 +313,8 @@ def test_marlin_moe_mmm( g_idx, sort_indices, topk, - renormalize=False) + renormalize=False, + num_bits=num_bits) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt index fe76705746766..2f5c6c5a117f3 100644 --- a/tests/weight_loading/models-large.txt +++ b/tests/weight_loading/models-large.txt @@ -1,3 +1,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main -gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main \ No newline at end of file +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main +gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main diff --git a/tests/weight_loading/run_model_weight_loading_test.sh b/tests/weight_loading/run_model_weight_loading_test.sh old mode 100644 new mode 100755 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ed08878f14875..74b3b69606c67 100644 --- 
a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -559,7 +559,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * 2), + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 200a6148978aa..866b18d725a8c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -7,18 +7,21 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size, try_get_optimal_moe_config) +from vllm.scalar_type import scalar_types def single_marlin_moe( - hidden_states: torch.Tensor, - w: torch.Tensor, - scales: torch.Tensor, - gating_output: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor: + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + num_bits: int = 8, +) -> torch.Tensor: """ This function computes the multiplication of hidden_states with expert weights used in Marlin MoE, using weights w and top-k gating mechanism. @@ -36,6 +39,7 @@ def single_marlin_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. + - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -48,10 +52,11 @@ def single_marlin_moe( assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w.is_contiguous(), "Expert weights must be contiguous" assert hidden_states.dtype == torch.float16 + assert num_bits in [4, 8] M, K = hidden_states.shape E = w.shape[0] - N = w.shape[2] // 2 + N = w.shape[2] // (num_bits // 2) topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) @@ -76,10 +81,13 @@ def single_marlin_moe( device="cuda", requires_grad=False) + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True, - False) + g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk, + block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -98,6 +106,7 @@ def fused_marlin_moe( override_config: Optional[Dict[str, Any]] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + num_bits: int = 8, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -122,6 +131,7 @@ def fused_marlin_moe( w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. 
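A minimal sketch of the shape arithmetic these hunks rely on, using a hypothetical helper (the name is illustrative, not from the diff) and assuming size_k is a multiple of 16: each int32 word of the repacked weight holds 32 // num_bits quantized values, which is why gptq_marlin_moe_repack above allocates size_n * (num_bits // 2) columns and single_marlin_moe recovers N as w.shape[2] // (num_bits // 2).

    # Illustrative helper only; the function name is not part of the patch.
    def repacked_moe_weight_shape(num_experts: int, size_k: int, size_n: int,
                                  num_bits: int) -> tuple:
        assert num_bits in (4, 8)
        pack_factor = 32 // num_bits                     # 8 nibbles or 4 bytes per int32 word
        packed_words = size_k * size_n // pack_factor    # packed int32 words per expert
        shape = (num_experts, size_k // 16, size_n * (num_bits // 2))
        assert shape[1] * shape[2] == packed_words       # same element count, Marlin 16-row tiling
        return shape

For num_bits=4 this reproduces the previous (size_k // 16, size_n * 2) layout; for num_bits=8 the column count doubles, matching the new uint8b128 path selected below.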
@@ -131,13 +141,14 @@ def fused_marlin_moe( 0], "Number of tokens mismatch" assert hidden_states.shape[ 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" + assert hidden_states.shape[1] == w2.shape[2] // ( + num_bits // 2), "Hidden size mismatch w2" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype == torch.float16 + assert num_bits in [4, 8] M, K = hidden_states.shape E = w1.shape[0] @@ -165,6 +176,9 @@ def fused_marlin_moe( device="cuda", requires_grad=False) + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), device=hidden_states.device, @@ -181,6 +195,7 @@ def fused_marlin_moe( g_idx1, perm1, workspace, + scalar_type, M, 2 * N, K, @@ -204,6 +219,7 @@ def fused_marlin_moe( g_idx2, perm2, workspace, + scalar_type, M, K, N, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a0cb4337f9dee..3e01112eaa14d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor, if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids.to(torch.int32) + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) def get_config_dtype_str(dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 49c29c2775cb6..7dee2fca81153 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -6,6 +6,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + WNA16_SUPPORTED_BITS) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat) from vllm.model_executor.utils import set_weight_attrs @@ -38,10 +40,11 @@ def __init__( if not (self.quant_config.quant_format == CompressionFormat.pack_quantized.value - and self.num_bits == 4): + and self.num_bits in WNA16_SUPPORTED_BITS): raise ValueError("For Fused MoE layers, only ", f"{CompressionFormat.pack_quantized.value} ", - "is supported for 4 bits") + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}") def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, @@ -292,4 +295,5 @@ def apply( topk_ids, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, + num_bits=self.num_bits, ) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 3617a32f80fc1..cc699f5b4554f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -611,4 +611,5 @@ def apply( topk_ids, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, 
+ num_bits=self.quant_config.quant_type.size_bits, ).to(orig_dtype) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 0052489d99dc4..2bfe6ea09bd62 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,13 +23,7 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors"] - # for gptq_marlin, only run fused MoE for int4 - if model_config.quantization == "gptq_marlin": - hf_quant_config = getattr(model_config.hf_config, - "quantization_config", None) - if hf_quant_config and hf_quant_config.get("bits") == 4: - mixtral_supported.append("gptq_marlin") + mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported From 837c1968f9f1fdd9d894b2071d605ca1bdc97942 Mon Sep 17 00:00:00 2001 From: lewtun Date: Mon, 16 Sep 2024 17:55:26 +0200 Subject: [PATCH 49/98] [Frontend] Expose revision arg in OpenAI server (#8501) --- vllm/entrypoints/openai/api_server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d8704d5e24964..7c1f307e06619 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -69,8 +69,10 @@ def model_is_embedding(model_name: str, trust_remote_code: bool, - quantization: Optional[str]) -> bool: + quantization: Optional[str], + revision: Optional[str]) -> bool: return ModelConfig(model=model_name, + revision=revision, tokenizer=model_name, tokenizer_mode="auto", trust_remote_code=trust_remote_code, @@ -130,7 +132,7 @@ async def build_async_engine_client_from_engine_args( # If manually triggered or embedding model, use AsyncLLMEngine in process. # TODO: support embedding model via RPC. 
if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, - engine_args.quantization) + engine_args.quantization, engine_args.revision) or disable_frontend_multiprocessing): engine_client = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) From acd5511b6d0e196b273b6250201115b5c5cfd1ca Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 16 Sep 2024 17:33:46 +0100 Subject: [PATCH 50/98] [BugFix] Fix clean shutdown issues (#8492) --- tests/async_engine/test_async_llm_engine.py | 10 +- vllm/engine/async_llm_engine.py | 70 +++++--- vllm/engine/llm_engine.py | 21 ++- vllm/entrypoints/launcher.py | 21 ++- vllm/entrypoints/openai/api_server.py | 181 ++++++++++++-------- vllm/entrypoints/openai/rpc/server.py | 8 +- vllm/executor/multiproc_gpu_executor.py | 14 -- vllm/executor/multiproc_worker_utils.py | 5 +- vllm/executor/ray_tpu_executor.py | 2 + vllm/scripts.py | 4 +- vllm/utils.py | 15 ++ 11 files changed, 215 insertions(+), 136 deletions(-) diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index a093a2b29278a..6cae76f74603d 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -26,6 +26,11 @@ class RequestOutput: finished: bool = False +@dataclass +class MockModelConfig: + use_async_output_proc = True + + class MockEngine: def __init__(self): @@ -35,6 +40,7 @@ def __init__(self): self.request_id = None # Ugly, remove dependency when possible self.parallel_config = ParallelConfig(1, 1, False) + self.model_config = MockModelConfig() async def step_async(self, virtual_engine): # PP size is 1, ignore virtual engine @@ -80,7 +86,7 @@ class MockAsyncLLMEngine(AsyncLLMEngine): @pytest.mark.asyncio async def test_new_requests_event(): - engine = MockAsyncLLMEngine(worker_use_ray=False) + engine = MockAsyncLLMEngine() engine.start_background_loop() await asyncio.sleep(0.01) assert engine.engine.step_calls == 0 @@ -113,7 +119,7 @@ async def test_new_requests_event(): assert engine.engine.add_request_calls == 3 assert engine.engine.step_calls == old_step_calls + 1 - engine = MockAsyncLLMEngine(worker_use_ray=True) + engine = MockAsyncLLMEngine() assert engine.get_model_config() is not None assert engine.get_tokenizer() is not None assert engine.get_decoding_config() is not None diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8a07ce1c965e1..410e6ffaa2d50 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,8 +1,10 @@ import asyncio import time +import weakref from functools import partial from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union) +from weakref import ReferenceType import vllm.envs as envs from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, @@ -26,6 +28,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext +from vllm.utils import weak_bind logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -450,9 +453,6 @@ class AsyncLLMEngine: method yields the outputs from the :class:`LLMEngine` to the caller. Args: - worker_use_ray: Whether to use Ray for model workers. Required for - distributed execution. Should be the same as - `parallel_config.worker_use_ray`. log_requests: Whether to log the requests. 
start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. @@ -463,23 +463,22 @@ class AsyncLLMEngine: _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine def __init__(self, - worker_use_ray: bool, *args, log_requests: bool = True, start_engine_loop: bool = True, **kwargs) -> None: - self.worker_use_ray = worker_use_ray self.log_requests = log_requests self.engine = self._engine_class(*args, **kwargs) # This ensures quick processing of request outputs # so the append to asyncio queues is not delayed, # especially for multi-step. - # - self.use_process_request_outputs_callback = True + self.use_process_request_outputs_callback = ( + self.engine.model_config.use_async_output_proc) + if self.use_process_request_outputs_callback: self.engine.process_request_outputs_callback = \ - self.process_request_outputs + weak_bind(self.process_request_outputs) self.background_loop: Optional[asyncio.Future] = None # We need to keep a reference to unshielded @@ -492,6 +491,11 @@ def __init__(self, # Lazy initialized fields self._request_tracker: RequestTracker + def __del__(self): + if rt := getattr(self, "request_tracker", None): + # Wake up engine loop so that it will exit cleanly + rt.new_requests_event.set() + @classmethod def _get_executor_cls( cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: @@ -502,15 +506,12 @@ def _get_executor_cls( raise TypeError( "distributed_executor_backend must be a subclass of " f"ExecutorAsyncBase. Got {distributed_executor_backend}.") - if distributed_executor_backend.uses_ray: # type: ignore - initialize_ray_cluster(engine_config.parallel_config) executor_class = distributed_executor_backend elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutorAsync executor_class = NeuronExecutorAsync elif engine_config.device_config.device_type == "tpu": if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync executor_class = RayTPUExecutorAsync else: @@ -531,11 +532,9 @@ def _get_executor_cls( from vllm.executor.xpu_executor import XPUExecutorAsync executor_class = XPUExecutorAsync elif distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync executor_class = RayXPUExecutorAsync elif distributed_executor_backend == "mp": - initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.multiproc_xpu_executor import ( MultiprocessingXPUExecutorAsync) executor_class = MultiprocessingXPUExecutorAsync @@ -543,7 +542,6 @@ def _get_executor_cls( raise RuntimeError( "Not supported distributed execution model on XPU device.") elif distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync elif distributed_executor_backend == "mp": @@ -559,19 +557,23 @@ def _get_executor_cls( def from_engine_args( cls, engine_args: AsyncEngineArgs, + engine_config: Optional[EngineConfig] = None, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, ) -> "AsyncLLMEngine": """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. 
- engine_config = engine_args.create_engine_config() + if engine_config is None: + engine_config = engine_args.create_engine_config() executor_class = cls._get_executor_cls(engine_config) + if executor_class.uses_ray: + initialize_ray_cluster(engine_config.parallel_config) + # Create the async LLM engine. engine = cls( - executor_class.uses_ray, **engine_config.to_dict(), executor_class=executor_class, log_requests=not engine_args.disable_log_requests, @@ -628,7 +630,7 @@ def start_background_loop(self) -> None: self._request_tracker = RequestTracker() self._background_loop_unshielded = asyncio.get_event_loop( - ).create_task(self.run_engine_loop()) + ).create_task(self.run_engine_loop(weakref.ref(self))) self._background_loop_unshielded.add_done_callback( partial(_log_task_completion, error_callback=self._error_callback)) self.background_loop = asyncio.shield(self._background_loop_unshielded) @@ -698,9 +700,16 @@ def process_request_outputs(self, request_outputs) -> bool: async def _engine_abort(self, request_ids: Iterable[str]): self.engine.abort_request(request_ids) - async def run_engine_loop(self): + @staticmethod + async def run_engine_loop(engine_ref: ReferenceType): + """We use a weakref to the engine so that the running loop + doesn't prevent the engine being garbage collected.""" + engine: Optional["AsyncLLMEngine"] = engine_ref() + if not engine: + return + pipeline_parallel_size = \ - self.engine.parallel_config.pipeline_parallel_size + engine.engine.parallel_config.pipeline_parallel_size has_requests_in_progress = [False] * pipeline_parallel_size while True: if not any(has_requests_in_progress): @@ -711,11 +720,21 @@ async def run_engine_loop(self): # timeout, and unblocks the RPC thread in the workers so that # they can process any other queued control plane messages, # such as add/remove lora adapters. - await self.engine.stop_remote_worker_execution_loop_async() - await self._request_tracker.wait_for_new_requests() + await engine.engine.stop_remote_worker_execution_loop_async() + request_tracker = engine._request_tracker + # Allow engine to be garbage collected while + # waiting for new requests + del engine + await asyncio.sleep(0) + if engine_ref() is None: + return + await request_tracker.wait_for_new_requests() + engine = engine_ref() + if not engine: + return logger.debug("Got new requests!") requests_in_progress = [ - asyncio.create_task(self.engine_step(ve)) + asyncio.create_task(engine.engine_step(ve)) for ve in range(pipeline_parallel_size) ] has_requests_in_progress = [True] * pipeline_parallel_size @@ -733,19 +752,20 @@ async def run_engine_loop(self): result = task.result() virtual_engine = requests_in_progress.index(task) has_unfinished_requests = ( - self.engine.has_unfinished_requests_for_virtual_engine( + engine.engine. + has_unfinished_requests_for_virtual_engine( virtual_engine)) if result or has_unfinished_requests: requests_in_progress[virtual_engine] = ( asyncio.create_task( - self.engine_step(virtual_engine))) + engine.engine_step(virtual_engine))) has_requests_in_progress[virtual_engine] = True else: has_requests_in_progress[virtual_engine] = False except asyncio.TimeoutError as exc: logger.error( "Engine iteration timed out. 
This should never happen!") - self.set_errored(exc) + engine.set_errored(exc) raise await asyncio.sleep(0) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index dfdbc22ef00e1..8b5009b2c6668 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,8 +1,8 @@ -import functools import time from collections import deque from contextlib import contextmanager from dataclasses import dataclass +from functools import partial from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence @@ -51,7 +51,7 @@ BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter, Device +from vllm.utils import Counter, Device, weak_bind from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -382,11 +382,16 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: for _ in range(self.parallel_config.pipeline_parallel_size) ] - self.async_callbacks = [ - functools.partial(self._process_model_outputs, - ctx=self.scheduler_contexts[v_id]) - for v_id in range(self.parallel_config.pipeline_parallel_size) - ] + if model_config.use_async_output_proc: + process_model_outputs = weak_bind(self._process_model_outputs) + + self.async_callbacks = [ + partial(process_model_outputs, + ctx=self.scheduler_contexts[v_id]) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + else: + self.async_callbacks = [] # Currently used by AsyncLLMEngine to ensure quick append # of request outputs to asyncio queues @@ -869,8 +874,8 @@ def has_unfinished_requests_for_virtual_engine( """ return self.scheduler[virtual_engine].has_unfinished_seqs() + @staticmethod def _process_sequence_group_outputs( - self, seq_group: SequenceGroup, outputs: List[EmbeddingSequenceGroupOutput], ) -> None: diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 3598872b65bb0..47d227010c075 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -1,21 +1,20 @@ import asyncio import signal from http import HTTPStatus -from typing import Any +from typing import Any, Optional import uvicorn -from fastapi import FastAPI, Response +from fastapi import FastAPI, Request, Response from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError -from vllm.engine.protocol import AsyncEngineClient from vllm.logger import init_logger from vllm.utils import find_process_using_port logger = init_logger(__name__) -async def serve_http(app: FastAPI, engine: AsyncEngineClient, +async def serve_http(app: FastAPI, limit_concurrency: Optional[int], **uvicorn_kwargs: Any): logger.info("Available routes are:") for route in app.routes: @@ -29,16 +28,16 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient, # Set concurrency limits in uvicorn if running in multiprocessing mode # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536). - if engine.limit_concurrency is not None: + if limit_concurrency is not None: logger.info( "Launching Uvicorn with --limit_concurrency %s. 
To avoid this " "limit at the expense of performance run with " - "--disable-frontend-multiprocessing", engine.limit_concurrency) - uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency + "--disable-frontend-multiprocessing", limit_concurrency) + uvicorn_kwargs["limit_concurrency"] = limit_concurrency config = uvicorn.Config(app, **uvicorn_kwargs) server = uvicorn.Server(config) - _add_shutdown_handlers(app, server, engine) + _add_shutdown_handlers(app, server) loop = asyncio.get_running_loop() @@ -68,15 +67,15 @@ async def dummy_shutdown() -> None: return server.shutdown() -def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server, - engine: AsyncEngineClient) -> None: +def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: """Adds handlers for fatal errors that should crash the server""" @app.exception_handler(RuntimeError) - async def runtime_error_handler(_, __): + async def runtime_error_handler(request: Request, __): """On generic runtime error, check to see if the engine has died. It probably has, in which case the server will no longer be able to handle requests. Trigger a graceful shutdown with a SIGTERM.""" + engine = request.app.state.engine_client if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored and not engine.is_running): logger.fatal("AsyncLLMEngine has failed, terminating server " diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7c1f307e06619..b50fc6a265f8d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -4,16 +4,20 @@ import multiprocessing import os import re +import signal import tempfile from argparse import Namespace from contextlib import asynccontextmanager +from functools import partial from http import HTTPStatus from typing import AsyncIterator, Optional, Set +import uvloop from fastapi import APIRouter, FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse +from starlette.datastructures import State from starlette.routing import Mount from typing_extensions import assert_never @@ -54,12 +58,6 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds -async_engine_client: AsyncEngineClient -engine_args: AsyncEngineArgs -openai_serving_chat: OpenAIServingChat -openai_serving_completion: OpenAIServingCompletion -openai_serving_embedding: OpenAIServingEmbedding -openai_serving_tokenization: OpenAIServingTokenization prometheus_multiproc_dir: tempfile.TemporaryDirectory # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) @@ -83,18 +81,28 @@ def model_is_embedding(model_name: str, trust_remote_code: bool, @asynccontextmanager async def lifespan(app: FastAPI): - - async def _force_log(): - while True: - await asyncio.sleep(10) - await async_engine_client.do_log_stats() - - if not engine_args.disable_log_stats: - task = asyncio.create_task(_force_log()) - _running_tasks.add(task) - task.add_done_callback(_running_tasks.remove) - - yield + try: + if app.state.log_stats: + async_engine_client = app.state.engine_client + + async def _force_log(): + while True: + await asyncio.sleep(10) + await async_engine_client.do_log_stats() + + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) + else: + task = None + try: + yield + finally: + if task is not None: + task.cancel() + finally: + # Ensure app state including engine ref is gc'd + del 
app.state @asynccontextmanager @@ -103,16 +111,10 @@ async def build_async_engine_client( # Context manager to handle async_engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit - global engine_args engine_args = AsyncEngineArgs.from_cli_args(args) - # Backend itself still global for the silly lil' health handler - global async_engine_client - async with build_async_engine_client_from_engine_args( engine_args, args.disable_frontend_multiprocessing) as engine: - - async_engine_client = engine # type: ignore[assignment] yield engine @@ -134,12 +136,22 @@ async def build_async_engine_client_from_engine_args( if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, engine_args.quantization, engine_args.revision) or disable_frontend_multiprocessing): - engine_client = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - try: - yield engine_client - finally: - engine_client.shutdown_background_loop() + engine_config = engine_args.create_engine_config() + uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), + "uses_ray", False) + + build_engine = partial(AsyncLLMEngine.from_engine_args, + engine_args=engine_args, + engine_config=engine_config, + usage_context=UsageContext.OPENAI_API_SERVER) + if uses_ray: + # Must run in main thread with ray for its signal handlers to work + engine_client = build_engine() + else: + engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_engine) + + yield engine_client return # Otherwise, use the multiprocessing AsyncLLMEngine. @@ -241,16 +253,36 @@ def mount_metrics(app: FastAPI): app.routes.append(metrics_route) +def chat(request: Request) -> OpenAIServingChat: + return request.app.state.openai_serving_chat + + +def completion(request: Request) -> OpenAIServingCompletion: + return request.app.state.openai_serving_completion + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +def embedding(request: Request) -> OpenAIServingEmbedding: + return request.app.state.openai_serving_embedding + + +def engine_client(request: Request) -> AsyncEngineClient: + return request.app.state.engine_client + + @router.get("/health") -async def health() -> Response: +async def health(raw_request: Request) -> Response: """Health check.""" - await async_engine_client.check_health() + await engine_client(raw_request).check_health() return Response(status_code=200) @router.post("/tokenize") -async def tokenize(request: TokenizeRequest): - generator = await openai_serving_tokenization.create_tokenize(request) +async def tokenize(request: TokenizeRequest, raw_request: Request): + generator = await tokenization(raw_request).create_tokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -261,8 +293,8 @@ async def tokenize(request: TokenizeRequest): @router.post("/detokenize") -async def detokenize(request: DetokenizeRequest): - generator = await openai_serving_tokenization.create_detokenize(request) +async def detokenize(request: DetokenizeRequest, raw_request: Request): + generator = await tokenization(raw_request).create_detokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -273,8 +305,8 @@ async def detokenize(request: DetokenizeRequest): @router.get("/v1/models") -async def show_available_models(): - models = 
await openai_serving_completion.show_available_models() +async def show_available_models(raw_request: Request): + models = await completion(raw_request).show_available_models() return JSONResponse(content=models.model_dump()) @@ -288,7 +320,7 @@ async def show_version(): async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): - generator = await openai_serving_chat.create_chat_completion( + generator = await chat(raw_request).create_chat_completion( request, raw_request) if isinstance(generator, ErrorResponse): @@ -303,7 +335,7 @@ async def create_chat_completion(request: ChatCompletionRequest, @router.post("/v1/completions") async def create_completion(request: CompletionRequest, raw_request: Request): - generator = await openai_serving_completion.create_completion( + generator = await completion(raw_request).create_completion( request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -316,7 +348,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @router.post("/v1/embeddings") async def create_embedding(request: EmbeddingRequest, raw_request: Request): - generator = await openai_serving_embedding.create_embedding( + generator = await embedding(raw_request).create_embedding( request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -333,16 +365,16 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): "used for local development!") @router.post("/start_profile") - async def start_profile(): + async def start_profile(raw_request: Request): logger.info("Starting profiler...") - await async_engine_client.start_profile() + await engine_client(raw_request).start_profile() logger.info("Profiler started.") return Response(status_code=200) @router.post("/stop_profile") - async def stop_profile(): + async def stop_profile(raw_request: Request): logger.info("Stopping profiler...") - await async_engine_client.stop_profile() + await engine_client(raw_request).stop_profile() logger.info("Profiler stopped.") return Response(status_code=200) @@ -353,13 +385,14 @@ async def stop_profile(): "This should ONLY be used for local development!") @router.post("/v1/load_lora_adapter") - async def load_lora_adapter(request: LoadLoraAdapterRequest): - response = await openai_serving_chat.load_lora_adapter(request) + async def load_lora_adapter(request: LoadLoraAdapterRequest, + raw_request: Request): + response = await chat(raw_request).load_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), status_code=response.code) - response = await openai_serving_completion.load_lora_adapter(request) + response = await completion(raw_request).load_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), status_code=response.code) @@ -367,13 +400,14 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest): return Response(status_code=200, content=response) @router.post("/v1/unload_lora_adapter") - async def unload_lora_adapter(request: UnloadLoraAdapterRequest): - response = await openai_serving_chat.unload_lora_adapter(request) + async def unload_lora_adapter(request: UnloadLoraAdapterRequest, + raw_request: Request): + response = await chat(raw_request).unload_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), status_code=response.code) 
- response = await openai_serving_completion.unload_lora_adapter(request) + response = await completion(raw_request).unload_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), status_code=response.code) @@ -398,7 +432,8 @@ def build_app(args: Namespace) -> FastAPI: @app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): - err = openai_serving_chat.create_error_response(message=str(exc)) + chat = app.state.openai_serving_chat + err = chat.create_error_response(message=str(exc)) return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) @@ -430,30 +465,26 @@ async def authentication(request: Request, call_next): return app -async def init_app( +def init_app_state( async_engine_client: AsyncEngineClient, + model_config: ModelConfig, + state: State, args: Namespace, -) -> FastAPI: - app = build_app(args) - +) -> None: if args.served_model_name is not None: served_model_names = args.served_model_name else: served_model_names = [args.model] - model_config = await async_engine_client.get_model_config() - if args.disable_log_requests: request_logger = None else: request_logger = RequestLogger(max_log_len=args.max_log_len) - global openai_serving_chat - global openai_serving_completion - global openai_serving_embedding - global openai_serving_tokenization + state.engine_client = async_engine_client + state.log_stats = not args.disable_log_stats - openai_serving_chat = OpenAIServingChat( + state.openai_serving_chat = OpenAIServingChat( async_engine_client, model_config, served_model_names, @@ -465,7 +496,7 @@ async def init_app( return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser) - openai_serving_completion = OpenAIServingCompletion( + state.openai_serving_completion = OpenAIServingCompletion( async_engine_client, model_config, served_model_names, @@ -474,13 +505,13 @@ async def init_app( request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) - openai_serving_embedding = OpenAIServingEmbedding( + state.openai_serving_embedding = OpenAIServingEmbedding( async_engine_client, model_config, served_model_names, request_logger=request_logger, ) - openai_serving_tokenization = OpenAIServingTokenization( + state.openai_serving_tokenization = OpenAIServingTokenization( async_engine_client, model_config, served_model_names, @@ -488,25 +519,31 @@ async def init_app( request_logger=request_logger, chat_template=args.chat_template, ) - app.root_path = args.root_path - - return app async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + async with build_async_engine_client(args) as async_engine_client: # If None, creation of the client failed and we exit. 
if async_engine_client is None: return - app = await init_app(async_engine_client, args) + app = build_app(args) + + model_config = await async_engine_client.get_model_config() + init_app_state(async_engine_client, model_config, app.state, args) shutdown_task = await serve_http( app, - engine=async_engine_client, + limit_concurrency=async_engine_client.limit_concurrency, host=args.host, port=args.port, log_level=args.uvicorn_log_level, @@ -530,4 +567,4 @@ async def run_server(args, **uvicorn_kwargs) -> None: parser = make_arg_parser(parser) args = parser.parse_args() - asyncio.run(run_server(args)) + uvloop.run(run_server(args)) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index bebc2faedb680..460ff0636b6e9 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -46,7 +46,6 @@ def cleanup(self): """Cleanup all resources.""" self.socket.close() self.context.destroy() - self.engine.shutdown_background_loop() # Clear the engine reference so that it can be GC'ed. del self.engine @@ -233,5 +232,12 @@ def signal_handler() -> None: def run_rpc_server(async_engine_args: AsyncEngineArgs, usage_context: UsageContext, rpc_path: str): + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("AsyncEngineRPCServer terminated") + + signal.signal(signal.SIGTERM, signal_handler) + server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path) uvloop.run(run_server(server)) diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 9c6d4051eb3f8..cc535e99a06ef 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -1,8 +1,5 @@ import asyncio import os -import signal -import threading -import weakref from functools import partial from typing import Any, List, Optional @@ -108,17 +105,6 @@ def _init_executor(self) -> None: # Set up signal handlers to shutdown the executor cleanly # sometimes gc does not work well - # Use weakref to avoid holding a reference to self - ref = weakref.ref(self) - - def shutdown(signum, frame): - if executor := ref(): - executor.shutdown() - - if threading.current_thread() is threading.main_thread(): - signal.signal(signal.SIGINT, shutdown) - signal.signal(signal.SIGTERM, shutdown) - self.driver_worker = self._create_worker( distributed_init_method=distributed_init_method) self._run_workers("init_device") diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py index 28c8e8699f083..aa2a16c04d08d 100644 --- a/vllm/executor/multiproc_worker_utils.py +++ b/vllm/executor/multiproc_worker_utils.py @@ -120,7 +120,8 @@ def run(self) -> None: logger.error("Worker %s pid %s died, exit code: %s", process.name, process.pid, process.exitcode) # Cleanup any remaining workers - logger.info("Killing local vLLM worker processes") + if logger: + logger.info("Killing local vLLM worker processes") for worker in self.workers: worker.kill_worker() # Must be done after worker task queues are all closed @@ -221,6 +222,8 @@ def _run_worker_process( try: executor = getattr(worker, method) output = executor(*args, **kwargs) + except KeyboardInterrupt: + break except BaseException as e: tb = traceback.format_exc() logger.error( diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 732b69d6e5954..d02fecb46f007 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py 
@@ -26,6 +26,8 @@ class RayTPUExecutor(TPUExecutor): + uses_ray: bool = True + def __init__(self, *args, **kwargs): # This is non-None when the execute model loop is running # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. diff --git a/vllm/scripts.py b/vllm/scripts.py index e557961a335bf..231a18e99f3d7 100644 --- a/vllm/scripts.py +++ b/vllm/scripts.py @@ -1,11 +1,11 @@ # The CLI entrypoint to vLLM. import argparse -import asyncio import os import signal import sys from typing import List, Optional +import uvloop from openai import OpenAI from openai.types.chat import ChatCompletionMessageParam @@ -34,7 +34,7 @@ def serve(args: argparse.Namespace) -> None: # EngineArgs expects the model name to be passed as --model. args.model = args.model_tag - asyncio.run(run_server(args)) + uvloop.run(run_server(args)) def interactive_cli(args: argparse.Namespace) -> None: diff --git a/vllm/utils.py b/vllm/utils.py index aba243071b69a..014fc16a17c1f 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -12,6 +12,7 @@ import threading import uuid import warnings +import weakref from asyncio import FIRST_COMPLETED, ensure_future from functools import lru_cache, partial, wraps from platform import uname @@ -1079,6 +1080,20 @@ def cuda_device_count_stateless() -> int: return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) +def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: + """Make an instance method that weakly references + its associated instance and no-ops once that + instance is collected.""" + ref = weakref.ref(bound_method.__self__) # type: ignore[attr-defined] + unbound = bound_method.__func__ # type: ignore[attr-defined] + + def weak_bound(*args, **kwargs) -> None: + if inst := ref(): + unbound(inst, *args, **kwargs) + + return weak_bound + + #From: https://stackoverflow.com/a/4104188/2749989 def run_once(f: Callable[P, None]) -> Callable[P, None]: From 781e3b9a4281babf5576f8dc22445af69814d8f6 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Mon, 16 Sep 2024 18:15:57 +0000 Subject: [PATCH 51/98] [Bugfix][Kernel] Fix build for sm_60 in GGUF kernel (#8506) --- csrc/quantization/gguf/vecdotq.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/quantization/gguf/vecdotq.cuh b/csrc/quantization/gguf/vecdotq.cuh index ff339753bcbb5..d5af345a6b26f 100644 --- a/csrc/quantization/gguf/vecdotq.cuh +++ b/csrc/quantization/gguf/vecdotq.cuh @@ -1671,6 +1671,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 const block_iq1_s * bq1 = (const block_iq1_s *) vbq; const int qs_packed = get_int_b2(bq1->qs, iqs); @@ -1697,10 +1698,12 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); const float2 ds = __half22float2(bq8_1[iqs].ds); return d1q * (ds.x*sumi + ds.y*delta); +#endif } static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 const block_iq1_m * bq1 = (const block_iq1_m *) vbq; @@ -1741,6 +1744,7 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( const int sc0 = 2*((tmp >> 0) & 0x07) + 1; const int sc1 = 2*((tmp >> 3) & 0x07) + 1; return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * 
sc1); +#endif } static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values, From 5d73ae49d65394f8dbe46accd921fb21e8247b82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Mon, 16 Sep 2024 14:52:40 -0400 Subject: [PATCH 52/98] [Kernel] AQ AZP 3/4: Asymmetric quantization kernels (#7270) --- csrc/cpu/quant.cpp | 9 +- csrc/cpu/torch_bindings.cpp | 9 +- csrc/ops.h | 6 +- .../compressed_tensors/int8_quant_kernels.cu | 173 ++++++++++++++++-- csrc/torch_bindings.cpp | 8 +- tests/kernels/test_int8_quant.py | 158 ++++++++++++++-- vllm/_custom_ops.py | 29 ++- .../model_executor/layers/quantization/qqq.py | 2 +- .../layers/quantization/utils/w8a8_utils.py | 2 +- 9 files changed, 339 insertions(+), 57 deletions(-) diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index 0cfc19097fded..2d7abe6145fee 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -257,11 +257,13 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major // static-per-tensor quantization. void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] const torch::Tensor& input, // [..., hidden_size] - const torch::Tensor& scale) { + const torch::Tensor& scale, + c10::optional const& azp) { CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU."); const int hidden_size = input.size(-1); const int num_tokens = input.numel() / hidden_size; @@ -277,11 +279,12 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] const torch::Tensor& input, // [..., hidden_size] - torch::Tensor& scale // [..., 1] -) { + torch::Tensor& scale, // [..., 1] + c10::optional const& azp) { CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU."); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index b45da1b386b5b..ab697e3e6aef7 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -94,13 +94,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #ifdef __AVX512F__ // Compute int8 quantized tensor for given scaling factor. ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> " - "()"); + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); + // Compute int8 quantized tensor and scaling factor ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> " - "()"); + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "Tensor!? 
azp) -> ()"); ops.impl("dynamic_scaled_int8_quant", torch::kCPU, &dynamic_scaled_int8_quant); // W8A8 GEMM, supporting symmetric per-tensor or per-row/column diff --git a/csrc/ops.h b/csrc/ops.h index 5333b22c536d6..681ab4b898ca3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -184,10 +184,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor const& scale); + torch::Tensor const& scale, + c10::optional const& azp); void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor& scales); + torch::Tensor& scales, + c10::optional const& azp); torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 616fc149760e5..aec9fa002f96e 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -14,12 +14,17 @@ static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM - static const float i8_min = + static constexpr auto i8_min = static_cast(std::numeric_limits::min()); - static const float i8_max = + static constexpr auto i8_max = static_cast(std::numeric_limits::max()); - // round + + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. float dst = std::nearbyint(x); + // saturate dst = std::clamp(dst, i8_min, i8_max); return static_cast(dst); @@ -31,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) { #endif } +static inline __device__ int32_t float_to_int32_rn(float x) { +#ifdef USE_ROCM + // int32_max is not exactly representable as float. + // Therefore, we need to be careful and manually return int32_max on overflow. + // For symmetry, we also do the same for int32_min, even though it is exactly + // representable as float and the conversion should be exact. + static constexpr auto i32_min = std::numeric_limits::min(); + static constexpr auto i32_min_f = static_cast(i32_min); + static constexpr auto i32_max = std::numeric_limits::max(); + static constexpr auto i32_max_f = static_cast(i32_max); + + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. + float dst = std::nearbyint(x); + + // saturate on the higher end. + if (dst >= i32_max_f) { + return i32_max; + } + // saturate on the lower end. 
+ if (dst <= i32_min_f) { + return i32_min; + } + + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +static inline __device__ int8_t int32_to_int8(int32_t x) { +#ifdef USE_ROCM + static constexpr auto i8_min = + static_cast(std::numeric_limits::min()); + static constexpr auto i8_max = + static_cast(std::numeric_limits::max()); + + // saturate + int32_t dst = std::clamp(x, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x)); + return reinterpret_cast(dst); +#endif +} + namespace vllm { template @@ -47,6 +105,23 @@ __global__ void static_scaled_int8_quant_kernel( } } +template +__global__ void static_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type const* scale_ptr, azp_type const* azp_ptr, + const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + scale_type const scale = *scale_ptr; + azp_type const azp = *azp_ptr; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + auto const val = static_cast(input[token_idx * hidden_size + i]); + auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); + out[token_idx * hidden_size + i] = quant_val; + } +} + template __global__ void dynamic_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, @@ -80,14 +155,68 @@ __global__ void dynamic_scaled_int8_quant_kernel( } } +template +__global__ void dynamic_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type* scale, azp_type* azp, const int hidden_size) { + int const token_idx = blockIdx.x; + + // Scan for the min and max value for this token + float max_val = std::numeric_limits::min(); + float min_val = std::numeric_limits::max(); + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto val = static_cast(input[token_idx * hidden_size + i]); + max_val = std::max(max_val, val); + min_val = std::min(min_val, val); + } + + // Reduce the max and min values across the block + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStorage; + max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x); + __syncthreads(); // Make sure min doesn't mess with max shared memory + min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x); + + __shared__ scale_type scale_sh; + __shared__ azp_type azp_sh; + + // Compute the scale and zero point and store them, only on the first thread + if (threadIdx.x == 0) { + float const scale_val = (max_val - min_val) / 255.0f; + // Use rounding to even (same as torch.round) + auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val); + auto const azp_val = static_cast(azp_float); + + // Store the scale and azp into shared and global + scale[token_idx] = scale_sh = scale_val; + azp[token_idx] = azp_sh = azp_val; + } + + // Wait for the scale and azp to be computed + __syncthreads(); + + float const scale_val = scale_sh; + azp_type const azp_val = azp_sh; + + // Quantize the values + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto const val = static_cast(input[token_idx * hidden_size + i]); + auto const quant_val = + int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); + out[token_idx * hidden_size + i] = quant_val; + } +} + } // namespace 
vllm void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor const& scale) { + torch::Tensor const& scale, + c10::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp || azp->numel() == 1); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; @@ -96,19 +225,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { - vllm::static_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scale.data_ptr(), hidden_size); + if (!azp) { + vllm::static_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), hidden_size); + } else { + vllm::static_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), + hidden_size); + } }); } void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor& scales) { + torch::Tensor& scales, c10::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scales.is_contiguous()); + TORCH_CHECK(!azp || azp->is_contiguous()); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; @@ -117,9 +256,17 @@ void dynamic_scaled_int8_quant( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { - vllm::dynamic_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scales.data_ptr(), hidden_size); + if (!azp) { + vllm::dynamic_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), hidden_size); + } else { + vllm::dynamic_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), azp->data_ptr(), + hidden_size); + } }); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 51afeacfdc0ad..d7f7547fbef55 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -336,14 +336,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute int8 quantized tensor for given scaling factor. ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> " - "()"); + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> " - "()"); + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "Tensor!? 
azp) -> ()"); ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); } diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index a82ecb026482e..e93cb535d715a 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -13,14 +13,28 @@ SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] -def opcheck_int8_quant(output, input, scale=None): - if scale is not None: - opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale)) +def opcheck_int8_quant_static(output, input, scale, azp=None): + if azp is None: + opcheck(torch.ops._C.static_scaled_int8_quant, + (output, input, scale, None)) else: - scale = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale)) + opcheck(torch.ops._C.static_scaled_int8_quant, + (output, input, scale, azp)) + + +def opcheck_int8_quant_dynamic(output, input, symmetric=True): + scale = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + if symmetric: + opcheck(torch.ops._C.dynamic_scaled_int8_quant, + (output, input, scale, None)) + else: + azp = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.int32) + opcheck(torch.ops._C.dynamic_scaled_int8_quant, + (output, input, scale, azp)) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -38,14 +52,56 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, # reference ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8) # kernel - ops_out, ops_scales = scaled_int8_quant(x) + ops_out, ops_scales, _ = scaled_int8_quant(x) torch.testing.assert_close(ops_scales, ref_scales) - torch.testing.assert_close( - ops_out, ref_out, atol=1, - rtol=0.0) # big atol to account for rounding errors + # big atol to account for rounding errors + torch.testing.assert_close(ops_out, ref_out, atol=1, rtol=0.0) - opcheck_int8_quant(ops_out, x) + opcheck_int8_quant_dynamic(ops_out, x) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + int8_traits = torch.iinfo(torch.int8) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, + device="cuda") * 1000 - 300 + + x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True) + x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True) + + # calculate scale and azp, and adjust the range + scales = (x_token_max - x_token_min) / torch.tensor(255.0) + azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to( + torch.int32) + + torch_out = ((x / scales).round() + azps).clamp( + int8_traits.min, int8_traits.max).to(torch.int8) + assert torch_out.min() >= int8_traits.min and torch_out.max( + ) <= int8_traits.max + + ops_out = torch.empty_like(x, dtype=torch.int8) + scales_out = torch.empty_like(scales, dtype=torch.float32) + azp_out = torch.empty_like(azps, dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out) + + if (not torch.allclose(scales_out, scales)): + print(torch.argmax(torch.abs(scales_out - scales))) + torch.testing.assert_close(scales_out, scales) + # big atol to account for rounding 
errors + torch.testing.assert_close(azp_out, azps, atol=1, rtol=0.0) + # if AZP is off by 1, after rounding-to-even, the output may be off by 2 + torch.testing.assert_close(ops_out, torch_out, atol=2, rtol=0.0) + + opcheck_int8_quant_dynamic(ops_out, x, False) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -62,14 +118,76 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - scale = torch.tensor([scale], dtype=torch.float32, device="cuda") + scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda") + + out1 = (x / scale_arg).round().clamp(int8_traits.min, + int8_traits.max).to(torch.int8) + out2, _, _ = scaled_int8_quant(x, scale_arg) + + # big atol to account for rounding errors + torch.testing.assert_close(out1, out2, atol=1, rtol=0.0) + + opcheck_int8_quant_static(out2, x, scale_arg) - out1 = (x / scale).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) - out2, _ = scaled_int8_quant(x, scale) - torch.testing.assert_close( - out1, out2, atol=1, - rtol=0.0) # big atol to account for rounding errors +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE[2:]) # Reduce test time +@pytest.mark.parametrize("azp", [-255, 54]) +@torch.inference_mode() +def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int, + scale: float, azp: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + int8_traits = torch.iinfo(torch.int8) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, + device="cuda") * 1000 - 300 + + out1 = ((x / scale).round() + azp).clamp(int8_traits.min, + int8_traits.max).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda") + azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda") + + torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg) + + # big atol to account for rounding errors + torch.testing.assert_close(out1, out2, atol=1, rtol=0.0) + + opcheck_int8_quant_static(out2, x, scale_arg, azp_arg) + + +@pytest.mark.parametrize("is_max", [True, False]) +@torch.inference_mode() +def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None: + # Test that the saturating cast works correctly for values near i32 max/min + + from numpy import inf, nextafter + + int32_traits = torch.iinfo(torch.int32) + val = float(int32_traits.max if is_max else int32_traits.min) + + x_vals = [[ + nextafter(val, inf), val + 1, val, val - 1, + nextafter(val, -inf) + ]] + x = torch.tensor(x_vals, dtype=torch.float32, device="cuda") + + # The calculation in the kernel is: cast(cast(x / scale) + azp) + # where cast is a saturating cast to type T. + # Scale is set to 1.0 so that the input values are the ones that are cast. + # AZP is set to 0 to make sure the int8 saturating cast is tested as well. 
+ scale = torch.scalar_tensor(1.0, dtype=torch.float32, device="cuda") + azp = torch.scalar_tensor(0, dtype=torch.int32, device="cuda") + + int8_traits = torch.iinfo(torch.int8) + val_i8 = int8_traits.max if is_max else int8_traits.min + expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda") - opcheck_int8_quant(out2, x, scale) + out = torch.empty_like(expected) + torch.ops._C.static_scaled_int8_quant(out, x, scale, azp) + torch.testing.assert_close(expected, out, atol=0, rtol=0) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 74b3b69606c67..d5b3d7bc6dd5a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -684,32 +684,43 @@ def scaled_fp8_quant( # int8 def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None -) -> Tuple[torch.Tensor, torch.Tensor]: + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ - Quantize the input tensor to int8 and return the quantized tensor and scale. + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. Args: input: The input tensor to be quantized to int8. scale: Optional scaling factor for the int8 quantization. When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales. + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. - torch.ops._C.static_scaled_int8_quant(output, input, scale) - return output, scale + assert symmetric == ( + azp is + None), "azp must only be provided for asymmetric quantization." + torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, None # dynamic-per-token quantization. 
input_scales = torch.empty((input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales) - return output, input_scales + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, + input_azp) + return output, input_scales, input_azp # qqq ops diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py index c3434214a1cde..5bc3737520865 100644 --- a/vllm/model_executor/layers/quantization/qqq.py +++ b/vllm/model_executor/layers/quantization/qqq.py @@ -260,7 +260,7 @@ def apply( size_k = x_2d.shape[1] size_n = s_ch.shape[1] - x_int8, s_tok = ops.scaled_int8_quant(x_2d) + x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d) output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index a54e3cae73b14..887ee6605560c 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -188,7 +188,7 @@ def apply_int8_linear( # ops.scaled_int8_quant supports both dynamic and static quant. # * dynamic, layer.input_scale is None and x_scale computed from x. # * static, layer.input_scale is scalar and x_scale is input_scale. - x_q, x_scale = ops.scaled_int8_quant(input, input_scale) + x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale) return ops.cutlass_scaled_mm(x_q, weight, From 2759a43a26e4eecb7ff7d741c2b6da0d544462ad Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 16 Sep 2024 12:10:23 -0700 Subject: [PATCH 53/98] [doc] update doc on testing and debugging (#8514) --- docs/source/getting_started/debugging.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/getting_started/debugging.rst b/docs/source/getting_started/debugging.rst index 31ecca1332e5d..81287762d3c0a 100644 --- a/docs/source/getting_started/debugging.rst +++ b/docs/source/getting_started/debugging.rst @@ -98,6 +98,13 @@ Here are some common issues that can cause hangs: If the script runs successfully, you should see the message ``sanity check is successful!``. + Note that multi-node environment is more complicated than single-node. If you see errors such as ``torch.distributed.DistNetworkError``, it is likely that the network/DNS setup is incorrect. In that case, you can manually assign node rank and specify the IP via command line arguments: + + - In the first node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 0 --master_addr $MASTER_ADDR test.py``. + - In the second node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 1 --master_addr $MASTER_ADDR test.py``. + + Adjust ``--nproc-per-node``, ``--nnodes``, and ``--node-rank`` according to your setup. The difference is that you need to execute different commands (with different ``--node-rank``) on different nodes. + If the problem persists, feel free to `open an issue on GitHub `_, with a detailed description of the issue, your environment, and the logs. 
Some known issues: From 47f5e03b5b9fc719b7e5ee00cbd6d1e79627f105 Mon Sep 17 00:00:00 2001 From: Kevin Lin <42618777+kevin314@users.noreply.github.com> Date: Mon, 16 Sep 2024 15:56:28 -0500 Subject: [PATCH 54/98] [Bugfix] Bind api server port before starting engine (#8491) --- vllm/entrypoints/openai/api_server.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b50fc6a265f8d..3d1d832986c1e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -5,6 +5,7 @@ import os import re import signal +import socket import tempfile from argparse import Namespace from contextlib import asynccontextmanager @@ -525,6 +526,9 @@ async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) + temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + temp_socket.bind(("", args.port)) + def signal_handler(*_) -> None: # Interrupt server on sigterm while initializing raise KeyboardInterrupt("terminated") @@ -541,6 +545,8 @@ def signal_handler(*_) -> None: model_config = await async_engine_client.get_model_config() init_app_state(async_engine_client, model_config, app.state, args) + temp_socket.close() + shutdown_task = await serve_http( app, limit_concurrency=async_engine_client.limit_concurrency, From 5478c4b41f60995b92b9997306b2e0702055341f Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 16 Sep 2024 14:30:02 -0700 Subject: [PATCH 55/98] [perf bench] set timeout to debug hanging (#8516) --- .buildkite/nightly-benchmarks/benchmark-pipeline.yaml | 3 +-- .buildkite/nightly-benchmarks/scripts/wait-for-image.sh | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 2b70e2da5d87c..eec2a51e2f8fd 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -8,8 +8,7 @@ steps: containers: - image: badouralix/curl-jq command: - - sh - - .buildkite/nightly-benchmarks/scripts/wait-for-image.sh + - sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh - wait - label: "A100" agents: diff --git a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh index c785e6a0da628..f16862907def1 100644 --- a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh +++ b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh @@ -2,9 +2,11 @@ TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-test-repo:pull" | jq -r .token) URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-test-repo/manifests/$BUILDKITE_COMMIT" +TIMEOUT_SECONDS=10 + retries=0 while [ $retries -lt 1000 ]; do - if [ $(curl -s -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" $URL) -eq 200 ]; then + if [ $(curl -s --max-time $TIMEOUT_SECONDS -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" $URL) -eq 200 ]; then exit 0 fi From 5ce45eb54d3fb870f1fb6865c67aac05ec9bf555 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 16 Sep 2024 15:11:27 -0700 Subject: [PATCH 56/98] [misc] small qol fixes for release process (#8517) --- Dockerfile | 2 ++ setup.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 5484be5bc5785..620f549cf3955 100644 --- a/Dockerfile +++ 
b/Dockerfile @@ -82,6 +82,7 @@ ENV BUILDKITE_COMMIT=${buildkite_commit} ARG USE_SCCACHE ARG SCCACHE_BUCKET_NAME=vllm-build-sccache ARG SCCACHE_REGION_NAME=us-west-2 +ARG SCCACHE_S3_NO_CREDENTIALS=0 # if USE_SCCACHE is set, use sccache to speed up compilation RUN --mount=type=cache,target=/root/.cache/pip \ if [ "$USE_SCCACHE" = "1" ]; then \ @@ -92,6 +93,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ && rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \ && export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \ && export SCCACHE_REGION=${SCCACHE_REGION_NAME} \ + && export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \ && export SCCACHE_IDLE_TIMEOUT=0 \ && export CMAKE_BUILD_TYPE=Release \ && sccache --show-stats \ diff --git a/setup.py b/setup.py index 8930ea7239dc9..7da9115440433 100644 --- a/setup.py +++ b/setup.py @@ -371,7 +371,9 @@ def get_vllm_version() -> str: cuda_version = str(get_nvcc_cuda_version()) if cuda_version != MAIN_CUDA_VERSION: cuda_version_str = cuda_version.replace(".", "")[:3] - version += f"+cu{cuda_version_str}" + # skip this for source tarball, required for pypi + if "sdist" not in sys.argv: + version += f"+cu{cuda_version_str}" elif _is_hip(): # Get the HIP version hipcc_version = get_hipcc_rocm_version() From cca61642e0484212e6cd78b35b4789afed8d19c6 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 16 Sep 2024 18:01:45 -0600 Subject: [PATCH 57/98] [Bugfix] Fix 3.12 builds on main (#8510) Signed-off-by: Joe Runde --- Dockerfile | 4 ---- requirements-common.txt | 1 + 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index 620f549cf3955..001068b4b36ca 100644 --- a/Dockerfile +++ b/Dockerfile @@ -182,10 +182,6 @@ FROM vllm-base AS test ADD . /vllm-workspace/ # install development dependencies (for testing) -# A newer setuptools is required for installing some test dependencies from source that do not publish python 3.12 wheels -# This installation must complete before the test dependencies are collected and installed. -RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install "setuptools>=74.1.1" RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt diff --git a/requirements-common.txt b/requirements-common.txt index ad950d0313454..ad53395307ec5 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -29,4 +29,5 @@ importlib_metadata mistral_common >= 1.4.0 pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 +setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. 
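
For readers of [PATCH 52/98] above: the asymmetric ("azp", zero-point) quantization math that the new CUDA/HIP kernels implement, and that tests/kernels/test_int8_quant.py checks against, can be summarized with a short PyTorch reference sketch. This is an illustrative example written for this document, not code from the patch; the function name, the constant-row guard, and the final error printout are assumptions made only for the illustration.

import torch

def ref_dynamic_int8_azp_quant(x: torch.Tensor):
    """Per-token dynamic quantization: returns (int8 values, float32 scales, int32 zero points)."""
    x = x.to(torch.float32)
    x_max = x.max(dim=1, keepdim=True).values
    x_min = x.min(dim=1, keepdim=True).values
    # Map [min, max] onto the full [-128, 127] range; the clamp_min guard against
    # constant rows is only for this toy example.
    scale = ((x_max - x_min) / 255.0).clamp_min(1e-7)
    azp = torch.round(-128.0 - x_min / scale).to(torch.int32)   # zero point
    q = torch.clamp(torch.round(x / scale) + azp, -128, 127).to(torch.int8)
    return q, scale, azp

x = torch.randn(4, 16) * 100
q, scale, azp = ref_dynamic_int8_azp_quant(x)
# Dequantize: (q - azp) * scale recovers x to within roughly one quantization step.
x_hat = (q.to(torch.float32) - azp.to(torch.float32)) * scale
print((x_hat - x).abs().div(scale).max().item())

The point of the zero point is that activations whose range is not centered on zero can still use all 256 int8 levels, which is what distinguishes these kernels from the existing symmetric (scale-only) path.
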
From 546034b466bf11f12936791312981b9982850eb0 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 16 Sep 2024 20:04:48 -0700 Subject: [PATCH 58/98] [refactor] remove triton based sampler (#8524) --- tests/kernels/test_rand.py | 52 --- tests/kernels/test_sampler.py | 209 ----------- vllm/model_executor/layers/ops/__init__.py | 0 vllm/model_executor/layers/ops/rand.py | 157 -------- vllm/model_executor/layers/ops/sample.py | 394 --------------------- vllm/model_executor/layers/sampler.py | 97 +---- vllm/model_executor/sampling_metadata.py | 211 +++-------- vllm/triton_utils/sample.py | 13 - vllm/utils.py | 37 +- 9 files changed, 75 insertions(+), 1095 deletions(-) delete mode 100644 tests/kernels/test_rand.py delete mode 100644 tests/kernels/test_sampler.py delete mode 100644 vllm/model_executor/layers/ops/__init__.py delete mode 100644 vllm/model_executor/layers/ops/rand.py delete mode 100644 vllm/model_executor/layers/ops/sample.py delete mode 100644 vllm/triton_utils/sample.py diff --git a/tests/kernels/test_rand.py b/tests/kernels/test_rand.py deleted file mode 100644 index a4242d22eb489..0000000000000 --- a/tests/kernels/test_rand.py +++ /dev/null @@ -1,52 +0,0 @@ -import random - -import pytest -import torch - -from vllm.model_executor.layers.ops.rand import seeded_uniform -from vllm.model_executor.utils import set_random_seed - - -@pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("use_3d", [True, False]) -def test_seeded_uniform(dtype: torch.dtype, use_3d: bool): - device = "cuda" - for seed in range(512): - set_random_seed(seed) - rows = random.randint(1, 512) - cols = random.randint(1, 64000) - if use_3d: - third_dim = random.randint(2, 10) - dims = [rows, third_dim, cols] - else: - dims = [rows, cols] - seeds = torch.randint(torch.iinfo(torch.long).min, - torch.iinfo(torch.long).max, (rows, ), - device=device) - - # Test that the same seed produces the same output - out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) - out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) - torch.testing.assert_close(out, out2) - # del to save memory - del out2 - - out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) - torch.testing.assert_close(out, out3) - # del to save memory - del out3 - - # Initialize out tensor with garbage to ensure that it is overwritten - out_with_tensor = seeded_uniform( - *dims, - out=torch.full( - (*dims, ), - -1, - dtype=dtype, - device=device, - ), - seeds=seeds, - dtype=dtype, - ) - torch.testing.assert_close(out, out_with_tensor) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py deleted file mode 100644 index 03844aba20f8a..0000000000000 --- a/tests/kernels/test_sampler.py +++ /dev/null @@ -1,209 +0,0 @@ -import gc -from unittest.mock import patch - -import pytest -import torch -import triton -import triton.language as tl - -from vllm.model_executor.layers.ops.sample import (_sample_triton, - _uniform_to_exponential, - sample) -from vllm.model_executor.sampling_metadata import SamplingTensors -from vllm.model_executor.utils import set_random_seed -from vllm.triton_utils.libentry import LibEntry -from vllm.triton_utils.sample import (MAX_TRITON_N_COLS, - get_num_triton_sampler_splits) - -SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size -MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100 - - -@pytest.fixture(autouse=True) -def _cleanup(): - yield - gc.collect() - torch.cuda.empty_cache() - - -@triton.jit -def 
_uniform_to_exponential_kernel(input, output, n: tl.constexpr): - idx = tl.arange(0, n) - x = tl.load(input + idx) - y = _uniform_to_exponential(x) - tl.store(output + idx, y) - - -def test_uniform_to_exponential(): - """Test that we can convert uniform to exponential without div by 0.""" - input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps], - dtype=torch.float32, - device="cuda") - output = torch.zeros(input.shape, dtype=torch.float32, device="cuda") - _uniform_to_exponential_kernel[(1, )](input, output, 2) - assert torch.all(torch.isfinite(output)) - assert torch.all(output > 0) - assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output)) - - -@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) -@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) -@pytest.mark.parametrize("modify_greedy_probs", [True, False]) -@pytest.mark.parametrize("seed", [1337]) -@pytest.mark.parametrize("vocab_size", - [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) -@pytest.mark.parametrize("save_logprobs", [True, False]) -def test_sample_decoding_only(random_sampling, max_best_of, - modify_greedy_probs, seed, vocab_size, - save_logprobs): - set_random_seed(seed) - bs = 8 - probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") - for i in range(bs): - probs[i, i * (vocab_size // bs)] = 1.0 - logprobs = torch.rand_like(probs) - sample_indices = torch.arange(bs, dtype=torch.long, device="cuda") - n_splits = get_num_triton_sampler_splits(probs.shape[1]) - if random_sampling == "mixed": - random_sampling_mask = (torch.rand( - (1, bs), device="cuda") < 0.5).expand(n_splits, bs) - elif random_sampling: - random_sampling_mask = torch.ones((n_splits, bs), - dtype=torch.bool, - device="cuda") - else: - random_sampling_mask = torch.zeros((n_splits, bs), - dtype=torch.bool, - device="cuda") - - seeds = torch.randint(1, - torch.iinfo(torch.long).max, (n_splits, bs), - device="cuda").mul_(random_sampling_mask) - #The current _sample_triton does not utilize the - # libentry decoration. The purpose of adding this patch is to test - # the correctness of libentry. 
- with patch("vllm.model_executor.layers.ops.sample._sample_triton", - LibEntry(_sample_triton)): - sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( - probs=probs, - logprobs=logprobs, - sample_indices=sample_indices, - seeds=seeds, - max_best_of=max_best_of, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - _save_modified_probs=True) - assert sampled_tokens.shape == (bs, max_best_of) - for i in range(bs): - assert torch.all(sampled_tokens[i] == i * (vocab_size // bs)) - request_uses_random_sampling = random_sampling_mask[0, i] - if modify_greedy_probs and not request_uses_random_sampling: - # If we are modifying greedy probs and the request is greedy, - # we want to make sure the probs tensor is modified in place - torch.testing.assert_close( - probs[i][sampled_tokens[i]], - torch.full_like(probs[i][sampled_tokens[i]], 1.0)) - assert torch.sum(probs[i]) == 1.0 - torch.testing.assert_close( - sampled_modified_probs[i][0], - torch.full_like(sampled_modified_probs[i][0], 1.0)) - elif request_uses_random_sampling: - # If the request is random, we want to make sure - # sampled_modified_probs tensor has noise added - # (and thus is different from probs tensor) - assert not torch.allclose(sampled_modified_probs[i][0], - probs[i][sampled_tokens[i]]) - elif not request_uses_random_sampling: - # If the request is greedy and we are not modifying greedy probs, - # we want to make sure sampled_modified_probs tensor is the same as - # the probs tensor. - torch.testing.assert_close(sampled_modified_probs[i], - probs[i][sampled_tokens[i]]) - - if save_logprobs: - assert sampled_logprobs.shape == (bs, max_best_of) - for i in range(bs): - for best_of in range(max_best_of): - assert torch.all(sampled_logprobs[i] == logprobs[i][ - sampled_tokens[i, best_of]]) - else: - assert sampled_logprobs is None - - -@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) -@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) -@pytest.mark.parametrize("modify_greedy_probs", [True, False]) -@pytest.mark.parametrize("seed", [1337]) -@pytest.mark.parametrize("vocab_size", - [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) -def test_sample_prompt_logprobs(random_sampling, max_best_of, - modify_greedy_probs, seed, vocab_size): - - set_random_seed(seed) - prompt_sizes = [16, 32, 64, 128] * 2 - samples = 8 - bs = samples + sum(prompt_sizes) - probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") - for i in range(bs): - probs[i, i * (vocab_size // bs)] = 1.0 - logprobs = torch.rand_like(probs) - sample_indices = torch.tensor(prompt_sizes, - dtype=torch.long, - device="cuda").cumsum_(0) - n_splits = get_num_triton_sampler_splits(probs.shape[1]) - if random_sampling == "mixed": - random_sampling_mask = torch.rand( - (n_splits, samples), device="cuda") < 0.5 - elif random_sampling: - random_sampling_mask = torch.ones((n_splits, samples), - dtype=torch.bool, - device="cuda") - else: - random_sampling_mask = torch.zeros((n_splits, samples), - dtype=torch.bool, - device="cuda") - - seeds = torch.randint(1, - torch.iinfo(torch.long).max, (n_splits, samples), - device="cuda").mul_(random_sampling_mask) - #ditto - with patch("vllm.model_executor.layers.ops.sample._sample_triton", - LibEntry(_sample_triton)): - sampled_tokens, sampled_logprobs, _ = sample( - probs=probs, - logprobs=logprobs, - sample_indices=sample_indices, - seeds=seeds, - max_best_of=max_best_of, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=True) - assert sampled_tokens.shape 
== (samples, max_best_of) - assert sampled_logprobs.shape == (samples, max_best_of) - for i, t in enumerate(sample_indices): - assert torch.all(sampled_tokens[i] == t * (vocab_size // bs)) - for best_of in range(max_best_of): - assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]] - [sampled_tokens[i, best_of]]) - - -@pytest.mark.parametrize("seed", list(range(16))) -def test_get_sequence_seeds(seed): - """Ensure that we get a different child seed from base - seed + extra entropy""" - starting_seed = seed - seq_seed = None - extra_entropy = 1 - for i in range(512): - new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed, - i, - seeds_to_generate=1, - is_greedy=False)[0] - new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds( - starting_seed, - i, - extra_entropy, - seeds_to_generate=1, - is_greedy=False)[0] - assert new_seq_seed_extra_entropy != new_seq_seed - assert seq_seed != new_seq_seed - seq_seed = new_seq_seed diff --git a/vllm/model_executor/layers/ops/__init__.py b/vllm/model_executor/layers/ops/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/model_executor/layers/ops/rand.py b/vllm/model_executor/layers/ops/rand.py deleted file mode 100644 index 4a429e329567d..0000000000000 --- a/vllm/model_executor/layers/ops/rand.py +++ /dev/null @@ -1,157 +0,0 @@ -from typing import Optional, Union - -import torch -import triton -import triton.language as tl - - -def seeded_uniform( - *size, - seeds: torch.Tensor, - out: Optional[torch.Tensor] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[Union[torch.device, str]] = None, - pin_memory: Optional[bool] = False, -) -> torch.Tensor: - """Similar to torch.rand, but allows for seeds to be set per row. - - seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d. - If it is 3d, the additional seeds needed will be derived automatically - in a deterministic fashion: - [ - row 0: [columns_with_seed_0], [columns_with_seed0^1], ... - ] - """ - n_dims = len(size) - - if n_dims > 3: - raise ValueError("seeded_uniform only supports up to 3D tensors") - - if out is None: - out = torch.empty(*size, - dtype=dtype, - device=device, - pin_memory=pin_memory) - elif out.shape != size: - raise ValueError("shape of out and size must be the same") - - if n_dims == 3: - n_rows, n_3d, n_cols = out.shape - stride_row = out.stride(0) - stride_3d = out.stride(1) - elif n_dims == 2: - n_rows, n_cols = out.shape - n_3d = 1 - stride_row = out.stride(0) - stride_3d = 1 - else: - n_cols = out.shape[0] - n_rows = 1 - n_3d = 1 - stride_row = 1 - stride_3d = 1 - - if seeds.ndim != 1: - raise ValueError("seeds must be a 1D tensor") - - if seeds.numel() != n_rows: - raise ValueError( - "seeds must have the same number of elements as out has rows") - - # The philox PRNG Triton uses generates 4 random numbers at once. - # Therefore, the most efficient use of it is to divide the - # block size by 4, and then save the generated random numbers to - # each of the 4 slices of the tensor. - full_block_size = triton.next_power_of_2(n_cols) - philox_block_size = max(full_block_size // 4, 1) - n_slices = full_block_size // philox_block_size - num_warps = 4 - # Manual tuning. This seems to give best performance on A100 for - # simple kernels like this. 
- if philox_block_size >= 8192: - num_warps = 32 - elif philox_block_size >= 4096: - num_warps = 16 - elif philox_block_size >= 2048: - num_warps = 8 - - _seeded_uniform_triton[(n_rows, n_3d)]( - out, - seeds, - stride_row, - stride_3d, - seeds.stride(0), - n_rows, - n_3d, - n_cols, - n_slices=n_slices, - num_warps=num_warps, - block_size=philox_block_size, - ) - return out - - -@triton.jit -def _seeded_uniform_triton( - out_ptr: torch.Tensor, - seed_ptr: torch.Tensor, - out_row_stride: int, - out_3d_stride: int, - seed_row_stride: int, - n_rows: int, - n_3d: int, - n_cols: int, - n_slices: tl.constexpr, - block_size: tl.constexpr, -): - """ - Generate a random float32 number in [0, 1) for each element in the output - tensor. The random numbers in a row generated using the seed for that row. - - Args: - out_ptr: The output tensor. - seed_ptr: The per-row seeds to use for random number generation. - out_row_stride: The stride between rows of the output tensor. - out_3d_stride: The stride between 3D slices of the output tensor. - seed_row_stride: The stride between rows of the seed tensor. - n_rows: The number of rows in the output tensor. - n_3d: The size of second dimension of the output tensor, - if output tensor is 3D. - n_cols: The number of columns in the output tensor. - n_slices: The number of philox outputs to use. - """ - tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4") - - # Get the row index. - row_idx = tl.program_id(axis=0) - three_d_idx = tl.program_id(axis=1) - - philox_offsets = tl.arange(0, block_size) - # Get the seed for the current element. - seed = tl.load(seed_ptr + row_idx * seed_row_stride) - if three_d_idx > 0: - seed ^= three_d_idx - # Generate random numbers in [0, 1). - out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets) - - output_row_start_ptr = (out_ptr + row_idx * out_row_stride + - three_d_idx * out_3d_stride) - out1_offsets = philox_offsets - tl.store(output_row_start_ptr + out1_offsets, - out1, - mask=out1_offsets < n_cols) - if n_slices > 1: - out2_offsets = tl.arange(block_size, block_size * 2) - tl.store(output_row_start_ptr + out2_offsets, - out2, - mask=out2_offsets < n_cols) - if n_slices > 2: - out3_offsets = tl.arange(block_size * 2, block_size * 3) - tl.store(output_row_start_ptr + out3_offsets, - out3, - mask=out3_offsets < n_cols) - if n_slices > 3: - out4_offsets = tl.arange(block_size * 3, block_size * 4) - tl.store(output_row_start_ptr + out4_offsets, - out4, - mask=out4_offsets < n_cols) diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py deleted file mode 100644 index fb88a05daf482..0000000000000 --- a/vllm/model_executor/layers/ops/sample.py +++ /dev/null @@ -1,394 +0,0 @@ -from typing import Optional, Tuple - -import torch -import triton -import triton.language as tl - -from vllm.model_executor.layers.ops.rand import seeded_uniform -from vllm.triton_utils.sample import get_num_triton_sampler_splits - -_EPS: tl.constexpr = 1e-6 - - -def _multi_split_sample( - probs: torch.Tensor, - seeds: torch.Tensor, - n_splits: int, - sampled_tokens_size: Tuple[int, int], - sampled_logprobs_size: Tuple[int, int], - sample_indices: torch.Tensor, - logprobs: torch.Tensor, - *, - modify_greedy_probs: bool = False, - save_logprobs: bool = False, -): - """Sample tokens where vocab size is split into multiple parts - (too large for Triton otherwise).""" - assert seeds.ndim == 2 and seeds.shape[0] == n_splits - split_probs = probs.tensor_split(n_splits, 1) - split_logprobs = 
logprobs.tensor_split(n_splits, 1) - sampled_tokens_tmp = [ - torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device) - for _ in range(n_splits) - ] - sampled_logprobs_tmp = [ - torch.empty(sampled_logprobs_size, - dtype=probs.dtype, - device=probs.device) for _ in range(n_splits) - ] - # We are purposefuly using sampled_tokens_size as we need to always - # save modified probs in this case. - sampled_modified_probs_tmp = [ - torch.empty(sampled_tokens_size, - dtype=probs.dtype, - device=probs.device) for _ in range(n_splits) - ] - for i in range(n_splits): - n_samples = sample_indices.shape[0] - n_cols = split_probs[i].shape[1] - n_best = sampled_tokens_tmp[i].shape[1] - uniform_noise = seeded_uniform(n_samples, - n_best, - n_cols, - seeds=seeds[i].flatten(), - device=split_probs[i].device, - dtype=split_probs[i].dtype) - # TODO(yard1): See if we can remove the contiguous() calls. - # Will need kernel support. - _sample( - split_probs[i].contiguous(), - split_logprobs[i].contiguous(), - sample_indices, - sampled_tokens_tmp[i], - sampled_logprobs_tmp[i], - sampled_modified_probs_tmp[i], - seeds[i], - uniform_noise, - modify_greedy_probs=False, - save_logprobs=save_logprobs, - save_modified_probs=True, - ) - if i > 0: - # Add offset to sampled tokens - sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1]) - sampled_tokens = torch.stack(sampled_tokens_tmp) - sampled_modified_probs = torch.stack(sampled_modified_probs_tmp) - # Reduce the results from the splits. - sampled_modified_probs, indices = torch.max(sampled_modified_probs, - dim=0, - keepdim=True) - sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0) - if save_logprobs: - sampled_logprobs = torch.stack(sampled_logprobs_tmp) - sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0) - else: - sampled_logprobs = None - sampled_modified_probs = sampled_modified_probs.squeeze(0) - - if modify_greedy_probs: - # We need to modify the greedy probs for the sampled tokens. - # We can't do this in the kernel as we need to know the - # sampled tokens. - probs.fill_(0.0) - probs.scatter_(1, sampled_tokens, 1.0) - - return (sampled_tokens, sampled_logprobs, sampled_modified_probs) - - -def sample( - probs: torch.Tensor, - seeds: torch.Tensor, - *, - max_best_of: int = 1, - sample_indices: Optional[torch.Tensor] = None, - logprobs: Optional[torch.Tensor] = None, - modify_greedy_probs: bool = False, - save_logprobs: bool = False, - _save_modified_probs: bool = False, # pylint: disable=invalid-name -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - """Sample tokens from probs. with per-sequence seeds. - - Can sample from a subset of sequences through sample_indices. - - Args: - probs: Probabilities to sample from. - shape = [batch_size, vocab_size] - seeds: Per-sequence seed values. - shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)] - max_best_of: Number of samples to generate per sequence. - Sequence seed will be incremented by 1 each time. - sample_indices: Indices of sequences to sample from. - If not provided, will sample from all sequences. - shape = [n] - logprobs: Log-probabilities of the sampled tokens. - Only used for saving the logprobs if save_logprobs is True. - shape = [batch_size, vocab_size] - modify_greedy_probs: Whether to modify the greedy probabilities - for speculative sampling (sampled token = 1.0, - everything else = 0.0). - save_logprobs: Whether to save the log-probabilities of the - sampled tokens to a tensor. 
- _save_modified_probs: Whether to save the modified probabilities - (including gumbel noise) of the sampled tokens to a tensor. - DOES NOT include the modification done by modify_greedy_probs - (because we want to use the unmodified probs to pick the best - split in case of multi-split sampling). - This is exposed only for testing. - - Returns: - sampled_tokens: shape = [n, max_best_of] - sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None - sampled_modified_probs: shape = [n, max_best_of] - if save_modified_probs else None - """ - if sample_indices is None: - sample_indices = torch.arange(0, probs.shape[0], device=probs.device) - - sampled_tokens_size = (sample_indices.size(0), max_best_of) - if save_logprobs: - if logprobs is None: - raise ValueError( - "logprobs tensor must be provided if save_logprobs is True") - sampled_logprobs_size = sampled_tokens_size - else: - # Empty tensors to invoke the kernel - sampled_logprobs_size = (0, 0) - logprobs = probs - - assert logprobs is not None - if _save_modified_probs: - sampled_modified_probs_size = sampled_tokens_size - else: - # Empty tensors to invoke the kernel - sampled_modified_probs_size = (0, 0) - - # If the number of columns in probs is too large for Triton to handle, - # we split the tensor and sample from each split separately, and then - # do an argmax+gather to combine the results. - n_splits = get_num_triton_sampler_splits(probs.shape[1]) - if n_splits > 1: - (sampled_tokens, sampled_logprobs, - sampled_modified_probs) = _multi_split_sample( - probs, - seeds, - n_splits, - sampled_tokens_size, - sampled_logprobs_size, - sample_indices, - logprobs=logprobs, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs) - else: - sampled_tokens = torch.empty(sampled_tokens_size, - dtype=torch.long, - device=probs.device) - sampled_logprobs = torch.empty(sampled_logprobs_size, - dtype=probs.dtype, - device=probs.device) - sampled_modified_probs = torch.empty(sampled_modified_probs_size, - dtype=probs.dtype, - device=probs.device) - n_samples = sample_indices.shape[0] - n_cols = probs.shape[1] - uniform_noise = seeded_uniform(n_samples, - max_best_of, - n_cols, - seeds=seeds.flatten(), - device=probs.device, - dtype=probs.dtype) - - _sample( - probs, - logprobs, - sample_indices, - sampled_tokens, - sampled_logprobs, - sampled_modified_probs, - seeds, - uniform_noise, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - save_modified_probs=_save_modified_probs, - ) - return (sampled_tokens, sampled_logprobs if save_logprobs else None, - sampled_modified_probs if _save_modified_probs else None) - - -def _sample(probs: torch.Tensor, - logprobs: torch.Tensor, - sample_indices: torch.Tensor, - output_samples: torch.Tensor, - output_logprobs: torch.Tensor, - output_modified_probs: torch.Tensor, - seeds: torch.Tensor, - uniform_noise: torch.Tensor, - *, - modify_greedy_probs: bool = False, - save_logprobs: bool = True, - save_modified_probs: bool = False) -> torch.Tensor: - """Sample tokens from probs. - - Args: - probs [batch_size, vocab_size]: probs to sample from. - logprobs [batch_size, vocab_size]: logprobs (used when - save_logprobsis True). - sample_indices [n]: Indices of the samples to use for each row of probs. - output_samples [n, n_best]: Output tensor to store samples in. - output_logprobs [n, n_best]: Output tensor to store logprobs in. - output_modified_probs [n, n_best]: Output tensor to store - probs of chosen tokens in (modified with noise). 
- seeds [n]: Seeds to use for sampling. If the seed is 0, we use - greedy sampling. Note this is ONLY used for determining - whether to use random sampling or not. The actual random - noise should be passed as uniform_noise. - uniform_noise [batch_size, n_best, vocab_size]: Uniform - noise to use for random sampling (will be converted - to exponential gumbel noise by the kernel). - modify_greedy_probs: If True, we modify the probs tensor in-place - to encode the sampling method used for each row. This is used - in speculative decoding. Only applies in greedy decoding. - save_logprobs: If True, we save the logprobs of the sampled tokens - in the output_logprobs tensor. - save_modified_probs: If True, we save the modified probs (with noise) - of the sampled tokens in the output_modified_probs tensor. - DOES NOT include the modification done by modify_greedy_probs - (because we want to use the unmodified probs to pick the best - split in case of multi-split sampling). - """ - n_samples = sample_indices.shape[0] - n_cols = probs.shape[1] - n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1 - - # The block size is the smallest power of two greater than the number of - # columns in probs - block_size = triton.next_power_of_2(n_cols) - num_warps = 4 - # Manual tuning. This seems to give best performance on A100 for - # simple kernels like this. - if block_size >= 8192: - num_warps = 32 - elif block_size >= 4096: - num_warps = 16 - elif block_size >= 2048: - num_warps = 8 - - # Enqueue kernel. The 1D launch grid is simple: we have one kernel - # instance per row of the probs matrix - _sample_triton[(n_samples, n_best)]( - sample_indices, - output_samples, - output_logprobs, - output_modified_probs, - probs, - logprobs, - seeds, - uniform_noise, - output_samples.stride(0), - probs.stride(0), - uniform_noise.stride(0), - uniform_noise.stride(1) if n_best > 1 else 1, - n_samples, - n_cols, - n_best, - num_warps=num_warps, - block_size=block_size, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - save_modified_probs=save_modified_probs, - ) - return output_samples, output_logprobs, output_modified_probs - - -@triton.jit -def _uniform_to_exponential(uniform_noise): - """Convert uniform samples to exponential samples.""" - # tl.rand returns values in [0, 1), so we clamp lower bound - # to _EPS to avoid log(0) and thus division by 0 later - lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) - uniform_noise = tl.maximum(uniform_noise, lb) - # Use the inversion method to turn uniform samples - # into exponential samples - exponential_noise = -tl.log(uniform_noise) - return exponential_noise - - -@triton.jit -def _sample_triton( - sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, - output_logprobs_ptr: torch.Tensor, - output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor, - logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor, - uniform_noise_ptr: torch.Tensor, output_row_stride: int, - probs_row_stride: int, uniform_noise_row_stride: int, - uniform_noise_best_stride: int, n_samples: int, n_cols: int, - n_best: int, block_size: tl.constexpr, - modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr, - save_modified_probs: tl.constexpr): - # The rows are independent, so we parallelize across those - sample_idx = tl.program_id(0) - best_idx = tl.program_id(1) - - # Load the row index from DRAM - row_idx = tl.load(sample_indices_ptr + sample_idx) - seed = tl.load(seeds_ptr + sample_idx) - uses_random_sampling = seed != 0 - - # 
The stride represents how much we need to increase the - # pointer to advance 1 row - row_start_ptr = probs_ptr + row_idx * probs_row_stride - - # The block size is the next power of two greater than n_cols, - # so we can fit each row in a single block - col_offsets = tl.arange(0, block_size) - - # Load the row into SRAM, using a mask since block_size may be > than n_cols - row = tl.load(row_start_ptr + col_offsets, - mask=col_offsets < n_cols, - other=float("-inf")) - - if uses_random_sampling: - uniform_noise_start_ptr = (uniform_noise_ptr + - sample_idx * uniform_noise_row_stride + - best_idx * uniform_noise_best_stride) - uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets, - mask=col_offsets < n_cols, - other=0.5) - exponential_noise = _uniform_to_exponential(uniform_noise) - row /= exponential_noise - - sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True) - # clamp sampled token to n_cols - 1 - # this should not be necessary, but we do it - # just in case - if sampled_token >= n_cols: - sampled_token = n_cols - 1 - # Write back output to DRAM - output_row_start_ptr = (output_ptr + sample_idx * output_row_stride + - best_idx) - tl.store(output_row_start_ptr, sampled_token) - - if modify_greedy_probs: # noqa - if not uses_random_sampling: - # Set the probability of the sampled token to 1, all other - # tokens to zero. This is used in speculative decoding where - # the sampling method must be encoded within the sampled - # probability distributions. - row = tl.where(col_offsets == sampled_token, 1.0, 0.0) - tl.store(row_start_ptr + col_offsets, - row, - mask=col_offsets < n_cols) - - if save_modified_probs: - output_row_start_ptr = (output_modified_probs_ptr + - sample_idx * output_row_stride + best_idx) - tl.store(output_row_start_ptr, sampled_value) - - if save_logprobs: - # Load the row into SRAM, using a mask since block_size - # may be > than n_cols - sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride + - sampled_token) - # Write back output to DRAM - output_row_start_ptr = (output_logprobs_ptr + - sample_idx * output_row_stride + best_idx) - tl.store(output_row_start_ptr, sampled_logprob) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c00da106734ae..487f5a3d2a441 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -10,12 +10,6 @@ import torch import torch.nn as nn -from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON: - from vllm.model_executor.layers.ops.sample import sample as sample_triton - import vllm.envs as envs from vllm.model_executor.sampling_metadata import (SamplingMetadata, SamplingTensors, @@ -23,6 +17,7 @@ from vllm.sampling_params import SamplingType from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): import flashinfer.sampling @@ -740,7 +735,7 @@ def _sample_with_torch( ) -> SampleReturnType: '''Torch-oriented _sample() implementation. - Single-step scheduling: + Single-step scheduling: * Perform GPU-side sampling computation * Immediately Pythonize sampling result @@ -777,7 +772,7 @@ def _sample_with_torch( # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. 
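The deleted `_sample_triton` kernel above draws a token by dividing each probability row by per-element Exp(1) noise and taking the argmax, a Gumbel-max variant sometimes called the exponential race. A minimal sketch of that trick in plain PyTorch, purely illustrative and not part of this patch:

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
counts = torch.zeros(3)
for _ in range(10_000):
    # The clamp mirrors the kernel's _EPS guard against log(0).
    noise = -torch.log(torch.rand(3).clamp_min(1e-6))  # Exp(1) samples
    counts[(probs / noise).argmax()] += 1
print(counts / counts.sum())  # roughly tensor([0.1, 0.2, 0.7])
```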
for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type][:, 0] + sample_indices = categorized_sample_indices[sampling_type] num_tokens = len(sample_indices) if num_tokens == 0: continue @@ -863,88 +858,6 @@ def _sample_with_torch( ) -def _sample_with_triton_kernel( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, -) -> SampleResultType: - categorized_seq_group_ids: Dict[SamplingType, - List[int]] = {t: [] - for t in SamplingType} - categorized_sample_indices = sampling_metadata.categorized_sample_indices - for i, seq_group in enumerate(sampling_metadata.seq_groups): - sampling_params = seq_group.sampling_params - sampling_type = sampling_params.sampling_type - categorized_seq_group_ids[sampling_type].append(i) - - sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} - sample_metadata: Dict[SamplingType, - Tuple[List[int], List[SequenceGroupToSample], - torch.Tensor, torch.Tensor]] = {} - max_best_of_in_batch = 1 - - # Counterintiutively, having two loops here is actually faster. - # The first loop can run without waiting on GPU<->CPU sync. - for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type][:, 0] - sampled_token_indices = categorized_sample_indices[sampling_type][:, 1] - num_tokens = len(sample_indices) - if num_tokens == 0: - continue - seq_group_id = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] - sample_metadata[sampling_type] = (seq_group_id, seq_groups, - sample_indices, - sampled_token_indices) - if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, - SamplingType.RANDOM_SEED): - for seq_group in seq_groups: - if seq_group.is_prompt: - sampling_params = seq_group.sampling_params - max_best_of_in_batch = max(max_best_of_in_batch, - sampling_params.best_of) - elif sampling_type == SamplingType.BEAM: - beam_search_logprobs = logprobs[sample_indices] - else: - raise ValueError(f"Unsupported sampling type: {sampling_type}") - - sampled_tokens, _, _ = sample_triton( - probs=probs, - seeds=sampling_tensors.sampling_seeds, - max_best_of=max_best_of_in_batch, - sample_indices=sampling_tensors.sample_indices, - logprobs=logprobs, - # don't save logprobs because we have logic for that below - # TODO: use this instead of the CPU-based logic below - save_logprobs=False, - ) - - # GPU<->CPU sync happens in the loop below. - - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups, sample_indices, - sampled_token_indices) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample( - seq_groups, sampled_tokens[sampled_token_indices][:, 0]) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample( - seq_groups, sampled_tokens[sampled_token_indices]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_id, sample_results)) - - sample_results = [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] - return sample_results - - def _sample( probs: torch.Tensor, logprobs: torch.Tensor, @@ -974,10 +887,6 @@ def _sample( modify_greedy_probs=modify_greedy_probs, ) - # TODO: Enable once Triton kernel & associated code is faster. 
- # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, - # sampling_tensors) - def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: """ diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index a085779bc61a7..97d36d31f2b11 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -1,4 +1,3 @@ -import random from array import array from dataclasses import dataclass from typing import Dict, List, Optional, Tuple @@ -8,15 +7,10 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, SequenceGroupMetadata) -from vllm.triton_utils.sample import get_num_triton_sampler_splits from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad, - maybe_expand_dim) + is_pin_memory_available, make_tensor_with_pad) _SAMPLING_EPS = 1e-5 -_SEED_0_REPLACEMENT = 3403598558 -# Some triton sampler related code is guarded before it is ready. -_USE_TRITON_SAMPLER = False @dataclass @@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int): generator=None, is_prompt=True, prompt_logprob_indices=[], - sample_indices=[]) + sample_indices=[], + ) class SamplingMetadataCache: - """Used to cache SamplingMetadata objects between scheduler iterations - """ + """Used to cache SamplingMetadata objects between scheduler iterations""" def __init__(self): self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {} @@ -124,12 +118,12 @@ def sample(logits): The first tuple is [1, 2] (sampled index within original logit), and the second tuple is [0, 1] (sampled index within pruned logit). num_prompts: Number of prompt sequence groups in seq_groups. - skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU + skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU serialization of token outputs. - reuse_sampling_tensors: Indicates if we want to reuse sampling + reuse_sampling_tensors: Indicates if we want to reuse sampling tensors that are part of the sampler forward pass. Currently, it is mainly used for multi-step decode. - + """ def __init__( @@ -165,16 +159,19 @@ def prepare( num_prompts, ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, device, generators, cache) - selected_token_indices = async_tensor_h2d(selected_token_indices, - dtype=torch.long, - target_device=device, - pin_memory=pin_memory) + selected_token_indices = async_tensor_h2d( + selected_token_indices, + dtype=torch.long, + target_device=device, + pin_memory=pin_memory, + ) categorized_sample_indices = { - t: maybe_expand_dim( - async_tensor_h2d(seq_ids, - dtype=torch.int, - target_device=device, - pin_memory=pin_memory), 2, 2) + t: async_tensor_h2d( + seq_ids, + dtype=torch.int, + target_device=device, + pin_memory=pin_memory, + ) for t, seq_ids in categorized_sample_indices.items() } @@ -201,8 +198,8 @@ def _prepare_seq_groups( device: str, generators: Optional[Dict[str, torch.Generator]] = None, cache: Optional[SamplingMetadataCache] = None, -) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ - SamplingType, List[Tuple[int, int]]], int]: +) -> Tuple[List[SequenceGroupToSample], List[int], Dict[SamplingType, + List[int]], int, ]: """Prepare sequence groups and indices for sampling. 
Args: @@ -233,16 +230,13 @@ def _prepare_seq_groups( # Sampling type -> ( # indices to sample/prompt logprob within pruned output logits, # indices to sample within pruned logits) - categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = { + categorized_sample_indices: Dict[SamplingType, List[int]] = { t: [] for t in SamplingType } # Index of logits to compute logprob. Logits include both prompt logprob # and sample logprob indices. logit_idx = 0 - # Index to sample from a sample tensor. It is used by triton sample kernel. - # See `_sample_with_triton_kernel` for more details. - sample_idx = 0 # Total number of prompts from given sequence groups. num_prompts = 0 @@ -264,10 +258,10 @@ def _prepare_seq_groups( # If the current seq group is in decode stage, it is None. seq_len: Optional[int] = None query_len: Optional[int] = None - prompt_logprob_indices: List[int] = \ - sample_obj.prompt_logprob_indices if cache is not None else [] - sample_indices: List[int] = \ - sample_obj.sample_indices if cache is not None else [] + prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices + if cache is not None else []) + sample_indices: List[int] = (sample_obj.sample_indices + if cache is not None else []) do_sample = seq_group_metadata.do_sample if seq_group_metadata.is_prompt: @@ -333,11 +327,8 @@ def sample(logits): if do_sample: sample_indices.extend(range(logit_idx, logit_idx + sample_len)) categorized_sample_indices[sampling_params.sampling_type].extend( - list( - zip(range(logit_idx, logit_idx + sample_len), - range(sample_idx, sample_idx + sample_len)))) + list(range(logit_idx, logit_idx + sample_len))) logit_idx += sample_len - sample_idx += sample_len if cache is not None: sample_obj.sampling_params = sampling_params @@ -356,7 +347,8 @@ def sample(logits): generator=generator, is_prompt=is_prompt, prompt_logprob_indices=list(prompt_logprob_indices), - sample_indices=list(sample_indices)) + sample_indices=list(sample_indices), + ) seq_groups.append(sample_obj) @@ -378,9 +370,6 @@ class SamplingTensors: presence_penalties: torch.Tensor frequency_penalties: torch.Tensor repetition_penalties: torch.Tensor - sampling_seeds: torch.Tensor - sample_indices: torch.Tensor - extra_seeds: Optional[torch.Tensor] prompt_tokens: torch.Tensor output_tokens: torch.Tensor @@ -391,15 +380,7 @@ def from_sampling_metadata( vocab_size: int, device: torch.device, dtype: torch.dtype, - *, - extra_seeds_to_generate: int = 0, - extra_entropy: Optional[Tuple[int, ...]] = None ) -> Tuple["SamplingTensors", bool, bool, bool]: - """ - extra_seeds_to_generate: extra seeds to generate using the - user-defined seed for each sequence. - extra_entropy: extra entropy to use when generating seeds. - """ prompt_tokens: List[array] = [] output_tokens: List[array] = [] top_ks: List[int] = [] @@ -409,19 +390,10 @@ def from_sampling_metadata( presence_penalties: List[float] = [] frequency_penalties: List[float] = [] repetition_penalties: List[float] = [] - sampling_seeds: List[int] = [] - sample_indices: List[int] = [] do_penalties = False do_top_p_top_k = False do_min_p = False - if _USE_TRITON_SAMPLER: - prompt_best_of: List[int] = [] - - # We need one base seed per Triton slice. 
- seeds_to_generate = (extra_seeds_to_generate + - get_num_triton_sampler_splits(vocab_size)) - assert sampling_metadata.seq_groups is not None for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -452,7 +424,7 @@ def from_sampling_metadata( do_penalties = True is_prompt = seq_group.is_prompt - if (is_prompt and sampling_params.prompt_logprobs is not None): + if is_prompt and sampling_params.prompt_logprobs is not None: # For tokens in the prompt that we only need to get # their logprobs query_len = seq_group.query_len @@ -477,28 +449,6 @@ def from_sampling_metadata( frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) - if _USE_TRITON_SAMPLER: - if is_prompt: - prompt_best_of.append(sampling_params.best_of) - query_len = seq_group.query_len - assert query_len is not None - - seed = sampling_params.seed - is_greedy = sampling_params.sampling_type == SamplingType.GREEDY - - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - extra_entropy = extra_entropy or () - seq_seeds = cls._get_sequence_seeds( - seed, - seq_data.get_len(), - *extra_entropy, - seq_id, - seeds_to_generate=seeds_to_generate, - is_greedy=is_greedy) - sampling_seeds.append(seq_seeds) - sample_indices.extend(seq_group.sample_indices) - if do_penalties: for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -518,23 +468,37 @@ def from_sampling_metadata( output_tokens.append(seq_data.output_token_ids_array) sampling_tensors = SamplingTensors.from_lists( - temperatures, top_ps, top_ks, min_ps, presence_penalties, - frequency_penalties, repetition_penalties, sampling_seeds, - sample_indices, prompt_tokens, output_tokens, vocab_size, - extra_seeds_to_generate, device, dtype) + temperatures, + top_ps, + top_ks, + min_ps, + presence_penalties, + frequency_penalties, + repetition_penalties, + prompt_tokens, + output_tokens, + vocab_size, + device, + dtype, + ) return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) @classmethod - def from_lists(cls, temperatures: List[float], top_ps: List[float], - top_ks: List[int], min_ps: List[float], - presence_penalties: List[float], - frequency_penalties: List[float], - repetition_penalties: List[float], - sampling_seeds: List[int], sample_indices: List[int], - prompt_tokens: List[array], output_tokens: List[array], - vocab_size: int, extra_seeds_to_generate: int, - device: torch.device, - dtype: torch.dtype) -> "SamplingTensors": + def from_lists( + cls, + temperatures: List[float], + top_ps: List[float], + top_ks: List[int], + min_ps: List[float], + presence_penalties: List[float], + frequency_penalties: List[float], + repetition_penalties: List[float], + prompt_tokens: List[array], + output_tokens: List[array], + vocab_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> "SamplingTensors": # Note that the performance will be very bad without # pinned memory. pin_memory = is_pin_memory_available() @@ -603,34 +567,9 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype=torch.int, pin_memory=pin_memory, ) - sample_indices_t = torch.tensor( - sample_indices, - device="cpu", - dtype=torch.long, - pin_memory=pin_memory, - ) - # need to transpose and make contiguous to - # copy the tensor correctly. - # [batch_size, n_seeds] -> [n_seeds, batch_size] - sampling_seeds_t = torch.tensor( - sampling_seeds, - device="cpu", - dtype=torch.long, - pin_memory=pin_memory, - ).t().contiguous() - # Because the memory is pinned, we can do non-blocking # transfer to device. 
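The comment above refers to a general PyTorch behavior rather than anything vLLM-specific: host-to-device copies from pinned (page-locked) CPU memory can be issued asynchronously with `non_blocking=True`. A tiny sketch, assuming a CUDA device is available:

```python
import torch

if torch.cuda.is_available():
    cpu_t = torch.ones(1024, device="cpu", pin_memory=True)
    gpu_t = cpu_t.to(device="cuda", non_blocking=True)  # async H2D copy
    torch.cuda.synchronize()  # wait before relying on the result
    print(gpu_t.sum())  # tensor(1024., device='cuda:0')
```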
- # How many seeds the sample operation itself will need. - num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate - sampling_seeds_gpu = sampling_seeds_t.to(device=device, - non_blocking=True) - extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:] - if not extra_seeds_gpu.numel(): - extra_seeds_gpu = None - sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds] - return cls( temperatures=temperatures_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True), @@ -644,38 +583,4 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], non_blocking=True), prompt_tokens=prompt_t.to(device=device, non_blocking=True), output_tokens=output_t.to(device=device, non_blocking=True), - sampling_seeds=sampling_seeds_gpu, - sample_indices=sample_indices_t.to(device=device, - non_blocking=True), - extra_seeds=extra_seeds_gpu, ) - - @staticmethod - def _get_sequence_seeds( - seed: int, - *extra_entropy: int, - seeds_to_generate: int, - is_greedy: bool, - ): - """Get `seeds_to_generate` child seeds from `seed` and extra entropy.""" - if not is_greedy: - if seed is None: - randint_fn = random.randint - else: - generator = random.Random(str((seed, ) + extra_entropy)) - randint_fn = generator.randint - lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max - # If the user/random sets seed = 0 but request should - # have sampling, we need to change it to something - # else. We use a constant in that case. - # This way we don't need to create and load a bool - # matrix in the sampling kernel, which reduces CPU - # overhead and latency. - seq_seeds = [ - randint_fn(lo, hi) or _SEED_0_REPLACEMENT - for _ in range(seeds_to_generate) - ] - else: - # For the kernel, seed == 0 means greedy decoding. - seq_seeds = [0] * seeds_to_generate - return seq_seeds diff --git a/vllm/triton_utils/sample.py b/vllm/triton_utils/sample.py deleted file mode 100644 index 401e4d28a3c99..0000000000000 --- a/vllm/triton_utils/sample.py +++ /dev/null @@ -1,13 +0,0 @@ -import math - -# This is a hardcoded limit in Triton (max block size). -MAX_TRITON_N_COLS = 131072 - - -def get_num_triton_sampler_splits(n_cols: int) -> int: - """Get the number of splits to use for Triton sampling. - - Triton has a limit on the number of columns it can handle, so we need to - split the tensor and call the kernel multiple times if it's too large. - """ - return math.ceil(n_cols / MAX_TRITON_N_COLS) diff --git a/vllm/utils.py b/vllm/utils.py index 014fc16a17c1f..1cbd9d55c68b3 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -270,7 +270,7 @@ def clear(self): class PyObjectCache: - """Used to cache python objects to avoid object allocations + """Used to cache python objects to avoid object allocations across scheduler iterations. """ @@ -289,7 +289,7 @@ def _grow_cache(self): self._obj_cache.append(self._obj_builder()) def get_object(self): - """Returns a pre-allocated cached object. If there is not enough + """Returns a pre-allocated cached object. If there is not enough objects, then the cache size will double. 
""" if self._index >= len(self._obj_cache): @@ -837,15 +837,6 @@ def async_tensor_h2d( return t.to(device=target_device, non_blocking=True) -def maybe_expand_dim(tensor: torch.Tensor, - target_dims: int, - size: int = 1) -> torch.Tensor: - """Expand the tensor to the target_dims.""" - if tensor.ndim < target_dims: - tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim))) - return tensor - - def get_dtype_size(dtype: torch.dtype) -> int: """Get the size of the data type in bytes.""" return torch.tensor([], dtype=dtype).element_size() @@ -1070,7 +1061,7 @@ def _cuda_device_count_stateless( def cuda_device_count_stateless() -> int: """Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES at the time of call. - + This should be used instead of torch.cuda.device_count() unless CUDA_VISIBLE_DEVICES has already been set to the desired value.""" @@ -1136,10 +1127,10 @@ def parse_args(self, args=None, namespace=None): def _pull_args_from_config(args: List[str]) -> List[str]: """Method to pull arguments specified in the config file into the command-line args variable. - - The arguments in config file will be inserted between + + The arguments in config file will be inserted between the argument list. - + example: ```yaml port: 12323 @@ -1150,21 +1141,21 @@ def _pull_args_from_config(args: List[str]) -> List[str]: --config config.yaml -tp 2 $: args = [ "serve,chat,complete", - "facebook/opt-12B", - '--config', 'config.yaml', + "facebook/opt-12B", + '--config', 'config.yaml', '-tp', '2' ] $: args = [ "serve,chat,complete", - "facebook/opt-12B", - '--port', '12323', - '--tensor-parallel-size', '4', + "facebook/opt-12B", + '--port', '12323', + '--tensor-parallel-size', '4', '-tp', '2' ] ``` Please note how the config args are inserted after the sub command. - this way the order of priorities is maintained when these are args + this way the order of priorities is maintained when these are args parsed by super(). 
""" assert args.count( @@ -1190,7 +1181,7 @@ def _pull_args_from_config(args: List[str]) -> List[str]: @staticmethod def _load_config_file(file_path: str) -> List[str]: - """Loads a yaml file and returns the key value pairs as a + """Loads a yaml file and returns the key value pairs as a flattened list with argparse like pattern ```yaml port: 12323 @@ -1201,7 +1192,7 @@ def _load_config_file(file_path: str) -> List[str]: '--port': '12323', '--tensor-parallel-size': '4' ] - + """ extension: str = file_path.split('.')[-1] From 1c1bb388e0d35a2d10da5c5cda2edac57bf62591 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 16 Sep 2024 22:17:32 -0600 Subject: [PATCH 59/98] [Frontend] Improve Nullable kv Arg Parsing (#8525) Signed-off-by: Alex-Brooks --- tests/engine/test_arg_utils.py | 20 +++++++++++++++++++- vllm/engine/arg_utils.py | 28 +++++++++++++++++++++------- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 3208d6bb48bdc..8dd200b35d0f3 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -1,6 +1,8 @@ +from argparse import ArgumentTypeError + import pytest -from vllm.engine.arg_utils import EngineArgs +from vllm.engine.arg_utils import EngineArgs, nullable_kvs from vllm.utils import FlexibleArgumentParser @@ -13,6 +15,10 @@ "image": 16, "video": 2 }), + ("Image=16, Video=2", { + "image": 16, + "video": 2 + }), ]) def test_limit_mm_per_prompt_parser(arg, expected): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) @@ -22,3 +28,15 @@ def test_limit_mm_per_prompt_parser(arg, expected): args = parser.parse_args(["--limit-mm-per-prompt", arg]) assert args.limit_mm_per_prompt == expected + + +@pytest.mark.parametrize( + ("arg"), + [ + "image", # Missing = + "image=4,image=5", # Conflicting values + "image=video=4" # Too many = in tokenized arg + ]) +def test_bad_nullable_kvs(arg): + with pytest.raises(ArgumentTypeError): + nullable_kvs(arg) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b5eba9ca3727a..35013eedea9c6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -44,22 +44,36 @@ def nullable_str(val: str): def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: + """Parses a string containing comma separate key [str] to value [int] + pairs into a dictionary. + + Args: + val: String value to be parsed. + + Returns: + Dictionary with parsed values. 
+ """ if len(val) == 0: return None out_dict: Dict[str, int] = {} for item in val.split(","): - try: - key, value = item.split("=") - except TypeError as exc: - msg = "Each item should be in the form KEY=VALUE" - raise ValueError(msg) from exc + kv_parts = [part.lower().strip() for part in item.split("=")] + if len(kv_parts) != 2: + raise argparse.ArgumentTypeError( + "Each item should be in the form KEY=VALUE") + key, value = kv_parts try: - out_dict[key] = int(value) + parsed_value = int(value) except ValueError as exc: msg = f"Failed to parse value of item {key}={value}" - raise ValueError(msg) from exc + raise argparse.ArgumentTypeError(msg) from exc + + if key in out_dict and out_dict[key] != parsed_value: + raise argparse.ArgumentTypeError( + f"Conflicting values specified for key: {key}") + out_dict[key] = parsed_value return out_dict From ee2bceaaa67bd2f420f62a924da5834a7c1c862b Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Mon, 16 Sep 2024 22:22:45 -0700 Subject: [PATCH 60/98] [Misc][Bugfix] Disable guided decoding for mistral tokenizer (#8521) --- .../guided_decoding/__init__.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 7161e83952a3d..f4fe8a7307c04 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -6,6 +6,7 @@ from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest) from vllm.sampling_params import LogitsProcessor +from vllm.transformers_utils.tokenizer import MistralTokenizer async def get_guided_decoding_logits_processor( @@ -15,12 +16,23 @@ async def get_guided_decoding_logits_processor( request = _adapt_request_for_tool_use(request) if guided_decoding_backend == 'outlines': + if isinstance(tokenizer, MistralTokenizer): + raise NotImplementedError( + "Guided decoding with 'outlines' is currently not supported " + "for Mistral tokenizer. Please consider contributing to the " + "'outlines' project if you are interested in this feature.") # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( request, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': + if isinstance(tokenizer, MistralTokenizer): + raise NotImplementedError( + "Guided decoding with 'lm-format-enforcer' is currently not " + "supported for Mistral tokenizer. Please consider contributing " + "to the 'lm-format-enforcer' project if you are interested " + "in this feature.") from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_lm_format_enforcer_guided_decoding_logits_processor) return await get_lm_format_enforcer_guided_decoding_logits_processor( @@ -37,12 +49,23 @@ def get_local_guided_decoding_logits_processor( # request = _adapt_request_for_tool_use(request) if guided_decoding_backend == 'outlines': + if isinstance(tokenizer, MistralTokenizer): + raise NotImplementedError( + "Guided decoding with 'outlines' is currently not supported " + "for Mistral tokenizer. 
Please consider contributing to the " + "'outlines' project if you are interested in this feature.") # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( guided_options, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': + if isinstance(tokenizer, MistralTokenizer): + raise NotImplementedError( + "Guided decoding with 'lm-format-enforcer' is currently not " + "supported for Mistral tokenizer. Please consider contributing " + "to the 'lm-format-enforcer' project if you are interested " + "in this feature.") from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( From 99aa4eddaf929f57dac405b00db3f5286624ee8b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 16 Sep 2024 22:57:57 -0700 Subject: [PATCH 61/98] [torch.compile] register allreduce operations as custom ops (#8526) --- .buildkite/test-pipeline.yaml | 10 +- csrc/custom_all_reduce.cu | 12 -- csrc/ops.h | 2 - csrc/torch_bindings.cpp | 5 - tests/compile/__init__.py | 0 tests/compile/test_full_graph.py | 15 ++- vllm/_custom_ops.py | 6 - .../device_communicators/custom_all_reduce.py | 21 +++- vllm/distributed/parallel_state.py | 116 +++++++++++++++--- 9 files changed, 137 insertions(+), 50 deletions(-) create mode 100644 tests/compile/__init__.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9b0cb6663a55b..9483adcc5d587 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -163,13 +163,6 @@ steps: - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py -- label: torch compile integration test - source_file_dependencies: - - vllm/ - commands: - - pytest -v -s ./compile/test_full_graph.py - - pytest -v -s ./compile/test_wrapper.py - - label: Prefix Caching Test # 7min #mirror_hardwares: [amd] source_file_dependencies: @@ -348,7 +341,10 @@ steps: - vllm/executor/ - vllm/model_executor/models/ - tests/distributed/ + - vllm/compilation commands: + - pytest -v -s ./compile/test_full_graph.py + - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus # Avoid importing model tests that cause CUDA reinitialization error diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 82a3563979f16..9b82bec44c3c6 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -55,18 +55,6 @@ bool _is_weak_contiguous(torch::Tensor& t) { t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, - bool full_nvlink) { - auto inp_size = inp.numel() * inp.element_size(); - // custom allreduce requires input byte size to be multiples of 16 - if (inp_size % 16 != 0) return false; - if (!_is_weak_contiguous(inp)) return false; - if (world_size == 2 || full_nvlink) return inp_size <= max_size; - // for 4 or more non 
NVLink-capable GPUs, custom allreduce provides little - // performance improvement over NCCL. - return false; -} - void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, cudaStream_t stream) { auto fa = reinterpret_cast(_fa); diff --git a/csrc/ops.h b/csrc/ops.h index 681ab4b898ca3..ee89ad32cb025 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -241,8 +241,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, const std::vector& offsets, int64_t rank, bool full_nvlink); -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, - bool full_nvlink); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index d7f7547fbef55..7009180a8687c 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -411,11 +411,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { "bool full_nvlink) -> int"); custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); - custom_ar.def( - "should_custom_ar(Tensor inp, int max_size, int world_size, " - "bool full_nvlink) -> bool"); - custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar); - custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg); diff --git a/tests/compile/__init__.py b/tests/compile/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 5452ce6be8110..6fc445539bbbe 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -2,9 +2,20 @@ import pytest +from vllm.utils import cuda_device_count_stateless + +from ..utils import fork_new_process_for_each_test + @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) -def test_full_graph(model): +@pytest.mark.parametrize("tp_size", [1, 2]) +@fork_new_process_for_each_test +def test_full_graph(model, tp_size): + + # Skip the test if there are not enough CUDA devices. 
+ if cuda_device_count_stateless() < tp_size: + pytest.skip("Not enough CUDA devices for the test.") + # make sure these models can be captured in full graph mode if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" @@ -17,7 +28,7 @@ def test_full_graph(model): "The future of AI is", ] sampling_params = SamplingParams(temperature=0) - llm = LLM(model=model, enforce_eager=True) + llm = LLM(model=model, enforce_eager=True, tensor_parallel_size=tp_size) outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d5b3d7bc6dd5a..ac90895b11c37 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -870,12 +870,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor, offsets, rank, full_nvlink) -def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int, - full_nvlink: bool) -> bool: - return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size, - full_nvlink) - - def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 6229f1d6ec788..d239d645edc14 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -33,6 +33,12 @@ def _can_p2p(rank: int, world_size: int) -> bool: return True +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or (inp.storage().nbytes() - + inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size()) + + class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] @@ -224,8 +230,19 @@ def register_graph_buffers(self): ops.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): - return ops.should_custom_ar(inp, self.max_size, self.world_size, - self.full_nvlink) + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides + # little performance improvement over NCCL. + if self.world_size == 2 or self.full_nvlink: + return inp_size < self.max_size + return False # all reduce, assuming inp tensor is IPC registered with register_buffer, # or, in the context of cuda graphs, register_graph_buffers diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6755b20eec9bb..1c864bcd5d708 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -21,11 +21,12 @@ """ import contextlib import pickle +import weakref from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch import torch @@ -69,6 +70,58 @@ def _split_tensor_dict( return metadata_list, tensor_list +_group_name_counter: Dict[str, int] = {} + + +def _get_unique_name(name: str) -> str: + """Get a unique name for the group. 
+ Example: + _get_unique_name("tp") -> "tp:0" + _get_unique_name("tp") -> "tp:1" + """ + if name not in _group_name_counter: + _group_name_counter[name] = 0 + newname = f"{name}:{_group_name_counter[name]}" + _group_name_counter[name] += 1 + return newname + + +_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {} + + +def _register_group(group: "GroupCoordinator") -> None: + # looks like Python 3.8 does not understand `ReferenceType` + _groups[group.unique_name] = weakref.ref(group) # type: ignore + + +@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"]) +def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + group._all_reduce(tensor) + + +@inplace_all_reduce.register_fake +def _(tensor: torch.Tensor, group_name: str) -> None: + return + + +@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[]) +def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce(tensor) + + +@outplace_all_reduce.register_fake +def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) + + class GroupCoordinator: """ PyTorch ProcessGroup wrapper for a group of processes. @@ -111,7 +164,11 @@ def __init__( use_custom_allreduce: bool, use_tpu_communicator: bool, use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) self.rank = torch.distributed.get_rank() self.local_rank = local_rank @@ -149,28 +206,24 @@ def __init__( from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) - self.pynccl_comm: Optional[PyNcclCommunicator] + self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( group=self.cpu_group, device=self.device, ) - else: - self.pynccl_comm = None - self.ca_comm: Optional[CustomAllreduce] + self.ca_comm: Optional[CustomAllreduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( group=self.cpu_group, device=self.device, ) - else: - self.ca_comm = None from vllm.distributed.device_communicators.tpu_communicator import ( TpuCommunicator) - self.tpu_communicator: Optional[TpuCommunicator] + self.tpu_communicator: Optional[TpuCommunicator] = None if use_tpu_communicator and self.world_size > 1: self.tpu_communicator = TpuCommunicator(group=self.cpu_group) @@ -264,16 +317,46 @@ def graph_capture( def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: """ + User-facing all-reduce function before we actually call the + all-reduce operation. + + We need this because Dynamo does not support passing an arbitrary + object (`self` in this case) to a custom op. We need to pass the + group name as a string, and then look up the group coordinator from + the group name, dispatch the all-reduce operation to the group + coordinator. + + In addition, PyTorch custom ops do not support mutation or returning + a new tensor in the same op. 
So we need to figure out if the op is + in-place or out-of-place ahead of time. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + + if self.tpu_communicator is not None and \ + not self.tpu_communicator.disabled: + # TPU handles Dynamo with its own logic. + return self._all_reduce(input_) + + if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_): + return torch.ops.vllm.outplace_all_reduce( + input_, group_name=self.unique_name) + else: + torch.ops.vllm.inplace_all_reduce(input_, + group_name=self.unique_name) + return input_ + + def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + The actual all-reduce implementation. + NOTE: This operation will be applied in-place or out-of-place. Always assume this function modifies its input, but use the return value as the output. """ ca_comm = self.ca_comm - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return input_ - # For TPUs, use TPU communicator. tpu_comm = self.tpu_communicator if tpu_comm is not None and not tpu_comm.disabled: @@ -758,6 +841,7 @@ def init_world_group(ranks: List[int], local_rank: int, use_pynccl=False, use_custom_allreduce=False, use_tpu_communicator=False, + group_name="world", ) @@ -767,6 +851,7 @@ def init_model_parallel_group( backend: str, use_custom_allreduce: Optional[bool] = None, use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, ) -> GroupCoordinator: if use_custom_allreduce is None: use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE @@ -778,6 +863,7 @@ def init_model_parallel_group( use_custom_allreduce=use_custom_allreduce, use_tpu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, ) @@ -931,7 +1017,8 @@ def initialize_model_parallel( _TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, - use_message_queue_broadcaster=True) + use_message_queue_broadcaster=True, + group_name="tp") # Build the pipeline model-parallel groups. 
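The dispatch above depends on the `torch.library.custom_op` registrations earlier in this diff: Dynamo cannot pass an arbitrary Python object such as the group coordinator into a custom op, so the op takes a `group_name` string and resolves the coordinator through a module-level weakref registry, and a `register_fake` implementation lets torch.compile trace the op without executing it. A self-contained sketch of that pattern, assuming PyTorch >= 2.4, with made-up names (`demo_lib`, `DummyGroup`) standing in for the real collective:

```python
import weakref
from typing import Callable, Dict

import torch

_groups: Dict[str, Callable[[], "DummyGroup"]] = {}


class DummyGroup:
    """Stand-in for a group coordinator; its 'all-reduce' just doubles."""

    def __init__(self, name: str):
        self.unique_name = name
        _groups[name] = weakref.ref(self)

    def _all_reduce(self, t: torch.Tensor) -> torch.Tensor:
        return t * 2


@torch.library.custom_op("demo_lib::outplace_all_reduce", mutates_args=())
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    group = _groups[group_name]()
    assert group is not None, f"Group {group_name} was destroyed."
    return group._all_reduce(tensor)


@outplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    # Shape/dtype-only version so the op can be traced without running it.
    return torch.empty_like(tensor)


group = DummyGroup("tp:0")
out = torch.ops.demo_lib.outplace_all_reduce(torch.ones(4), group_name="tp:0")
print(out)  # tensor([2., 2., 2., 2.])
```

The in-place variant in the patch is registered separately with `mutates_args=["tensor"]`, since a single custom op cannot both mutate an argument and return a fresh tensor.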
num_pipeline_model_parallel_groups: int = (world_size // @@ -947,7 +1034,8 @@ def initialize_model_parallel( _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, - use_custom_allreduce=False) + use_custom_allreduce=False, + group_name="pp") def ensure_model_parallel_initialized( From cbdb25225914a04d94e8830f4e739faca8ff3b9d Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Tue, 17 Sep 2024 00:06:26 -0700 Subject: [PATCH 62/98] [Misc] Limit to ray[adag] 2.35 to avoid backward incompatible change (#8509) Signed-off-by: Rui Qiao --- requirements-test.txt | 2 +- vllm/executor/ray_gpu_executor.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index 16a883b81ce50..10d463de27be5 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -14,7 +14,7 @@ librosa # required for audio test opencv-python # required for video test peft requests -ray[adag]>=2.35 +ray[adag]==2.35 sentence-transformers # required for embedding soundfile # required for audio test compressed-tensors==0.4.0 # required for compressed-tensors diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index b124fe2e08ea6..9433dce842b09 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -437,8 +437,10 @@ def _check_ray_adag_installation(self): required_version = version.parse("2.35") current_version = version.parse( pkg_resources.get_distribution("ray").version) - if current_version < required_version: - raise ValueError(f"Ray version {required_version} or greater is " + # TODO: update the constraint once we adapt to the backward + # incompatible API change from ray 2.36 + if current_version != required_version: + raise ValueError(f"Ray version {required_version} is " f"required, but found {current_version}") import importlib.util From 1b6de8352b878348974b3f117cbb68ed18daa609 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 17 Sep 2024 15:34:27 +0800 Subject: [PATCH 63/98] [Benchmark] Support sample from HF datasets and image input for benchmark_serving (#8495) --- benchmarks/backend_request_func.py | 6 +- benchmarks/benchmark_serving.py | 239 +++++++++++++++++++++-------- 2 files changed, 177 insertions(+), 68 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 3243bb94f787c..3def4a6d67acf 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -25,6 +25,7 @@ class RequestFuncInput: best_of: int = 1 use_beam_search: bool = False logprobs: Optional[int] = None + multi_modal_content: Optional[dict] = None @dataclass @@ -312,12 +313,15 @@ async def async_request_openai_chat_completions( async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: assert not request_func_input.use_beam_search + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.append(request_func_input.multi_modal_content) payload = { "model": request_func_input.model, "messages": [ { "role": "user", - "content": request_func_input.prompt, + "content": content }, ], "temperature": 0.0, diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 9ba3f649810b7..3ace910a6cac6 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -24,6 +24,8 @@ """ import argparse import asyncio +import base64 +import io import json 
import os import random @@ -31,11 +33,13 @@ import warnings from dataclasses import dataclass from datetime import datetime -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple import numpy as np from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, RequestFuncOutput) +from datasets import load_dataset +from PIL.Image import Image from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase @@ -84,7 +88,7 @@ def sample_sharegpt_requests( num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, -) -> List[Tuple[str, int, int]]: +) -> List[Tuple[str, int, int, None]]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") # Load the dataset. @@ -119,7 +123,7 @@ def sample_sharegpt_requests( if prompt_len > 1024 or prompt_len + output_len > 2048: # Prune too long sequences. continue - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_dataset.append((prompt, prompt_len, output_len, None)) return filtered_dataset @@ -131,7 +135,7 @@ def sample_sonnet_requests( output_len: int, prefix_len: int, tokenizer: PreTrainedTokenizerBase, -) -> List[Tuple[str, str, int, int]]: +) -> List[Tuple[str, str, int, int, None]]: assert ( input_len > prefix_len ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." @@ -189,7 +193,65 @@ def sample_sonnet_requests( message, add_generation_prompt=True, tokenize=False) prompt_len = len(tokenizer(prompt_formatted).input_ids) sampled_requests.append( - (prompt, prompt_formatted, prompt_len, output_len)) + (prompt, prompt_formatted, prompt_len, output_len, None)) + + return sampled_requests + + +def sample_hf_requests( + dataset_path: str, + dataset_subset: str, + dataset_split: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: + dataset = load_dataset(dataset_path, + name=dataset_subset, + split=dataset_split, + streaming=True) + assert "conversations" in dataset.features, ( + "HF Dataset must have 'conversations' column.") + filtered_dataset = dataset.shuffle().filter( + lambda x: len(x["conversations"]) >= 2) + sampled_requests: List[Tuple[str, int, int, Dict[str, + Collection[str]]]] = [] + for data in filtered_dataset: + if len(sampled_requests) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = data["conversations"][0]["value"] + prompt_token_ids = tokenizer(prompt).input_ids + completion = data["conversations"][1]["value"] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+ continue + + if "image" in data and isinstance(data["image"], Image): + image: Image = data["image"] + image = image.convert("RGB") + image_data = io.BytesIO() + image.save(image_data, format='JPEG') + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + mm_content = { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + else: + mm_content = None + + sampled_requests.append((prompt, prompt_len, output_len, mm_content)) return sampled_requests @@ -223,8 +285,8 @@ def sample_random_requests( [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]) - input_requests.append( - (prompt, int(prefix_len + input_lens[i]), int(output_lens[i]))) + input_requests.append((prompt, int(prefix_len + input_lens[i]), + int(output_lens[i]), None)) return input_requests @@ -343,7 +405,12 @@ async def benchmark( raise ValueError(f"Unknown backend: {backend}") print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len = input_requests[0] + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0]) + if backend != "openai-chat" and test_mm_content is not None: + # multi-modal benchmark is only available on OpenAI Chat backend. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' backend.") test_input = RequestFuncInput( model=model_id, prompt=test_prompt, @@ -353,6 +420,7 @@ async def benchmark( logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, + multi_modal_content=test_mm_content, ) test_output = await request_func(request_func_input=test_input) if not test_output.success: @@ -373,6 +441,7 @@ async def benchmark( logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, + multi_modal_content=test_mm_content, ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: @@ -385,7 +454,7 @@ async def benchmark( benchmark_start_time = time.perf_counter() tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): - prompt, prompt_len, output_len = request + prompt, prompt_len, output_len, mm_content = request request_func_input = RequestFuncInput( model=model_id, prompt=prompt, @@ -395,6 +464,7 @@ async def benchmark( logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, + multi_modal_content=mm_content, ) tasks.append( asyncio.create_task( @@ -575,6 +645,16 @@ def main(args: argparse.Namespace): for prompt, prompt_formatted, prompt_len, output_len in input_requests] + elif args.dataset_name == "hf": + input_requests = sample_hf_requests( + dataset_path=args.dataset_path, + dataset_subset=args.hf_subset, + dataset_split=args.hf_split, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.hf_output_len, + ) + elif args.dataset_name == "random": input_requests = sample_random_requests( prefix_len=args.random_prefix_len, @@ -685,13 +765,14 @@ def main(args: argparse.Namespace): "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "sonnet", "random"], + choices=["sharegpt", "sonnet", "random", "hf"], help="Name of the dataset to benchmark on.", ) parser.add_argument("--dataset-path", type=str, default=None, - help="Path to the dataset.") + help="Path to the sharegpt/sonnet dataset. 
" + "Or the huggingface dataset ID if using HF dataset.") parser.add_argument( "--model", type=str, @@ -718,26 +799,6 @@ def main(args: argparse.Namespace): default=1000, help="Number of prompts to process.", ) - parser.add_argument( - "--sharegpt-output-len", - type=int, - default=None, - help="Output length for each request. Overrides the output length " - "from the ShareGPT dataset.") - parser.add_argument( - "--sonnet-input-len", - type=int, - default=550, - help= - "Number of input tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--sonnet-output-len", - type=int, - default=150, - help= - "Number of output tokens per request, used only for sonnet dataset.", - ) parser.add_argument( "--logprobs", type=int, @@ -748,42 +809,6 @@ def main(args: argparse.Namespace): "logprob is returned for each token; or (2) if beam search " "is enabled 1 logprob per token is computed"), ) - parser.add_argument( - "--sonnet-prefix-len", - type=int, - default=200, - help= - "Number of prefix tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--random-input-len", - type=int, - default=1024, - help= - "Number of input tokens per request, used only for random sampling.", - ) - parser.add_argument( - "--random-output-len", - type=int, - default=128, - help= - "Number of output tokens per request, used only for random sampling.", - ) - parser.add_argument( - "--random-range-ratio", - type=float, - default=1.0, - help="Range of sampled ratio of input/output length, " - "used only for random sampling.", - ) - parser.add_argument( - "--random-prefix-len", - type=int, - default=0, - help="Number of fixed prefix tokens before random " - " context. The length range of context in a random " - " request is [random-prefix-len, " - " random-prefix-len + random-prefix-len * random-range-ratio).") parser.add_argument( "--request-rate", type=float, @@ -857,5 +882,85 @@ def main(args: argparse.Namespace): "Use \"--percentile-metrics\" to select metrics.", ) + # group for dataset specific arguments + sonnet_group = parser.add_argument_group("sonnet dataset options") + sonnet_group.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help= + "Number of input tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help= + "Number of output tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help= + "Number of prefix tokens per request, used only for sonnet dataset.", + ) + + sharegpt_group = parser.add_argument_group("sharegpt dataset options") + sharegpt_group.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. 
Overrides the output length " + "from the ShareGPT dataset.") + + random_group = parser.add_argument_group("random dataset options") + random_group.add_argument( + "--random-input-len", + type=int, + default=1024, + help= + "Number of input tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-output-len", + type=int, + default=128, + help= + "Number of output tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-range-ratio", + type=float, + default=1.0, + help="Range of sampled ratio of input/output length, " + "used only for random sampling.", + ) + random_group.add_argument( + "--random-prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before random " + " context. The length range of context in a random " + " request is [random-prefix-len, " + " random-prefix-len + random-prefix-len * random-range-ratio).") + + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output lengths " + "from the sampled HF dataset.", + ) + args = parser.parse_args() main(args) From 1009e93c5d634c724eeff3d4e453369337f502d4 Mon Sep 17 00:00:00 2001 From: sroy745 <142070531+sroy745@users.noreply.github.com> Date: Tue, 17 Sep 2024 07:35:01 -0700 Subject: [PATCH 64/98] [Encoder decoder] Add cuda graph support during decoding for encoder-decoder models (#7631) --- .buildkite/test-pipeline.yaml | 7 + tests/encoder_decoder/__init__.py | 0 tests/encoder_decoder/test_e2e_correctness.py | 98 ++++++++++ .../test_encoder_decoder_model_runner.py | 182 +++++++++++++++--- vllm/attention/backends/abstract.py | 17 +- vllm/attention/backends/flashinfer.py | 12 +- vllm/attention/backends/utils.py | 113 ++++++++++- vllm/config.py | 41 +--- vllm/engine/arg_utils.py | 5 +- vllm/entrypoints/llm.py | 8 +- vllm/model_executor/models/bart.py | 6 +- vllm/utils.py | 5 - vllm/worker/enc_dec_model_runner.py | 43 ++++- vllm/worker/model_runner.py | 97 ++++++++-- vllm/worker/utils.py | 4 - 15 files changed, 526 insertions(+), 112 deletions(-) create mode 100644 tests/encoder_decoder/__init__.py create mode 100644 tests/encoder_decoder/test_e2e_correctness.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9483adcc5d587..63ce9bff7d4c1 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -252,6 +252,13 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 +- label: Encoder Decoder tests # 5min + source_file_dependencies: + - vllm/ + - tests/encoder_decoder + commands: + - pytest -v -s encoder_decoder + - label: OpenAI-Compatible Tool Use # 20 min fast_check: false mirror_hardwares: [ amd ] diff --git a/tests/encoder_decoder/__init__.py b/tests/encoder_decoder/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py new file mode 100644 index 0000000000000..9324a737a779c --- /dev/null +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -0,0 +1,98 @@ +"""E2E tests to verify the correctness of the encoder-decoder framework + +Run `pytest 
tests/encoder_decoder/test_e2e_correctness.py`. +""" +from typing import List, Optional, Tuple + +import pytest +from transformers import AutoModelForSeq2SeqLM + +from vllm.sequence import SampleLogprobs +from vllm.utils import is_cpu + +from ..conftest import DecoderPromptType +from ..models.utils import check_logprobs_close + + +def vllm_to_hf_output( + vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], + decoder_prompt_type: DecoderPromptType, +): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "" + if decoder_prompt_type == DecoderPromptType.NONE: + hf_output_str = "" + hf_output_str + + return output_ids, hf_output_str, out_logprobs + + +@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) +@pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.skipif( + is_cpu(), + reason="CPU backend is not currently supported with encoder/decoder models" +) +def test_encoder_decoder_e2e( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, + decoder_prompt_type: DecoderPromptType, + enforce_eager: bool, +) -> None: + ''' + End-to-End (E2E) test for the encoder-decoder framework. + This test evaluates the encoder-decoder functionality using the BART + model. We compare the outputs of the Hugging Face and vLLM + implementations to ensure that both implementations produce consistent + and correct results. + ''' + test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type] + + # Configuration settings for HF baseline + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( + test_case_prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + )) + with vllm_runner(model, dtype=dtype, + enforce_eager=enforce_eager) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + test_case_prompts, max_tokens, num_logprobs) + + hf_skip_tokens = (1 + if decoder_prompt_type == DecoderPromptType.NONE else 0) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 32bff22f66a8b..a00d46ddeb007 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -1,3 +1,4 @@ +import itertools from array import array from typing import List @@ -7,13 +8,9 @@ from vllm.engine.arg_utils import EngineArgs from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, SequenceData, SequenceGroupMetadata) -from vllm.utils import is_cpu +from vllm.utils import is_cpu, make_tensor_with_pad from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner - -# CUDA graph scenarios to test -# -# 
Currently CUDA graph is not supported -ENFORCE_EAGER = [True] +from vllm.worker.model_runner import _get_graph_batch_size BATCH_SIZES = [1, 4, 16, 64, 256] @@ -40,8 +37,7 @@ def _create_model_runner(model: str, *args, reason="CPU backend is currently " "unsupported for encoder/ " "decoder models") -@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) -def test_empty_seq_group(enforce_eager, ): +def test_empty_seq_group(): """Verify prepare prompt and decode returns empty output for empty seq group list""" @@ -52,7 +48,7 @@ def test_empty_seq_group(enforce_eager, ): max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, - enforce_eager=enforce_eager, + enforce_eager=True, ) seq_group_metadata_list: List[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( @@ -85,11 +81,7 @@ def test_empty_seq_group(enforce_eager, ): "unsupported for encoder/ " "decoder models") @pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) -def test_prepare_prompt( - batch_size, - enforce_eager, -): +def test_prepare_prompt(batch_size): ''' Test the ability of the encoder/decoder model runner subclass to produce prefill-phase model inputs & attention metadata. @@ -115,7 +107,7 @@ def test_prepare_prompt( max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, - enforce_eager=enforce_eager, + enforce_eager=True, ) seq_lens: List[int] = [] @@ -281,11 +273,7 @@ def test_prepare_prompt( "unsupported for encoder/ " "decoder models") @pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) -def test_prepare_decode( - batch_size, - enforce_eager, -): +def test_prepare_decode(batch_size): ''' Test the ability of the encoder/decoder model runner subclass to produce decode-phase model inputs & attention metadata. @@ -311,7 +299,7 @@ def test_prepare_decode( max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, - enforce_eager=enforce_eager, + enforce_eager=True, ) seq_lens: List[int] = [] @@ -428,7 +416,8 @@ def test_prepare_decode( expected, ) - # Cuda graph should is currently not supported for encoder/decoer. + # Model runner's CUDAGraph setting should be propagated to attention + # metadata. assert attn_metadata.use_cuda_graph is False # Verify the lengths of input tokens & positions @@ -484,3 +473,152 @@ def test_prepare_decode( dtype=actual.dtype, ) assert torch.equal(actual, expected) + + +@pytest.mark.parametrize("batch_size", list(range(1, 257))) +def test_prepare_decode_cuda_graph(batch_size): + """ + Tests that for encoder-decoder models with CUDA Graph capture and replay + enabled, the tensors used during the decode phase are correctly padded + for varying input batch sizes. 
+ """ + model_runner = _create_model_runner( + "facebook/bart-base", + seed=0, + dtype="float16", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=False, + ) + + seq_lens: List[int] = [] + encoder_seq_lens: List[int] = [] + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + block_tables = {0: [1]} + cross_block_table = [2] + for i in range(batch_size): + # make sure all tokens fit into one block + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) + encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 + encoder_seq_lens.append(encoder_seq_len) + encoder_seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) + + model_input = model_runner.prepare_model_input(seq_group_metadata_list) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = attn_metadata.slot_mapping + encoder_input_tokens = model_input.encoder_input_tokens + encoder_input_positions = model_input.encoder_input_positions + cross_slot_mapping = attn_metadata.cross_slot_mapping + + # With CUDA Graph capture and replay enabled, the decoder and encoder + # input sequences will be padded. Create the expected padded tensors + # accordingly. + graph_batch_size = _get_graph_batch_size(batch_size) + cuda_graph_pad_size = graph_batch_size - batch_size + padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) + padded_encoder_seq_lens = encoder_seq_lens + list( + itertools.repeat(1, cuda_graph_pad_size)) + + assert return_seq_lens == padded_seq_lens + assert len(slot_mapping) == len(input_tokens) + assert len(cross_slot_mapping) == len(encoder_input_tokens) + + # Verify attention metadata + device = model_runner.device + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_decode_tokens > 0 + assert torch.equal( + attn_metadata.seq_lens_tensor, + torch.tensor(padded_seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == padded_seq_lens + assert attn_metadata.max_prefill_seq_len == 0 + assert attn_metadata.max_decode_seq_len == max(seq_lens) + # - Encoder attention metadata + assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens + assert torch.equal( + attn_metadata.encoder_seq_lens_tensor, + torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens) + assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens) + + # Verify block tables are correct for prompts + # - Decoder self-attention. Pad the block tables as expected. + expected = [block_tables[0] for _ in range(batch_size)] + expected.extend([[] for _ in range(cuda_graph_pad_size)]) + expected = make_tensor_with_pad( + expected, + max_len=64, + pad=0, + dtype=torch.int32, + device=model_runner.device, + ) + assert torch.equal( + attn_metadata.block_tables, + expected, + ) + # - Encoder/decoder cross-attention. 
Pad the cross-attention block tables + # as expected. + expected = [cross_block_table for _ in range(len(seq_group_metadata_list))] + expected.extend([[] for _ in range(cuda_graph_pad_size)]) + expected = make_tensor_with_pad( + expected, + max_len=64, + pad=0, + dtype=torch.int32, + device=model_runner.device, + ) + assert torch.equal( + attn_metadata.cross_block_tables, + expected, + ) + + # Model runner's CUDAGraph setting should be propagated to attention + # metadata. + assert attn_metadata.use_cuda_graph is True + + # Verify the lengths of input tokens & positions + # - Decoder + assert len(input_tokens) == len(padded_seq_lens) + assert len(input_positions) == len(padded_seq_lens) + # -- An indirect check that model_input.input_tokens + # and model_input.input_positions are correct - + # by design of the test, the input tokens are + # equal to the input position values, so if + # the model_input data structure has the correct + # values then these two should be equal + assert torch.equal( + input_tokens, + input_positions, + ) + # - Encoder + assert len(encoder_input_tokens) == 0 + assert len(encoder_input_tokens) == 0 + # -- An indirect check that model_input.encoder_input_tokens + # and model_input.encoder_input_positions are correct - + # by design of the test, the input tokens are + # equal to the input position values, so if + # the model_input data structure has the correct + # values then these two should be equal + assert torch.equal( + encoder_input_tokens, + encoder_input_positions, + ) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index adc8390e6f9ec..2bc36ff18a96b 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -156,18 +156,27 @@ def graph_clone(self, batch_size: int) -> "AttentionState[T]": ... @abstractmethod - def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T: + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: """Get attention metadata for CUDA graph capture of batch_size.""" ... @abstractmethod - def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]: + def get_graph_input_buffers( + self, + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: """Get attention-specific input buffers for CUDA graph capture.""" ... @abstractmethod - def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any], - attn_metadata: T) -> None: + def prepare_graph_input_buffers( + self, + input_buffers: Dict[str, Any], + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> None: """In-place modify input buffers dict for CUDA graph replay.""" ... 
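To make the widened AttentionState contract above concrete, the short sketch below shows how an implementation can branch on the new is_encoder_decoder_model flag when assembling the buffers that a captured CUDA graph reads on replay. This is illustrative only and not part of the patch: the class name is hypothetical and attn_metadata is assumed to be a duck-typed object exposing the same fields the patch uses. It condenses the pattern that CommonAttentionState applies in vllm/attention/backends/utils.py later in this patch.

    # Hypothetical sketch, not vLLM source: condenses the buffer-selection
    # pattern that CommonAttentionState uses for encoder-decoder CUDA graphs.
    from typing import Any, Dict

    class SketchAttentionState:

        def get_graph_input_buffers(
                self,
                attn_metadata,
                is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
            # Buffers that every decode-phase graph capture needs.
            buffers = {
                "slot_mapping": attn_metadata.slot_mapping,
                "seq_lens_tensor":
                attn_metadata.decode_metadata.seq_lens_tensor,
                "block_tables": attn_metadata.decode_metadata.block_tables,
            }
            if is_encoder_decoder_model:
                # Cross-attention tensors are captured only for
                # encoder-decoder models.
                buffers["encoder_seq_lens_tensor"] = (
                    attn_metadata.decode_metadata.encoder_seq_lens_tensor)
                buffers["cross_slot_mapping"] = (
                    attn_metadata.decode_metadata.cross_slot_mapping)
                buffers["cross_block_tables"] = (
                    attn_metadata.decode_metadata.cross_block_tables)
            return buffers

On replay, prepare_graph_input_buffers copies the live decode metadata into these captured tensors in place, which is why the encoder sequence lengths and cross block tables are padded to the captured graph batch size, as test_prepare_decode_cuda_graph above verifies.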
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 4054d337316fe..3a602fbfbbc04 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -172,7 +172,8 @@ def graph_clone(self, batch_size: int): state._prefill_wrapper = self._get_prefill_wrapper() return state - def graph_capture_get_metadata_for_batch(self, batch_size: int): + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): assert self._is_graph_capturing _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] @@ -232,12 +233,17 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): attn_metadata.begin_forward() return attn_metadata - def get_graph_input_buffers(self, attn_metadata): + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): return { "slot_mapping": attn_metadata.slot_mapping, } - def prepare_graph_input_buffers(self, input_buffers, attn_metadata): + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): return def begin_forward(self, model_input): diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 0375d3488eb15..089008967a244 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -304,7 +304,8 @@ def graph_clone(self, batch_size: int) -> "CommonAttentionState": assert self._is_graph_capturing return self.__class__(self.runner) - def graph_capture_get_metadata_for_batch(self, batch_size: int): + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): assert self._is_graph_capturing attn_metadata = self.runner.attn_backend.make_metadata( num_prefills=0, @@ -322,21 +323,121 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): block_tables=self._graph_block_tables[:batch_size], use_cuda_graph=True, ) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers backend. + # Assert the same. + assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) + return attn_metadata - def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]: - return { + def get_graph_input_buffers( + self, + attn_metadata, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + input_buffers = { "slot_mapping": attn_metadata.slot_mapping, "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, } - - def prepare_graph_input_buffers(self, input_buffers, - attn_metadata) -> None: + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers backend. + # Assert the same. 
+ assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) + return input_buffers + + def prepare_graph_input_buffers( + self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False) -> None: input_buffers["seq_lens_tensor"].copy_( attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers backend. + # Assert the same. + assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) def begin_forward(self, model_input) -> None: return + + def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, + attn_metadata): + """ + Updates the attention metadata parameters for CUDA graph capture in an + encoder-decoder model. + + This method modifies attention-related tensors and metadata required + for CUDA graph capture in encoder-decoder models. Specifically, it + updates the cross-attention and encoder sequence tensors in the + AttentionMetadata object. + """ + # During decode phase the cross_slot_mapping will be empty. Hence set + # an empty tensor for CUDA Graph capture. + attn_metadata.cross_slot_mapping = torch.tensor( + [], dtype=torch.int).cuda() + attn_metadata.cross_block_tables = torch.full( + (batch_size, self.runner.get_max_block_per_batch()), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens = torch.full((batch_size, ), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens_tensor = torch.full( + (batch_size, ), 1, dtype=torch.int).cuda() + attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture + + def _add_additonal_input_buffers_for_enc_dec_model( + self, attn_metadata, input_buffers: Dict[str, Any]): + """ + Saves additional input buffers specific to the encoder-decoder model + from the attention metadata. + + This method extracts and stores encoder-decoder related input buffers + from the `attn_metadata` into the `input_buffers` dictionary. The + buffers include encoder sequence lengths, cross-slot mappings, and + cross-block tables, which are essential for the encoder-decoder model + during CUDA graph replay. + """ + input_buffers["encoder_seq_lens_tensor"] = ( + attn_metadata.decode_metadata.encoder_seq_lens_tensor) + input_buffers["cross_slot_mapping"] = ( + attn_metadata.decode_metadata.cross_slot_mapping) + input_buffers["cross_block_tables"] = ( + attn_metadata.decode_metadata.cross_block_tables) + + def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, + input_buffers: Dict[str, + Any]): + """ + Populates input buffers with data from the encoder-decoder model's + attention metadata. + + This method fills the input buffers with encoder-decoder specific + tensors. It copies data from the `attn_metadata` and keyword arguments + (`kwargs`) into corresponding buffers in the `input_buffers` dictionary. + The copied data includes attention-related metadata as well as input + IDs and positional information for the encoder. 
+ """ + input_buffers["encoder_seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.encoder_seq_lens_tensor, + non_blocking=True) + input_buffers["cross_slot_mapping"].copy_( + attn_metadata.decode_metadata.cross_slot_mapping, + non_blocking=True) + input_buffers["cross_block_tables"].copy_( + attn_metadata.decode_metadata.cross_block_tables, + non_blocking=True) diff --git a/vllm/config.py b/vllm/config.py index 89cffc8b306b2..a0991597d0673 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -16,9 +16,8 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config) -from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes, - cuda_device_count_stateless, get_cpu_memory, is_cpu, - is_hip, is_neuron, is_openvino, is_xpu, +from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, + is_cpu, is_hip, is_neuron, is_openvino, is_xpu, print_warning_once) if TYPE_CHECKING: @@ -96,15 +95,15 @@ class ModelConfig: enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. - If None, the user did not specify, so default to False - - except for encoder/decoder models, which currently require - eager mode. + If None, the user did not specify, so default to False. max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode + to eager mode. Additionally for encoder-decoder models, if the + sequence length of the encoder input is larger than this, we fall + back to the eager mode. disable_sliding_window: Whether to disable sliding window. If True, we will disable the sliding window functionality of the model. If the model does not support sliding window, this argument is @@ -186,32 +185,8 @@ def __init__(self, self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc - # Choose a default enforce_eager value if the user did not specify - # a value (enforce_eager is None) - if getattr(self.hf_config, 'is_encoder_decoder', False): - if self.enforce_eager is None: - # *Only for encoder/decoder models* and - # *only if enforce_eager is unset*, override - # to enforce_eager=True - # - # Add a logger message since it is *somewhat* non-intuitive that - # enforce_eager is True when the user has not specified its - # value. - logger.info("Forcing enforce_eager == True because " - "enforce_eager setting was unspecified and " - "CUDAGraph is not supported with encoder/ " - "decoder models.") - self.enforce_eager = True - - if not self.enforce_eager: - # Eager mode explicitly disabled by user for an encoder/ - # decoder model; however CUDAGRAPH + encoder/decoder is - # not currently supported - raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH) - elif self.enforce_eager is None: - # *Only for decoder-only models*, enforce_eager - # defaults to False if unset. This is intuitive - # so no logging message needed. + # Set enforce_eager to False if the value is unset. 
+ if self.enforce_eager is None: self.enforce_eager = False if (not self.disable_sliding_window diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 35013eedea9c6..4139eca9c1832 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -472,7 +472,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.max_seq_len_to_capture, help='Maximum sequence length covered by CUDA ' 'graphs. When a sequence has context length ' - 'larger than this, we fall back to eager mode.') + 'larger than this, we fall back to eager mode. ' + 'Additionally for encoder-decoder models, if the ' + 'sequence length of the encoder input is larger ' + 'than this, we fall back to the eager mode.') parser.add_argument('--disable-custom-all-reduce', action='store_true', default=EngineArgs.disable_custom_all_reduce, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c01bffeb4289d..a26b721093521 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -88,7 +88,9 @@ class LLM: to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead). max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode. + to eager mode. Additionally for encoder-decoder models, if the + sequence length of the encoder input is larger than this, we fall + back to the eager mode. disable_custom_all_reduce: See ParallelConfig **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See :ref:`engine_args`) @@ -137,9 +139,7 @@ def __init__( LLM constructor. Note: if enforce_eager is unset (enforce_eager is None) - it defaults to False for decoder-only models and True - for encoder/decoder models, since encoder/decoder models - do not currently support CUDAGraph. + it defaults to False. 
''' if "disable_log_stats" not in kwargs: diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 9b4c4be7fcb09..cbdacf779b089 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -848,11 +848,13 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, ) -> torch.Tensor: r""" Args: diff --git a/vllm/utils.py b/vllm/utils.py index 1cbd9d55c68b3..29b8a8c2907eb 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -71,10 +71,6 @@ "currently supported with encoder/" "decoder models.") -STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not " - "currently supported with encoder/" - "decoder models.") - STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend " "currently supported with encoder/" "decoder models.") @@ -98,7 +94,6 @@ "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP, "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM, "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC, - "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, "STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index d6189d82d51d9..09dab0135f390 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -1,4 +1,5 @@ import dataclasses +import itertools from typing import Any, Dict, List, Optional, Tuple, Type, cast import torch @@ -24,7 +25,8 @@ from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + ModelInputForGPUWithSamplingMetadata, + _get_graph_batch_size) from vllm.worker.model_runner_base import ( _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict) @@ -178,7 +180,15 @@ def execute_model( raise ValueError("num_steps > 1 is not supported in " "EncoderDecoderModelRunner") - model_executable = self.model + if (model_input.attn_metadata is not None + and model_input.attn_metadata.prefill_metadata is None + and model_input.attn_metadata.decode_metadata.use_cuda_graph): + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[ + model_input.virtual_engine][graph_batch_size] + else: + model_executable = self.model seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, @@ -200,6 +210,9 @@ def execute_model( if not self.is_driver_worker: return [] + if model_input.async_callback is not None: + model_input.async_callback() + # Sample the next token. 
output: SamplerOutput = self.model.sample( logits=logits, @@ -231,14 +244,12 @@ def prepare_model_input( """ model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) - ( attn_metadata, encoder_input_tokens_tensor, encoder_input_positions_tensor, ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, model_input)) - # Inject attn_metadata encoder/cross-attention fields & # encoder input tokens/positions into model_input. # Frozen dataclass fields cannot be modified, so use @@ -437,11 +448,29 @@ def _prepare_encoder_model_input_tensors( cross_block_tables.append([] if ( cross_block_table is None) else cross_block_table) - # Convert cross-attention block tables to encoder input tensor + if (model_input.attn_metadata is not None + and model_input.attn_metadata.use_cuda_graph): + # We will be using CUDA graph replay for this decode. + max_len_of_block_table = self.get_max_block_per_batch() + batch_size = len(encoder_seq_lens) + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + cuda_graph_pad_size = graph_batch_size - batch_size + # extend the cross_block_tables and encoder_seq_lens to match + # the graph_batch_size. + cross_block_tables.extend([[] + for _ in range(cuda_graph_pad_size) + ]) + encoder_seq_lens.extend( + itertools.repeat(1, cuda_graph_pad_size)) + + else: + max_len_of_block_table = max( + len(block_table) for block_table in cross_block_tables) + cross_block_tables = make_tensor_with_pad( cross_block_tables, - max_len=max( - len(block_table) for block_table in cross_block_tables), + max_len=max_len_of_block_table, pad=0, dtype=torch.int32, device=self.device, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9df9ae783b9fa..e8c472df8b5fc 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -243,6 +243,7 @@ def __init__( prefix_cache_hit: bool = False, reinit: bool = False, reinit_use_defaults: bool = False, + encoder_seq_len: int = 0, ): if reinit: assert len(self.seq_ids) == len(seq_ids) # type: ignore @@ -256,6 +257,7 @@ def __init__( self.block_tables = block_tables self.computed_block_nums = computed_block_nums self.n_seqs = n_seqs + self.encoder_seq_len = encoder_seq_len if reinit: if len(self.seq_ids) == 1 and reinit_use_defaults: @@ -702,6 +704,11 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): assert n_seqs == 1 self.decode_only = False + encoder_seq_len = 0 + + if self.runner.model_config.is_encoder_decoder_model: + encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() + inter_data = self.init_cached_inter_data( request_id=seq_group_metadata.request_id, seq_ids=seq_ids, @@ -709,7 +716,8 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): block_tables=seq_group_metadata.block_tables, computed_block_nums=seq_group_metadata.computed_block_nums, reinit=True, - reinit_use_defaults=True) + reinit_use_defaults=True, + encoder_seq_len=encoder_seq_len) self.inter_data_list.append(inter_data) @@ -719,11 +727,15 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): for per_seq_group_fn in self.per_seq_group_compute_fns: per_seq_group_fn(inter_data, seq_group_metadata) - def _use_captured_graph(self, batch_size: int, - max_decode_seq_len: int) -> bool: + def _use_captured_graph(self, + batch_size: int, + max_decode_seq_len: int, + max_encoder_seq_len: int = 0) -> bool: return (self.decode_only and not self.runner.model_config.enforce_eager - and batch_size <= 
self.runner.max_batchsize_to_capture - and max_decode_seq_len <= self.runner.max_seq_len_to_capture) + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_decode_seq_len <= self.runner.max_seq_len_to_capture + and max_encoder_seq_len <= self.runner.max_seq_len_to_capture + and batch_size <= self.runner.max_batchsize_to_capture) def build(self) -> ModelInputForGPU: """Finalize the builder intermediate data and @@ -763,15 +775,18 @@ def build(self) -> ModelInputForGPU: input_positions.extend(cur_input_positions) seq_lens = [] + query_lens = [] max_decode_seq_len = 0 + max_encoder_seq_len = 0 for inter_data in self.inter_data_list: seq_lens.extend(inter_data.seq_lens) + query_lens.extend(inter_data.query_lens) if not inter_data.is_prompt: max_decode_seq_len = max(max_decode_seq_len, max(inter_data.seq_lens)) - query_lens = [] - for inter_data in self.inter_data_list: - query_lens.extend(inter_data.query_lens) + if self.runner.model_config.is_encoder_decoder_model: + max_encoder_seq_len = max(max_encoder_seq_len, + inter_data.encoder_seq_len) # Mapping from request IDs to sequence IDs. Used for Jamba models # that manages the cache by itself. @@ -781,8 +796,10 @@ def build(self) -> ModelInputForGPU: } batch_size = len(input_tokens) - use_captured_graph = self._use_captured_graph(batch_size, - max_decode_seq_len) + use_captured_graph = self._use_captured_graph( + batch_size, + max_decode_seq_len, + max_encoder_seq_len=max_encoder_seq_len) # If cuda graph can be used, pad tensors accordingly. # See `capture_model` API for more details. @@ -1364,7 +1381,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: for batch_size in reversed(batch_size_capture_list): attn_metadata = ( self.attn_state.graph_capture_get_metadata_for_batch( - batch_size)) + batch_size, + is_encoder_decoder_model=self.model_config. + is_encoder_decoder_model)) if self.lora_config: lora_mapping = LoRAMapping( @@ -1380,10 +1399,10 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: ) self.set_active_prompt_adapters( set(), prompt_adapter_mapping) - graph_runner = CUDAGraphRunner( self.model, self.attn_backend.get_name(), - self.attn_state.graph_clone(batch_size)) + self.attn_state.graph_clone(batch_size), + self.model_config.is_encoder_decoder_model) capture_inputs = { "input_ids": @@ -1420,6 +1439,12 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self.model.get_seqlen_agnostic_capture_inputs( batch_size) }) + if self.model_config.is_encoder_decoder_model: + # add the additional inputs to capture for + # encoder-decoder models. + self._update_inputs_to_capture_for_enc_dec_model( + capture_inputs) + graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( @@ -1430,6 +1455,24 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: # This usually takes < 10 seconds. logger.info("Graph capturing finished in %.0f secs.", elapsed_time) + def _update_inputs_to_capture_for_enc_dec_model(self, + capture_inputs: Dict[str, + Any]): + """ + Updates the set of input tensors needed for CUDA graph capture in an + encoder-decoder model. + + This method modifies the provided `capture_inputs` dictionary by + adding tensors specific to encoder-decoder specific models that + need to be captured for CUDA Graph replay. + """ + # During the decode phase encoder_input_ids and encoder_positions are + # unset. Do the same thing for graph capture. 
+ capture_inputs["encoder_input_ids"] = torch.tensor( + [], dtype=torch.long).cuda() + capture_inputs["encoder_positions"] = torch.tensor( + [], dtype=torch.long).cuda() + @property def vocab_size(self) -> int: return self.model_config.get_vocab_size() @@ -1629,7 +1672,7 @@ def execute_model( class CUDAGraphRunner: def __init__(self, model: nn.Module, backend_name: str, - attn_state: AttentionState): + attn_state: AttentionState, is_encoder_decoder_model: bool): self.model = model self.backend_name = backend_name self.attn_state = attn_state @@ -1638,6 +1681,7 @@ def __init__(self, model: nn.Module, backend_name: str, self.output_buffers: Dict[str, torch.Tensor] = {} self._graph: Optional[torch.cuda.CUDAGraph] = None + self._is_encoder_decoder_model = is_encoder_decoder_model @property def graph(self): @@ -1671,8 +1715,9 @@ def capture( intermediate_tensors=intermediate_inputs, **kwargs, ) + # Wait for the warm up operations to finish before proceeding with + # Graph Capture. torch.cuda.synchronize() - # Capture the graph. self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): @@ -1704,10 +1749,14 @@ def capture( # Save the input and output buffers. self.input_buffers = { - "input_ids": input_ids, - "positions": positions, - "kv_caches": kv_caches, - **self.attn_state.get_graph_input_buffers(attn_metadata), + "input_ids": + input_ids, + "positions": + positions, + "kv_caches": + kv_caches, + **self.attn_state.get_graph_input_buffers( + attn_metadata, self._is_encoder_decoder_model), **kwargs, } if intermediate_inputs is not None: @@ -1737,8 +1786,8 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.attn_state.prepare_graph_input_buffers(self.input_buffers, - attn_metadata) + self.attn_state.prepare_graph_input_buffers( + self.input_buffers, attn_metadata, self._is_encoder_decoder_model) if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_inputs_before_cuda_graphs(self.input_buffers, **kwargs) @@ -1752,6 +1801,12 @@ def forward( if key != "model_execute_time" and key != "model_forward_time": self.input_buffers[key].copy_(intermediate_tensors[key], non_blocking=True) + if self._is_encoder_decoder_model: + self.input_buffers["encoder_input_ids"].copy_( + kwargs['encoder_input_ids'], non_blocking=True) + self.input_buffers["encoder_positions"].copy_( + kwargs['encoder_positions'], non_blocking=True) + # Run the graph. self.graph.replay() # Return the output tensor. 
diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index d73023e8e1724..a58b80e4f2adb 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -47,10 +47,6 @@ def assert_enc_dec_mr_supported_scenario( raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC']) - if not enc_dec_mr.model_config.enforce_eager: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH']) - if enc_dec_mr.prompt_adapter_config is not None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ 'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER']) From 9855b99502c7537db5ef018129e603650800ac46 Mon Sep 17 00:00:00 2001 From: chenqianfzh <51831990+chenqianfzh@users.noreply.github.com> Date: Tue, 17 Sep 2024 08:09:12 -0700 Subject: [PATCH 65/98] [Feature][kernel] tensor parallelism with bitsandbytes quantization (#8434) --- tests/quantization/test_bitsandbytes.py | 26 ++++++++++--- vllm/config.py | 6 --- vllm/model_executor/layers/linear.py | 21 ++++++++--- vllm/model_executor/model_loader/loader.py | 44 +++++++++++++++++++++- 4 files changed, 80 insertions(+), 17 deletions(-) diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 87200b1dcc534..36167cf95f589 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -64,6 +64,24 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name) +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason='Test requires at least 2 GPUs.') +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", models_4bit_to_test) +@fork_new_process_for_each_test +def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + hf_model_kwargs = {"load_in_4bit": True} + validate_generated_texts(hf_runner, + vllm_runner, + example_prompts[:1], + model_name, + hf_model_kwargs, + vllm_tp_size=2) + + def log_generated_texts(prompts, outputs, runner_name): logged_texts = [] for i, (_, generated_text) in enumerate(outputs): @@ -80,22 +98,21 @@ def validate_generated_texts(hf_runner, vllm_runner, prompts, model_name, - hf_model_kwargs=None): + hf_model_kwargs=None, + vllm_tp_size=1): # NOTE: run vLLM first, as it requires a clean process # when using distributed inference - - #Run with vLLM runner with vllm_runner(model_name, quantization='bitsandbytes', load_format='bitsandbytes', + tensor_parallel_size=vllm_tp_size, enforce_eager=True, gpu_memory_utilization=0.8) as llm: vllm_outputs = llm.generate_greedy(prompts, 8) vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") # Clean up the GPU memory for the next test - torch.cuda.synchronize() gc.collect() torch.cuda.empty_cache() @@ -108,7 +125,6 @@ def validate_generated_texts(hf_runner, hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") # Clean up the GPU memory for the next test - torch.cuda.synchronize() gc.collect() torch.cuda.empty_cache() diff --git a/vllm/config.py b/vllm/config.py index a0991597d0673..6c24d15640e99 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -393,12 +393,6 @@ def verify_with_parallel_config( "Pipeline parallelism is only supported for the following " f" architectures: {_PP_SUPPORTED_MODELS}.") - if self.quantization == "bitsandbytes" and ( - parallel_config.tensor_parallel_size > 1 - or parallel_config.pipeline_parallel_size > 1): - raise 
ValueError( - "BitAndBytes quantization with TP or PP is not supported yet.") - # Remove the constraint after the bitsandbytes issue is fixed: # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308 if self.quantization == "bitsandbytes" and self.enforce_eager is False: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index cea768469aeb8..568892778abe2 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -530,8 +530,11 @@ def weight_loader(self, param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if not use_bitsandbytes_4bit: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) # Special case for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 @@ -899,8 +902,13 @@ def weight_loader(self, else: shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if not use_bitsandbytes_4bit: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + # Special case for for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 @@ -1000,6 +1008,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() input_dim = getattr(param, "input_dim", None) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) # Special case for GGUF is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -1015,7 +1024,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data - if input_dim is not None: + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if input_dim is not None and not use_bitsandbytes_4bit: shard_size = param_data.shape[input_dim] start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index ac869e56ce198..fd9533ab156a5 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -22,6 +22,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, SchedulerConfig) +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( @@ -689,6 +691,8 @@ def save_model( class BitsAndBytesModelLoader(BaseModelLoader): """Model loader to load model weights with BitAndBytes quantization.""" + # TODO: these module names are for Llama only, + # change so that it works with other models as well default_target_modules = [ "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", "o_proj" @@ -911,13 +915,44 @@ def _parse_quant_state(param_name: str, def _unquantized_generator(self, 
hf_weights_files, use_safetensors, quant_state_dict) -> Generator: from bitsandbytes.functional import quantize_4bit + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): if any(target_module in weight_name for target_module in self.target_modules): weight_name = weight_name.replace(".weight", ".qweight") + + # weight partitions of different modules occur at + # different dimensions + # TODO: these module names are for Llama only, + # change so that it works with other models as well + if 'down_proj' in weight_name or 'o_proj' in weight_name: + total_size = weight_tensor.size(-1) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[..., + start_index:end_index] + + else: + total_size = weight_tensor.size(0) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[start_index:end_index, + ...] + # bitsandbytes requires data in GPU - loaded_weight = weight_tensor.cuda().data + if weight_sub_tensor.is_cuda: + loaded_weight = weight_sub_tensor + else: + loaded_weight = weight_sub_tensor.cuda() + + # remove the following after the issue is fixed: + # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 + if loaded_weight.is_contiguous() is False: + loaded_weight = loaded_weight.contiguous() + with set_default_torch_dtype(torch.float32): processed_weight, quant_state = quantize_4bit( loaded_weight, @@ -958,6 +993,13 @@ def _load_weights(self, model_config: ModelConfig, f"BitsAndBytes loader does not support {quant_method} " "quantization") + # The quant_states in pre_quantized models cannot work with a split + # weight tensor. So TP does not work with pre_quantized bnb models. + if pre_quant and get_tensor_model_parallel_world_size() > 1: + raise ValueError( + "Prequant BitsAndBytes models with TP is not supported." 
+ "Please try with PP.") + load_8bit = False if pre_quant: load_8bit = quant_config.get('load_in_8bit', False) From a54ed8024953dc6b59906072a7a89cd4791ec4f0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 17 Sep 2024 19:50:37 +0200 Subject: [PATCH 66/98] [Model] Add mistral function calling format to all models loaded with "mistral" format (#8515) Co-authored-by: Cyrus Leung --- examples/offline_chat_with_tools.py | 138 ++++++++++++++++++ .../decoder_only/language/test_mistral.py | 67 +++++++++ vllm/entrypoints/llm.py | 6 +- vllm/entrypoints/openai/serving_chat.py | 9 +- vllm/transformers_utils/tokenizers/mistral.py | 8 +- 5 files changed, 219 insertions(+), 9 deletions(-) create mode 100644 examples/offline_chat_with_tools.py diff --git a/examples/offline_chat_with_tools.py b/examples/offline_chat_with_tools.py new file mode 100644 index 0000000000000..e69a6c067e4da --- /dev/null +++ b/examples/offline_chat_with_tools.py @@ -0,0 +1,138 @@ +# ruff: noqa +import json +import random +import string + +from vllm import LLM +from vllm.sampling_params import SamplingParams + +# This script is an offline demo for function calling +# +# If you want to run a server/client setup, please follow this code: +# +# - Server: +# +# ```bash +# vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral --load-format mistral --config-format mistral +# ``` +# +# - Client: +# +# ```bash +# curl --location 'http://:8000/v1/chat/completions' \ +# --header 'Content-Type: application/json' \ +# --header 'Authorization: Bearer token' \ +# --data '{ +# "model": "mistralai/Mistral-7B-Instruct-v0.3" +# "messages": [ +# { +# "role": "user", +# "content": [ +# {"type" : "text", "text": "Describe this image in detail please."}, +# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}}, +# {"type" : "text", "text": "and this one as well. Answer in French."}, +# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}} +# ] +# } +# ] +# }' +# ``` +# +# Usage: +# python demo.py simple +# python demo.py advanced + +model_name = "mistralai/Mistral-7B-Instruct-v0.3" +# or switch to "mistralai/Mistral-Nemo-Instruct-2407" +# or "mistralai/Mistral-Large-Instruct-2407" +# or any other mistral model with function calling ability + +sampling_params = SamplingParams(max_tokens=8192, temperature=0.0) +llm = LLM(model=model_name, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral") + + +def generate_random_id(length=9): + characters = string.ascii_letters + string.digits + random_id = ''.join(random.choice(characters) for _ in range(length)) + return random_id + + +# simulate an API that can be called +def get_current_weather(city: str, state: str, unit: 'str'): + return (f"The weather in {city}, {state} is 85 degrees {unit}. It is " + "partly cloudly, with highs in the 90's.") + + +tool_funtions = {"get_current_weather": get_current_weather} + +tools = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, e.g. 
'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["city", "state", "unit"] + } + } +}] + +messages = [{ + "role": + "user", + "content": + "Can you tell me what the temperate will be in Dallas, in fahrenheit?" +}] + +outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools) +output = outputs[0].outputs[0].text.strip() + +# append the assistant message +messages.append({ + "role": "assistant", + "content": output, +}) + +# let's now actually parse and execute the model's output simulating an API call by using the +# above defined function +tool_calls = json.loads(output) +tool_answers = [ + tool_funtions[call['name']](**call['arguments']) for call in tool_calls +] + +# append the answer as a tool message and let the LLM give you an answer +messages.append({ + "role": "tool", + "content": "\n\n".join(tool_answers), + "tool_call_id": generate_random_id(), +}) + +outputs = llm.chat(messages, sampling_params, tools=tools) + +print(outputs[0].outputs[0].text.strip()) +# yields +# 'The weather in Dallas, TX is 85 degrees fahrenheit. ' +# 'It is partly cloudly, with highs in the 90's.' diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 687ba6a03a691..26f90456849f1 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -4,13 +4,61 @@ """ import pytest +from vllm import SamplingParams + from ...utils import check_logprobs_close MODELS = [ "mistralai/Mistral-7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3", + # Mistral-Nemo is to big for CI, but passes locally + # "mistralai/Mistral-Nemo-Instruct-2407" ] +SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) + +# for function calling +TOOLS = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, e.g. 'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state that the city is" + " in, e.g. 
'CA' which would mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["city", "state", "unit"] + } + } +}] +MSGS = [{ + "role": + "user", + "content": ("Can you tell me what the temperate" + " will be in Dallas, in fahrenheit?") +}] +EXPECTED_FUNC_CALL = ( + '[{"name": "get_current_weather", "arguments": ' + '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]') + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @@ -81,3 +129,22 @@ def test_mistral_format( name_0="hf", name_1="mistral", ) + + +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("model", MODELS[1:]) # v1 can't do func calling +def test_mistral_function_calling( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, + dtype=dtype, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral") as vllm_model: + outputs = vllm_model.model.chat(MSGS, + tools=TOOLS, + sampling_params=SAMPLING_PARAMS) + + assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a26b721093521..248b070611cd2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,5 +1,6 @@ from contextlib import contextmanager -from typing import ClassVar, List, Optional, Sequence, Union, cast, overload +from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast, + overload) from tqdm import tqdm @@ -357,6 +358,7 @@ def chat( lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, add_generation_prompt: bool = True, + tools: Optional[List[Dict[str, Any]]] = None, ) -> List[RequestOutput]: """ Generate responses for a chat conversation. 
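Since the hunk above adds a tools argument to LLM.chat(), a condensed sketch of the resulting call shape may help; it mirrors the examples/offline_chat_with_tools.py script added earlier in this patch (illustrative only, not part of the diff):

from vllm import LLM
from vllm.sampling_params import SamplingParams

llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
          tokenizer_mode="mistral",
          config_format="mistral",
          load_format="mistral")

# OpenAI-style tool schema, as in the example script above.
tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "state": {"type": "string"},
                "unit": {"type": "string",
                         "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["city", "state", "unit"],
        },
    },
}]
messages = [{
    "role": "user",
    "content": "Can you tell me what the temperature will be in Dallas, "
               "in fahrenheit?",
}]

outputs = llm.chat(messages,
                   sampling_params=SamplingParams(temperature=0.0),
                   tools=tools)
print(outputs[0].outputs[0].text)
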
@@ -401,6 +403,7 @@ def chat( messages=messages, chat_template=chat_template, add_generation_prompt=add_generation_prompt, + tools=tools, ) else: prompt = apply_hf_chat_template( @@ -408,6 +411,7 @@ def chat( conversation=conversation, chat_template=chat_template, add_generation_prompt=add_generation_prompt, + tools=tools, ) inputs: PromptInputs diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 58e42fb5363fb..d28362a12abdb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -123,7 +123,8 @@ async def create_chat_completion( ] prompt: Union[str, List[int]] - if isinstance(tokenizer, MistralTokenizer): + is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer) + if is_mistral_tokenizer: prompt = apply_mistral_chat_template( tokenizer, messages=request.messages, @@ -159,10 +160,10 @@ async def create_chat_completion( return self.create_error_response( "tool_choice = \"required\" is not supported!") - # "auto" tools requires --enable-auto-tool-choice - # and --tool-call-parser - if request.tool_choice == "auto" and not ( + if not is_mistral_tokenizer and request.tool_choice == "auto" and not ( self.enable_auto_tools and self.tool_parser is not None): + # for hf tokenizers, "auto" tools requires + # --enable-auto-tool-choice and --tool-call-parser return self.create_error_response( "\"auto\" tool choice requires " "--enable-auto-tool-choice and --tool-call-parser to be set") diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index ea1910ed20ec3..7a228a3efa6e8 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -165,10 +165,9 @@ def apply_chat_template(self, messages: List["ChatCompletionMessageParam"], tools: Optional[Dict[str, Any]] = None, **kwargs) -> List[int]: - assert tools is None, "`tools` are not yet supported." 
- request = ChatCompletionRequest( - messages=messages) # type: ignore[type-var] + request = ChatCompletionRequest(messages=messages, + tools=tools) # type: ignore[type-var] encoded = self.mistral.encode_chat_completion(request) # encode-decode to get clean prompt @@ -176,7 +175,8 @@ def apply_chat_template(self, def convert_tokens_to_string(self, tokens: List[str]) -> str: if isinstance(self.tokenizer, Tekkenizer): - return "".join(tokens) + return "".join(t for t in tokens + if t not in self.tokenizer._all_special_tokens) else: return self.tokenizer.decode(tokens) # type: ignore[arg-type] From 56c3de018c35580fd088655c2f9951cd4da5335d Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 17 Sep 2024 20:24:29 +0100 Subject: [PATCH 67/98] [Misc] Don't dump contents of kvcache tensors on errors (#8527) --- vllm/worker/model_runner_base.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 94d2507968382..975b88c0e79a2 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -3,11 +3,13 @@ from abc import ABC, abstractmethod from datetime import datetime from functools import wraps -from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, - TypeVar) +from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, + Optional, Type, TypeVar) import torch +from torch import is_tensor +from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SequenceGroupMetadata @@ -17,6 +19,8 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata +logger = init_logger(__name__) + T = TypeVar('T', bound="BroadcastableModelInput") @@ -113,6 +117,8 @@ def _wrapper(*args, **kwargs): except Exception as err: timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl" + logger.info("Writing input of failed execution to %s...", + filename) with open(filename, "wb") as filep: dumped_inputs = { k: v @@ -122,7 +128,19 @@ def _wrapper(*args, **kwargs): for i, arg in enumerate(args): if i not in (exclude_args or []): dumped_inputs[f"arg_{i}"] = arg + + # Only persist dtype and shape for kvcache tensors + # (can be way to big otherwise) + if (kv_caches := dumped_inputs.get("kv_caches")) \ + and isinstance(kv_caches, Iterable): + dumped_inputs["kv_caches"] = [(t.dtype, t.shape) + for t in kv_caches + if is_tensor(t)] + pickle.dump(dumped_inputs, filep) + logger.info( + "Completed writing input of failed execution to %s.", + filename) raise type(err)( f"Error in model execution (input dumped to {filename}): " f"{str(err)}") from err From 98f9713399bd602ff954a83e6e6abcb4cf8b8864 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 17 Sep 2024 17:17:08 -0600 Subject: [PATCH 68/98] [Bugfix] Fix TP > 1 for new granite (#8544) Signed-off-by: Joe Runde --- vllm/model_executor/models/granite.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index b0325e8b616c8..5f365bbc30670 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -428,7 +428,8 @@ def compute_logits( sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, 
sampling_metadata) - logits /= self.config.logits_scaling + if logits is not None: + logits /= self.config.logits_scaling return logits def sample( From fa0c114fad4e2b807503e78d5110558cfee92ba4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 17 Sep 2024 16:24:06 -0700 Subject: [PATCH 69/98] [doc] improve installation doc (#8550) Co-authored-by: Andy Dai <76841985+Imss27@users.noreply.github.com> --- docs/source/getting_started/installation.rst | 2 ++ tests/compile/test_full_graph.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 50a761b49490c..0322503a89a56 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -95,6 +95,8 @@ You can also build and install vLLM from source: $ export MAX_JOBS=6 $ pip install -e . + This is especially useful when you are building on less powerful machines. For example, when you use WSL, it only `gives you half of the memory by default `_, and you'd better use ``export MAX_JOBS=1`` to avoid compiling multiple files simultaneously and running out of memory. The side effect is that the build process will be much slower. If you only touch the Python code, slow compilation is okay, as you are building in an editable mode: you can just change the code and run the Python script without any re-compilation or re-installation. + .. tip:: If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image. diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 6fc445539bbbe..2e309aaa58d48 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -28,7 +28,10 @@ def test_full_graph(model, tp_size): "The future of AI is", ] sampling_params = SamplingParams(temperature=0) - llm = LLM(model=model, enforce_eager=True, tensor_parallel_size=tp_size) + llm = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tp_size, + disable_custom_all_reduce=True) outputs = llm.generate(prompts, sampling_params) From 09deb4721f830602d0417604c7e18b7e384f9594 Mon Sep 17 00:00:00 2001 From: "Alexey Kondratiev(AMD)" <143633163+alexeykondrat@users.noreply.github.com> Date: Tue, 17 Sep 2024 19:40:29 -0400 Subject: [PATCH 70/98] [CI/Build] Excluding kernels/test_gguf.py from ROCm (#8520) --- .buildkite/run-amd-test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 6659440135ff4..9274a30e04325 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -83,6 +83,7 @@ if [[ $commands == *" kernels "* ]]; then --ignore=kernels/test_encoder_decoder_attn.py \ --ignore=kernels/test_flash_attn.py \ --ignore=kernels/test_flashinfer.py \ + --ignore=kernels/test_gguf.py \ --ignore=kernels/test_int8_quant.py \ --ignore=kernels/test_machete_gemm.py \ --ignore=kernels/test_mamba_ssm.py \ From 8110e44529f431d54b02060528601c0d3e3f7d02 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 17 Sep 2024 19:44:27 -0400 Subject: [PATCH 71/98] [Kernel] Change interface to Mamba causal_conv1d_update for continuous batching (#8012) --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 30 +++++++++- csrc/mamba/causal_conv1d/causal_conv1d.h | 4 ++ csrc/ops.h | 9 ++- csrc/torch_bindings.cpp | 5 +- tests/kernels/test_causal_conv1d.py | 58 +++++++++++++++++++ vllm/_custom_ops.py | 14 +++-- .../layers/mamba/ops/causal_conv1d.py | 10 +++- 7 files changed, 114 insertions(+), 16 
deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 88a64a8ece585..32261ec17d897 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -198,7 +198,8 @@ causal_conv1d_update(const at::Tensor &x, const at::Tensor &conv_state, const at::Tensor &weight, const c10::optional &bias_, - bool silu_activation) { + bool silu_activation, + const c10::optional &conv_state_indices_) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -216,7 +217,6 @@ causal_conv1d_update(const at::Tensor &x, const int width = weight.size(-1); CHECK_SHAPE(x, batch_size, dim); - CHECK_SHAPE(conv_state, batch_size, dim, width); CHECK_SHAPE(weight, dim, width); TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); @@ -241,6 +241,22 @@ causal_conv1d_update(const at::Tensor &x, params.conv_state_c_stride = conv_state.stride(1); params.conv_state_l_stride = conv_state.stride(2); + if (conv_state_indices_.has_value()) { + auto conv_state_indices = conv_state_indices_.value(); + TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) + TORCH_CHECK(conv_state_indices.is_cuda()); + TORCH_CHECK(conv_state_indices.stride(0) == 1) + CHECK_SHAPE(conv_state_indices, batch_size); + + int conv_state_entries = conv_state.size(0); + CHECK_SHAPE(conv_state, conv_state_entries, dim, width); + + params.conv_state_indices_ptr = conv_state_indices.data_ptr(); + } else { + CHECK_SHAPE(conv_state, batch_size, dim, width); + params.conv_state_indices_ptr = nullptr; + } + // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)x.get_device()}; @@ -646,8 +662,16 @@ void causal_conv1d_update_kernel(ConvParamsBase params) { const int channel_id = blockIdx.y * kNThreads + tidx; input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + channel_id * params.x_c_stride; - input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride + + // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor + // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id. + const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr + ? batch_id + : params.conv_state_indices_ptr[batch_id]; + input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + channel_id * params.out_c_stride; diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h index bb25314c8bbbd..32a7d83c09b8d 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -36,6 +36,10 @@ struct ConvParamsBase { void *__restrict__ conv_state_ptr; + // For the continuous batching case. Makes it so that the mamba state for + // the current batch doesn't need to be a contiguous tensor. 
+ int32_t *__restrict__ conv_state_indices_ptr; + void *__restrict__ seq_idx_ptr; // No __restrict__ since initial_states could be the same as final_states. diff --git a/csrc/ops.h b/csrc/ops.h index ee89ad32cb025..15e9ebe87408a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -222,11 +222,10 @@ std::vector selective_scan_fwd( const c10::optional& index_, const c10::optional& x); -at::Tensor causal_conv1d_update(const at::Tensor& x, - const at::Tensor& conv_state, - const at::Tensor& weight, - const c10::optional& bias_, - bool silu_activation); +at::Tensor causal_conv1d_update( + const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, + const c10::optional& bias, bool silu_activation, + const c10::optional& conv_state_indices); at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const c10::optional& bias_, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7009180a8687c..045203c3de8a8 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -279,8 +279,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "causal_conv1d_update(Tensor! x," "Tensor! conv_state," "Tensor! weight," - "Tensor? bias_," - "bool silu_activation) -> Tensor"); + "Tensor? bias," + "bool silu_activation," + "Tensor? conv_state_indices) -> Tensor"); ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); ops.def( diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 7bf338b36953a..344e07e739454 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -203,3 +203,61 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("seqlen", [1, 4, 5]) +@pytest.mark.parametrize("width", [2, 3, 4]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, + silu_activation, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + + # set seed + torch.random.manual_seed(0) + batch = 64 + + x = torch.randn(batch, dim, device=device, dtype=itype) + + total_entries = 10 * batch + conv_state = torch.randn(total_entries, + dim, + width, + device=device, + dtype=itype) + conv_state_indices = torch.randperm(total_entries)[:batch].to( + dtype=torch.int32, device=device) + + weight = torch.randn(dim, + width, + device=device, + dtype=itype, + requires_grad=True) + if has_bias: + bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) + else: + bias = None + conv_state_ref = conv_state[conv_state_indices, :].detach().clone() + activation = None if not silu_activation else "silu" + out = causal_conv1d_update(x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=conv_state_indices) + out_ref = causal_conv1d_update_ref(x, + conv_state_ref, + weight, + bias, + activation=activation) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) + assert torch.allclose(out, 
out_ref, rtol=rtol, atol=atol) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ac90895b11c37..ff5aa8bee3c27 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -768,11 +768,17 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, silu_activation) -def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, - weight: torch.Tensor, bias_: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias_: Optional[torch.Tensor], + silu_activation: bool, + conv_state_indices: Optional[torch.Tensor], +) -> torch.Tensor: return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, - silu_activation) + silu_activation, + conv_state_indices) def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 413c8bc227ae8..196d81267f32f 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -1,4 +1,5 @@ # Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py from typing import Optional @@ -70,12 +71,17 @@ def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, - activation: Optional[str] = None): + activation: Optional[str] = None, + conv_state_indices: Optional[torch.Tensor] = None): """ x: (batch, dim) conv_state: (batch, dim, width) weight: (dim, width) bias: (dim,) + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. 
out: (batch, dim) """ @@ -83,4 +89,4 @@ def causal_conv1d_update(x: torch.Tensor, raise NotImplementedError("activation must be None, silu, or swish") activation_bool = activation in ["silu", "swish"] return ops.causal_conv1d_update(x, conv_state, weight, bias, - activation_bool) + activation_bool, conv_state_indices) From 95965d31b6ac2c9557816a6ffabe4a3117a5ccb2 Mon Sep 17 00:00:00 2001 From: Daniele <36171005+dtrifiro@users.noreply.github.com> Date: Wed, 18 Sep 2024 04:49:53 +0200 Subject: [PATCH 72/98] [CI/Build] fix Dockerfile.cpu on podman (#8540) --- Dockerfile.cpu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 34b4c95e34ffc..4d7289366296b 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -24,6 +24,8 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl +WORKDIR /workspace + ENV PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \ From e351572900f7d87e14fe203ea3a49c1c7ddae0d6 Mon Sep 17 00:00:00 2001 From: Jiaxin Shan Date: Wed, 18 Sep 2024 02:51:59 -0700 Subject: [PATCH 73/98] [Misc] Add argument to disable FastAPI docs (#8554) --- vllm/entrypoints/openai/api_server.py | 8 +++++++- vllm/entrypoints/openai/cli_args.py | 7 +++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3d1d832986c1e..b891debfd2b91 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -417,7 +417,13 @@ async def unload_lora_adapter(request: UnloadLoraAdapterRequest, def build_app(args: Namespace) -> FastAPI: - app = FastAPI(lifespan=lifespan) + if args.disable_fastapi_docs: + app = FastAPI(openapi_url=None, + docs_url=None, + redoc_url=None, + lifespan=lifespan) + else: + app = FastAPI(lifespan=lifespan) app.include_router(router) app.root_path = args.root_path diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 7ccee0b6b55b7..bbb0823de9a51 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -190,6 +190,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'ID numbers being printed in log.' 
'\n\nDefault: Unlimited') + parser.add_argument( + "--disable-fastapi-docs", + action='store_true', + default=False, + help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint" + ) + return parser From 6ffa3f314c59e42238f1c5f923ff2839e0af9698 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 18 Sep 2024 18:38:11 +0800 Subject: [PATCH 74/98] [CI/Build] Avoid CUDA initialization (#8534) --- benchmarks/kernels/benchmark_layernorm.py | 9 +-- benchmarks/kernels/benchmark_moe.py | 6 +- .../kernels/benchmark_paged_attention.py | 7 +-- benchmarks/kernels/benchmark_quant.py | 9 +-- benchmarks/kernels/benchmark_rope.py | 6 +- tests/kernels/test_activation.py | 9 +-- tests/kernels/test_attention.py | 18 ++---- tests/kernels/test_attention_selector.py | 2 +- tests/kernels/test_awq_triton.py | 5 +- tests/kernels/test_blocksparse_attention.py | 12 +--- tests/kernels/test_cache.py | 25 +++----- tests/kernels/test_causal_conv1d.py | 5 +- tests/kernels/test_cutlass.py | 11 ++-- tests/kernels/test_flash_attn.py | 5 +- tests/kernels/test_flashinfer.py | 10 +-- tests/kernels/test_fp8_quant.py | 10 ++- tests/kernels/test_gguf.py | 5 +- tests/kernels/test_int8_quant.py | 13 ++-- tests/kernels/test_layernorm.py | 5 +- tests/kernels/test_machete_gemm.py | 2 +- tests/kernels/test_mamba_ssm.py | 5 +- tests/kernels/test_moe.py | 3 +- tests/kernels/test_pos_encoding.py | 14 ++--- tests/kernels/test_prefix_prefill.py | 12 +--- tests/lora/test_layers.py | 5 +- tests/lora/test_punica_sizes.py | 18 ++---- tests/lora/test_punica_variation.py | 18 ++---- .../decoder_only/language/test_granite.py | 9 +-- tests/quantization/test_fp8.py | 4 +- tests/quantization/utils.py | 8 ++- vllm/attention/backends/rocm_flash_attn.py | 3 +- .../ops/blocksparse_attention/interface.py | 5 +- vllm/attention/ops/prefix_prefill.py | 3 +- vllm/attention/selector.py | 4 +- vllm/config.py | 12 ++-- vllm/distributed/parallel_state.py | 3 +- vllm/envs.py | 1 + .../compressed_tensors/compressed_tensors.py | 6 +- .../layers/quantization/fbgemm_fp8.py | 4 +- .../model_executor/layers/quantization/fp8.py | 5 +- .../layers/quantization/utils/marlin_utils.py | 10 +-- .../quantization/utils/marlin_utils_fp8.py | 3 +- .../layers/quantization/utils/w8a8_utils.py | 5 +- vllm/model_executor/model_loader/loader.py | 6 +- vllm/model_executor/models/qwen2_vl.py | 2 +- vllm/model_executor/utils.py | 10 +-- vllm/platforms/cpu.py | 8 +-- vllm/platforms/cuda.py | 17 ++--- vllm/platforms/interface.py | 62 ++++++++++++++++--- vllm/platforms/rocm.py | 14 ++--- vllm/platforms/tpu.py | 8 ++- vllm/prompt_adapter/utils.py | 4 +- vllm/usage/usage_lib.py | 3 +- vllm/utils.py | 28 ++++++--- vllm/worker/worker.py | 16 +++-- 55 files changed, 256 insertions(+), 256 deletions(-) diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py index 4947fda02e1cc..92f6053cc6d7e 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -1,10 +1,10 @@ -import random import time import torch from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, + seed_everything) @torch.inference_mode() @@ -16,10 +16,7 @@ def main(num_tokens: int, do_profile: bool = False, num_warmup_iters: int = 5, num_iters: int = 100) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + 
seed_everything(seed) torch.set_default_device("cuda") layer = RMSNorm(hidden_size).to(dtype=dtype) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index fd233c71b10a6..c2ad98b7e2656 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -10,7 +10,7 @@ from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import * -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, seed_everything class BenchmarkConfig(TypedDict): @@ -166,7 +166,7 @@ class BenchmarkWorker: def __init__(self, seed: int) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(seed) + seed_everything(seed) self.seed = seed def benchmark( @@ -180,7 +180,7 @@ def benchmark( use_fp8_w8a8: bool, use_int8_w8a16: bool, ) -> Tuple[Dict[str, int], float]: - torch.cuda.manual_seed_all(self.seed) + seed_everything(self.seed) dtype_str = get_config_dtype_str(dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index a04433142da42..87864d038d593 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -6,7 +6,7 @@ from vllm import _custom_ops as ops from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, - create_kv_caches_with_random) + create_kv_caches_with_random, seed_everything) NUM_BLOCKS = 1024 PARTITION_SIZE = 512 @@ -28,10 +28,7 @@ def main( device: str = "cuda", kv_cache_dtype: Optional[str] = None, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) scale = float(1.0 / (head_size**0.5)) query = torch.empty(num_seqs, diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index 4c1a7b26213a5..743a5744e8614 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -1,10 +1,10 @@ -import random import time import torch from vllm import _custom_ops as ops -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, + seed_everything) @torch.inference_mode() @@ -17,10 +17,7 @@ def main(num_tokens: int, do_profile: bool = False, num_warmup_iters: int = 5, num_iters: int = 100) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device("cuda") x = torch.randn(num_tokens, hidden_size, dtype=dtype) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index f542684a9a2a9..73fc9e9dbf461 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -6,7 +6,7 @@ from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, get_rope) -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, seed_everything def benchmark_rope_kernels_multi_lora( @@ -22,9 +22,7 @@ def benchmark_rope_kernels_multi_lora( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size diff --git 
a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index ed050ce851535..9b476585fa19e 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul, NewGELU, QuickGELU, SiluAndMul) +from vllm.utils import seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -34,9 +35,7 @@ def test_act_and_mul( seed: int, device: str, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) x = torch.randn(num_tokens, 2 * d, dtype=dtype) if activation == "silu": @@ -77,9 +76,7 @@ def test_activation( seed: int, device: str, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) x = torch.randn(num_tokens, d, dtype=dtype) layer = activation[0]() diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 46831b506aff3..4bd6f7863a658 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -6,7 +6,7 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops -from vllm.utils import get_max_shared_memory_bytes, is_hip +from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -139,10 +139,8 @@ def test_paged_attention( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + + seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -354,10 +352,7 @@ def test_paged_attention_rocm( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -506,10 +501,7 @@ def test_multi_query_kv_attention( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. 
# As the xformers library is already tested with its own tests, we can use diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index a20a741c27f74..c1fb45955a0e5 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -45,7 +45,7 @@ def test_flash_attn(monkeypatch): override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch - with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + with patch("torch.cuda.get_device_capability", return_value=(7, 5)): backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) assert backend.name != STR_FLASH_ATTN_VAL diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/test_awq_triton.py index 198d40a155ccb..e95e5bd948212 100644 --- a/tests/kernels/test_awq_triton.py +++ b/tests/kernels/test_awq_triton.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.quantization.awq_triton import ( AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton) +from vllm.utils import seed_everything device = "cuda" @@ -79,7 +80,7 @@ def test_dequantize(qweight_rows, qweight_cols, group_size): zeros_cols = qweight_cols zeros_dtype = torch.int32 - torch.manual_seed(0) + seed_everything(0) qweight = torch.randint(0, torch.iinfo(torch.int32).max, @@ -133,7 +134,7 @@ def test_gemm(N, K, M, splitK, group_size): qzeros_rows = scales_rows qzeros_cols = qweight_cols - torch.manual_seed(0) + seed_everything(0) input = torch.rand((input_rows, input_cols), dtype=input_dtype, diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index 7357508751ae1..f3bd8f0524264 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -7,7 +7,7 @@ from vllm import _custom_ops as ops from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn) -from vllm.utils import get_max_shared_memory_bytes, is_hip +from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -172,10 +172,7 @@ def test_paged_attention( blocksparse_block_size: int, blocksparse_head_sliding_step: int, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -386,10 +383,7 @@ def test_varlen_blocksparse_attention_prefill( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. 
# As the xformers library is already tested with its own tests, we can use diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 19402a337b8d6..b0e7097fdfbd4 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -6,6 +6,7 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from vllm import _custom_ops as ops +from vllm.utils import seed_everything COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -55,10 +56,7 @@ def test_copy_blocks( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # Generate random block mappings where each source block is mapped to two # destination blocks. @@ -134,10 +132,7 @@ def test_reshape_and_cache( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # Create a random slot mapping. num_slots = block_size * num_blocks @@ -229,9 +224,7 @@ def test_reshape_and_cache_flash( device: str, kv_cache_dtype: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # Create a random slot mapping. @@ -345,10 +338,8 @@ def test_swap_blocks( pytest.skip() if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + + seed_everything(seed) src_device = device if direction[0] == "cuda" else 'cpu' dst_device = device if direction[1] == "cuda" else 'cpu' @@ -417,9 +408,7 @@ def test_fp8_e4m3_conversion( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) low = -224.0 high = 224.0 diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 344e07e739454..043c4923bd660 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) +from vllm.utils import seed_everything def causal_conv1d_ref( @@ -104,7 +105,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed - torch.random.manual_seed(0) + seed_everything(0) if not channel_last: x = torch.randn(batch, 4096 + dim + 64, @@ -175,7 +176,7 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed - torch.random.manual_seed(0) + seed_everything(0) batch = 2 x = torch.randn(batch, dim, device=device, dtype=itype) conv_state = torch.randn(batch, dim, width, device=device, dtype=itype) diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index d1f0524f83c4c..cc4ca2e91e76f 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -15,9 +15,6 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] -capability = current_platform.get_device_capability() -capability = capability[0] * 10 + capability[1] - 
def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) @@ -119,7 +116,7 @@ def cutlass_int8_gemm_helper(m: int, @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool): @@ -157,7 +154,7 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, out_dtype: Type[torch.dtype], @@ -175,7 +172,7 @@ def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool, use_bias: bool, device: str): @@ -207,7 +204,7 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool, @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool, use_bias: bool): diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 870a8bf65eb92..8e960d098c408 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -4,6 +4,7 @@ import torch import vllm.attention.backends.flash_attn # noqa: F401 +from vllm.utils import seed_everything NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 256] @@ -87,7 +88,7 @@ def test_flash_attn_with_paged_kv( num_blocks: int, ) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] @@ -174,7 +175,7 @@ def test_varlen_with_paged_kv( num_blocks: int, ) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py index 696cc0c6cdf10..80a388db6530e 100644 --- a/tests/kernels/test_flashinfer.py +++ b/tests/kernels/test_flashinfer.py @@ -4,6 +4,8 @@ import pytest import torch +from vllm.utils import seed_everything + NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)] HEAD_SIZES = [128, 256] BLOCK_SIZES = [16, 32] @@ -82,7 +84,7 @@ def test_flashinfer_decode_with_paged_kv( soft_cap: Optional[float], ) -> None: torch.set_default_device("cuda") - 
torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] @@ -168,7 +170,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], block_size: int, soft_cap: Optional[float]) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] @@ -266,7 +268,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( head_size: int, dtype: torch.dtype, block_size: int, soft_cap: Optional[float]) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] @@ -379,7 +381,7 @@ def test_flashinfer_decode_with_paged_fp8_kv( ) -> None: # test doesn't work for num_heads = (16,16) torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index bae9b39203ff9..49f5ce53aab54 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -5,6 +5,7 @@ from tests.kernels.quant_utils import (FP8_DTYPE, ref_dynamic_per_tensor_fp8_quant, ref_dynamic_per_token_quant) +from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, @@ -24,8 +25,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6 # avoid nans @@ -49,8 +49,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") @@ -67,8 +66,7 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() @pytest.mark.parametrize("seed", SEEDS) def test_fp8_quant_large(seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) num_tokens = 1024000 # Mistral-Nemo's max_position_embeddings hidden_size = 1152 # Smallest hidden_size to reproduce the error diff --git a/tests/kernels/test_gguf.py b/tests/kernels/test_gguf.py index ee29ed93b61fc..1513fc196153c 100644 --- a/tests/kernels/test_gguf.py +++ b/tests/kernels/test_gguf.py @@ -7,6 +7,7 @@ from huggingface_hub import snapshot_download import vllm._custom_ops as ops +from vllm.utils import seed_everything GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample") @@ -74,7 +75,7 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype, @torch.inference_mode() def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType): - torch.cuda.manual_seed_all(0) + seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((1, hidden_size), dtype=dtype, device="cuda") @@ -110,7 +111,7 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype, @torch.inference_mode() def test_mmq(num_tokens: int, 
hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType): - torch.cuda.manual_seed_all(0) + seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda") diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index e93cb535d715a..41e103e1d09f9 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -4,6 +4,7 @@ from tests.kernels.quant_utils import ref_dynamic_per_token_quant from tests.kernels.utils import opcheck from vllm._custom_ops import scaled_int8_quant +from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -44,8 +45,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True): @torch.inference_mode() def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 @@ -68,8 +68,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, @@ -113,8 +112,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 @@ -140,8 +138,7 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float, azp: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 6eaf67ec75f41..382079d472ee9 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -3,6 +3,7 @@ from tests.kernels.utils import opcheck from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing @@ -30,9 +31,7 @@ def test_rms_norm( seed: int, device: str, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) layer = RMSNorm(hidden_size).to(dtype=dtype) layer.weight.data.normal_(mean=1.0, std=0.1) diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py index ce65aaef60ac6..0a90882223077 100644 --- a/tests/kernels/test_machete_gemm.py +++ b/tests/kernels/test_machete_gemm.py @@ -48,7 +48,7 @@ # `is_quant_method_supported` conflates kernels with quantization methods # an assumption which is breaking down as quantizations methods can have # have kernels and some kernels 
support multiple quantization methods. -IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 +IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) def rand_data(shape, dtype=torch.float16): diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index d3cb0a8656a02..f582445692344 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -5,6 +5,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) +from vllm.utils import seed_everything def selective_state_update_ref(state, @@ -186,7 +187,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, rtolw = max(rtolw, rtol) atolw = max(atolw, atol) # set seed - torch.random.manual_seed(0) + seed_everything(0) batch_size = 2 dim = 4 dstate = 8 @@ -287,7 +288,7 @@ def test_selective_state_update(dim, dstate, has_z, itype): if torch.version.hip: atol *= 2 # set seed - torch.random.manual_seed(0) + seed_everything(0) batch_size = 1 state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) x = torch.randn(batch_size, dim, device=device, dtype=itype) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 8072cf09e5b65..b1f0516dfa0b3 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -18,6 +18,7 @@ marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE from vllm.scalar_type import scalar_types +from vllm.utils import seed_everything def torch_moe(a, w1, w2, score, topk): @@ -151,7 +152,7 @@ def test_fused_marlin_moe( act_order: bool, num_bits: int, ): - torch.manual_seed(7) + seed_everything(7) if topk > e: return diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 65242e275650c..ba9d2d4389b21 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -5,6 +5,7 @@ import torch from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.utils import seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -46,9 +47,8 @@ def test_rotary_embedding( ) -> None: if rotary_dim is None: rotary_dim = head_size - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size @@ -100,9 +100,7 @@ def test_batched_rotary_embedding( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size @@ -162,9 +160,7 @@ def test_batched_rotary_embedding_multi_lora( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 60f9a4dc9f90f..3181d92562399 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -9,7 +9,7 @@ from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.ops.prefix_prefill import context_attention_fwd -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything NUM_HEADS = [64] 
NUM_QUERIES_PER_KV = [1, 8, 64] @@ -39,10 +39,7 @@ def test_contexted_kv_attention( kv_cache_dtype: str, device: str, ) -> None: - random.seed(0) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed(0) + seed_everything(0) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process @@ -237,10 +234,7 @@ def test_contexted_kv_attention_alibi( kv_cache_dtype: str, device: str, ) -> None: - random.seed(0) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed(0) + seed_everything(0) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index effcffc5c174e..e3233c6b60696 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -39,6 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) from vllm.model_executor.utils import set_random_seed +from vllm.utils import seed_everything from .utils import DummyLoRAManager @@ -922,9 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, seq_len) -> None: dtype = torch.float16 seed = 0 - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index c36fb3afb0cc3..314d6215cbd9c 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -4,7 +4,6 @@ whether the corresponding Triton kernel can run normally when tensor parallelism is set to [1, 2, 4, 8, 16, 32, 64]. """ -import random from unittest.mock import patch import pytest @@ -17,6 +16,7 @@ from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.triton_utils.libentry import LibEntry +from vllm.utils import seed_everything from .utils import (generate_data, generate_data_for_expand_nslices, ref_torch_groupgemm) @@ -145,11 +145,8 @@ def test_punica_sgmv( seed: int, device: str, ): - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 128 ( @@ -238,11 +235,8 @@ def test_punica_bgmv( from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 1 ( @@ -329,11 +323,9 @@ def test_punica_expand_nslices( ): from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) + seq_length = 128 if op_type == "sgmv" else 1 ( inputs_tensor, diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index d026e34878e04..28a395af19e6d 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -3,7 +3,6 @@ under different conditions, including various batches, numbers of LoRA , and maximum ranks. 
""" -import random from unittest.mock import patch import pytest @@ -16,6 +15,7 @@ from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.triton_utils.libentry import LibEntry +from vllm.utils import seed_everything from .utils import (generate_data, generate_data_for_expand_nslices, ref_torch_groupgemm) @@ -60,11 +60,8 @@ def test_punica_sgmv( seed: int, device: str, ): - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 128 ( @@ -153,11 +150,8 @@ def test_punica_bgmv( from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 1 ( @@ -244,11 +238,9 @@ def test_punica_expand_nslices( ): from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) + seq_length = 128 if op_type == "sgmv" else 1 ( inputs_tensor, diff --git a/tests/models/decoder_only/language/test_granite.py b/tests/models/decoder_only/language/test_granite.py index 82c753855e714..e5c5ce4a8f745 100644 --- a/tests/models/decoder_only/language/test_granite.py +++ b/tests/models/decoder_only/language/test_granite.py @@ -2,23 +2,18 @@ Run `pytest tests/models/test_granite.py`. """ -import importlib.metadata - import pytest +import transformers from ...utils import check_logprobs_close -TRANSFORMERS_VERSION = tuple( - map(int, - importlib.metadata.version("transformers").split("."))) - MODELS = [ "ibm/PowerLM-3b", ] # GraniteForCausalLM will be in transformers >= 4.45 -@pytest.mark.skipif(TRANSFORMERS_VERSION < (4, 45), +@pytest.mark.skipif(transformers.__version__ < "4.45", reason="granite model test requires transformers >= 4.45") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 58864e83173f9..a0c1d7e24c503 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -86,9 +86,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, assert attn._k_scale == 1.0 assert attn._v_scale == 1.0 - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability >= 89 and not force_marlin: + if current_platform.has_device_capability(89) and not force_marlin: # For GPUs with hardware support, we keep weights in fp8 assert fc1.weight.dtype == torch.float8_e4m3fn else: diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py index 5fad06878f4a3..061a077592e80 100644 --- a/tests/quantization/utils.py +++ b/tests/quantization/utils.py @@ -8,6 +8,8 @@ def is_quant_method_supported(quant_method: str) -> bool: return False capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - return (capability >= - QUANTIZATION_METHODS[quant_method].get_min_capability()) + assert capability is not None + + min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability() + + return capability.to_int() >= min_capability diff --git 
a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index f1404b8b6bfe7..6bd276ade1d41 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -13,6 +13,7 @@ from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -299,7 +300,7 @@ def __init__( else: # if not using triton, navi3x/navi21/navi10 do not use flash-attn # either - if torch.cuda.get_device_capability()[0] != 9: + if not current_platform.has_device_capability(90): self.use_naive_attn = True else: try: diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py index e870a8e614d12..1ead541f391b5 100644 --- a/vllm/attention/ops/blocksparse_attention/interface.py +++ b/vllm/attention/ops/blocksparse_attention/interface.py @@ -8,8 +8,7 @@ from .utils import (dense_to_crow_col, get_head_sliding_step, get_sparse_attn_mask) -IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available() - and current_platform.get_device_capability()[0] >= 8) +IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80) if IS_COMPUTE_8_OR_ABOVE: from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd @@ -36,7 +35,7 @@ def __init__( use_spda = is_hip() or is_cpu() or not \ IS_COMPUTE_8_OR_ABOVE device = device or (torch.cuda.current_device() - if torch.cuda.is_available() else "cpu") + if current_platform.is_cuda_alike() else "cpu") device = torch.device(device) # NOTE: vllm CPU backend support BF16 instead of FP16. dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 558b2f3eeac7e..a2a649c8ebcfd 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -709,8 +709,7 @@ def context_attention_fwd(q, alibi_slopes=None, sliding_window=None): - cap = current_platform.get_device_capability() - BLOCK = 128 if cap[0] >= 8 else 64 + BLOCK = 128 if current_platform.has_device_capability(80) else 64 NUM_WARPS = 8 # need to reduce num. blocks when using fp32 diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 855586d4e5961..fbda263ba8e08 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -203,7 +203,7 @@ def which_attn_to_use( selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if selected_backend == _Backend.ROCM_FLASH: - if current_platform.get_device_capability()[0] != 9: + if not current_platform.has_device_capability(90): # not Instinct series GPUs. logger.info("flash_attn is not supported on NAVI GPUs.") else: @@ -212,7 +212,7 @@ def which_attn_to_use( # FlashAttn in NVIDIA GPUs. if selected_backend == _Backend.FLASH_ATTN: - if current_platform.get_device_capability()[0] < 8: + if not current_platform.has_device_capability(80): # Volta and Turing NVIDIA GPUs. 
logger.info( "Cannot use FlashAttention-2 backend for Volta and Turing " diff --git a/vllm/config.py b/vllm/config.py index 6c24d15640e99..9d42b75c1c462 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -17,7 +17,7 @@ get_hf_image_processor_config, get_hf_text_config) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, - is_cpu, is_hip, is_neuron, is_openvino, is_xpu, + is_hip, is_neuron, is_openvino, is_xpu, print_warning_once) if TYPE_CHECKING: @@ -1035,20 +1035,20 @@ class DeviceConfig: def __init__(self, device: str = "auto") -> None: if device == "auto": # Automated device type detection - if is_neuron(): + if current_platform.is_cuda_alike(): + self.device_type = "cuda" + elif is_neuron(): self.device_type = "neuron" elif is_openvino(): self.device_type = "openvino" elif current_platform.is_tpu(): self.device_type = "tpu" - elif is_cpu(): + elif current_platform.is_cpu(): self.device_type = "cpu" elif is_xpu(): self.device_type = "xpu" else: - # We don't call torch.cuda.is_available() here to - # avoid initializing CUDA before workers are forked - self.device_type = "cuda" + raise RuntimeError("Failed to infer device type") else: # Device type is assigned explicitly self.device_type = device diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 1c864bcd5d708..df07842edfa56 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -35,6 +35,7 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.platforms import current_platform @dataclass @@ -191,7 +192,7 @@ def __init__( assert self.cpu_group is not None assert self.device_group is not None - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): self.device = torch.device(f"cuda:{local_rank}") else: self.device = torch.device("cpu") diff --git a/vllm/envs.py b/vllm/envs.py index 2003ede95d2d8..6edb06ecd2e20 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -60,6 +60,7 @@ VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None + VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b5b2570966600..ab8207f128348 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -116,10 +116,10 @@ def get_config_filenames(cls) -> List[str]: def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: - capability = current_platform.get_device_capability() # type: ignore + capability_tuple = current_platform.get_device_capability() - if capability is not None: - capability = capability[0] * 10 + capability[1] + if capability_tuple is not None: + capability = capability_tuple.to_int() supported = capability >= min_capability if error and not supported: raise RuntimeError( diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 3ccf1af9eb898..eb59344f36d2e 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -32,9 +32,7 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float): # For GPUs that lack FP8 hardware support, we can leverage 
the Marlin # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 + self.use_marlin = not current_platform.has_device_capability(89) @classmethod def get_name(cls) -> str: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 32affe06b89b7..b5feb55db0e74 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -120,9 +120,8 @@ def __init__(self, quant_config: Fp8Config): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN + self.use_marlin = (not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN) # Disable marlin for rocm if is_hip(): self.use_marlin = False diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 699d5f1844146..fea94cf7322ad 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool, device_capability: Optional[int] = None ): if device_capability is None: - major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) if device_capability < 80: return [] @@ -52,8 +53,9 @@ def _check_marlin_supported( device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: if device_capability is None: - major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) supported_types = query_marlin_supported_quant_types( has_zp, device_capability) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 5f9d8658a342f..8b3dfaae971c3 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -10,8 +10,7 @@ def is_fp8_marlin_supported(): - capability = current_platform.get_device_capability() - return capability[0] >= 8 + return current_platform.has_device_capability(80) def apply_fp8_marlin_linear( diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 887ee6605560c..d86fea63d8a1b 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -17,8 +17,9 @@ def cutlass_fp8_supported() -> bool: # cutlass is not supported on Rocm if is_hip(): return False - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() return 
ops.cutlass_scaled_mm_supports_fp8(capability) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index fd9533ab156a5..f0d2a9e7f06be 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -97,10 +97,10 @@ def _get_quantization_config( """Get the quantization config.""" if model_config.quantization is not None: quant_config = get_quant_config(model_config, load_config) - capability = current_platform.get_device_capability() # type: ignore + capability_tuple = current_platform.get_device_capability() - if capability is not None: - capability = capability[0] * 10 + capability[1] + if capability_tuple is not None: + capability = capability_tuple.to_int() if capability < quant_config.get_min_capability(): raise ValueError( f"The quantization method {model_config.quantization} " diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 179399a12a3d5..a9a0329e99f08 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -207,7 +207,7 @@ def __init__( selected_backend = backend_name_to_enum(backend_by_env_var) if selected_backend is None: # For Volta and Turing GPUs, use xformers instead. - device_available = current_platform.get_device_capability()[0] >= 8 + device_available = current_platform.has_device_capability(80) if device_available: from transformers.utils import is_flash_attn_2_available diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 336bc1cd005cf..d7eec818cbba4 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -1,17 +1,13 @@ """Utils for model executor.""" -import random from typing import Any, Dict, Optional -import numpy as np import torch +from vllm.utils import seed_everything + def set_random_seed(seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) + seed_everything(seed) def set_weight_attrs( diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 4736e898b6a52..9b348f3e17a5f 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -6,10 +6,10 @@ class CpuPlatform(Platform): _enum = PlatformEnum.CPU - @staticmethod - def get_device_name(device_id: int = 0) -> str: + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: return "cpu" - @staticmethod - def inference_mode(): + @classmethod + def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 8d18527e7c973..a9978d5d84d7c 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -11,7 +11,7 @@ from vllm.logger import init_logger -from .interface import Platform, PlatformEnum +from .interface import DeviceCapability, Platform, PlatformEnum logger = init_logger(__name__) @@ -96,19 +96,20 @@ def device_id_to_physical_device_id(device_id: int) -> int: class CudaPlatform(Platform): _enum = PlatformEnum.CUDA - @staticmethod - def get_device_capability(device_id: int = 0) -> Tuple[int, int]: + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: physical_device_id = device_id_to_physical_device_id(device_id) - return get_physical_device_capability(physical_device_id) + major, minor = get_physical_device_capability(physical_device_id) + return DeviceCapability(major=major, minor=minor) - @staticmethod - def get_device_name(device_id: int = 0) -> str: + 
@classmethod
+    def get_device_name(cls, device_id: int = 0) -> str:
         physical_device_id = device_id_to_physical_device_id(device_id)
         return get_physical_device_name(physical_device_id)
 
-    @staticmethod
+    @classmethod
     @with_nvml_context
-    def is_full_nvlink(physical_device_ids: List[int]) -> bool:
+    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
         """
         query if the set of gpus are fully connected by nvlink (1 hop)
         """
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 676f4c9fccf5a..360590d7d5eb6 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -1,5 +1,5 @@
 import enum
-from typing import Optional, Tuple
+from typing import NamedTuple, Optional, Tuple, Union
 
 import torch
 
@@ -12,6 +12,23 @@ class PlatformEnum(enum.Enum):
     UNSPECIFIED = enum.auto()
 
 
+class DeviceCapability(NamedTuple):
+    major: int
+    minor: int
+
+    def as_version_str(self) -> str:
+        return f"{self.major}.{self.minor}"
+
+    def to_int(self) -> int:
+        """
+        Express device capability as an integer ``<major><minor>``.
+
+        It is assumed that the minor version is always a single digit.
+        """
+        assert 0 <= self.minor < 10
+        return self.major * 10 + self.minor
+
+
 class Platform:
     _enum: PlatformEnum
 
@@ -27,16 +44,47 @@ def is_tpu(self) -> bool:
     def is_cpu(self) -> bool:
         return self._enum == PlatformEnum.CPU
 
-    @staticmethod
-    def get_device_capability(device_id: int = 0) -> Optional[Tuple[int, int]]:
+    def is_cuda_alike(self) -> bool:
+        """Stateless version of :func:`torch.cuda.is_available`."""
+        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
+
+    @classmethod
+    def get_device_capability(
+        cls,
+        device_id: int = 0,
+    ) -> Optional[DeviceCapability]:
+        """Stateless version of :func:`torch.cuda.get_device_capability`."""
         return None
 
-    @staticmethod
-    def get_device_name(device_id: int = 0) -> str:
+    @classmethod
+    def has_device_capability(
+        cls,
+        capability: Union[Tuple[int, int], int],
+        device_id: int = 0,
+    ) -> bool:
+        """
+        Test whether this platform is compatible with a device capability.
+
+        The ``capability`` argument can either be:
+
+        - A tuple ``(major, minor)``.
+        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
+        """
+        current_capability = cls.get_device_capability(device_id=device_id)
+        if current_capability is None:
+            return False
+
+        if isinstance(capability, tuple):
+            return current_capability >= capability
+
+        return current_capability.to_int() >= capability
+
+    @classmethod
+    def get_device_name(cls, device_id: int = 0) -> str:
         raise NotImplementedError
 
-    @staticmethod
-    def inference_mode():
+    @classmethod
+    def inference_mode(cls):
         """A device-specific wrapper of `torch.inference_mode`.
 
This wrapper is recommended because some hardware backends such as TPU diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 28525e8ff8811..b6a19eca01745 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,12 +1,11 @@ import os from functools import lru_cache -from typing import Tuple import torch from vllm.logger import init_logger -from .interface import Platform, PlatformEnum +from .interface import DeviceCapability, Platform, PlatformEnum logger = init_logger(__name__) @@ -20,12 +19,13 @@ class RocmPlatform(Platform): _enum = PlatformEnum.ROCM - @staticmethod + @classmethod @lru_cache(maxsize=8) - def get_device_capability(device_id: int = 0) -> Tuple[int, int]: - return torch.cuda.get_device_capability(device_id) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) - @staticmethod + @classmethod @lru_cache(maxsize=8) - def get_device_name(device_id: int = 0) -> str: + def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 393fc230da0b9..b30bccb103af3 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -6,6 +6,10 @@ class TpuPlatform(Platform): _enum = PlatformEnum.TPU - @staticmethod - def inference_mode(): + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + raise NotImplementedError + + @classmethod + def inference_mode(cls): return torch.no_grad() diff --git a/vllm/prompt_adapter/utils.py b/vllm/prompt_adapter/utils.py index 989cc5a0f87c8..4cde2a0254b90 100644 --- a/vllm/prompt_adapter/utils.py +++ b/vllm/prompt_adapter/utils.py @@ -8,13 +8,15 @@ from huggingface_hub.utils import EntryNotFoundError from safetensors.torch import load_file as safe_load_file +from vllm.platforms import current_platform + WEIGHTS_NAME = "adapter_model.bin" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" # Get current device name based on available devices def infer_device() -> str: - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): return "cuda" return "cpu" diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 515e0a4d8abe7..7fadfd5dfffb4 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -17,6 +17,7 @@ import vllm.envs as envs from vllm.connections import global_http_connection +from vllm.platforms import current_platform from vllm.version import __version__ as VLLM_VERSION _config_home = envs.VLLM_CONFIG_ROOT @@ -151,7 +152,7 @@ def _report_usage_once(self, model_architecture: str, usage_context: UsageContext, extra_kvs: Dict[str, Any]) -> None: # Platform information - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): device_property = torch.cuda.get_device_properties(0) self.gpu_count = torch.cuda.device_count() self.gpu_type = device_property.name diff --git a/vllm/utils.py b/vllm/utils.py index 29b8a8c2907eb..060b387ec7834 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -5,6 +5,7 @@ import enum import gc import os +import random import socket import subprocess import sys @@ -32,6 +33,7 @@ import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -373,6 +375,22 @@ def get_cpu_memory() -> int: return psutil.virtual_memory().total +def seed_everything(seed: int) -> None: + """ + Set the seed of each random 
module. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + + if current_platform.is_cuda_alike(): + torch.cuda.manual_seed_all(seed) + + if is_xpu(): + torch.xpu.manual_seed_all(seed) + + def random_uuid() -> str: return str(uuid.uuid4().hex) @@ -634,9 +652,7 @@ def create_kv_caches_with_random_flash( seed: int = 0, device: Optional[str] = "cuda", ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) @@ -678,9 +694,7 @@ def create_kv_caches_with_random( f"Does not support key cache of type fp8 with head_size {head_size}" ) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) @@ -750,7 +764,7 @@ def __init__(self, device: Optional[torch.types.Device] = None): def current_memory_usage(self) -> float: # Return the memory usage in bytes. - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): torch.cuda.reset_peak_memory_stats(self.device) mem = torch.cuda.max_memory_allocated(self.device) elif is_xpu(): diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 52092dc2dc291..3851843afc960 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -454,14 +454,20 @@ def init_worker_distributed_environment( def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. - if torch_dtype == torch.bfloat16: - compute_capability = current_platform.get_device_capability() - if compute_capability[0] < 8: + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not current_platform.has_device_capability(80): + capability = current_platform.get_device_capability() gpu_name = current_platform.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + raise ValueError( "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU has compute capability " - f"{compute_capability[0]}.{compute_capability[1]}. " + f"of at least 8.0. Your {gpu_name} GPU {compute_str}. 
" "You can use float16 instead by explicitly setting the" "`dtype` flag in CLI, for example: --dtype=half.") From 9d104b5beb7bbb51c64b680e007f39169489ea86 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 18 Sep 2024 07:00:56 -0400 Subject: [PATCH 75/98] [CI/Build] Update Ruff version (#8469) Signed-off-by: Aaron Pham Co-authored-by: Cyrus Leung --- .github/workflows/ruff.yml | 4 ++-- benchmarks/kernels/graph_machete_bench.py | 4 +--- format.sh | 4 ++-- pyproject.toml | 2 ++ requirements-lint.txt | 2 +- tests/conftest.py | 5 +---- tests/lora/conftest.py | 5 +---- tests/multimodal/test_base.py | 2 +- tests/test_cache_block_hashing.py | 5 +---- tests/test_logger.py | 4 ++-- tests/worker/test_encoder_decoder_model_runner.py | 4 +--- tests/worker/test_model_runner.py | 4 +--- vllm/adapter_commons/utils.py | 2 +- vllm/attention/backends/utils.py | 6 ++---- vllm/core/block/prefix_caching_block.py | 4 +--- vllm/core/block_manager_v2.py | 4 +--- vllm/engine/async_llm_engine.py | 6 +++--- vllm/engine/llm_engine.py | 6 +++--- .../guided_decoding/outlines_logits_processors.py | 4 ++-- .../layers/quantization/awq_marlin.py | 6 +++--- .../compressed_tensors/compressed_tensors.py | 14 +++++++------- .../layers/quantization/gptq_marlin.py | 8 ++++---- vllm/model_executor/model_loader/tensorizer.py | 4 +--- vllm/model_executor/models/minicpmv.py | 2 +- vllm/spec_decode/draft_model_runner.py | 5 +---- vllm/spec_decode/metrics.py | 7 ++----- vllm/triton_utils/libentry.py | 4 ++-- 27 files changed, 50 insertions(+), 77 deletions(-) diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 1a794af572fef..90735d6e2bbf9 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -25,10 +25,10 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ruff==0.1.5 codespell==2.3.0 tomli==2.0.1 isort==5.13.2 + pip install -r requirements-lint.txt - name: Analysing the code with ruff run: | - ruff . + ruff check . - name: Spelling check with codespell run: | codespell --toml pyproject.toml diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py index 1d076ed6d5c18..de608fd05af70 100644 --- a/benchmarks/kernels/graph_machete_bench.py +++ b/benchmarks/kernels/graph_machete_bench.py @@ -45,8 +45,7 @@ rows = int(math.ceil(len(results) / 2)) fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) axs = axs.flatten() - axs_idx = 0 - for shape, data in results.items(): + for axs_idx, (shape, data) in enumerate(results.items()): plt.sca(axs[axs_idx]) df = pd.DataFrame(data) sns.lineplot(data=df, @@ -59,6 +58,5 @@ palette="Dark2") plt.title(f"Shape: {shape}") plt.ylabel("time (median, s)") - axs_idx += 1 plt.tight_layout() plt.savefig("graph_machete_bench.pdf") diff --git a/format.sh b/format.sh index 2204b3ba59498..6563d89b192ea 100755 --- a/format.sh +++ b/format.sh @@ -159,7 +159,7 @@ echo 'vLLM codespell: Done' # Lint specified files lint() { - ruff "$@" + ruff check "$@" } # Lint files that differ from main branch. Ignores dirs that are not slated @@ -175,7 +175,7 @@ lint_changed() { if ! 
git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ - ruff + ruff check fi } diff --git a/pyproject.toml b/pyproject.toml index 6b682f5d4dd4d..14f0934499c46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,8 @@ ignore = [ "E731", # Loop control variable not used within loop body "B007", + # f-string format + "UP032", ] [tool.mypy] diff --git a/requirements-lint.txt b/requirements-lint.txt index d0b2fef6deaef..07f738873e1a8 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -2,7 +2,7 @@ yapf==0.32.0 toml==0.10.2 tomli==2.0.1 -ruff==0.1.5 +ruff==0.6.5 codespell==2.3.0 isort==5.13.2 clang-format==18.1.5 diff --git a/tests/conftest.py b/tests/conftest.py index e4c7b96e82429..e9c7fc7bf9c67 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -158,10 +158,7 @@ def should_do_global_cleanup_after_test(request) -> bool: to initialize torch. """ - if request.node.get_closest_marker("skip_global_cleanup"): - return False - - return True + return not request.node.get_closest_marker("skip_global_cleanup") @pytest.fixture(autouse=True) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 0bcae5b0c96dc..4834a9d35a3ee 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -65,10 +65,7 @@ def should_do_global_cleanup_after_test(request) -> bool: to initialize torch. """ - if request.node.get_closest_marker("skip_global_cleanup"): - return False - - return True + return not request.node.get_closest_marker("skip_global_cleanup") @pytest.fixture(autouse=True) diff --git a/tests/multimodal/test_base.py b/tests/multimodal/test_base.py index e9562d2048f06..68d05de904ba8 100644 --- a/tests/multimodal/test_base.py +++ b/tests/multimodal/test_base.py @@ -5,7 +5,7 @@ def assert_nested_tensors_equal(expected: NestedTensors, actual: NestedTensors): - assert type(expected) == type(actual) + assert type(expected) == type(actual) # noqa: E721 if isinstance(expected, torch.Tensor): assert torch.equal(expected, actual) else: diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index fe413d1228021..3576a4834ebc3 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -66,8 +66,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, hashes.append([]) prompts = [prefix + prompt for prompt in sample_prompts] - seq_id = 0 - for prompt in prompts: + for seq_id, prompt in enumerate(prompts): hashes[-1].append([]) prompt_token_ids = tokenizer.encode(prompt) seq = Sequence(seq_id, @@ -83,8 +82,6 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, for idx in range(num_blocks): hashes[-1][-1].append(seq.hash_of_block(idx)) - seq_id += 1 - # Check that hashes made with two prefixes with different first blocks are # different everywhere. 
for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])): diff --git a/tests/test_logger.py b/tests/test_logger.py index 8f3d218416870..fadf66f2b61d4 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -111,7 +111,7 @@ def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist(): configuration occurs.""" with pytest.raises(RuntimeError) as ex_info: _configure_vllm_root_logger() - assert ex_info.type == RuntimeError + assert ex_info.type == RuntimeError # noqa: E721 assert "File does not exist" in str(ex_info) @@ -152,7 +152,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json( logging_config_file.name): with pytest.raises(ValueError) as ex_info: _configure_vllm_root_logger() - assert ex_info.type == ValueError + assert ex_info.type == ValueError # noqa: E721 assert "Invalid logging config. Expected Dict, got" in str(ex_info) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index a00d46ddeb007..c0654712b71b5 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -453,8 +453,7 @@ def test_prepare_decode(batch_size): # each sequence) in the decode phase expected_selected_token_indices = [] - selected_token_start_idx = 0 - for seq_len in seq_lens: + for selected_token_start_idx, seq_len in enumerate(seq_lens): # Compute the index offset of the final token in each # sequence's decoded outputs; since a single token is # decoded per iteration per sequence, then the length @@ -463,7 +462,6 @@ def test_prepare_decode(batch_size): # generated tokens is 0 (i.e. the expected sampling index # for a given sequence is just `selected_token_start_idx`) expected_selected_token_indices.append(selected_token_start_idx) - selected_token_start_idx += 1 sampling_metadata = model_input.sampling_metadata actual = sampling_metadata.selected_token_indices diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index a20aa37bcc1e2..42b2337f46914 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -241,10 +241,8 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify Sampling expected_selected_token_indices = [] - selected_token_start_idx = 0 - for _ in context_lens: + for selected_token_start_idx, _ in enumerate(context_lens): expected_selected_token_indices.append(selected_token_start_idx) - selected_token_start_idx += 1 sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py index 6c5411f7d3d5c..1e9adca50093b 100644 --- a/vllm/adapter_commons/utils.py +++ b/vllm/adapter_commons/utils.py @@ -42,7 +42,7 @@ def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: def get_adapter(adapter_id: int, registered_adapters: Dict[int, Any]) -> Optional[Any]: - return registered_adapters.get(adapter_id, None) + return registered_adapters.get(adapter_id) ## worker functions diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 089008967a244..49fbb25f4547b 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -33,10 +33,8 @@ def is_block_tables_empty(block_tables: Union[None, Dict]): """ if block_tables is None: return True - if isinstance(block_tables, dict) and all( - value is None for value in block_tables.values()): - return True - return False + return (isinstance(block_tables, 
dict) + and all(value is None for value in block_tables.values())) def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index a87e814cfb041..db67c95c32429 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -417,9 +417,7 @@ def get_prefix_cache_hit_rate(self) -> float: def is_block_cached(self, block: Block) -> bool: assert block.content_hash is not None - if block.content_hash in self._cached_blocks: - return True - return False + return block.content_hash in self._cached_blocks def promote_to_immutable_block(self, block: Block) -> BlockId: """Once a mutable block is full, it can be promoted to an immutable diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b06385b062e83..54818c7e3e9a6 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -399,9 +399,7 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: """ alloc_status = self._can_swap(seq_group, Device.CPU, SequenceStatus.RUNNING) - if alloc_status == AllocStatus.OK: - return True - return False + return alloc_status == AllocStatus.OK def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: """Returns the block id mapping (from GPU to CPU) generated by diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 410e6ffaa2d50..82cdd41ad497e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -826,7 +826,7 @@ async def generate( request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. trace_headers: OpenTelemetry trace headers. - prompt_adapter_request: Prompt Adapter request to use + prompt_adapter_request: Prompt Adapter request to use for generation, if any. Yields: @@ -1042,7 +1042,7 @@ def remove_logger(self, logger_name: str) -> None: async def start_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes - if type(self.engine.model_executor) == GPUExecutorAsync: + if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721 self.engine.model_executor.start_profile() else: self.engine.model_executor._run_workers("start_profile") @@ -1050,7 +1050,7 @@ async def start_profile(self) -> None: async def stop_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes - if type(self.engine.model_executor) == GPUExecutorAsync: + if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721 self.engine.model_executor.stop_profile() else: self.engine.model_executor._run_workers("stop_profile") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8b5009b2c6668..bdf1af014342a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -144,7 +144,7 @@ class LLMEngine: decoding. executor_class: The model executor class for managing distributed execution. - prompt_adapter_config (Optional): The configuration related to serving + prompt_adapter_config (Optional): The configuration related to serving prompt adapters. log_stats: Whether to log statistics. usage_context: Specified entry point, used for usage info collection. 
@@ -1605,7 +1605,7 @@ def check_health(self) -> None: def start_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: + if type(self.model_executor) == GPUExecutor: # noqa: E721 self.model_executor.start_profile() else: self.model_executor._run_workers("start_profile") @@ -1613,7 +1613,7 @@ def start_profile(self) -> None: def stop_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: + if type(self.model_executor) == GPUExecutor: # noqa: E721 self.model_executor.stop_profile() else: self.model_executor._run_workers("stop_profile") diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 554dcc0ed43ed..c28bd71c9f682 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -67,9 +67,9 @@ def __call__(self, input_ids: List[int], instruction = self._guide.get_next_instruction( state=self._fsm_state[seq_id]) - if type(instruction) == Generate: + if type(instruction) == Generate: # noqa: E721 allowed_tokens = instruction.tokens - elif type(instruction) == Write: + elif type(instruction) == Write: # noqa: E721 # TODO: support fast forward tokens allowed_tokens = [instruction.tokens[0]] else: diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index eee6a8f7cff49..eed01953fb4af 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -110,9 +110,9 @@ def get_scaled_act_names(self) -> List[str]: def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): # Extract data from quant config. 
quant_method = quant_config.get("quant_method", "").lower() - num_bits = quant_config.get("bits", None) - group_size = quant_config.get("group_size", None) - has_zp = quant_config.get("zero_point", None) + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + has_zp = quant_config.get("zero_point") if quant_method != "awq": return False diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index ab8207f128348..e536fae45c845 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, cast import torch from pydantic import BaseModel @@ -79,8 +79,8 @@ def get_quant_method( @classmethod def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": target_scheme_map: Dict[str, Any] = dict() - ignore: List[str] = config.get("ignore", None) - quant_format: str = config.get("format", None) + ignore = cast(List[str], config.get("ignore")) + quant_format = cast(str, config.get("format")) # The quant_config has multiple config_groups, each containing # an input_activations key with details about how the activations are @@ -200,7 +200,7 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel, is_per_tensor_or_channel_weight = (weight_quant.strategy in [ QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL ]) - if not (is_symmetric_weight and is_static_weight + if not (is_symmetric_weight and is_static_weight # noqa: SIM103 and is_per_tensor_or_channel_weight): return False @@ -333,7 +333,7 @@ def create_weights(self, layer: torch.nn.Module, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): """ - Use the CompressedTensorsScheme associated with each layer to create + Use the CompressedTensorsScheme associated with each layer to create the necessary parameters for the layer. See LinearMethodBase for param details """ @@ -352,8 +352,8 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None): """ - Use the output of create_weights and the CompressedTensorsScheme - associated with the layer to apply the forward pass with the + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the layer input. See LinearMethodBase for param details """ diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index cc699f5b4554f..5a1b2d701ab0d 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -132,10 +132,10 @@ def get_scaled_act_names(self) -> List[str]: def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): # Extract data from quant config. 
quant_method = quant_config.get("quant_method", "").lower() - num_bits = quant_config.get("bits", None) - group_size = quant_config.get("group_size", None) - sym = quant_config.get("sym", None) - desc_act = quant_config.get("desc_act", None) + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + sym = quant_config.get("sym") + desc_act = quant_config.get("desc_act") if quant_method != "gptq": return False diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 3aac5cd2b43a5..36f33d6d139ee 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -408,9 +408,7 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: "inferred as vLLM models, so setting vllm_tensorized=True is " "only necessary for models serialized prior to this change.") return True - if (".vllm_tensorized_marker" in deserializer): - return True - return False + return ".vllm_tensorized_marker" in deserializer def serialize_vllm_model( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index f8be9490ee55d..f0fc950defed7 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -884,7 +884,7 @@ def __new__( version = str(config.version).split(".") version = tuple([int(x) for x in version]) # Dispatch class based on version - instance_class = _SUPPORT_VERSION.get(version, None) + instance_class = _SUPPORT_VERSION.get(version) if instance_class is None: raise ValueError( "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 1e403637d2388..cf64af72a14a5 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -183,10 +183,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): return False # TODO: Add soft-tuning prompt adapter support - if self.prompt_adapter_config: - return False - - return True + return not self.prompt_adapter_config @torch.inference_mode() def execute_model( diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index ad4e2dc879d7b..89ccaba70e93c 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -104,13 +104,10 @@ def _should_collect_rejsample_metrics(self, now: float) -> bool: if self._rank != 0: return False - if (now - self._last_metrics_collect_time < - self._rejsample_metrics_collect_interval_s): - return False - return True + return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s # noqa: E501 def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: - """Copy rejection/typical-acceptance sampling metrics + """Copy rejection/typical-acceptance sampling metrics (number of accepted tokens, etc) to CPU asynchronously. Returns a CUDA event recording when the copy is complete. 
diff --git a/vllm/triton_utils/libentry.py b/vllm/triton_utils/libentry.py index ae00af44a048a..4335c7adfc13b 100644 --- a/vllm/triton_utils/libentry.py +++ b/vllm/triton_utils/libentry.py @@ -35,8 +35,8 @@ def key(self, spec_args, dns_args, const_args): dns_key = [ arg.dtype if hasattr( arg, "data_ptr") else type(arg) if not isinstance(arg, int) - else "i32" if -(2**31) <= arg and arg <= 2**31 - - 1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64" + else "i32" if arg >= -(2**31) and arg <= 2**31 - + 1 else "u64" if arg >= 2**63 and arg <= 2**64 - 1 else "i64" for arg in dns_args ] # const args passed by position From 7c7714d856eee6fa94aade729b67f00584f72a4c Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Wed, 18 Sep 2024 09:56:58 -0400 Subject: [PATCH 76/98] [Core][Bugfix][Perf] Introduce `MQLLMEngine` to avoid `asyncio` OH (#8157) Co-authored-by: Nick Hill Co-authored-by: rshaw@neuralmagic.com Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Co-authored-by: Simon Mo --- .buildkite/test-pipeline.yaml | 4 +- docs/source/dev/profiling/profiling_index.rst | 4 +- tests/async_engine/test_openapi_server.py | 106 ---- .../entrypoints/openai/rpc/test_zmq_client.py | 120 ----- tests/entrypoints/openai/test_accuracy.py | 56 +-- .../openai}/test_chat_template.py | 2 +- .../entrypoints/openai/test_mp_api_server.py | 40 -- tests/entrypoints/openai/test_serving_chat.py | 5 +- .../entrypoints/openai/test_serving_engine.py | 4 +- tests/entrypoints/openai/test_shutdown.py | 2 +- .../openai/rpc => mq_llm_engine}/__init__.py | 0 tests/mq_llm_engine/test_abort.py | 67 +++ tests/mq_llm_engine/test_error_handling.py | 244 ++++++++++ tests/mq_llm_engine/test_load.py | 57 +++ tests/mq_llm_engine/utils.py | 78 +++ tests/tpu/test_custom_dispatcher.py | 7 + tests/utils.py | 2 +- vllm/engine/async_llm_engine.py | 9 +- vllm/engine/llm_engine.py | 1 + vllm/engine/multiprocessing/__init__.py | 73 +++ vllm/engine/multiprocessing/client.py | 452 ++++++++++++++++++ vllm/engine/multiprocessing/engine.py | 321 +++++++++++++ vllm/engine/protocol.py | 8 +- vllm/entrypoints/launcher.py | 30 +- vllm/entrypoints/openai/api_server.py | 121 +++-- vllm/entrypoints/openai/rpc/__init__.py | 50 -- vllm/entrypoints/openai/rpc/client.py | 451 ----------------- vllm/entrypoints/openai/rpc/server.py | 243 ---------- vllm/entrypoints/openai/serving_chat.py | 21 +- vllm/entrypoints/openai/serving_completion.py | 21 +- vllm/entrypoints/openai/serving_embedding.py | 11 +- vllm/entrypoints/openai/serving_engine.py | 8 +- .../openai/serving_tokenization.py | 10 +- vllm/envs.py | 6 +- vllm/executor/cpu_executor.py | 1 + vllm/executor/multiproc_worker_utils.py | 4 + 36 files changed, 1467 insertions(+), 1172 deletions(-) delete mode 100644 tests/async_engine/test_openapi_server.py delete mode 100644 tests/entrypoints/openai/rpc/test_zmq_client.py rename tests/{async_engine => entrypoints/openai}/test_chat_template.py (99%) delete mode 100644 tests/entrypoints/openai/test_mp_api_server.py rename tests/{entrypoints/openai/rpc => mq_llm_engine}/__init__.py (100%) create mode 100644 tests/mq_llm_engine/test_abort.py create mode 100644 tests/mq_llm_engine/test_error_handling.py create mode 100644 tests/mq_llm_engine/test_load.py create mode 100644 tests/mq_llm_engine/utils.py create mode 100644 vllm/engine/multiprocessing/__init__.py create mode 100644 vllm/engine/multiprocessing/client.py create mode 100644 
vllm/engine/multiprocessing/engine.py delete mode 100644 vllm/entrypoints/openai/rpc/__init__.py delete mode 100644 vllm/entrypoints/openai/rpc/client.py delete mode 100644 vllm/entrypoints/openai/rpc/server.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 63ce9bff7d4c1..37207b677a1ee 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -43,13 +43,15 @@ steps: fast_check: true source_file_dependencies: - vllm/ + - tests/mq_llm_engine - tests/async_engine - tests/test_inputs - tests/multimodal - tests/test_utils - tests/worker commands: - - pytest -v -s async_engine # Async Engine + - pytest -v -s mq_llm_engine # MQLLMEngine + - pytest -v -s async_engine # AsyncLLMEngine - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py - pytest -v -s multimodal diff --git a/docs/source/dev/profiling/profiling_index.rst b/docs/source/dev/profiling/profiling_index.rst index e22d547293445..9e8b2f1817567 100644 --- a/docs/source/dev/profiling/profiling_index.rst +++ b/docs/source/dev/profiling/profiling_index.rst @@ -21,8 +21,8 @@ Traces can be visualized using https://ui.perfetto.dev/. .. tip:: To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100. - Set the env variable VLLM_RPC_GET_DATA_TIMEOUT_MS to a big number before you start the server. Say something like 30 minutes. - ``export VLLM_RPC_GET_DATA_TIMEOUT_MS=1800000`` + Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes. + ``export VLLM_RPC_TIMEOUT=1800000`` Example commands and usage: =========================== diff --git a/tests/async_engine/test_openapi_server.py b/tests/async_engine/test_openapi_server.py deleted file mode 100644 index 9e5c7c04287eb..0000000000000 --- a/tests/async_engine/test_openapi_server.py +++ /dev/null @@ -1,106 +0,0 @@ -import openai # use the official client for correctness check -import pytest -import pytest_asyncio - -from ..utils import VLLM_PATH, RemoteOpenAIServer - -# any model with a chat template should work here -MODEL_NAME = "facebook/opt-125m" -chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" -assert chatml_jinja_path.exists() - - -@pytest.fixture(scope="module") -def server(): - args = [ - # use half precision for speed and memory savings in CI environment - "--dtype", - "float16", - "--max-model-len", - "2048", - "--enforce-eager", - "--chat-template", - str(chatml_jinja_path), - ] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest_asyncio.fixture -async def client(server): - async with server.get_async_client() as async_client: - yield async_client - - -@pytest.mark.asyncio -async def test_check_models(client: openai.AsyncOpenAI): - models = await client.models.list() - models = models.data - served_model = models[0] - assert served_model.id == MODEL_NAME - assert all(model.root == MODEL_NAME for model in models) - - -@pytest.mark.asyncio -async def test_single_completion(client: openai.AsyncOpenAI): - completion = await client.completions.create(model=MODEL_NAME, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) - - assert completion.id is not None - assert len(completion.choices) == 1 - assert len(completion.choices[0].text) >= 5 - assert completion.choices[0].finish_reason == "length" - assert 
completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) - - # test using token IDs - completion = await client.completions.create( - model=MODEL_NAME, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert len(completion.choices[0].text) >= 5 - - -@pytest.mark.asyncio -async def test_single_chat_session(client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] - - # test single completion - chat_completion = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_tokens=10, - logprobs=True, - top_logprobs=5) - assert chat_completion.id is not None - assert len(chat_completion.choices) == 1 - - choice = chat_completion.choices[0] - assert choice.finish_reason == "length" - assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=55, total_tokens=65) - - message = choice.message - assert message.content is not None and len(message.content) >= 10 - assert message.role == "assistant" - messages.append({"role": "assistant", "content": message.content}) - - # test multi-turn dialogue - messages.append({"role": "user", "content": "express your result in json"}) - chat_completion = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_tokens=10, - ) - message = chat_completion.choices[0].message - assert message.content is not None and len(message.content) >= 0 diff --git a/tests/entrypoints/openai/rpc/test_zmq_client.py b/tests/entrypoints/openai/rpc/test_zmq_client.py deleted file mode 100644 index cafd125c5a598..0000000000000 --- a/tests/entrypoints/openai/rpc/test_zmq_client.py +++ /dev/null @@ -1,120 +0,0 @@ -import asyncio -import tempfile -import unittest -import unittest.mock -import uuid - -import pytest -import pytest_asyncio - -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient, - RPCClientClosedError) -from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer - - -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -@pytest_asyncio.fixture(scope="function") -async def dummy_server(tmp_socket, monkeypatch): - dummy_engine = unittest.mock.AsyncMock() - - def dummy_engine_builder(*args, **kwargs): - return dummy_engine - - with monkeypatch.context() as m: - m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder) - server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket) - - loop = asyncio.get_running_loop() - server_task = loop.create_task(server.run_server_loop()) - - try: - yield server - finally: - server_task.cancel() - server.cleanup() - - -@pytest_asyncio.fixture(scope="function") -async def client(tmp_socket): - client = AsyncEngineRPCClient(rpc_path=tmp_socket) - # Sanity check: the server is connected - await client._wait_for_server_rpc() - - try: - yield client - finally: - client.close() - - -@pytest.mark.asyncio -async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Make the server _not_ reply with a model config - m.setattr(dummy_server, "get_config", lambda x: None) - m.setattr(client, "_data_timeout", 10) - - # And ensure the task completes anyway - # (client.setup() invokes server.get_config()) - client_task = 
asyncio.get_running_loop().create_task(client.setup()) - with pytest.raises(TimeoutError, match="Server didn't reply within"): - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def test_client_aborts_use_timeouts(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Hang all abort requests - m.setattr(dummy_server, "abort", lambda x: None) - m.setattr(client, "_data_timeout", 10) - - # The client should suppress timeouts on `abort`s - # and return normally, assuming the server will eventually - # abort the request. - client_task = asyncio.get_running_loop().create_task( - client.abort("test request id")) - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def test_client_data_methods_reraise_exceptions( - monkeypatch, dummy_server, client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Make the server raise some random exception - exception = RuntimeError("Client test exception") - - def raiser(): - raise exception - - m.setattr(dummy_server.engine, "get_model_config", raiser) - m.setattr(client, "_data_timeout", 10) - - client_task = asyncio.get_running_loop().create_task(client.setup()) - # And ensure the task completes, raising the exception - with pytest.raises(RuntimeError, match=str(exception)): - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def test_client_errors_after_closing(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - - client.close() - - # Healthchecks and generate requests will fail with explicit errors - with pytest.raises(RPCClientClosedError): - await client.check_health() - with pytest.raises(RPCClientClosedError): - async for _ in client.generate(None, None, None): - pass - - # But no-ops like aborting will pass - await client.abort("test-request-id") - await client.do_log_stats() diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index b442a903c33ae..2ad8460023c25 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -18,38 +18,32 @@ FILTER = "exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUE = 0.58 +DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] +MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]] -@pytest.fixture(scope="module") -def server(): - args = [ - "--max-model-len", "4096", "--enable-chunked-prefill", - "--disable-log-requests", "--enforce-eager" - ] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest.fixture(scope="module") -def server_data(server): - return { - "url": f"{server.url_for('v1')}/completions", - } +@pytest.mark.parametrize("more_args", MORE_ARGS_LIST) +def test_lm_eval_accuracy(more_args): + args = list(DEFAULT_ARGS) + args.extend(more_args) + print(f"Running with: {args}") -def test_lm_eval_accuracy(server_data): - model_args = (f"model={MODEL_NAME}," - f"base_url={server_data['url']}," - f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") - - results = lm_eval.simple_evaluate( - model="local-completions", - model_args=model_args, - tasks=TASK, - ) - - measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + url = 
f"{remote_server.url_for('v1')}/completions" + + model_args = ( + f"model={MODEL_NAME}," + f"base_url={url}," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + + results = lm_eval.simple_evaluate( + model="local-completions", + model_args=model_args, + tasks=TASK, + ) + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/async_engine/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py similarity index 99% rename from tests/async_engine/test_chat_template.py rename to tests/entrypoints/openai/test_chat_template.py index 61a6d77cd8756..b98ab2e30d78d 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -5,7 +5,7 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.transformers_utils.tokenizer import get_tokenizer -from ..utils import VLLM_PATH +from ...utils import VLLM_PATH chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" assert chatml_jinja_path.exists() diff --git a/tests/entrypoints/openai/test_mp_api_server.py b/tests/entrypoints/openai/test_mp_api_server.py deleted file mode 100644 index fbfe0db19dd03..0000000000000 --- a/tests/entrypoints/openai/test_mp_api_server.py +++ /dev/null @@ -1,40 +0,0 @@ -import time - -import pytest - -from vllm.entrypoints.openai.api_server import build_async_engine_client -from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.utils import FlexibleArgumentParser - - -@pytest.mark.asyncio -async def test_mp_crash_detection(): - - parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - # use an invalid tensor_parallel_size to trigger the - # error in the server - args.tensor_parallel_size = 65536 - - start = time.perf_counter() - async with build_async_engine_client(args): - pass - end = time.perf_counter() - - assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s " - "if there is an error in the startup.") - - -@pytest.mark.asyncio -async def test_mp_cuda_init(): - # it should not crash, when cuda is initialized - # in the API server process - import torch - torch.cuda.init() - parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - - async with build_async_engine_client(args): - pass diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index c3a6c65be1d90..de2a932199a01 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock from vllm.config import MultiModalConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.transformers_utils.tokenizer import get_tokenizer @@ -52,8 +52,9 @@ def test_async_serving_chat_init(): def test_serving_chat_should_set_correct_max_tokens(): - mock_engine = MagicMock(spec=AsyncLLMEngine) + mock_engine = MagicMock(spec=MQLLMEngineClient) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False serving_chat 
= OpenAIServingChat(mock_engine, MockModelConfig(), diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 325bc03434287..6d9e620b4af7d 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -4,7 +4,7 @@ import pytest from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.protocol import (ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest) @@ -18,7 +18,7 @@ async def _async_serving_engine_init(): - mock_engine_client = MagicMock(spec=AsyncEngineClient) + mock_engine_client = MagicMock(spec=EngineClient) mock_model_config = MagicMock(spec=ModelConfig) # Set the max_model_len attribute to avoid missing attribute mock_model_config.max_model_len = 2048 diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py index 73ecb74007272..25ab91ef69333 100644 --- a/tests/entrypoints/openai/test_shutdown.py +++ b/tests/entrypoints/openai/test_shutdown.py @@ -44,5 +44,5 @@ async def test_shutdown_on_engine_failure(tmp_path): prompt="Hello, my name is") # Now the server should shut down - return_code = remote_server.proc.wait(timeout=3) + return_code = remote_server.proc.wait(timeout=8) assert return_code is not None diff --git a/tests/entrypoints/openai/rpc/__init__.py b/tests/mq_llm_engine/__init__.py similarity index 100% rename from tests/entrypoints/openai/rpc/__init__.py rename to tests/mq_llm_engine/__init__.py diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py new file mode 100644 index 0000000000000..782b508a57149 --- /dev/null +++ b/tests/mq_llm_engine/test_abort.py @@ -0,0 +1,67 @@ +"""Test that aborting is handled properly.""" + +import asyncio +import tempfile +import uuid + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate +from vllm.engine.arg_utils import AsyncEngineArgs + +MODEL = "google/gemma-1.1-2b-it" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL) +RAISED_ERROR = KeyError +RAISED_VALUE = "foo" +EXPECTED_TOKENS = 250 + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +@pytest.mark.asyncio +async def test_abort(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + request_id_to_be_aborted = "request-aborted" + request_ids_a = [f"request-a-{idx}" for idx in range(10)] + request_ids_b = [f"request-b-{idx}" for idx in range(10)] + + # Requests started before one to be aborted. + tasks = [] + for request_id in request_ids_a: + tasks.append( + asyncio.create_task( + generate(client, request_id, EXPECTED_TOKENS))) + + # Aborted. + task_aborted = asyncio.create_task( + generate(client, request_id_to_be_aborted, EXPECTED_TOKENS)) + + # Requests started after one to be aborted. + for request_id in request_ids_b: + tasks.append( + asyncio.create_task( + generate(client, request_id, EXPECTED_TOKENS))) + + # Actually abort. + await asyncio.sleep(0.5) + await client.abort(request_id_to_be_aborted) + + # Confirm that we got all the EXPECTED tokens from the requests. + for task in tasks: + count, request_id = await task + assert count == EXPECTED_TOKENS, ( + f"{request_id} generated only {count} tokens") + + # Cancel task (this will hang indefinitely if not). 
+ task_aborted.cancel() + + # Shutdown. + client.close() diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py new file mode 100644 index 0000000000000..49cfc5aa04c36 --- /dev/null +++ b/tests/mq_llm_engine/test_error_handling.py @@ -0,0 +1,244 @@ +"""Test that various errors are handled properly.""" + +import asyncio +import tempfile +import time +import uuid +from unittest.mock import Mock + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.engine.multiprocessing import MQEngineDeadError +from vllm.engine.multiprocessing.engine import MQLLMEngine +from vllm.entrypoints.openai.api_server import build_async_engine_client +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.lora.request import LoRARequest +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser + +MODEL = "google/gemma-1.1-2b-it" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL) +RAISED_ERROR = KeyError +RAISED_VALUE = "foo" + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Raise error during first forward pass. + engine.engine.model_executor.execute_model = Mock( + side_effect=RAISED_ERROR(RAISED_VALUE)) + + # Run engine. + engine.start() + + +@pytest.mark.asyncio +async def test_evil_forward(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_forward) as engine: + + client = await engine.make_client() + + # Server should be healthy after initial probe. + await asyncio.sleep(2.0) + await client.check_health() + + # Throws an error in first forward pass. + with pytest.raises(RAISED_ERROR): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + assert client.errored + + # Engine is errored, should get ENGINE_DEAD_ERROR. + with pytest.raises(MQEngineDeadError): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + assert client.errored + + await asyncio.sleep(1.0) + with pytest.raises(RAISED_ERROR): + await client.check_health() + assert client.errored + + # Shutdown. + client.close() + + +def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs, + ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Raise error during first forward pass. + engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR) + + # Run engine. + engine.start() + + +@pytest.mark.asyncio +async def test_failed_health_check(tmp_socket): + with RemoteMQLLMEngine( + engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_model_executor_health) as engine: + + client = await engine.make_client() + assert client.is_running + + # Health probe should throw RAISED_ERROR. + await asyncio.sleep(15.) 
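
The failure handling exercised by these tests follows one latching pattern: the first engine-level exception is recorded, health checks re-raise that original error, and any later generate() call fails fast with a dead-engine error. A minimal standalone sketch of that pattern, with every name below invented for illustration rather than taken from the diff:

from typing import Optional


class FailureLatch:
    """Records the first fatal error and replays it for later callers."""

    def __init__(self) -> None:
        self._errored_with: Optional[BaseException] = None

    def set_errored(self, exc: BaseException) -> None:
        # Keep only the first error; later ones are usually symptoms of it.
        if self._errored_with is None:
            self._errored_with = exc

    def check_health(self) -> None:
        # Health checks surface the original failure.
        if self._errored_with is not None:
            raise self._errored_with

    def ensure_alive(self) -> None:
        # New work is rejected with a dead-engine style error.
        if self._errored_with is not None:
            raise RuntimeError(
                f"Engine loop is not running: {self._errored_with!r}")


latch = FailureLatch()
latch.set_errored(KeyError("foo"))
try:
    latch.check_health()
except KeyError as exc:
    print(f"health check re-raised the original error: {exc!r}")
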
+ + with pytest.raises(RAISED_ERROR): + await client.check_health() + assert client.errored + + # Generate call should throw ENGINE_DEAD_ERROR + with pytest.raises(MQEngineDeadError): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + + client.close() + + +def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Raise error during abort call. + engine.engine.abort_request = Mock(side_effect=RAISED_ERROR) + + # Run engine. + engine.start() + + +@pytest.mark.asyncio +async def test_failed_abort(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_abort) as engine: + + client = await engine.make_client() + assert client.is_running + + # Firsh check health should work. + await client.check_health() + + # Trigger an abort on the client side. + async def bad_abort_after_2s(): + await asyncio.sleep(2.0) + await client.abort(request_id="foo") + + # Trigger an abort in 2s from now. + abort_task = asyncio.create_task(bad_abort_after_2s()) + + # Exception in abort() will happen during this generation. + # This will kill the engine and should return ENGINE_DEAD_ERROR + # with reference to the original KeyError("foo") + with pytest.raises(MQEngineDeadError) as execinfo: + async for _ in client.generate( + inputs="Hello my name is", + sampling_params=SamplingParams(max_tokens=2000), + request_id=uuid.uuid4()): + pass + assert "KeyError" in repr(execinfo.value) + assert client.errored + + await abort_task + + # This should raise the original error. + with pytest.raises(RAISED_ERROR): + await client.check_health() + + client.close() + + +@pytest.mark.asyncio +async def test_bad_request(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + # Invalid request should fail, but not crash the server. + with pytest.raises(ValueError): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id="abcd-1", + lora_request=LoRARequest( + "invalid-lora", 1, + "invalid-path")): + pass + + # This request should be okay. + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id="abcd-2"): + pass + + # Shutdown. + client.close() + + +@pytest.mark.asyncio +async def test_mp_crash_detection(monkeypatch): + + parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) + args = parser.parse_args([]) + + # When LLMEngine is loaded, it will crash. 
+ def mock_init(): + raise ValueError + + monkeypatch.setattr(LLMEngine, "__init__", mock_init) + + start = time.perf_counter() + async with build_async_engine_client(args): + pass + end = time.perf_counter() + + assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s " + "if there is an error in the startup.") + + +@pytest.mark.asyncio +async def test_mp_cuda_init(): + # it should not crash, when cuda is initialized + # in the API server process + import torch + torch.cuda.init() + parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) + args = parser.parse_args([]) + + async with build_async_engine_client(args): + pass diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py new file mode 100644 index 0000000000000..630c112d0f0c9 --- /dev/null +++ b/tests/mq_llm_engine/test_load.py @@ -0,0 +1,57 @@ +"""Test that the MQLLMEngine is able to handle 10k concurrent requests.""" + +import asyncio +import tempfile +import uuid + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate +from vllm.engine.arg_utils import AsyncEngineArgs + +MODEL = "google/gemma-1.1-2b-it" +NUM_EXPECTED_TOKENS = 10 +NUM_REQUESTS = 10000 + +# Scenarios to test for num generated token. +ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True) + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +@pytest.mark.asyncio +async def test_load(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] + + # Create concurrent requests. + tasks = [] + for request_id in request_ids: + tasks.append( + asyncio.create_task( + generate(client, request_id, NUM_EXPECTED_TOKENS))) + + # Confirm that we got all the EXPECTED tokens from the requests. + failed_request_id = None + tokens = None + for task in tasks: + num_generated_tokens, request_id = await task + if (num_generated_tokens != NUM_EXPECTED_TOKENS + and failed_request_id is None): + failed_request_id = request_id + tokens = num_generated_tokens + + assert failed_request_id is None, ( + f"{failed_request_id} generated {tokens} but " + f"expected {NUM_EXPECTED_TOKENS}") + + # Shutdown. + client.close() diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py new file mode 100644 index 0000000000000..e27fd77923412 --- /dev/null +++ b/tests/mq_llm_engine/utils.py @@ -0,0 +1,78 @@ +import asyncio +import multiprocessing +from typing import Callable, Tuple, Union + +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import MQLLMEngine +from vllm.outputs import RequestOutput +from vllm.usage.usage_lib import UsageContext + + +async def generate( + client: MQLLMEngineClient, + request_id: str, + num_tokens: int, + return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]: + + final_output = None + count = 0 + async for out in client.generate( + request_id=request_id, + inputs="Hello my name is Robert and", + sampling_params=SamplingParams(max_tokens=num_tokens, + temperature=0)): + + count += 1 + final_output = out + await asyncio.sleep(0.) + + if return_output: + return final_output + + # Confirm we generated all the tokens we expected. 
+ return count, request_id + + +def run_normal(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Run engine. + engine.start() + + +class RemoteMQLLMEngine: + + def __init__(self, + engine_args: AsyncEngineArgs, + ipc_path: str, + run_fn: Callable = run_normal) -> None: + + self.engine_args = engine_args + self.ipc_path = ipc_path + context = multiprocessing.get_context("spawn") + self.proc = context.Process(target=run_fn, + args=(engine_args, ipc_path)) + self.proc.start() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.kill() + + async def make_client(self) -> MQLLMEngineClient: + engine_config = self.engine_args.create_engine_config() + client = MQLLMEngineClient(self.ipc_path, engine_config) + while True: + try: + await client.setup() + break + except TimeoutError: + assert self.proc.is_alive() + return client diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 7f3fb595321ad..69ab67abdd12b 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -1,5 +1,12 @@ +import os + from ..utils import compare_two_settings +# --enforce-eager on TPU causes graph compilation +# this times out default Health Check in the MQLLMEngine, +# so we set the timeout here to 30s +os.environ["VLLM_RPC_TIMEOUT"] = "30000" + def test_custom_dispatcher(): compare_two_settings("google/gemma-2b", diff --git a/tests/utils.py b/tests/utils.py index f6c2be17ebdcf..81442cad78da2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -119,7 +119,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.proc.terminate() try: - self.proc.wait(3) + self.proc.wait(8) except subprocess.TimeoutExpired: # force kill if needed self.proc.kill() diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 82cdd41ad497e..34e7e05341f02 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -601,9 +601,12 @@ def errored(self) -> bool: return self._errored_with is not None @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" - return None + def dead_error(self) -> BaseException: + return AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") def set_errored(self, exc: Exception) -> None: self._errored_with = exc diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bdf1af014342a..2743d5c7d2282 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1289,6 +1289,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # torch.distributed ops which may otherwise timeout, and unblocks # the RPC thread in the workers so that they can process any other # queued control plane messages, such as add/remove lora adapters. 
+ logger.debug("Stopping remote worker execution loop.") self.model_executor.stop_remote_worker_execution_loop() return ctx.request_outputs diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py new file mode 100644 index 0000000000000..ba5c6e15fc821 --- /dev/null +++ b/vllm/engine/multiprocessing/__init__.py @@ -0,0 +1,73 @@ +from dataclasses import dataclass +from enum import Enum +from typing import List, Mapping, Optional, Union + +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams + +VLLM_RPC_SUCCESS_STR = "SUCCESS" + +IPC_INPUT_EXT = "_input_socket" +IPC_OUTPUT_EXT = "_output_socket" +IPC_HEALTH_EXT = "_health_socket" +IPC_DATA_EXT = "_data_socket" + + +class MQEngineDeadError(RuntimeError): + pass + + +@dataclass +class RPCGenerateRequest: + inputs: PromptInputs + sampling_params: SamplingParams + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + +@dataclass +class RPCError: + request_id: Optional[str] + is_engine_errored: bool + exception: BaseException + + +@dataclass +class RPCAbortRequest: + request_id: str + + +class RPCHealthRequest: + pass + + +class RPCStartupRequest(Enum): + IS_SERVER_READY = 1 + + +@dataclass +class RPCStartupResponse: + tracing_enabled: bool + + +RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest, + RPCStartupRequest] + +REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] + + +def ENGINE_DEAD_ERROR( + error: Optional[BaseException] = None) -> MQEngineDeadError: + if error is None: + return MQEngineDeadError( + "Engine loop is not running. Inspect the stacktrace to " + "find the original error") + + return MQEngineDeadError( + "Engine loop is not running. 
Inspect the stacktrace to " + f"find the original error: {repr(error)}.") diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py new file mode 100644 index 0000000000000..18b620c74ddf9 --- /dev/null +++ b/vllm/engine/multiprocessing/client.py @@ -0,0 +1,452 @@ +import asyncio +import copy +import pickle +from contextlib import contextmanager, suppress +from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, + Union) + +import cloudpickle +import zmq +import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket + +from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.engine.arg_utils import AsyncEngineArgs +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, RPC_REQUEST_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCError, RPCGenerateRequest, + RPCHealthRequest, RPCStartupRequest, + RPCStartupResponse) +# yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.inputs import PromptInputs +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + +logger = init_logger(__name__) + + +class MQClientClosedError(Exception): + """Exception class raised when the client is used post-close. + + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ + + +class MQLLMEngineClient: + """A client wrapper for MQLLMEngine that conforms to the + EngineClient protocol. + + MQLLMEngine and MQLLMEngineClient are intended to run in separate + processes communicating via zeromq ipc sockets. + + The entrypoint to MQLLMEngineClient is through the generate() + method. On generate() MQLLMEngine does three things: + - Creates an asyncio output queue + - Sends a RPCGenerateRequest to the MQLLMEngine via zmq + - Pulls RequestOutputs from its queue and yields them + + MQLLMEngine runs two background loops: + - output_loop: the output loop pulls List[RequestOutput] + from the MQLLMEngine via zmq (each list is the output + of one engine_step in the LLMEngine). It then parses + the list and pushes individual request_outputs into + the corresponding output_queue such that they can be + consumed by the .generate() method. + - health_loop: the health loop queries the health socket + every N seconds, confirming the engine is healthy + """ + + def __init__(self, ipc_path: str, engine_config: EngineConfig): + self.context = zmq.asyncio.Context() + self._errored_with: Optional[BaseException] = None + + # Get the configs. + self.model_config = engine_config.model_config + self.decoding_config = engine_config.decoding_config + + # Create the tokenizer group. 
+ self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=engine_config.scheduler_config, + parallel_config=engine_config.parallel_config, + enable_lora=bool(engine_config.lora_config), + ) + + # Send RPCGenerateRequest to the MQLLMEngine. + self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") + + # Receive streams of RequestOutput from the MQLLMEngine. + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # IPC path for ack of check_health requests. + self.health_socket: Socket = self.context.socket(zmq.constants.PULL) + self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Stream for each individual request. + self.output_queues: Dict[str, asyncio.Queue] = {} + self.output_loop = asyncio.create_task(self.run_output_handler_loop()) + + # Loop to check health of the LLMEngine periodically. + # Started after the MQLLMEngine is ready. + self.health_loop: Optional[asyncio.Task] = None + + @staticmethod + def is_unsupported_config(engine_args: AsyncEngineArgs): + if engine_args.pipeline_parallel_size > 1: + return True + + is_embedding = ModelConfig( + model=engine_args.model, + revision=engine_args.revision, + tokenizer=engine_args.model, + tokenizer_mode="auto", + trust_remote_code=engine_args.trust_remote_code, + quantization=engine_args.quantization, + seed=0, + dtype="auto").embedding_mode + + return is_embedding + + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + async def run_check_health_loop(self, timeout: int): + """Background loop that continually probes the RPCServer for health. + + The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which + the MQLLMEngine server is blocking on. + + The Server replies on the HEALTH_SOCKET (rather than on the + OUTPUT_SOCKET such that the messages are not intermingled with + output streaming). + """ + + try: + while True: + if await self.health_socket.poll(timeout=timeout) == 0: + # Wakeup every N seconds and do a health probe. + await self._send_one_way_rpc_request( + RPCHealthRequest(), self.input_socket) + + # Wait for ack from the health socket. + await self._await_ack(error_message="Health check failed.", + socket=self.health_socket) + else: + # Server sent a health status message unprompted. + await self._check_success( + error_message="Health check failed.", + socket=self.health_socket) + + logger.debug("Health probe successful.") + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient check health loop.") + + except Exception as e: + self._set_errored(e) + + async def run_output_handler_loop(self): + """Get RequestOutputs from Engine and stream to Request Queues""" + + try: + while True: + # Poll, checking for ENGINE_DEAD + while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT + ) == 0: + logger.debug("Waiting for output from MQLLMEngine.") + + # If errored, alert all running requests. 
+ if self.errored: + for queue_j in tuple(self.output_queues.values()): + queue_j.put_nowait( + ENGINE_DEAD_ERROR(self._errored_with)) + return + + message: Frame = await self.output_socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + is_error = isinstance(request_outputs, + (BaseException, RPCError)) + if is_error: + if isinstance(request_outputs, RPCError): + rpc_error: RPCError = request_outputs + request_id = rpc_error.request_id + exception = rpc_error.exception + is_engine_errored = rpc_error.is_engine_errored + else: + # MPLLMEngine should always return an RPCError to + # the output_socket when an issue arises. + # If we are here, we are in a bad state and + # should shut down the server. + error: BaseException = request_outputs + logger.error( + "Received Exception %s rather than RPCError from " + "MPLLMEngine. This should never happen.", error) + request_id = None + exception = error + is_engine_errored = True + + # Set to error state only on engine critical error + # (and record only the first one) + if is_engine_errored and not self._errored_with: + self._errored_with = exception + + if request_id is None: + for queue_i in tuple(self.output_queues.values()): + queue_i.put_nowait(exception) + else: + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(exception) + else: + # Put each output into the appropriate steam. + for request_output in request_outputs: + queue = self.output_queues.get( + request_output.request_id) + if queue is not None: + queue.put_nowait(request_output) + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient output handler.") + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + with self.get_data_socket() as socket: + # Wait until server is ready. + response = await self._wait_for_server_rpc(socket) + + self.tracing_flag = response.tracing_enabled + + # Start health_loop. + self.health_loop = asyncio.create_task( + self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT)) + + def close(self): + """Destroy the ZeroMQ Context.""" + # Close all sockets and terminate the context. + self.context.destroy(linger=0) + + # Cancel background tasks. + if self.health_loop is not None: + self.health_loop.cancel() + self.output_loop.cancel() + + def _set_errored(self, e: BaseException): + logger.exception(repr(e)) + if self._errored_with is None: + self._errored_with = e + + @staticmethod + async def _send_get_data_rpc_request(request: RPCStartupRequest, + expected_type: Any, + error_message: str, + socket: Socket) -> Any: + """Send an RPC request that is expecting data back.""" + + # Ping RPCServer with a request. + await socket.send_multipart((pickle.dumps(request), ), copy=False) + + # Make sure the server responds in time. + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("RPCServer didn't reply within " + f"{VLLM_RPC_TIMEOUT} ms") + + # Await the data from the Server. 
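
run_output_handler_loop above is essentially a demultiplexer: each message from the output socket carries one engine step's worth of RequestOutputs, and every output is routed to the asyncio queue registered under its request id. A stripped-down sketch of that routing, using a plain asyncio.Queue in place of the zmq socket (the class and names below are invented for illustration):

import asyncio
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyOutput:  # stand-in for RequestOutput
    request_id: str
    text: str
    finished: bool


async def demux(source: "asyncio.Queue[List[ToyOutput]]",
                queues: Dict[str, "asyncio.Queue[ToyOutput]"]) -> None:
    # Each item pulled from `source` is one engine step's batch of outputs.
    while True:
        batch = await source.get()
        for out in batch:
            per_request = queues.get(out.request_id)
            if per_request is not None:  # request may already be aborted
                per_request.put_nowait(out)


async def main() -> None:
    source: asyncio.Queue = asyncio.Queue()
    queues: Dict[str, asyncio.Queue] = {"req-1": asyncio.Queue()}
    handler = asyncio.create_task(demux(source, queues))

    source.put_nowait([ToyOutput("req-1", "Hello", finished=True)])
    out = await queues["req-1"].get()
    print(out.request_id, out.text, out.finished)
    handler.cancel()


asyncio.run(main())
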
+ frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, BaseException): + raise data + elif not isinstance(data, expected_type): + raise ValueError(error_message) + + return data + + @staticmethod + async def _send_one_way_rpc_request(request: RPC_REQUEST_T, + socket: Socket): + """Send one-way RPC request to trigger an action.""" + + if socket.closed: + raise MQClientClosedError() + + await socket.send_multipart((pickle.dumps(request), )) + + async def _await_ack(self, error_message: str, socket: Socket): + """Await acknowledgement that a request succeeded.""" + + if socket.closed: + raise MQClientClosedError() + + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("MQLLMEngine didn't reply within " + f"{VLLM_RPC_TIMEOUT}ms") + + await self._check_success(error_message, socket) + + @staticmethod + async def _check_success(error_message: str, socket: Socket): + """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" + + if socket.closed: + raise MQClientClosedError() + + frame = await socket.recv(copy=False) + response = pickle.loads(frame.buffer) + + # Raise error if unsuccessful + if isinstance(response, BaseException): + raise response + elif (not isinstance(response, str) + or response != VLLM_RPC_SUCCESS_STR): + raise ValueError(error_message) + + async def get_tokenizer(self, lora_request: LoRARequest): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: + """Wait for the RPCServer to start up.""" + + return await self._send_get_data_rpc_request( + request=RPCStartupRequest.IS_SERVER_READY, + expected_type=RPCStartupResponse, + error_message="Unable to start RPC Server", + socket=socket) + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + with suppress(MQClientClosedError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), socket=self.input_socket) + + async def do_log_stats(self): + """Ignore do_log_stats (handled on MQLLMEngine polling)""" + pass + + async def check_health(self): + """ + The check health loop probes the health status of the + Engine's health every N seconds and sets _errored_with + if the engine is unhealthy. + """ + if self._errored_with is not None: + raise self._errored_with + + @property + def is_running(self) -> bool: + return not self.errored + + @property + def is_stopped(self) -> bool: + return self.errored + + @property + def errored(self) -> bool: + return self._errored_with is not None + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncGenerator[RequestOutput, None]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + # If already dead, error out. + if self._errored_with is not None: + raise ENGINE_DEAD_ERROR(self._errored_with) + + # 1) Create output queue for this requests. 
+ queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue + + try: + # 2) Detach logits processors so that they can be pickled + # separately (may require cloudpickle which is slower) + if sampling_params.logits_processors: + # Defensive shallow copy + sampling_params = copy.copy(sampling_params) + logits_processors = sampling_params.logits_processors + sampling_params.logits_processors = None + lp_bytes = cloudpickle.dumps(logits_processors) + else: + lp_bytes = None + + request_bytes = pickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)) + + # 3) Send the RPCGenerateRequest to the MQLLMEngine. + parts = (request_bytes, + lp_bytes) if lp_bytes else (request_bytes, ) + await self.input_socket.send_multipart(parts, copy=False) + + # 4) Stream the RequestOutputs from the output queue. Note + # that the output_loop pushes RequestOutput objects to this + # queue after pulling them from the zmq socket. + finished = False + try: + while not finished: + request_output = await queue.get() + + if isinstance(request_output, BaseException): + raise request_output + + finished = request_output.finished + yield request_output + finally: + # Request was canceled by the client. + if not finished and not self.errored: + await self.abort(request_id) + finally: + self.output_queues.pop(request_id) + + async def encode(self, *args, + **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: + raise NotImplementedError( + "Embeddings not supported with multiprocessing backend") diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py new file mode 100644 index 0000000000000..70cd6e5cb6000 --- /dev/null +++ b/vllm/engine/multiprocessing/engine.py @@ -0,0 +1,321 @@ +import pickle +import signal +from contextlib import contextmanager +from typing import Iterator, List, Optional, Union + +import cloudpickle +import zmq + +from vllm import AsyncEngineArgs, LLMEngine +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCError, RPCGenerateRequest, + RPCHealthRequest, RPCStartupRequest, + RPCStartupResponse) +# yapf: enable +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.usage.usage_lib import UsageContext + +CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, + SchedulerConfig, LoRAConfig] + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 10000 +HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) + + +class MQLLMEngine: + """A multiprocessing wrapper for :class:`LLMEngine`. + + This class is used to wrap the :class:`LLMEngine` class to enable use + in concurrnet manner. It runs a background loop and uses zeromq to + receive new requests and stream outputs incrementally via ipc. + + The :class:`LLMEngine.generate` is kicked off when a new + RPCGenerateRequest is received by the input_socket. 
+ + The self.engine_loop checks the input_socket for new requests, + adds them to the LLMEngine if there are any, calls the internal + :class:`LLMEngine.step()`, and sends the RequestOutputs back over + the output_socket. + + If use_async_sockets is set, the logic associated with reading new + requests from the socket and sending data to the socket is passed + as a callback to the llm_engine, which calls the logic asynchronously + such that the IPC can be overlapped with the GPU. + + Args: + ipc_path: Base path for zeromq interprocess messaging + use_async_sockets: Whether to make send/recv async with GPU + log_requests: Whether to log the requests. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. + """ + + def __init__(self, + ipc_path: str, + use_async_sockets: bool, + *args, + log_requests: bool = True, + **kwargs) -> None: + self.engine = LLMEngine(*args, **kwargs) + self.log_requests = log_requests + + self.use_async_sockets = use_async_sockets + if self.use_async_sockets: + self.engine.process_request_outputs_callback = \ + self._async_socket_engine_callback + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Receive input from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") + + # Send output stream back to client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # Send health status back to client. + self.health_socket = self.ctx.socket(zmq.constants.PUSH) + self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Error state. + self._errored_with: Optional[BaseException] = None + + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + + @classmethod + def from_engine_args(cls, engine_args: AsyncEngineArgs, + usage_context: UsageContext, ipc_path: str): + """Creates an MQLLMEngine from the engine arguments.""" + + engine_config = engine_args.create_engine_config() + + executor_class = LLMEngine._get_executor_cls(engine_config) + + return cls( + ipc_path=ipc_path, + use_async_sockets=engine_config.model_config.use_async_output_proc, + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context) + + def start(self): + try: + try: + logger.debug("Starting Startup Loop.") + self.run_startup_loop() + logger.debug("Starting Engine Loop.") + self.run_engine_loop() + except Exception as e: + logger.exception(repr(e)) + except KeyboardInterrupt: + logger.debug("Shutting down MQLLMEngine.") + finally: + logger.debug("MQLLMEngine is shut down.") + self.cleanup() + + def cleanup(self): + """Cleanup zeromq state on shutdown.""" + # Closes all sockets and destroys context. 
+ self.ctx.destroy(linger=0) + del self.engine + + @contextmanager + def make_data_socket( + self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + socket = self.ctx.socket(zmq.constants.ROUTER) + try: + socket.bind(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + def run_startup_loop(self) -> None: + """Startup loop for sending data from Engine -> Client.""" + + with self.make_data_socket() as socket: + response: Union[RPCStartupResponse, BaseException] + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCStartupRequest = pickle.loads(message.buffer) + + # Handle the query from the Client. + if request == RPCStartupRequest.IS_SERVER_READY: + tracing_enabled = self.engine.is_tracing_enabled() + response = RPCStartupResponse( + tracing_enabled=tracing_enabled) + + except Exception as e: + response = e + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) + + def run_engine_loop(self): + """Core busy loop of the LLMEngine.""" + + while True: + if not self.engine.has_unfinished_requests(): + # Poll until there is work to do. + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + self.engine.do_log_stats() + logger.debug("Waiting for new requests in engine loop.") + + # Handle any input from the client. + self.handle_new_input() + + # Engine step. + request_outputs = self.engine_step() + + # Send request outputs (if async, done in engine_step callback). + if not self.use_async_sockets: + self._send_outputs(request_outputs) + + def engine_step(self) -> List[RequestOutput]: + """Engine step wrapper with error handling.""" + + try: + return self.engine.step() + except SystemExit: + raise + except BaseException as e: + self._set_errored(e) + rpc_err = RPCError(request_id=None, + is_engine_errored=True, + exception=e) + self._send_outputs(rpc_err) + raise e + + def handle_new_input(self): + """Handle new input from the socket""" + try: + while self.input_socket.poll(timeout=0) != 0: + frames = self.input_socket.recv_multipart(copy=False) + request = pickle.loads(frames[0].buffer) + + if isinstance(request, RPCGenerateRequest): + if len(frames) > 1: + # Use cloudpickle for logits processors + lprocs = cloudpickle.loads(frames[1].buffer) + request.sampling_params.logits_processors = lprocs + self._handle_generate_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCHealthRequest): + self._handle_health_request() + else: + raise ValueError("Unknown RPCRequest Type: {request}") + + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + raise e + + def _handle_generate_request(self, request: RPCGenerateRequest): + """Handle RPCGenerateRequest by adding it to the LLMEngine.""" + request_id = request.request_id + + if self._errored_with is not None: + rpc_err = RPCError(request_id=request_id, + is_engine_errored=True, + exception=ENGINE_DEAD_ERROR(self._errored_with)) + self._send_outputs(rpc_err) + + try: + self.engine.add_request( + request_id=request_id, + inputs=request.inputs, + params=request.sampling_params, + lora_request=request.lora_request, + trace_headers=request.trace_headers, + prompt_adapter_request=request.prompt_adapter_request) + + if self.log_requests: + logger.info("Added request %s.", request.request_id) + + except Exception as e: + # We do not set self._errored = True here, since the error + # is due to an issue adding this request to the engine, + # rather than an issue with the engine itself. 
+ is_errored = self._errored_with is not None + rpc_err = RPCError(request_id=request_id, + is_engine_errored=is_errored, + exception=e) + self._send_outputs(rpc_err) + + # Remove request from the engine. + self.engine.abort_request(request_id) + + def _handle_abort_request(self, request: RPCAbortRequest): + self.engine.abort_request(request.request_id) + if self.log_requests: + logger.info("Aborted request %s.", request.request_id) + + def _handle_health_request(self): + if self._errored_with is not None: + self._send_unhealthy(self._errored_with) + + # Raises error if unhealthy. + self.engine.check_health() + self._send_healthy() + + def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): + """Send List of RequestOutput to RPCClient.""" + if outputs: + output_bytes = pickle.dumps(outputs) + self.output_socket.send_multipart((output_bytes, ), copy=False) + + def _send_healthy(self): + """Send HEALTHY message to RPCClient.""" + self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + + def _send_unhealthy(self, error: BaseException): + """Send UNHEALTHY message to RPCClient.""" + error_bytes = pickle.dumps(error) + self.health_socket.send_multipart((error_bytes, ), copy=False) + + def _async_socket_engine_callback(self, + request_outputs: REQUEST_OUTPUTS_T): + """Callback used by engine to make socket handling async with GPU.""" + self._send_outputs(request_outputs) + self.handle_new_input() + + def _set_errored(self, e: BaseException): + """Log and set errored status if this is the first issue.""" + if self._errored_with is None: + self._errored_with = e + + +def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, + ipc_path: str): + + def signal_handler(*_) -> None: + # Interrupt server on sigterm + raise KeyboardInterrupt("MQLLMEngine terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + engine = MQLLMEngine.from_engine_args(engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path) + engine.start() diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 34ae79f5fa8df..70444faa670a2 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -14,8 +14,8 @@ @runtime_checkable -class AsyncEngineClient(Protocol): - """Protocol class for Clients to AsyncLLMEngine""" +class EngineClient(Protocol): + """Protocol class for Clients to Engine""" @property def is_running(self) -> bool: @@ -30,8 +30,8 @@ def errored(self) -> bool: ... @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" + def dead_error(self) -> BaseException: + ... 
def generate( self, diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 47d227010c075..5dcf50bd1b0a1 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -1,21 +1,21 @@ import asyncio import signal from http import HTTPStatus -from typing import Any, Optional +from typing import Any import uvicorn from fastapi import FastAPI, Request, Response from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError +from vllm.engine.multiprocessing import MQEngineDeadError from vllm.logger import init_logger from vllm.utils import find_process_using_port logger = init_logger(__name__) -async def serve_http(app: FastAPI, limit_concurrency: Optional[int], - **uvicorn_kwargs: Any): +async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): logger.info("Available routes are:") for route in app.routes: methods = getattr(route, "methods", None) @@ -26,15 +26,6 @@ async def serve_http(app: FastAPI, limit_concurrency: Optional[int], logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) - # Set concurrency limits in uvicorn if running in multiprocessing mode - # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536). - if limit_concurrency is not None: - logger.info( - "Launching Uvicorn with --limit_concurrency %s. To avoid this " - "limit at the expense of performance run with " - "--disable-frontend-multiprocessing", limit_concurrency) - uvicorn_kwargs["limit_concurrency"] = limit_concurrency - config = uvicorn.Config(app, **uvicorn_kwargs) server = uvicorn.Server(config) _add_shutdown_handlers(app, server) @@ -63,7 +54,7 @@ async def dummy_shutdown() -> None: logger.debug( "port %s is used by process %s launched with command:\n%s", port, process, " ".join(process.cmdline())) - logger.info("Gracefully stopping http server") + logger.info("Shutting down FastAPI HTTP server.") return server.shutdown() @@ -90,7 +81,7 @@ async def runtime_error_handler(request: Request, __): return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) @app.exception_handler(AsyncEngineDeadError) - async def engine_dead_handler(_, __): + async def async_engine_dead_handler(_, __): """Kill the server if the async engine is already dead. It will not handle any further requests.""" if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: @@ -99,3 +90,14 @@ async def engine_dead_handler(_, __): server.should_exit = True return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + + @app.exception_handler(MQEngineDeadError) + async def mq_engine_dead_handler(_, __): + """Kill the server if the mq engine is already dead. 
It will + not handle any further requests.""" + if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: + logger.fatal("MQLLMEngine is already dead, terminating server " + "process") + server.should_exit = True + + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b891debfd2b91..1b9eb30252417 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -26,7 +26,9 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import run_mp_engine +from vllm.engine.protocol import EngineClient from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser @@ -44,8 +46,6 @@ TokenizeRequest, TokenizeResponse, UnloadLoraAdapterRequest) -from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient -from vllm.entrypoints.openai.rpc.server import run_rpc_server # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -67,29 +67,16 @@ _running_tasks: Set[asyncio.Task] = set() -def model_is_embedding(model_name: str, trust_remote_code: bool, - quantization: Optional[str], - revision: Optional[str]) -> bool: - return ModelConfig(model=model_name, - revision=revision, - tokenizer=model_name, - tokenizer_mode="auto", - trust_remote_code=trust_remote_code, - quantization=quantization, - seed=0, - dtype="auto").embedding_mode - - @asynccontextmanager async def lifespan(app: FastAPI): try: if app.state.log_stats: - async_engine_client = app.state.engine_client + engine_client: EngineClient = app.state.engine_client async def _force_log(): while True: - await asyncio.sleep(10) - await async_engine_client.do_log_stats() + await asyncio.sleep(10.) + await engine_client.do_log_stats() task = asyncio.create_task(_force_log()) _running_tasks.add(task) @@ -108,9 +95,9 @@ async def _force_log(): @asynccontextmanager async def build_async_engine_client( - args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]: + args: Namespace) -> AsyncIterator[Optional[EngineClient]]: - # Context manager to handle async_engine_client lifecycle + # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit engine_args = AsyncEngineArgs.from_cli_args(args) @@ -123,19 +110,18 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, -) -> AsyncIterator[Optional[AsyncEngineClient]]: +) -> AsyncIterator[Optional[EngineClient]]: """ - Create AsyncEngineClient, either: + Create EngineClient, either: - in-process using the AsyncLLMEngine Directly - multiprocess using AsyncLLMEngine RPC Returns the Client or None if the creation failed. """ - # If manually triggered or embedding model, use AsyncLLMEngine in process. - # TODO: support embedding model via RPC. - if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, - engine_args.quantization, engine_args.revision) + # Fall back + # TODO: fill out feature matrix. 
+ if (MQLLMEngineClient.is_unsupported_config(engine_args) or disable_frontend_multiprocessing): engine_config = engine_args.create_engine_config() uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), @@ -173,56 +159,60 @@ async def build_async_engine_client_from_engine_args( "and vLLM will properly handle cleanup.") # Select random path for IPC. - rpc_path = get_open_zmq_ipc_path() - logger.info("Multiprocessing frontend to use %s for RPC Path.", - rpc_path) - - # Build RPCClient, which conforms to AsyncEngineClient Protocol. - # NOTE: Actually, this is not true yet. We still need to support - # embedding models via RPC (see TODO above) - rpc_client = AsyncEngineRPCClient(rpc_path) + ipc_path = get_open_zmq_ipc_path() + logger.info("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) - # Start RPCServer in separate process (holds the AsyncLLMEngine). - context = multiprocessing.get_context("spawn") + # Start RPCServer in separate process (holds the LLMEngine). # the current process might have CUDA context, # so we need to spawn a new process - rpc_server_process = context.Process( - target=run_rpc_server, - args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path)) - rpc_server_process.start() - logger.info("Started engine process with PID %d", - rpc_server_process.pid) + context = multiprocessing.get_context("spawn") + + engine_process = context.Process(target=run_mp_engine, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + ipc_path)) + engine_process.start() + logger.info("Started engine process with PID %d", engine_process.pid) + + # Build RPCClient, which conforms to EngineClient Protocol. + # NOTE: Actually, this is not true yet. We still need to support + # embedding models via RPC (see TODO above) + engine_config = engine_args.create_engine_config() + mp_engine_client = MQLLMEngineClient(ipc_path, engine_config) try: while True: try: - await rpc_client.setup() + await mp_engine_client.setup() break except TimeoutError: - if not rpc_server_process.is_alive(): - logger.error( - "RPCServer process died before responding " - "to readiness probe") + if not engine_process.is_alive(): + logger.error("Engine process died before responding " + "to readiness probe") yield None return - yield rpc_client # type: ignore[misc] + yield mp_engine_client # type: ignore[misc] finally: # Ensure rpc server process was terminated - rpc_server_process.terminate() + engine_process.terminate() # Close all open connections to the backend - rpc_client.close() + mp_engine_client.close() - # Wait for server process to join - rpc_server_process.join() + # Wait for engine process to join + engine_process.join(4) + if engine_process.exitcode is None: + # Kill if taking longer than 5 seconds to stop + engine_process.kill() # Lazy import for prometheus multiprocessing. # We need to set PROMETHEUS_MULTIPROC_DIR environment variable # before prometheus_client is imported. 
# See https://prometheus.github.io/client_python/multiprocess/ from prometheus_client import multiprocess - multiprocess.mark_process_dead(rpc_server_process.pid) + multiprocess.mark_process_dead(engine_process.pid) router = APIRouter() @@ -270,7 +260,7 @@ def embedding(request: Request) -> OpenAIServingEmbedding: return request.app.state.openai_serving_embedding -def engine_client(request: Request) -> AsyncEngineClient: +def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client @@ -473,7 +463,7 @@ async def authentication(request: Request, call_next): def init_app_state( - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, state: State, args: Namespace, @@ -488,11 +478,11 @@ def init_app_state( else: request_logger = RequestLogger(max_log_len=args.max_log_len) - state.engine_client = async_engine_client + state.engine_client = engine_client state.log_stats = not args.disable_log_stats state.openai_serving_chat = OpenAIServingChat( - async_engine_client, + engine_client, model_config, served_model_names, args.response_role, @@ -504,7 +494,7 @@ def init_app_state( enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser) state.openai_serving_completion = OpenAIServingCompletion( - async_engine_client, + engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -513,13 +503,13 @@ def init_app_state( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) state.openai_serving_embedding = OpenAIServingEmbedding( - async_engine_client, + engine_client, model_config, served_model_names, request_logger=request_logger, ) state.openai_serving_tokenization = OpenAIServingTokenization( - async_engine_client, + engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -541,21 +531,20 @@ def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler) - async with build_async_engine_client(args) as async_engine_client: + async with build_async_engine_client(args) as engine_client: # If None, creation of the client failed and we exit. - if async_engine_client is None: + if engine_client is None: return app = build_app(args) - model_config = await async_engine_client.get_model_config() - init_app_state(async_engine_client, model_config, app.state, args) + model_config = await engine_client.get_model_config() + init_app_state(engine_client, model_config, app.state, args) temp_socket.close() shutdown_task = await serve_http( app, - limit_concurrency=async_engine_client.limit_concurrency, host=args.host, port=args.port, log_level=args.uvicorn_log_level, diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py deleted file mode 100644 index efc7e43afdcc9..0000000000000 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from typing import Mapping, Optional, Union - -from vllm.inputs import PromptInputs -from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams - -# Success string used for RPC instructions. -VLLM_RPC_SUCCESS_STR = "SUCCESS" - -# Minimum value of ZMQ.SOCKET_LIMIT to run mp. -VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000 - -# HWM is set to Infinity. 
-VLLM_RPC_ZMQ_HWM = 0 - - -@dataclass -class RPCGenerateRequest: - inputs: PromptInputs - sampling_params: SamplingParams - request_id: str - lora_request: Optional[LoRARequest] = None - trace_headers: Optional[Mapping[str, str]] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None - - -@dataclass -class RPCAbortRequest: - request_id: str - - -class RPCUtilityRequest(Enum): - IS_SERVER_READY = 1 - GET_MODEL_CONFIG = 2 - GET_DECODING_CONFIG = 3 - GET_PARALLEL_CONFIG = 4 - GET_SCHEDULER_CONFIG = 5 - GET_LORA_CONFIG = 6 - DO_LOG_STATS = 7 - IS_SERVER_HEALTHY = 8 - IS_TRACING_ENABLED = 9 - START_PROFILE = 10 - STOP_PROFILE = 11 - - -RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, - RPCUtilityRequest] diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py deleted file mode 100644 index 9b88db746be5c..0000000000000 --- a/vllm/entrypoints/openai/rpc/client.py +++ /dev/null @@ -1,451 +0,0 @@ -import asyncio -import pickle -from contextlib import contextmanager, suppress -from typing import Any, AsyncGenerator, Iterator, Mapping, Optional -from uuid import uuid4 - -import cloudpickle -import zmq -import zmq.asyncio -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -# yapf: disable -from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, - VLLM_RPC_SOCKET_LIMIT_CUTOFF, - VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) -# yapf: enable -from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS -from vllm.inputs import PromptInputs -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.outputs import EmbeddingRequestOutput, RequestOutput -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs - -logger = init_logger(__name__) - -# Path used for inprocess proxy. -INPROC_PROXY_PATH = f"inproc://{uuid4()}" - - -class RPCClientClosedError(Exception): - """Exception class raised when the client is used post-close. - - The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which - causes a ZMQError and creates a huge stack trace. - So, we throw this error such that we can suppress it. - """ - - -class AsyncEngineRPCClient: - """ - RPCClient that connects to the RPCServer wrapping AsyncLLMEngine. 
- - The overall design mirrors the Asynchronous Client Server Pattern - https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern - - On startup, the RPCClient: - - makes DEALER socket (to_rpc_server) that connects to the RPCServer - via ipc, which uses unix sockets under the hood - (https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html) - - makes ROUTER socket (from_api_server) that binds to a random - inproc address, which uses memory under the hood - (https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html) - - runs a proxy in a background asyncio task between - from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, ) - - Each request handled by the asyncio api_server calls generate(): - - make a DEALER socket that connects to from_api_server via inproc - - send a RCPGenerateRequest to the inproc socket - - background proxy forwards the request from inproc -> ipc - - RPCServer responds to the request one token at a time over ipc - - background proxy forwards the response from ipc -> inproc - - The connection looks like this: - DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER - - Message routing is performed via identities that are managed by the - ROUTER socket. ROUTER sockets track every connection it has and - tells the caller about these. The way it tells the caller is to stick - the connection identity in front of each message received. When we - send the message via a ROUTER, we first send an identity frame. - See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope - for more details on connection identities. - - This proxy design enables us to use a single unix socket, which - improves performance by avoiding syscalls (~5%) and avoids resource limits - such as ulimit, which defaults to 1024 on ubuntu. - - Note: we run set_hwm(0) on each socket, which sets the HWM to inf, - which is required to avoid dropping messages under high load. - This is generally not advisable. However, since we are in control - of both sides of the connection + failure on either side is - catastrophic to the overall system health and memory profiling - suggests limited memory overhead relative to asyncio, we will - proceed for now. - - See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks - for more details on high water marks. - """ - - def __init__(self, rpc_path: str): - self.context = zmq.asyncio.Context() - self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS - self._errored = False - - # Maximum number of sockets that can be opened (typically 65536). - # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) - socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT) - assert isinstance(socket_limit, int) - if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF: - raise ValueError( - f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps " - "the number of concurrent requests vLLM can process. Launch " - "vLLM with --disable-frontend-multiprocessing and open a " - "GitHub issue so we can investigate.") - - # We only have 1 ipc connection that uses unix sockets, so - # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will - # not run into ulimit issues) - self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) - - # IPC connection to RPC Server (uses unix sockets). - self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER) - self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.to_rpc_server.bind(rpc_path) - - # In process proxy to RPC Server (uses memory-based messaging). 
- self.from_api_server: Socket = self.context.socket( - zmq.constants.ROUTER) - self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.from_api_server.bind(INPROC_PROXY_PATH) - - # Asyncio background task for the proxy. - self.proxy_in_task = asyncio.create_task( - self.run_proxy(self.from_api_server, self.to_rpc_server)) - self.proxy_out_task = asyncio.create_task( - self.run_proxy(self.to_rpc_server, self.from_api_server)) - - # Since we open 1 inproc socket per request, we have a hard cap on - # the number of requests that can run in vLLM w. frontend - # mulitprocessing. This value is used uvicorn to launch - # with --limit-concurrency to return 503 when server is overloaded. - # We need 2 sockets per request - 2: - # 1 for generate(), 1 for abort(), do_log_stats(), check_health() - self.limit_concurrency = socket_limit // 2 - 2 - - async def run_proxy(self, socket_from: Socket, socket_to: Socket): - """Background task that runs a proxy""" - while True: - frames = await socket_from.recv_multipart(copy=False) - await socket_to.send_multipart(frames, copy=False) - - async def setup(self): - """Setup the client before it starts sending server requests.""" - - # Wait until server is ready. - await self._wait_for_server_rpc() - - # Get the configs. - self.model_config = await self._get_model_config_rpc() - self.decoding_config = await self._get_decoding_config_rpc() - self.tracing_flag = await self._is_tracing_enabled_rpc() - - # Create the tokenizer group. - # TODO: refactor OAI server to avoid needing this info. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=(await self._get_scheduler_config_rpc()), - parallel_config=(await self._get_parallel_config_rpc()), - enable_lora=bool(await self._get_lora_config_rpc()), - ) - - def close(self): - """Destroy the ZeroMQ Context.""" - # Close all sockets associated with this context and - # then terminate the context. - self.from_api_server.close() - self.to_rpc_server.close() - self.context.destroy() - - @contextmanager - def to_proxy_socket(self) -> Iterator[Socket]: - # Connect to the RPCServer via the proxy. - - # Raise a sensible error if the client was already closed. - # This can happen if a server shutdown is triggered but some coroutines - # are still running requests. - # There should not be a race condition with this check because we don't - # yield to the event loop between here and opening the socket. - if self.context.closed: - raise RPCClientClosedError("The ZMQ client has already shut down") - - # Note that we use DEALER to enable asynchronous communication - # to enable streaming. - socket = self.context.socket(zmq.constants.DEALER) - socket.set_hwm(VLLM_RPC_ZMQ_HWM) - try: - socket.connect(INPROC_PROXY_PATH) - yield socket - finally: - socket.close(linger=0) - - async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, - expected_type: Any, - error_message: str) -> Any: - """Send an RPC request that is expecting data back.""" - - with self.to_proxy_socket() as socket: - # Ping RPCServer with a request. - await socket.send_multipart((cloudpickle.dumps(request), ), - copy=False) - - # Make sure the server responds - if await socket.poll(timeout=self._data_timeout) == 0: - raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") - - # Await the data from the Server. 
- frame = await socket.recv(copy=False) - assert isinstance(frame, Frame) - data = pickle.loads(frame.buffer) - - if isinstance(data, Exception): - # Re-raise exceptions returned by the server - raise data - - if not isinstance(data, expected_type): - # LoRAConfig can be None. - if expected_type == LoRAConfig and data is None: - pass - elif isinstance(data, Exception): - logger.error(error_message) - raise data - else: - raise ValueError(error_message) - - return data - - async def _send_one_way_rpc_request(self, - request: RPC_REQUEST_TYPE, - error_message: str, - socket: Optional[Socket] = None): - """Send one-way RPC request to trigger an action.""" - - async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): - - await socket.send_multipart((cloudpickle.dumps(request), )) - - if await socket.poll(timeout=self._data_timeout) == 0: - raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") - - frame = await socket.recv(copy=False) - assert isinstance(frame, Frame) - return pickle.loads(frame.buffer) - - # Make a new socket connection. - if socket is None: - with self.to_proxy_socket() as socket: - response = await do_rpc_call(socket, request) - - # Use existing socket connection. - else: - response = await do_rpc_call(socket, request) - - if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: - if isinstance(response, Exception): - logger.error(error_message) - raise response - raise ValueError(error_message) - - async def get_tokenizer(self, lora_request: LoRARequest): - return await self.tokenizer.get_lora_tokenizer_async(lora_request) - - async def get_decoding_config(self) -> DecodingConfig: - return self.decoding_config - - async def get_model_config(self) -> ModelConfig: - return self.model_config - - async def is_tracing_enabled(self) -> bool: - return self.tracing_flag - - async def _wait_for_server_rpc(self): - """Wait for the RPCServer to start up.""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_READY, - error_message="Unable to start RPC Server") - - async def _get_model_config_rpc(self) -> ModelConfig: - """Get the ModelConfig object from the RPC Server""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_MODEL_CONFIG, - expected_type=ModelConfig, - error_message="Could not get ModelConfig from RPC Server") - - async def _get_decoding_config_rpc(self) -> DecodingConfig: - """Get DecodingConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_DECODING_CONFIG, - expected_type=DecodingConfig, - error_message="Could not get DecodingConfig from RPC Server") - - async def _get_parallel_config_rpc(self) -> ParallelConfig: - """Get ParallelConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_PARALLEL_CONFIG, - expected_type=ParallelConfig, - error_message="Could not get ParallelConfig from RPC Server") - - async def _get_scheduler_config_rpc(self) -> SchedulerConfig: - """Get SchedulerConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - expected_type=SchedulerConfig, - error_message="Could not get SchedulerConfig from RPC Server") - - async def _get_lora_config_rpc(self) -> LoRAConfig: - """Get LoRAConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_LORA_CONFIG, - expected_type=LoRAConfig, - error_message="Could not get LoRAConfig from RPC Server") - - async def 
_is_tracing_enabled_rpc(self) -> bool: - """Get is_tracing_enabled flag from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.IS_TRACING_ENABLED, - expected_type=bool, - error_message="Could not get is_tracing_enabled from RPC Server") - - async def abort(self, request_id: str): - """Send an ABORT_REQUEST signal to the RPC Server""" - - # Suppress timeouts as well. - # In cases where the server is busy processing requests and a very - # large volume of abort requests arrive, it is likely that the server - # will not be able to ack all of them in time. We have seen this when - # we abort 20k requests at once while another 2k are processing- many - # of them time out, but we see the server successfully abort all of the - # requests. - # In this case we assume that the server has received or will receive - # these abort requests, and ignore the timeout. This prevents a massive - # wall of `TimeoutError` stack traces. - with suppress(RPCClientClosedError, TimeoutError): - await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), - error_message=f"RPCAbortRequest {request_id} failed") - - async def do_log_stats(self): - """Send a DO_LOG_STATS signal to the RPC Server""" - with suppress(RPCClientClosedError): - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.DO_LOG_STATS, - error_message="RPCRequest DO_LOG_STATS failed.") - - @property - def is_running(self) -> bool: - return not self._errored - - @property - def is_stopped(self) -> bool: - return self._errored - - @property - def errored(self) -> bool: - return self._errored - - async def generate( - self, - inputs: PromptInputs, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncGenerator[RequestOutput, None]: - """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - - finished = False - try: - with self.to_proxy_socket() as socket: - # Send RPCGenerateRequest to the RPCServer. - await socket.send_multipart((cloudpickle.dumps( - RPCGenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request)), )) - - # Stream back the results from the RPC Server. - while not finished: - message = await socket.recv(copy=False) - assert isinstance(message, Frame) - request_output = pickle.loads(message.buffer) - - if isinstance(request_output, Exception): - # On exception, check if the server is still healthy - # possibly setting the `errored` property. - if not self._errored: - try: - await self.check_health(socket=socket) - except Exception as e: - self._errored = True - logger.exception(repr(e)) - - # NB: do before raising here so that the flag is set - # by the time the caller receives this exception - raise request_output - - finished = request_output.finished - yield request_output - - finally: - # Request was canceled by the client. 
- if not finished and not self._errored: - await self.abort(request_id) - - async def check_health(self, socket: Optional[Socket] = None) -> None: - """Raise if unhealthy""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_HEALTHY, - error_message="Got Unhealthy response from RPC Server", - socket=socket) - - async def encode(self, *args, - **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: - raise NotImplementedError( - "Embeddings not supported with multiprocessing backend") - - async def start_profile(self) -> None: - """Start profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.START_PROFILE, - error_message="RPCRequest START_PROFILE failed.") - - async def stop_profile(self) -> None: - """Stop profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.STOP_PROFILE, - error_message="RPCRequest STOP_PROFILE failed.") diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py deleted file mode 100644 index 460ff0636b6e9..0000000000000 --- a/vllm/entrypoints/openai/rpc/server.py +++ /dev/null @@ -1,243 +0,0 @@ -import asyncio -import pickle -import signal -from typing import Any, Coroutine, Union - -import cloudpickle -import uvloop -import zmq -import zmq.asyncio -from typing_extensions import Never -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm import AsyncEngineArgs, AsyncLLMEngine -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) -from vllm.logger import init_logger -from vllm.usage.usage_lib import UsageContext - -logger = init_logger(__name__) - -CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, - SchedulerConfig, LoRAConfig] - - -class AsyncEngineRPCServer: - - def __init__(self, async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, rpc_path: str): - # Initialize engine first. - self.engine = AsyncLLMEngine.from_engine_args( - async_engine_args, usage_context=usage_context) - - # Initialize context. - self.context = zmq.asyncio.Context() - - # Init socket. - self.socket: Socket = self.context.socket(zmq.constants.DEALER) - self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) - self.socket.connect(rpc_path) - - def cleanup(self): - """Cleanup all resources.""" - self.socket.close() - self.context.destroy() - # Clear the engine reference so that it can be GC'ed. 
- del self.engine - - async def get_config(self, identity, request): - try: - config: CONFIG_TYPE - if request == RPCUtilityRequest.GET_MODEL_CONFIG: - config = await self.engine.get_model_config() - elif request == RPCUtilityRequest.GET_DECODING_CONFIG: - config = await self.engine.get_decoding_config() - elif request == RPCUtilityRequest.GET_LORA_CONFIG: - config = await self.engine.get_lora_config() - elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: - config = await self.engine.get_scheduler_config() - elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: - config = await self.engine.get_parallel_config() - else: - raise ValueError("Unknown Config Request: %s", request) - - await self.socket.send_multipart((identity, pickle.dumps(config)), - copy=False) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def is_tracing_enabled(self, identity): - """Send the is_tracing_enabled flag""" - tracing_flag = await self.engine.is_tracing_enabled() - - await self.socket.send_multipart( - (identity, pickle.dumps(tracing_flag))) - - async def do_log_stats(self, identity): - """Log stats and confirm success.""" - await self.engine.do_log_stats() - - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - async def is_server_ready(self, identity): - """Notify the client that we are ready.""" - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - async def abort(self, identity, request: RPCAbortRequest): - """Abort request and notify the client of success.""" - try: - # Abort the request in the llm engine. - await self.engine.abort(request.request_id) - result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR - except Exception as e: - result = e - await self.socket.send_multipart((identity, pickle.dumps(result))) - - async def generate(self, identity, generate_request: RPCGenerateRequest): - try: - results_generator = self.engine.generate( - generate_request.inputs, - sampling_params=generate_request.sampling_params, - request_id=generate_request.request_id, - lora_request=generate_request.lora_request, - trace_headers=generate_request.trace_headers, - prompt_adapter_request=generate_request.prompt_adapter_request) - - async for request_output in results_generator: - await self.socket.send_multipart( - (identity, pickle.dumps(request_output)), copy=False) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def check_health(self, identity): - try: - await self.engine.check_health() - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def start_profile(self, identity): - logger.info("Starting profiler...") - await self.engine.start_profile() - logger.info("Profiler started.") - - await self.socket.send_multipart(( - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR), - )) - - async def stop_profile(self, identity): - logger.info("Stopping profiler...") - await self.engine.stop_profile() - logger.info("Profiler stopped.") - - await self.socket.send_multipart(( - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR), - )) - - def _make_handler_coro(self, identity, - message: Frame) -> Coroutine[Any, Any, Never]: - """Route the zmq message to the handler coroutine.""" - - request = cloudpickle.loads(message.buffer) - - if isinstance(request, RPCGenerateRequest): - return 
self.generate(identity, request) - - elif isinstance(request, RPCAbortRequest): - return self.abort(identity, request) - - elif isinstance(request, RPCUtilityRequest): - if request in [ - RPCUtilityRequest.GET_MODEL_CONFIG, - RPCUtilityRequest.GET_PARALLEL_CONFIG, - RPCUtilityRequest.GET_DECODING_CONFIG, - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - RPCUtilityRequest.GET_LORA_CONFIG - ]: - return self.get_config(identity, request) - elif request == RPCUtilityRequest.DO_LOG_STATS: - return self.do_log_stats(identity) - elif request == RPCUtilityRequest.IS_SERVER_READY: - return self.is_server_ready(identity) - elif request == RPCUtilityRequest.IS_SERVER_HEALTHY: - return self.check_health(identity) - elif request == RPCUtilityRequest.IS_TRACING_ENABLED: - return self.is_tracing_enabled(identity) - elif request == RPCUtilityRequest.START_PROFILE: - return self.start_profile(identity) - elif request == RPCUtilityRequest.STOP_PROFILE: - return self.stop_profile(identity) - else: - raise ValueError(f"Unknown RPCUtilityRequest type: {request}") - - else: - raise ValueError(f"Unknown RPCRequest type: {request}") - - async def run_server_loop(self): - """Inner RPC Server Loop""" - - running_tasks = set() - while True: - # Wait for a request. - identity, message = await self.socket.recv_multipart(copy=False) - - # Process the request async. - task = asyncio.create_task( - self._make_handler_coro(identity, message)) - - # We need to keep around a strong reference to the task, - # to avoid the task disappearing mid-execution as running tasks - # can be GC'ed. Below is a common "fire-and-forget" tasks - # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task - running_tasks.add(task) - task.add_done_callback(running_tasks.discard) - - -async def run_server(server: AsyncEngineRPCServer): - # Put the server task into the asyncio loop. - loop = asyncio.get_running_loop() - server_task = loop.create_task(server.run_server_loop()) - - # Interruption handling. - def signal_handler() -> None: - # Kill the server on interrupt / terminate - server_task.cancel() - - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) - - try: - await server_task - except asyncio.CancelledError: - logger.info("vLLM ZMQ RPC Server was interrupted.") - finally: - # Clean up all resources. 
- server.cleanup() - - -def run_rpc_server(async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, rpc_path: str): - - def signal_handler(*_) -> None: - # Interrupt server on sigterm while initializing - raise KeyboardInterrupt("AsyncEngineRPCServer terminated") - - signal.signal(signal.SIGTERM, signal_handler) - - server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path) - uvloop.run(run_server(server)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d28362a12abdb..b84898dc39b0f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,7 +9,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, apply_hf_chat_template, apply_mistral_chat_template, @@ -45,7 +45,7 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], response_role: str, @@ -57,7 +57,7 @@ def __init__(self, return_tokens_as_token_ids: bool = False, enable_auto_tools: bool = False, tool_parser: Optional[str] = None): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -105,6 +105,12 @@ async def create_chat_completion( logger.error("Error with model %s", error_check_ret) return error_check_ret + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). 
+ if self.engine_client.errored: + raise self.engine_client.dead_error + try: ( lora_request, @@ -112,8 +118,7 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) conversation, mm_data_future = parse_chat_messages_futures( request.messages, model_config, tokenizer) @@ -207,8 +212,8 @@ async def create_chat_completion( if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data - is_tracing_enabled = ( - await self.async_engine_client.is_tracing_enabled()) + is_tracing_enabled = (await + self.engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled and raw_request: trace_headers = extract_trace_headers(raw_request.headers) @@ -216,7 +221,7 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): log_tracing_disabled_warning() - result_generator = self.async_engine_client.generate( + result_generator = self.engine_client.generate( engine_inputs, sampling_params, request_id, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 42142efb5f23e..14fa60243c584 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,7 +8,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -43,7 +43,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -52,7 +52,7 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -78,6 +78,12 @@ async def create_completion( if error_check_ret is not None: return error_check_ret + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + # Return error for unsupported features. 
if request.suffix is not None: return self.create_error_response( @@ -95,8 +101,7 @@ async def create_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -124,8 +129,8 @@ async def create_completion( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = ( - await self.async_engine_client.is_tracing_enabled()) + is_tracing_enabled = (await + self.engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled: trace_headers = extract_trace_headers(raw_request.headers) @@ -133,7 +138,7 @@ async def create_completion( raw_request.headers): log_tracing_disabled_warning() - generator = self.async_engine_client.generate( + generator = self.engine_client.generate( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, request_id_item, diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 12ec6be03cd62..f111a3a8277b5 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -8,7 +8,7 @@ from typing_extensions import assert_never from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, EmbeddingResponse, @@ -71,13 +71,13 @@ class OpenAIServingEmbedding(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, request_logger: Optional[RequestLogger], ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=None, @@ -118,8 +118,7 @@ async def create_embedding( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) pooling_params = request.to_pooling_params() @@ -144,7 +143,7 @@ async def create_embedding( "Prompt adapter is not supported " "for embedding models") - generator = self.async_engine_client.encode( + generator = self.engine_client.encode( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, request_id_item, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index ac74527441cd9..72f9381abc7db 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -8,7 +8,7 @@ from typing_extensions import Annotated from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -64,7 +64,7 @@ class OpenAIServing: def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -75,7 +75,7 @@ def __init__( ): super().__init__() - self.async_engine_client = async_engine_client + self.engine_client = 
engine_client self.model_config = model_config self.max_model_len = model_config.max_model_len @@ -159,7 +159,7 @@ def create_streaming_error_response( async def _guided_decode_logits_processor( self, request: Union[ChatCompletionRequest, CompletionRequest], tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: - decoding_config = await self.async_engine_client.get_decoding_config() + decoding_config = await self.engine_client.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend return await get_guided_decoding_logits_processor( diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 6e802b71ae2b4..8f8862897fc4e 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,7 +1,7 @@ from typing import List, Optional, Union from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (apply_hf_chat_template, apply_mistral_chat_template, load_chat_template, @@ -29,7 +29,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -37,7 +37,7 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -66,7 +66,7 @@ async def create_tokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) prompt: Union[str, List[int]] if isinstance(request, TokenizeChatRequest): @@ -132,7 +132,7 @@ async def create_detokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) self._log_inputs(request_id, request.tokens, diff --git a/vllm/envs.py b/vllm/envs.py index 6edb06ecd2e20..43c7aa8af85b2 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -57,7 +57,7 @@ VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False - VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 + VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_USE_TRITON_AWQ: bool = False @@ -393,8 +393,8 @@ def get_default_config_root(): # Time in ms for the zmq client to wait for a response from the backend # server for simple data operations - "VLLM_RPC_GET_DATA_TIMEOUT_MS": - lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")), + "VLLM_RPC_TIMEOUT": + lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), # a list of plugin names to load, separated by commas. 
# if this is not set, it means all plugins will be loaded diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 7380b73ad6548..9ad240ef60820 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -106,6 +106,7 @@ def _init_executor(self) -> None: )) for rank in range(1, world_size) ] + self.worker_monitor = None if world_size != 1 or is_async: if is_async: async_worker_list = self.workers + [self.driver_worker] diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py index aa2a16c04d08d..5bef76b90d332 100644 --- a/vllm/executor/multiproc_worker_utils.py +++ b/vllm/executor/multiproc_worker_utils.py @@ -168,6 +168,8 @@ def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], self.tasks[task_id] = future try: self._task_queue.put((task_id, method, args, kwargs)) + except SystemExit: + raise except BaseException as e: del self.tasks[task_id] raise ChildProcessError("worker died") from e @@ -222,6 +224,8 @@ def _run_worker_process( try: executor = getattr(worker, method) output = executor(*args, **kwargs) + except SystemExit: + raise except KeyboardInterrupt: break except BaseException as e: From a8c1d161a7d87dbc6c7cccfce303dcbe2e4ed6be Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Wed, 18 Sep 2024 11:38:43 -0400 Subject: [PATCH 77/98] [Core] *Prompt* logprobs support in Multi-step (#8199) --- tests/conftest.py | 84 +++++++++++------- tests/models/utils.py | 108 +++++++++++++++++++++-- tests/multi_step/test_correctness_llm.py | 92 +++++++++++++++++++ tests/utils.py | 3 +- vllm/worker/multi_step_model_runner.py | 72 ++++++++++----- 5 files changed, 300 insertions(+), 59 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e9c7fc7bf9c67..c2616bcf7091c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,8 @@ BatchFeature) from transformers.models.auto.auto_factory import _BaseAutoModelClass +from tests.models.utils import (TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs) from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset @@ -33,7 +35,6 @@ to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput -from vllm.sequence import SampleLogprobs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, identity, is_cpu) @@ -469,7 +470,7 @@ def generate_greedy_logprobs_limit( audios: Optional[PromptAudioInput] = None, videos: Optional[List[np.ndarray]] = None, **kwargs: Any, - ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + ) -> List[TokensTextLogprobs]: all_logprobs: List[List[Dict[int, float]]] = [] all_output_ids: List[List[int]] = [] all_output_strs: List[str] = [] @@ -525,7 +526,7 @@ def generate_encoder_decoder_greedy_logprobs_limit( max_tokens: int, num_logprobs: int, **kwargs: Any, - ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + ) -> List[TokensTextLogprobs]: ''' Greedy logprobs generation for vLLM encoder/decoder models ''' @@ -653,14 +654,16 @@ def generate( @staticmethod def _final_steps_generate_w_logprobs( req_outputs: List[RequestOutput], - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] + ) -> List[TokensTextLogprobsPromptLogprobs]: + outputs: List[TokensTextLogprobsPromptLogprobs] = [] for req_output in req_outputs: + assert 
len(req_output.outputs) > 0 for sample in req_output.outputs: output_str = sample.text output_ids = list(sample.token_ids) output_logprobs = sample.logprobs - outputs.append((output_ids, output_str, output_logprobs)) + outputs.append((output_ids, output_str, output_logprobs, + req_output.prompt_logprobs)) return outputs def generate_w_logprobs( @@ -670,7 +673,8 @@ def generate_w_logprobs( images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: assert sampling_params.logprobs is not None if images is not None: @@ -695,13 +699,20 @@ def generate_w_logprobs( req_outputs = self.model.generate(inputs, sampling_params=sampling_params) - return self._final_steps_generate_w_logprobs(req_outputs) + + toks_str_logsprobs_prompt_logprobs = ( + self._final_steps_generate_w_logprobs(req_outputs)) + # Omit prompt logprobs if not required by sampling params + return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None else + toks_str_logsprobs_prompt_logprobs) def generate_encoder_decoder_w_logprobs( self, encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], sampling_params: SamplingParams, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: ''' Logprobs generation for vLLM encoder/decoder models ''' @@ -709,7 +720,12 @@ def generate_encoder_decoder_w_logprobs( assert sampling_params.logprobs is not None req_outputs = self.model.generate(encoder_decoder_prompts, sampling_params=sampling_params) - return self._final_steps_generate_w_logprobs(req_outputs) + toks_str_logsprobs_prompt_logprobs = ( + self._final_steps_generate_w_logprobs(req_outputs)) + # Omit prompt logprobs if not required by sampling params + return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None else + toks_str_logsprobs_prompt_logprobs) def generate_greedy( self, @@ -727,44 +743,48 @@ def generate_greedy_logprobs( prompts: List[str], max_tokens: int, num_logprobs: int, + num_prompt_logprobs: Optional[int] = None, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, stop_token_ids: Optional[List[int]] = None, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - greedy_logprobs_params = SamplingParams(temperature=0.0, - max_tokens=max_tokens, - logprobs=num_logprobs, - stop_token_ids=stop_token_ids) - outputs = self.generate_w_logprobs(prompts, - greedy_logprobs_params, - images=images, - audios=audios, - videos=videos) - - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: + greedy_logprobs_params = SamplingParams( + temperature=0.0, + max_tokens=max_tokens, + logprobs=num_logprobs, + prompt_logprobs=(num_prompt_logprobs), + stop_token_ids=stop_token_ids) + + return self.generate_w_logprobs(prompts, + greedy_logprobs_params, + images=images, + audios=audios, + videos=videos) def generate_encoder_decoder_greedy_logprobs( self, encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, num_logprobs: int, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - 
greedy_logprobs_params = SamplingParams(temperature=0.0, - use_beam_search=False, - max_tokens=max_tokens, - logprobs=num_logprobs) + num_prompt_logprobs: Optional[int] = None, + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: + greedy_logprobs_params = SamplingParams( + temperature=0.0, + use_beam_search=False, + max_tokens=max_tokens, + logprobs=num_logprobs, + prompt_logprobs=(num_prompt_logprobs), + ) ''' Greedy logprobs generation for vLLM encoder/decoder models ''' - outputs = self.generate_encoder_decoder_w_logprobs( + return self.generate_encoder_decoder_w_logprobs( encoder_decoder_prompts, greedy_logprobs_params) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] - def generate_beam_search( self, prompts: List[str], diff --git a/tests/models/utils.py b/tests/models/utils.py index 93ec03995094b..8e31a1d6eefed 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,7 +1,7 @@ import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union -from vllm.sequence import Logprob, SampleLogprobs +from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs TokensText = Tuple[List[int], str] @@ -34,20 +34,47 @@ def check_outputs_equal( assert output_ids_0 == output_ids_1, fail_msg +# Representation of generated sequence as a tuple of +# * Token ID list +# * String +# * List of top sample logprobs for each sampled token +# +# Assumes prompt logprobs were not requested. TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]]] -# Allow for tokens to be represented as str's rather than IDs +# Allow for tokens to be represented as str's rather than IDs; +# tuple of +# * Token string representations list +# * String +# * Optional list of top sample logprobs for each sampled token +# +# Assumes prompt logprobs were not requested. TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]], List[Dict[str, Logprob]]]]] +# Representation of generated sequence as a tuple of +# * Token ID list +# * String +# * Optional list of top sample logprobs for each sampled token +# * Optional list of top prompt logprobs for each prompt token +# +# Allows prompt logprobs to be requested. +TokensTextLogprobsPromptLogprobs = Tuple[ + List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]], + Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]] + def check_logprobs_close( *, - outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], - outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], + outputs_0_lst: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs, + TextTextLogprobs]], + outputs_1_lst: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs, + TextTextLogprobs]], name_0: str, name_1: str, num_outputs_0_skip_tokens: int = 0, @@ -57,6 +84,18 @@ def check_logprobs_close( """Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. + How sample logprobs are compared: + * `always_check_logprobs == True`: set of highest-logprob token ids + must match between seq0 and seq1 at all sampled token offsets + * `always_check_logprobs == False`: highest-logprob token ids are + only compared at sampled token offsets for which generated token + ids don't match + + Prompt logprobs must be provided either for both input sequences, or + for neither. 
If prompt logprobs are provided, then highest-logprob + prompt token ids must match between seq0 and seq1 at all prompt token + offsets. + Args: outputs_0_lst: First sequence to compare outputs_0_lst: Second sequence to compare @@ -78,8 +117,65 @@ def check_logprobs_close( for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)): - output_ids_0, output_str_0, logprobs_0 = outputs_0 - output_ids_1, output_str_1, logprobs_1 = outputs_1 + assert len(outputs_0) == len(outputs_1) + if len(outputs_0) == 3: + assert len(outputs_1) == 3 + # Break out tokens, text & sample logprobs + # (prompt logprobs were not provided) + output_ids_0, output_str_0, logprobs_0 = outputs_0 + output_ids_1, output_str_1, logprobs_1 = outputs_1 + elif len(outputs_0) == 4: + assert len(outputs_1) == 4 + # Break out tokens, text, sample logprobs & prompt logprobs + ( + output_ids_0, + output_str_0, + logprobs_0, + prompt_logprobs_0, + ) = outputs_0 + ( + output_ids_1, + output_str_1, + logprobs_1, + prompt_logprobs_1, + ) = outputs_1 + + # Test prompt logprobs closeness + if (prompt_logprobs_0 is not None + and prompt_logprobs_1 is not None): + # Both sequences' prompt logprobs lists are not `None`` + # (although individual list elements may be `None`); + # for each token's logprobs: + for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate( + zip(prompt_logprobs_0, prompt_logprobs_1)): + fail_msg = ( + f"Prompt logprobs test:" + f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}" + f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}") + + if logprobs_elem_0 is None: + # If the seq 0 token's logprobs are `None`, + # the seq 1 token's logprobs must be `None` + assert logprobs_elem_1 is None, fail_msg + else: + # If the seq 0 token's logprobs are not `None`, + # the seq 1 token's logprobs must not be `None` + assert logprobs_elem_1 is not None, fail_msg + # Logprobs check: top-k token choices must be the same + assert (set(logprobs_elem_0.keys()) == set( + logprobs_elem_1.keys())), fail_msg + else: + # Both sequence logprobs lists must be `None` + fail_msg = (f"Prompt logprobs test:" + f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}" + f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}") + + assert (prompt_logprobs_0 is None + and prompt_logprobs_1 is None), fail_msg + else: + raise ValueError(f"Outputs tuple must have 3 or 4 elements but " + f"{len(outputs_0)} elements were provided: " + f"{outputs_0}") if logprobs_0 is None: logprobs_0 = [None] * len(output_ids_0) diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index 24ebb60a9cbfd..c5dc81cc25622 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -100,3 +100,95 @@ def test_multi_step_llm( name_0="hf", name_1="vllm", ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("tp_size", [1]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) +@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)]) +def test_multi_step_llm_w_prompt_logprobs( + vllm_runner, + example_prompts, + model: str, + dtype: str, + tp_size: int, + max_tokens: int, + enforce_eager: int, + num_scheduler_steps: int, + num_prompts: int, + num_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], +) -> None: + """Test prompt 
logprobs with multi-step scheduling via sync LLM Engine. + + Set up a vLLM engine instance w/ single-step scheduling as a ground-truth + reference. + + Prompt them with the same example prompts. + + Validate: + * All generated logprobs are all very close + + Args: + hf_runner: HF transformers model runner fixture + vllm_runner: vLLM model runner fixture + example_prompts: test fixture providing example prompts + model: model under test (same for single- and multi-step engines) + dtype: tensor datatype for engine to utilize + tp_size: degree of tensor-parallelism + max_tokens: the maximum number of tokens to generate + enforce_eager + num_scheduler_steps: for multi-step scheduling, GPU-side steps per + GPU -> CPU output transfer + num_prompts: number of example prompts under test + num_logprobs: corresponds to the `logprobs` argument to the OpenAI + completions endpoint; `None` -> no logprobs + num_prompt_logprobs: number of logprobs to return for each prompt token; + note that this argument is not supported by the + OpenAI completions endpoint. + """ + + prompts = example_prompts + if len(prompts) < num_prompts: + prompts = prompts * ((num_prompts // len(prompts)) + 1) + prompts = prompts[:num_prompts] + assert len(prompts) == num_prompts + + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + use_v2_block_manager=True, + num_scheduler_steps=num_scheduler_steps, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs, + num_prompt_logprobs=num_prompt_logprobs) + + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + ) as vllm_model: + single_step_vllm_outputs = vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs, + num_prompt_logprobs=num_prompt_logprobs) + + check_logprobs_close( + outputs_0_lst=single_step_vllm_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/utils.py b/tests/utils.py index 81442cad78da2..43825e8138362 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -493,6 +493,7 @@ async def completions_with_server_args( ''' outputs = None + max_wait_seconds = 240 * 3 # 240 is default with RemoteOpenAIServer(model_name, server_cli_args, max_wait_seconds=max_wait_seconds) as server: @@ -503,7 +504,7 @@ async def completions_with_server_args( stream=False, max_tokens=5, logprobs=num_logprobs) - assert outputs is not None + assert outputs is not None, "Completion API call failed." return outputs diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index b900eb5a610ff..ebcafbbab119a 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -614,34 +614,66 @@ def _pythonize_sampler_output( frozen_model_input = model_input.frozen_model_input assert frozen_model_input.sampling_metadata is not None + sampling_metadata = frozen_model_input.sampling_metadata # samples generation should have been skipped assert not output.outputs pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries] - # CPU GPU sync - pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False) + # We guarantee output tensors are ready, so it is safe to + # pythonize the sampler output & obtain CPU-side logprobs. + # + # However we should check whether logprobs pythonization may + # be skipped entirely, i.e. 
because no logprobs were requested + # or pythonization was not deferred. To that end, + # + # * `prompt_logprobs_are_requested_for_prefill` signals that + # there are *any* prefill-phase requests which specify that + # prompt logprobs should be returned. + # + # * `any_logprobs_are_requested` signals that there are any + # requests which (1) specify that sample logprobs should be + # returned, or (2) are in the prefill phase AND specify that + # prompt logprobs should be returned. + # + # Later on, these flags cause adjustments to the pythonization + # process to accommodate logprobs. + + seq_groups = sampling_metadata.seq_groups + prompt_logprobs_are_requested_for_prefill = any([ + sg.sampling_params.prompt_logprobs is not None and sg.is_prompt + for sg in seq_groups + ]) + any_logprobs_are_requested = ( + prompt_logprobs_are_requested_for_prefill + or any([sg.sampling_params.logprobs is not None for sg in seq_groups])) + + if prompt_logprobs_are_requested_for_prefill: + # CPU GPU sync, after gathering *only* sampled tokens (since + # requesting prompt logprobs leads `sampled_token_ids` to + # include prompt token ids in addition to sampled token ids.) + sample_idx_tensor = torch.tensor( + [sdx for sg in seq_groups for sdx in sg.sample_indices]) + pinned_buffer = pinned_buffer.copy_( + sampled_token_ids[sample_idx_tensor, :], non_blocking=False) + else: + # CPU GPU sync + pinned_buffer = pinned_buffer.copy_(sampled_token_ids, + non_blocking=False) # this will not block as the tensors are already on CPU samples_list = pinned_buffer.tolist() - sampling_metadata = frozen_model_input.sampling_metadata - skip_sampler_cpu_output = ( frozen_model_input.sampling_metadata.skip_sampler_cpu_output) - # We are guaranteed output tensors are ready, so it is safe to - # pythonize the sampler output & obtain CPU-side logprobs. - # - # However this computation may be skipped entirely - # if no pythonization was deferred. - seq_groups = sampling_metadata.seq_groups - logprobs_are_requested = any([ - sg.sampling_params.logprobs is not None - or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups - ]) + # *Don't* skip logprobs pythonization *if*: + # * Any requests require logprobs to be returned in this + # iteration AND + # * These requests are being scheduled in a fashion which + # defers pythonization (i.e. multi-step scheduling.) 
do_pythonize_logprobs = (skip_sampler_cpu_output - and logprobs_are_requested) + and any_logprobs_are_requested) ( prompt_logprobs, sample_logprobs, @@ -666,7 +698,7 @@ def _pythonize_sampler_output( prompt_logprobs[sgdx], sample_logprobs[sgdx], ) - elif logprobs_are_requested: + elif any_logprobs_are_requested: ( group_prompt_logprobs, group_sample_logprobs, @@ -696,7 +728,7 @@ def _pythonize_sampler_output( seq_output.parent_seq_id = seq_ids[parent_id] seq_output.output_token = next_token_id - if logprobs_are_requested: + if any_logprobs_are_requested: seq_output.logprobs = group_sample_logprobs[tdx] else: logprobs = next(iter(seq_output.logprobs.values())) @@ -714,7 +746,7 @@ def _pythonize_sampler_output( seq_outputs.append( SequenceOutput(seq_ids[parent_id], next_token_id, (group_sample_logprobs[tdx] - if logprobs_are_requested else { + if any_logprobs_are_requested else { next_token_id: Logprob(logprob=float('inf'), rank=None, @@ -722,12 +754,12 @@ def _pythonize_sampler_output( }))) if cache is not None: completion_seq_group_output.prompt_logprobs = \ - group_prompt_logprobs if logprobs_are_requested else None + group_prompt_logprobs if any_logprobs_are_requested else None output.outputs.append(completion_seq_group_output) else: output.outputs.append( CompletionSequenceGroupOutput( seq_outputs, (group_prompt_logprobs - if logprobs_are_requested else None))) + if any_logprobs_are_requested else None))) assert len(output.outputs) > 0 From d65798f78c76f03f068fc2f69a68cff430ee6b6f Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 18 Sep 2024 12:10:27 -0400 Subject: [PATCH 78/98] [Core] zmq: bind only to 127.0.0.1 for local-only usage (#8543) Signed-off-by: Russell Bryant --- .../device_communicators/shm_broadcast.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index d4847542688c0..b507cd2e1cddb 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -196,7 +196,9 @@ def __init__( # see http://api.zeromq.org/3-3:zmq-setsockopt for more details self.local_socket.setsockopt(XPUB_VERBOSE, True) local_subscribe_port = get_open_port() - self.local_socket.bind(f"tcp://*:{local_subscribe_port}") + socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}" + logger.debug("Binding to %s", socket_addr) + self.local_socket.bind(socket_addr) self.current_idx = 0 @@ -212,7 +214,8 @@ def __init__( self.remote_socket = context.socket(XPUB) self.remote_socket.setsockopt(XPUB_VERBOSE, True) remote_subscribe_port = get_open_port() - self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}") + socket_addr = f"tcp://*:{remote_subscribe_port}" + self.remote_socket.bind(socket_addr) else: remote_subscribe_port = None @@ -255,8 +258,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.local_socket = context.socket(SUB) self.local_socket.setsockopt_string(SUBSCRIBE, "") - self.local_socket.connect( - f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}") + socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}" + logger.debug("Connecting to %s", socket_addr) + self.local_socket.connect(socket_addr) self.remote_socket = None else: @@ -270,8 +274,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.remote_socket = context.socket(SUB) self.remote_socket.setsockopt_string(SUBSCRIBE, "") - self.remote_socket.connect( - 
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}") + socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" + logger.debug("Connecting to %s", socket_addr) + self.remote_socket.connect(socket_addr) return self From e18749ff09c277f7cdab278895ebdd9b1041b6e8 Mon Sep 17 00:00:00 2001 From: "Geun, Lim" Date: Thu, 19 Sep 2024 02:04:00 +0900 Subject: [PATCH 79/98] [Model] Support Solar Model (#8386) Co-authored-by: Michael Goin --- docs/source/models/supported_models.rst | 4 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/solar.py | 580 ++++++++++++++++++++ vllm/transformers_utils/config.py | 3 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/solar.py | 245 +++++++++ 6 files changed, 834 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/models/solar.py create mode 100644 vllm/transformers_utils/configs/solar.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 3dcc242803752..745b4b8e2e0eb 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -179,6 +179,10 @@ Decoder-only Language Models - Starcoder2 - :code:`bigcode/starcoder2-3b`, :code:`bigcode/starcoder2-7b`, :code:`bigcode/starcoder2-15b`, etc. - + * - :code:`SolarForCausalLM` + - EXAONE-3 + - :code:`upstage/solar-pro-preview-instruct`, etc. + - * - :code:`XverseForCausalLM` - Xverse - :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 41c8e754377c7..591007e787f47 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -60,6 +60,7 @@ "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), + "SolarForCausalLM": ("solar", "SolarForCausalLM"), "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py new file mode 100644 index 0000000000000..16e576d0ac29c --- /dev/null +++ b/vllm/model_executor/models/solar.py @@ -0,0 +1,580 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Solar model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.interfaces import SupportsLoRA +from vllm.model_executor.models.utils import (PPMissingLayer, + is_pp_missing_parameter, + make_layers) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import is_hip + + +class SolarMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class SolarAttention(nn.Module): + + def __init__( + self, + config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class SolarDecoderLayer(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] \ + = config.original_max_position_embeddings + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = SolarAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = SolarMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + 
eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class SolarModel(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: SolarDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + bskcn_h_1 = None + bskcn_h_2 = None + bskcn_r_1 = None + bskcn_r_2 = None + bskcn_tv = (self.config.bskcn_tv[0] + if self.training else self.config.bskcn_tv[1]) + + for i in range(self.start_layer, self.end_layer): + if i in self.config.bskcn_1: + bskcn_h_1 = hidden_states.clone() + bskcn_r_1 = residual.clone() + if i in self.config.bskcn_2: + bskcn_h_2 = hidden_states.clone() + bskcn_r_2 = residual.clone() + if i in self.config.bskcn_3: + hidden_states = bskcn_h_1 * bskcn_tv + hidden_states * ( + 1 - bskcn_tv) + residual = bskcn_r_1 * bskcn_tv + residual * (1 - bskcn_tv) + if i in self.config.bskcn_4: + hidden_states = bskcn_h_2 * bskcn_tv + hidden_states * ( + 1 - bskcn_tv) + residual = bskcn_r_2 * bskcn_tv + residual * (1 - bskcn_tv) + layer = self.layers[i] + hidden_states, residual = 
layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class SolarForCausalLM(nn.Module, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = SolarModel( + config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model", + ) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + "residual": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + }) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, 
shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). 
Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.model.layers[layer_idx], nn.Identity): + layer_self_attn = self.model.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 3c269bc10cdf8..1744935d624fb 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -24,7 +24,7 @@ JAISConfig, MedusaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, RWConfig, - UltravoxConfig) + SolarConfig, UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file @@ -50,6 +50,7 @@ "exaone": ExaoneConfig, "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, + "solar": SolarConfig, "ultravox": UltravoxConfig, # Granite can be removed from here once we have upgraded to # transformers 4.45+ diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 8381c5227584e..ea4fc8ad21f35 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -13,6 +13,7 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig +from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ @@ -27,6 +28,7 @@ "ExaoneConfig", "MLPSpeculatorConfig", "NemotronConfig", + "SolarConfig", "UltravoxConfig", # Granite can be removed from here once we have upgraded to # transformers 4.45+ diff --git a/vllm/transformers_utils/configs/solar.py b/vllm/transformers_utils/configs/solar.py new file mode 100644 index 0000000000000..d5113bf01695a --- /dev/null +++ b/vllm/transformers_utils/configs/solar.py @@ -0,0 +1,245 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Solar model configuration""" + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class SolarConfig(PretrainedConfig): + r""" + This is the configuration class to store + the configuration of a [`SolarModel`]. + It is used to instantiate an LLaMA model + according to the specified arguments, + defining the model architecture. + Instantiating a configuration with the + defaults will yield a similar + configuration to that of the LLaMA-7B. + Configuration objects inherit from [`PretrainedConfig`] + and can be used to control the model outputs. + Read the documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. + Defines the number of different tokens + that can be represented by the `inputs_ids` + passed when calling [`SolarModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer + in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that + should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, + the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model + will use Multi Query Attention (MQA) + otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, + each group key and value head should be constructed + by meanpooling all the original heads within that group. + For more details checkout [this paper] + (https://arxiv.org/pdf/2305.13245.pdf). + If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) + in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + Solar 1 supports up to 2048 tokens, + Solar 2 up to 4096, CodeSolar up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of + the truncated_normal_initializer for initializing + all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return + the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank + used during pretraining. + Please refer to [this + document](https://huggingface.co/docs/ + transformers/main/ + perf_train_gpu_many#tensor-parallelism) + to understand more about it. 
This value is + necessary to ensure exact reproducibility + of the pretraining results. + Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for + the RoPE embeddings. + Currently supports two scaling + strategies: linear and dynamic. + Their scaling factor must be a float greater than 1. + The expected format is + `{"type": strategy name, "factor": scaling factor}`. + When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/ + dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking + API changes in future versions. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value + and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj + layers in the MLP layers. + sliding_window (`int`, *optional*, defaults to 2047): + Sliding window attention window size. If not specified, + will default to `2047`. + ```python + >>> from transformers import SolarModel, SolarConfig + >>> # Initializing a Solar-pro style configuration + >>> configuration = SolarConfig() + >>> # Initializing a model from the Solar-pro style configuration + >>> model = SolarModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "solar" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + sliding_window=2047, + bskcn_1=None, + bskcn_2=None, + bskcn_3=None, + bskcn_4=None, + bskcn_tv=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.sliding_window = sliding_window + self.bskcn_1 = bskcn_1 if 
bskcn_1 is not None else [12, 20, 32, 44] + self.bskcn_2 = bskcn_2 if bskcn_2 is not None else [20, 32] + self.bskcn_3 = bskcn_3 if bskcn_3 is not None else [16, 24, 36, 48] + self.bskcn_4 = bskcn_4 if bskcn_4 is not None else [28, 40] + self.bskcn_tv = bskcn_tv if bskcn_tv is not None else [0.9, 0.8] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if (not isinstance(self.rope_scaling, dict) + or len(self.rope_scaling) != 2): + raise ValueError( + "`rope_scaling` must be a dictionary with two fields," + " `type` and `factor`, " + f"got {self.rope_scaling}") + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in [ + "linear", + "dynamic", + ]: + raise ValueError(f"`rope_scaling`'s type field must be one of " + f"['linear', 'dynamic'], got {rope_scaling_type}") + if (rope_scaling_factor is None + or not isinstance(rope_scaling_factor, float) + or rope_scaling_factor <= 1.0): + raise ValueError( + f"`rope_scaling`'s factor field must be a float > 1," + f" got {rope_scaling_factor}") From b3195bc9e4d57b6107af2222afea26c51475e262 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Wed, 18 Sep 2024 13:41:08 -0400 Subject: [PATCH 80/98] [AMD][ROCm]Quantization methods on ROCm; Fix _scaled_mm call (#8380) Co-authored-by: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Co-authored-by: Michael Goin --- vllm/config.py | 5 +- .../schemes/compressed_tensors_w8a8_fp8.py | 29 +++++++++-- .../layers/quantization/fbgemm_fp8.py | 15 +++++- .../layers/quantization/utils/w8a8_utils.py | 49 +++++++++++-------- 4 files changed, 71 insertions(+), 27 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 9d42b75c1c462..7a15606836dcc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -255,7 +255,10 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] - rocm_supported_quantization = ["awq", "gptq", "fp8"] + rocm_supported_quantization = [ + "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", + "fbgemm_fp8" + ] optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 8a3d24e2fd258..5931ec36c97d5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -8,10 +8,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( QuantizationStrategy) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale) + apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, + requantize_with_max_scale) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ModelWeightParameter, 
PerTensorScaleParameter) +from vllm.utils import is_hip __all__ = ["CompressedTensorsW8A8Fp8"] @@ -39,16 +41,37 @@ def process_weights_after_loading(self, layer) -> None: logical_widths=layer.logical_widths, ) + if is_hip(): + weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=max_w_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) # If channelwise, scales are already lined up, so just transpose. elif self.strategy == QuantizationStrategy.CHANNEL: weight = layer.weight + + if is_hip(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + else: + weight_scale = layer.weight_scale.data + layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter - layer.weight_scale = Parameter(layer.weight_scale.data, - requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) else: raise ValueError(f"Unknown quantization strategy {self.strategy}") diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index eb59344f36d2e..f26907176ad1a 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -15,10 +15,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear) + apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ModelWeightParameter) from vllm.platforms import current_platform +from vllm.utils import is_hip logger = init_logger(__name__) @@ -125,8 +126,18 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight = Parameter(layer.weight.data, requires_grad=False) weight = layer.weight - layer.weight = Parameter(weight.t(), requires_grad=False) + if is_hip(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=None) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + layer.weight = Parameter(weight.t(), requires_grad=False) if self.quant_config.use_marlin: prepare_fp8_layer_for_marlin(layer) # Activations not quantized for marlin. diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index d86fea63d8a1b..fb263d121fe55 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -6,11 +6,9 @@ from vllm.platforms import current_platform from vllm.utils import is_hip -# scaled_mm in pytorch on rocm has a bug that requires always -# providing scaling factor for result. This value is created -# as global value to avoid multiple tensor allocations, and -# can be removed once pytorch fixes the bug. 
-TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None +# Input scaling factors are no longer optional in _scaled_mm starting +# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale +TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None def cutlass_fp8_supported() -> bool: @@ -131,19 +129,17 @@ def apply_fp8_linear( if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ - output = torch._scaled_mm( - qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - scale_result=TORCH_SCALED_MM_SCALE_RESULT, - bias=bias) - # Since in torch 2.5, scaled_mm only returns single value - # This should be removed when vllm-nvidia also moves to 2.5 - if is_hip(): - return torch.narrow(output, 0, 0, input.shape[0]) - return torch.narrow(output[0], 0, 0, input.shape[0]) + output = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + return torch.narrow(output[0], 0, 0, input.shape[0]) + return torch.narrow(output, 0, 0, input.shape[0]) else: # Fallback for channelwise case, where we use unfused DQ @@ -161,12 +157,23 @@ def apply_fp8_linear( # For the scaled_mm fallback case, we break this down, since it # does not support s_w being a vector. + # Making sure the dummy tensor is on the same device as the weight + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY.device != weight.device: + TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) + # GEMM # This computes C = (X * W). # Output in fp32 to allow subsequent ops to happen in-place - output, _ = torch._scaled_mm(qinput, - weight, - out_dtype=torch.float32) + output = torch._scaled_mm(qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] # Unpad (undo num_token_padding) output = torch.narrow(output, 0, 0, input.shape[0]) x_scale = torch.narrow(x_scale, 0, 0, input.shape[0]) From db9120cdedba5033037432775417df0b6117495d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 18 Sep 2024 16:05:06 -0400 Subject: [PATCH 81/98] [Kernel] Change interface to Mamba selective_state_update for continuous batching (#8039) --- tests/kernels/test_mamba_ssm.py | 146 ++++++++++++++++++ .../layers/mamba/ops/mamba_ssm.py | 31 +++- 2 files changed, 174 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index f582445692344..366475222a68e 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -323,3 +323,149 @@ def test_selective_state_update(dim, dstate, has_z, itype): assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == 
torch.bfloat16: + rtol, atol = 7e-2, 7e-2 + if torch.version.hip: + atol *= 2 + # set seed + torch.random.manual_seed(0) + batch_size = 16 + + total_entries = 10 * batch_size + state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) + state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device) + + x = torch.randn(batch_size, dim, device=device, dtype=itype) + dt = torch.randn(batch_size, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(batch_size, dstate, device=device) + C = torch.randn(batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state[state_indices, :].detach().clone() + out = selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices) + out_ref = selective_state_update_ref(state_ref, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True) + + assert torch.allclose(state[state_indices, :], + state_ref, + rtol=rtol, + atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("tie_hdim", [False, True]) +@pytest.mark.parametrize("ngroups", [1, 2, 4]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dim", [2048, 4096]) +def test_selective_state_update_with_heads_with_batch_indices( + dim, dstate, ngroups, has_z, tie_hdim, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2) + if itype == torch.bfloat16: + rtol, atol = 1e-1, 1e-1 + # set seed + torch.random.manual_seed(0) + batch_size = 16 + headdim = 64 + nheads = dim // headdim + + total_entries = 10 * batch_size + state = torch.randn(total_entries, + nheads, + headdim, + dstate, + dtype=itype, + device=device) + state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device) + + x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) + if not tie_hdim: + dt = torch.randn(batch_size, + nheads, + headdim, + device=device, + dtype=itype) + dt_bias = torch.rand(nheads, headdim, device=device) - 4.0 + A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0 + D = torch.randn(nheads, headdim, device=device) + else: + dt = repeat(torch.randn(batch_size, nheads, device=device, + dtype=itype), + "b h -> b h p", + p=headdim) + dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, + "h -> h p", + p=headdim) + A = repeat(-torch.rand(nheads, device=device) - 1.0, + "h -> h p n", + p=headdim, + n=dstate) + D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim) + B = torch.randn(batch_size, ngroups, dstate, device=device) + C = torch.randn(batch_size, ngroups, dstate, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state[state_indices, :].detach().clone() + out = selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices) + out_ref = selective_state_update_ref(state_ref, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - 
out_ref).abs().mean().item()}") + assert torch.allclose(state[state_indices, :], + state_ref, + rtol=rtol, + atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 869c69214caf2..a0bed07ac6193 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -1,4 +1,5 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py import torch import triton @@ -27,6 +28,10 @@ def softplus(dt): {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics({ + "HAS_STATE_BATCH_INDICES": + lambda args: args["state_batch_indices_ptr"] is not None +}) @triton.heuristics( {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) @triton.jit @@ -42,6 +47,7 @@ def _selective_scan_update_kernel( D_ptr, z_ptr, out_ptr, + state_batch_indices_ptr, # Matrix dimensions batch, nheads, @@ -85,12 +91,24 @@ def _selective_scan_update_kernel( HAS_DT_BIAS: tl.constexpr, HAS_D: tl.constexpr, HAS_Z: tl.constexpr, + HAS_STATE_BATCH_INDICES: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): pid_m = tl.program_id(axis=0) pid_b = tl.program_id(axis=1) pid_h = tl.program_id(axis=2) - state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + + # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate + # is taken from the state_batch_indices_ptr Otherwise, the state coordinate + # is the same as the batch id. 
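
At the Python level, the new `state_batch_indices` argument (exercised by the tests added above) lets one long-lived state tensor keep a slot per cached request while a decode step only touches the rows of the currently scheduled batch. A minimal usage sketch, with shapes and dtypes borrowed from those tests:

    import torch
    from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
        selective_state_update)

    total_slots, batch, dim, dstate = 64, 3, 2048, 16
    device, dtype = "cuda", torch.float16
    # Persistent SSM state: one row per cached request, not per running batch.
    state = torch.zeros(total_slots, dim, dstate, device=device, dtype=dtype)
    # Rows owned by the three sequences scheduled in this decode step.
    slots = torch.tensor([5, 17, 42], dtype=torch.int32, device=device)
    x = torch.randn(batch, dim, device=device, dtype=dtype)
    dt = torch.randn(batch, dim, device=device, dtype=dtype)
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(batch, dstate, device=device)
    C = torch.randn(batch, dstate, device=device)
    # Updates state[slots] in place and returns the per-token output.
    out = selective_state_update(state, x, dt, A, B, C, dt_softplus=True,
                                 state_batch_indices=slots)
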
+ if HAS_STATE_BATCH_INDICES: + state_batch_indices_ptr += pid_b + state_batch_idx = tl.load(state_batch_indices_ptr) + state_ptr += (state_batch_idx * stride_state_batch + + pid_h * stride_state_head) + else: + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head if HAS_DT_BIAS: @@ -177,7 +195,8 @@ def selective_state_update(state, D=None, z=None, dt_bias=None, - dt_softplus=False): + dt_softplus=False, + state_batch_indices=None): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) @@ -211,7 +230,10 @@ def selective_state_update(state, z = z.unsqueeze(1) if dt_bias is not None and dt_bias.dim() == 1: dt_bias = dt_bias.unsqueeze(0) - batch, nheads, dim, dstate = state.shape + + _, nheads, dim, dstate = state.shape + batch = x.shape[0] + assert x.shape == (batch, nheads, dim) assert dt.shape == x.shape assert A.shape == (nheads, dim, dstate) @@ -225,6 +247,8 @@ def selective_state_update(state, assert z.shape == x.shape if dt_bias is not None: assert dt_bias.shape == (nheads, dim) + if state_batch_indices is not None: + assert state_batch_indices.shape == (batch, ) out = torch.empty_like(x) grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else @@ -249,6 +273,7 @@ def selective_state_update(state, D, z, out, + state_batch_indices, batch, nheads, dim, From d9cd78eb718c233ebc5b84377fc2226af7ef0fa2 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 18 Sep 2024 21:17:55 +0100 Subject: [PATCH 82/98] [BugFix] Nonzero exit code if MQLLMEngine startup fails (#8572) --- vllm/entrypoints/openai/api_server.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 1b9eb30252417..fd6f36e8768dd 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -11,7 +11,7 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Optional, Set +from typing import AsyncIterator, Set import uvloop from fastapi import APIRouter, FastAPI, Request @@ -95,7 +95,7 @@ async def _force_log(): @asynccontextmanager async def build_async_engine_client( - args: Namespace) -> AsyncIterator[Optional[EngineClient]]: + args: Namespace) -> AsyncIterator[EngineClient]: # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit @@ -110,7 +110,7 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, -) -> AsyncIterator[Optional[EngineClient]]: +) -> AsyncIterator[EngineClient]: """ Create EngineClient, either: - in-process using the AsyncLLMEngine Directly @@ -188,10 +188,8 @@ async def build_async_engine_client_from_engine_args( break except TimeoutError: if not engine_process.is_alive(): - logger.error("Engine process died before responding " - "to readiness probe") - yield None - return + raise RuntimeError( + "Engine process failed to start") from None yield mp_engine_client # type: ignore[misc] finally: @@ -532,10 +530,6 @@ def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler) async with build_async_engine_client(args) as engine_client: - # If None, creation of 
the client failed and we exit. - if engine_client is None: - return - app = build_app(args) model_config = await engine_client.get_model_config() From 0d47bf3bf40edfe9fcfd7e5cd909388497535bc5 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 18 Sep 2024 16:10:01 -0600 Subject: [PATCH 83/98] [Bugfix] add `dead_error` property to engine client (#8574) Signed-off-by: Joe Runde --- vllm/engine/multiprocessing/client.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 18b620c74ddf9..2cb4de79131f1 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -380,6 +380,13 @@ def is_stopped(self) -> bool: def errored(self) -> bool: return self._errored_with is not None + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + async def generate( self, inputs: PromptInputs, From 4c34ce8916da0e4967eadefcb7f91eb58dd7ac61 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 18 Sep 2024 21:42:49 -0400 Subject: [PATCH 84/98] [Kernel] Remove marlin moe templating on thread_m_blocks (#8573) Co-authored-by: lwilkinson@neuralmagic.com --- csrc/moe/marlin_moe_ops.cu | 79 ++++++++++++++------------------------ 1 file changed, 28 insertions(+), 51 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 666d87eb92595..49cc03f827f68 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1342,9 +1342,6 @@ __device__ inline void MarlinMoESingle( template shared @@ -1459,9 +1456,6 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, template shared @@ -1515,26 +1509,24 @@ const int STAGES = 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; -#define __CALL_IF_MOE(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, \ - NUM_THREADS) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - exec_cfg.max_m_blocks); \ +#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ + GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks); \ } typedef struct { @@ -1711,31 +1703,16 @@ exec_config_t 
determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, From 3118f63385c0d767fba8b6d2039fc35440678da9 Mon Sep 17 00:00:00 2001 From: sroy745 <142070531+sroy745@users.noreply.github.com> Date: Wed, 18 Sep 2024 19:24:15 -0700 Subject: [PATCH 85/98] [Bugfix] [Encoder-Decoder] Bugfix for encoder specific metadata construction during decode of encoder-decoder models. 
(#8545) --- .../test_encoder_decoder_model_runner.py | 88 +++++++++++++------ vllm/worker/enc_dec_model_runner.py | 12 +-- 2 files changed, 69 insertions(+), 31 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index c0654712b71b5..27cdf5f339ede 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -273,7 +273,8 @@ def test_prepare_prompt(batch_size): "unsupported for encoder/ " "decoder models") @pytest.mark.parametrize("batch_size", BATCH_SIZES) -def test_prepare_decode(batch_size): +@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) +def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): ''' Test the ability of the encoder/decoder model runner subclass to produce decode-phase model inputs & attention metadata. @@ -288,6 +289,7 @@ def test_prepare_decode(batch_size): Arguments: * batch_size + * multiple_seqs_per_seq_group * backend_name: The attention backend under test * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph) ''' @@ -305,22 +307,29 @@ def test_prepare_decode(batch_size): seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] - block_tables = {0: [1]} + block_tables = { + 0: [1], + 1: [3] + } if multiple_seqs_per_seq_group else { + 0: [1] + } cross_block_table = [2] for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_lens.append(encoder_seq_len) encoder_seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) + seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, - seq_data={0: seq_data}, + seq_data={ + 0: seq_data, + 1: seq_data + } if multiple_seqs_per_seq_group else {0: seq_data}, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, @@ -328,6 +337,10 @@ def test_prepare_decode(batch_size): ) assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) + seq_lens.extend( + [seq_len for _ in range(len(seq_group_metadata.seq_data))]) + encoder_seq_lens.extend( + [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) # Build # * Decoder model inputs @@ -398,19 +411,24 @@ def test_prepare_decode(batch_size): # Verify block tables are correct for prompts # - Decoder self-attention - expected = torch.tensor( - [block_tables[0] for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device) + flattened_block_tables = [ + block_table for block_table in block_tables.values() + ] + expected = torch.tensor(flattened_block_tables * + len(seq_group_metadata_list), + dtype=torch.int32, + device=model_runner.device) assert torch.equal( attn_metadata.block_tables, expected, ) # - Encoder/decoder cross-attention - expected = torch.tensor( - [cross_block_table for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device) + expected = torch.tensor([ + cross_block_table for seq_group_metadata in seq_group_metadata_list + for _ in range(len(seq_group_metadata.seq_data)) + ], + dtype=torch.int32, + device=model_runner.device) assert torch.equal( attn_metadata.cross_block_tables, 
expected, @@ -474,7 +492,8 @@ def test_prepare_decode(batch_size): @pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_decode_cuda_graph(batch_size): +@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) +def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): """ Tests that for encoder-decoder models with CUDA Graph capture and replay enabled, the tensors used during the decode phase are correctly padded @@ -489,32 +508,45 @@ def test_prepare_decode_cuda_graph(batch_size): enable_chunked_prefill=False, enforce_eager=False, ) - + block_tables = { + 0: [1], + 1: [3] + } if multiple_seqs_per_seq_group else { + 0: [1] + } seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] - block_tables = {0: [1]} + cross_block_table = [2] + expanded_batch_size = 0 for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_lens.append(encoder_seq_len) encoder_seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, - seq_data={0: seq_data}, + seq_data={ + 0: seq_data, + 1: seq_data + } if multiple_seqs_per_seq_group else {0: seq_data}, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, ) assert seq_group_metadata.token_chunk_size == 1 + seq_lens.extend( + [seq_len for _ in range(len(seq_group_metadata.seq_data))]) + encoder_seq_lens.extend( + [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) + expanded_batch_size = expanded_batch_size + len( + seq_group_metadata.seq_data) seq_group_metadata_list.append(seq_group_metadata) model_input = model_runner.prepare_model_input(seq_group_metadata_list) @@ -530,8 +562,8 @@ def test_prepare_decode_cuda_graph(batch_size): # With CUDA Graph capture and replay enabled, the decoder and encoder # input sequences will be padded. Create the expected padded tensors # accordingly. - graph_batch_size = _get_graph_batch_size(batch_size) - cuda_graph_pad_size = graph_batch_size - batch_size + graph_batch_size = _get_graph_batch_size(expanded_batch_size) + cuda_graph_pad_size = graph_batch_size - expanded_batch_size padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) padded_encoder_seq_lens = encoder_seq_lens + list( itertools.repeat(1, cuda_graph_pad_size)) @@ -560,10 +592,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify block tables are correct for prompts # - Decoder self-attention. Pad the block tables as expected. - expected = [block_tables[0] for _ in range(batch_size)] - expected.extend([[] for _ in range(cuda_graph_pad_size)]) + flattened_block_tables = [ + block_table for _ in range(len(seq_group_metadata_list)) + for block_table in block_tables.values() + ] + flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)]) expected = make_tensor_with_pad( - expected, + flattened_block_tables, max_len=64, pad=0, dtype=torch.int32, @@ -575,7 +610,10 @@ def test_prepare_decode_cuda_graph(batch_size): ) # - Encoder/decoder cross-attention. Pad the cross-attention block tables # as expected. 
- expected = [cross_block_table for _ in range(len(seq_group_metadata_list))] + expected = [ + cross_block_table for seq_group_metadata in seq_group_metadata_list + for _ in range(len(seq_group_metadata.seq_data)) + ] expected.extend([[] for _ in range(cuda_graph_pad_size)]) expected = make_tensor_with_pad( expected, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 09dab0135f390..709efdc8b9d57 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -435,18 +435,18 @@ def _prepare_encoder_model_input_tensors( encoder_input_tokens_tensor = self._empty_long_tensor() encoder_input_positions_tensor = self._empty_long_tensor() cross_slot_mapping_tensor = self._empty_long_tensor() - # Extract cross-attention block tables & # seq len from each sequence group metadata. # Cross-attention block tables are empty # during vLLM memory profiling. cross_block_tables = [] for seq_group_metadata in seq_group_metadata_list: - encoder_seq_lens.append( - seq_group_metadata.encoder_seq_data.get_len()) - cross_block_table = seq_group_metadata.cross_block_table - cross_block_tables.append([] if ( - cross_block_table is None) else cross_block_table) + for _ in range(len(seq_group_metadata.seq_data)): + encoder_seq_lens.append( + seq_group_metadata.encoder_seq_data.get_len()) + cross_block_table = seq_group_metadata.cross_block_table + cross_block_tables.append([] if ( + cross_block_table is None) else cross_block_table) if (model_input.attn_metadata is not None and model_input.attn_metadata.use_cuda_graph): From 02c9afa2d04a85269faa2760e9af30527a61d7f6 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Wed, 18 Sep 2024 21:14:28 -0700 Subject: [PATCH 86/98] Revert "[Misc][Bugfix] Disable guided decoding for mistral tokenizer" (#8593) --- .../guided_decoding/__init__.py | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index f4fe8a7307c04..7161e83952a3d 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -6,7 +6,6 @@ from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest) from vllm.sampling_params import LogitsProcessor -from vllm.transformers_utils.tokenizer import MistralTokenizer async def get_guided_decoding_logits_processor( @@ -16,23 +15,12 @@ async def get_guided_decoding_logits_processor( request = _adapt_request_for_tool_use(request) if guided_decoding_backend == 'outlines': - if isinstance(tokenizer, MistralTokenizer): - raise NotImplementedError( - "Guided decoding with 'outlines' is currently not supported " - "for Mistral tokenizer. Please consider contributing to the " - "'outlines' project if you are interested in this feature.") # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( request, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': - if isinstance(tokenizer, MistralTokenizer): - raise NotImplementedError( - "Guided decoding with 'lm-format-enforcer' is currently not " - "supported for Mistral tokenizer. 
Please consider contributing " - "to the 'lm-format-enforcer' project if you are interested " - "in this feature.") from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_lm_format_enforcer_guided_decoding_logits_processor) return await get_lm_format_enforcer_guided_decoding_logits_processor( @@ -49,23 +37,12 @@ def get_local_guided_decoding_logits_processor( # request = _adapt_request_for_tool_use(request) if guided_decoding_backend == 'outlines': - if isinstance(tokenizer, MistralTokenizer): - raise NotImplementedError( - "Guided decoding with 'outlines' is currently not supported " - "for Mistral tokenizer. Please consider contributing to the " - "'outlines' project if you are interested in this feature.") # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( guided_options, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': - if isinstance(tokenizer, MistralTokenizer): - raise NotImplementedError( - "Guided decoding with 'lm-format-enforcer' is currently not " - "supported for Mistral tokenizer. Please consider contributing " - "to the 'lm-format-enforcer' project if you are interested " - "in this feature.") from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( From c52ec5f03471008fa1312d82fb17d40b95a3ca5d Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 18 Sep 2024 22:24:24 -0700 Subject: [PATCH 87/98] [Bugfix] fixing sonnet benchmark bug in benchmark_serving.py (#8616) --- benchmarks/benchmark_serving.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 3ace910a6cac6..a407a263120bb 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -626,9 +626,9 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, ) - input_requests = [(prompt, prompt_len, output_len) + input_requests = [(prompt, prompt_len, output_len, None) for prompt, prompt_formatted, prompt_len, - output_len in input_requests] + output_len, _ in input_requests] else: assert ( tokenizer.chat_template or tokenizer.default_chat_template @@ -641,9 +641,9 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, ) - input_requests = [(prompt_formatted, prompt_len, output_len) + input_requests = [(prompt_formatted, prompt_len, output_len, None) for prompt, prompt_formatted, prompt_len, - output_len in input_requests] + output_len, _ in input_requests] elif args.dataset_name == "hf": input_requests = sample_hf_requests( @@ -963,4 +963,4 @@ def main(args: argparse.Namespace): ) args = parser.parse_args() - main(args) + main(args) \ No newline at end of file From 855c8ae2c9a4085b1ebd66d9a978fb23f47f822c Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Thu, 19 Sep 2024 13:33:20 +0800 Subject: [PATCH 88/98] [MISC] remove engine_use_ray in benchmark_throughput.py (#8615) --- benchmarks/benchmark_throughput.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 3f531ee82cc94..e1a5d4ee28ea1 100644 --- 
a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -191,7 +191,6 @@ async def run_vllm_async( use_v2_block_manager=use_v2_block_manager, disable_async_output_proc=disable_async_output_proc, worker_use_ray=False, - engine_use_ray=False, disable_log_requests=True, ) From 76515f303b44cb3ffc6de63c49148d5081a77119 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 19 Sep 2024 17:51:06 +0100 Subject: [PATCH 89/98] [Frontend] Use MQLLMEngine for embeddings models too (#8584) --- vllm/engine/multiprocessing/__init__.py | 7 +- vllm/engine/multiprocessing/client.py | 106 +++++++++++++++++------- vllm/engine/multiprocessing/engine.py | 23 ++--- 3 files changed, 90 insertions(+), 46 deletions(-) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index ba5c6e15fc821..700332864d17a 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -2,6 +2,7 @@ from enum import Enum from typing import List, Mapping, Optional, Union +from vllm import PoolingParams from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput @@ -21,9 +22,9 @@ class MQEngineDeadError(RuntimeError): @dataclass -class RPCGenerateRequest: +class RPCProcessRequest: inputs: PromptInputs - sampling_params: SamplingParams + params: Union[SamplingParams, PoolingParams] request_id: str lora_request: Optional[LoRARequest] = None trace_headers: Optional[Mapping[str, str]] = None @@ -55,7 +56,7 @@ class RPCStartupResponse: tracing_enabled: bool -RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest, +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest, RPCStartupRequest] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 2cb4de79131f1..aa9dbbd448af2 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -11,6 +11,7 @@ from zmq import Frame # type: ignore[attr-defined] from zmq.asyncio import Socket +from vllm import PoolingParams from vllm.config import DecodingConfig, EngineConfig, ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs # yapf conflicts with isort for this block @@ -19,8 +20,8 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCGenerateRequest, - RPCHealthRequest, RPCStartupRequest, + RPCError, RPCHealthRequest, + RPCProcessRequest, RPCStartupRequest, RPCStartupResponse) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT @@ -111,20 +112,8 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig): @staticmethod def is_unsupported_config(engine_args: AsyncEngineArgs): - if engine_args.pipeline_parallel_size > 1: - return True - - is_embedding = ModelConfig( - model=engine_args.model, - revision=engine_args.revision, - tokenizer=engine_args.model, - tokenizer_mode="auto", - trust_remote_code=engine_args.trust_remote_code, - quantization=engine_args.quantization, - seed=0, - dtype="auto").embedding_mode - - return is_embedding + # Pipeline parallel not yet supported + return engine_args.pipeline_parallel_size > 1 @contextmanager def get_data_socket(self) -> Iterator[Socket]: @@ -382,12 +371,9 @@ def errored(self) -> bool: @property def dead_error(self) -> BaseException: - if self._errored_with is not None: - return ENGINE_DEAD_ERROR(self._errored_with) - else: - return ENGINE_DEAD_ERROR() 
+ return ENGINE_DEAD_ERROR(self._errored_with) - async def generate( + def generate( self, inputs: PromptInputs, sampling_params: SamplingParams, @@ -396,6 +382,67 @@ async def generate( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. + """ + return self._process_request(inputs, sampling_params, request_id, + lora_request, trace_headers, + prompt_adapter_request) + + def encode( + self, + inputs: PromptInputs, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + """Generate outputs for a request from an embedding model. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. + pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + + Yields: + The output `EmbeddingRequestOutput` objects from the LLMEngine + for the request. + """ + return self._process_request(inputs, pooling_params, request_id, + lora_request, trace_headers) + + async def _process_request( + self, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ + EmbeddingRequestOutput, None]]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" # If already dead, error out. 
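
A hypothetical usage sketch of the embedding path this enables through the multiprocessing client (the `client` handle and request id are illustrative assumptions; `final.outputs.embedding` follows the `EmbeddingRequestOutput` layout):

    from vllm import PoolingParams

    async def embed_one(client, text: str, request_id: str):
        # encode() now streams EmbeddingRequestOutput objects over the same
        # IPC path that generate() uses; keep the last yielded output.
        final = None
        async for output in client.encode(inputs=text,
                                          pooling_params=PoolingParams(),
                                          request_id=request_id):
            final = output
        return final.outputs.embedding
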
@@ -410,19 +457,19 @@ async def generate( try: # 2) Detach logits processors so that they can be pickled # separately (may require cloudpickle which is slower) - if sampling_params.logits_processors: + if isinstance(params, SamplingParams) and params.logits_processors: # Defensive shallow copy - sampling_params = copy.copy(sampling_params) - logits_processors = sampling_params.logits_processors - sampling_params.logits_processors = None + params = copy.copy(params) + logits_processors = params.logits_processors + params.logits_processors = None lp_bytes = cloudpickle.dumps(logits_processors) else: lp_bytes = None request_bytes = pickle.dumps( - RPCGenerateRequest( + RPCProcessRequest( inputs=inputs, - sampling_params=sampling_params, + params=params, request_id=request_id, lora_request=lora_request, trace_headers=trace_headers, @@ -452,8 +499,3 @@ async def generate( await self.abort(request_id) finally: self.output_queues.pop(request_id) - - async def encode(self, *args, - **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: - raise NotImplementedError( - "Embeddings not supported with multiprocessing backend") diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 70cd6e5cb6000..f4ca231570853 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -6,7 +6,7 @@ import cloudpickle import zmq -from vllm import AsyncEngineArgs, LLMEngine +from vllm import AsyncEngineArgs, LLMEngine, SamplingParams from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) # yapf conflicts with isort for this block @@ -15,8 +15,8 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCGenerateRequest, - RPCHealthRequest, RPCStartupRequest, + RPCError, RPCHealthRequest, + RPCProcessRequest, RPCStartupRequest, RPCStartupResponse) # yapf: enable from vllm.logger import init_logger @@ -39,8 +39,8 @@ class MQLLMEngine: in concurrnet manner. It runs a background loop and uses zeromq to receive new requests and stream outputs incrementally via ipc. - The :class:`LLMEngine.generate` is kicked off when a new - RPCGenerateRequest is received by the input_socket. + The :class:`LLMEngine` generate or encode process is kicked off when a new + RPCProcessRequest is received by the input_socket. 
The self.engine_loop checks the input_socket for new requests, adds them to the LLMEngine if there are any, calls the internal @@ -213,12 +213,13 @@ def handle_new_input(self): frames = self.input_socket.recv_multipart(copy=False) request = pickle.loads(frames[0].buffer) - if isinstance(request, RPCGenerateRequest): + if isinstance(request, RPCProcessRequest): if len(frames) > 1: # Use cloudpickle for logits processors + assert isinstance(request.params, SamplingParams) lprocs = cloudpickle.loads(frames[1].buffer) - request.sampling_params.logits_processors = lprocs - self._handle_generate_request(request) + request.params.logits_processors = lprocs + self._handle_process_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) elif isinstance(request, RPCHealthRequest): @@ -231,8 +232,8 @@ def handle_new_input(self): self._send_unhealthy(e) raise e - def _handle_generate_request(self, request: RPCGenerateRequest): - """Handle RPCGenerateRequest by adding it to the LLMEngine.""" + def _handle_process_request(self, request: RPCProcessRequest): + """Handle RPCProcessRequest by adding it to the LLMEngine.""" request_id = request.request_id if self._errored_with is not None: @@ -245,7 +246,7 @@ def _handle_generate_request(self, request: RPCGenerateRequest): self.engine.add_request( request_id=request_id, inputs=request.inputs, - params=request.sampling_params, + params=request.params, lora_request=request.lora_request, trace_headers=request.trace_headers, prompt_adapter_request=request.prompt_adapter_request) From 9cc373f39036af789fb1ffc1e06b23766996d3f4 Mon Sep 17 00:00:00 2001 From: Charlie Fu Date: Thu, 19 Sep 2024 12:37:57 -0500 Subject: [PATCH 90/98] [Kernel][Amd] Add fp8 kv cache support for rocm custom paged attention (#8577) --- csrc/rocm/attention.cu | 240 +++++++++++++------- csrc/rocm/ops.h | 3 +- csrc/rocm/torch_bindings.cpp | 3 +- tests/kernels/test_attention.py | 251 ++++++--------------- vllm/_custom_ops.py | 4 +- vllm/attention/backends/rocm_flash_attn.py | 28 +-- 6 files changed, 246 insertions(+), 283 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 8fa7c862fbfa8..b48348a515c8d 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -18,8 +18,11 @@ #include #include #include +#include "cuda_compat.h" #include +#include "../attention/dtype_fp8.cuh" +#include "../quantization/fp8/amd/quant_utils.cuh" #if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ defined(__gfx941__) || defined(__gfx942__)) @@ -38,7 +41,6 @@ #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) -#define WARP_SIZE 64 #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support @@ -60,6 +62,8 @@ typedef struct _B16x8 { _B16x4 xy[2]; } _B16x8; +using _B8x8 = uint2; + ////// Non temporal load stores /////// template @@ -168,18 +172,40 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, } } +template +__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, + const float scale) { + union alignas(16) { + uint4 u4; + _B16x8 u16x8; + vllm::bf16_8_t b16x8; + } tmp; + if constexpr (std::is_same::value) { + tmp.u4 = vllm::fp8::scaled_convert(input, scale); + return tmp.u16x8; + } else if constexpr (std::is_same::value) { + tmp.b16x8 = vllm::fp8::scaled_convert( + input, scale); + return tmp.u16x8; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + /////////////////////////////////////// // grid (num_seqs, num_partitions,num_heads/gqa_ratio) // block (partition size) -template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] @@ -192,10 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - #if 0 - scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] - #endif - int max_ctx_blocks) { + int max_ctx_blocks, float k_scale, float v_scale) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -222,12 +245,14 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( constexpr int x = 16 / sizeof(scalar_t); constexpr int KHELOOP = HEAD_SIZE / x; _B16x8 Klocal[KHELOOP]; + _B8x8 Klocalb8[KHELOOP]; constexpr int VHELOOP = HEAD_SIZE / WARP_SIZE; // v head_size dimension is distributed across lanes constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 // 8xtokens _B16x8 Vlocal[VHELOOP][VTLOOP]; + _B8x8 Vlocalb8[VHELOOP][VTLOOP]; floatx4 dout[QHLOOP]; float qk_max[QHLOOP]; #pragma unroll @@ -279,6 +304,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( (vblock_idx <= last_ctx_block) ? 
vblock_idx : last_ctx_block; vphysical_blocks[b] = block_table[vblock_idx_ctx]; } + // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems const scalar_t* q_ptr = q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; @@ -298,17 +324,29 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( Qlocal[QHLOOP - 1].xy[1] = {0}; } - const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + - wg_start_kv_head_idx * kv_head_stride; + const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + wg_start_kv_head_idx * kv_head_stride; const int physical_block_offset = local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset // is already cast as _H8 - - const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + } + } else { + constexpr int X = 16 / sizeof(cache_t); + const cache_t* k_ptr2 = k_ptr + physical_block_offset * X; #pragma unroll - for (int d = 0; d < KHELOOP; d++) { - Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + for (int d = 0; d < KHELOOP; d++) { + const int head_elem = d * 8; + const int offset1 = head_elem / X; + const int offset2 = head_elem % X; + const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2; + Klocalb8[d] = *reinterpret_cast(k_ptr3); + } } float alibi_slope[QHLOOP]; @@ -322,30 +360,66 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } - const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; - const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); - // iterate over each v block + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block #pragma unroll - for (int b = 0; b < VBLOCKS; b++) { - // int32 physical_block_number leads to overflow when multiplied with - // kv_block_stride - const int64_t vphysical_block_number = - static_cast(vphysical_blocks[b]); - const _B16x8* v_ptrh8b = - v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; - // iterate over each head elem (within head_size) + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B16x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + } + } + } + } else { + const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B8x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) #pragma unroll - 
for (int h = 0; h < VHELOOP; h++) { - const int head_size_elem = h * WARP_SIZE + laneid; - const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; - // iterate over all velems within block + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block #pragma unroll - for (int d = 0; d < BLOCK_SIZE / 8; d++) { - Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + const _B8x8 Vlocalb8 = v_ptrh8be[d]; + Vlocal[h][b * BLOCK_SIZE / 8 + d] = + scaled_convert_b8x8(Vlocalb8, v_scale); + } } } } + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = + scaled_convert_b8x8(Klocalb8[d], k_scale); + } + } + #pragma unroll for (int h = 0; h < QHLOOP; h++) { dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], @@ -794,14 +868,16 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] @@ -814,10 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - #if 0 - scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] - #endif - int max_ctx_blocks) { + int max_ctx_blocks, float k_scale, float v_scale) { UNREACHABLE_CODE } @@ -839,26 +912,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support #define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ - paged_attention_ll4mi_QKV_kernel \ + paged_attention_ll4mi_QKV_kernel \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks); + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ + k_scale, v_scale); -template +template void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, const int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, - int max_context_len, -#if 0 - torch::Tensor& qk_out, - torch::Tensor& softmax_out, -#endif - const c10::optional& alibi_slopes) { - + int 
max_context_len, const c10::optional& alibi_slopes, + float k_scale, float v_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -878,14 +949,10 @@ void paged_attention_custom_launcher( float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); -#if 0 - T* qk_out_ptr = reinterpret_cast(qk_out.data_ptr()); - T* softmax_out_ptr = reinterpret_cast(softmax_out.data_ptr()); -#endif const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); const int max_num_partitions = @@ -972,32 +1039,32 @@ void paged_attention_custom_launcher( } } -#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE) \ - paged_attention_custom_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, max_context_len, \ - alibi_slopes); +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes, k_scale, v_scale); -#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE) \ +#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ switch (block_size) { \ case 16: \ - CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ break; \ case 32: \ - CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } -#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \ +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ switch (head_size) { \ case 64: \ - CALL_CUSTOM_LAUNCHER_BLK(T, 64); \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ break; \ case 128: \ - CALL_CUSTOM_LAUNCHER_BLK(T, 128); \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \ break; \ default: \ TORCH_CHECK(false, "Unsupported head size: ", head_size); \ @@ -1020,19 +1087,34 @@ void paged_attention( torch::Tensor& context_lens, // [num_seqs] int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype) { - assert(kv_cache_dtype == "auto"); + const std::string& kv_cache_dtype, double k_scale, double v_scale) { const int head_size = query.size(2); - if (query.dtype() == at::ScalarType::Half) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16); + if (kv_cache_dtype == "auto") { + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, + vllm::Fp8KVCacheDataType::kAuto); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16, + vllm::Fp8KVCacheDataType::kAuto); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (query.dtype() == 
at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype); } } #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 4a07a3f1775bd..9f085115a3956 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -10,4 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& context_lens, int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype, double k_scale, + double v_scale); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 082e314587908..a283d4263d293 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -26,7 +26,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " Tensor context_lens, int block_size," " int max_context_len," " Tensor? alibi_slopes," - " str kv_cache_dtype) -> ()"); + " str kv_cache_dtype," + " float k_scale, float v_scale) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 4bd6f7863a658..ecab512cba16f 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -31,8 +31,7 @@ # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256 - ] if not is_hip() else [64, 80, 96, 112, 128] +HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256] BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] @@ -114,7 +113,8 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) -@pytest.mark.parametrize("version", ["v1", "v2"]) +@pytest.mark.parametrize( + "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -137,7 +137,8 @@ def test_paged_attention( seed: int, device: str, ) -> None: - if kv_cache_dtype == "fp8" and head_size % 16: + if ((kv_cache_dtype == "fp8" and head_size % 16) + or (version == "rocm" and head_size not in (64, 128))): pytest.skip() seed_everything(seed) @@ -206,7 +207,7 @@ def test_paged_attention( kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0])) - elif version == "v2": + elif version in ("v2", "rocm"): num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -219,32 +220,61 @@ def test_paged_attention( dtype=torch.float32, ) max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - 
k_scale, - v_scale, - ) - - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, key_cache, - value_cache, num_kv_heads, scale, block_tables, seq_lens, - block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0])) + if version == "v2": + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._C.paged_attention_v2, + (output, exp_sums, max_logits, tmp_output, query, + key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + cond=(head_size == HEAD_SIZES[0])) + + else: + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._rocm_C.paged_attention, + (output, exp_sums, max_logits, tmp_output, query, + key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale), + cond=(head_size == HEAD_SIZES[0])) else: raise AssertionError(f"Unknown version: {version}") @@ -328,162 +358,6 @@ def ref_multi_query_kv_attention( return torch.cat(ref_outputs, dim=0) -@pytest.mark.parametrize("version", ["rocm"]) -@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", [64, 128]) # only test 64 128 -@pytest.mark.parametrize("use_alibi", USE_ALIBI) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", ["auto"]) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(not is_hip(), reason="only for rocm") -def test_paged_attention_rocm( - kv_cache_factory, - version: str, - num_seqs: int, - num_heads: Tuple[int, int], - head_size: int, - use_alibi: bool, - block_size: int, - dtype: torch.dtype, - kv_cache_dtype: str, - seed: int, - device: str, -) -> None: - seed_everything(seed) - torch.set_default_device(device) - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) - query.uniform_(-scale, scale) - - assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - alibi_slopes = None - if use_alibi: - alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - - context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - context_lens[-1] = MAX_SEQ_LEN - #context_lens = [8192 for _ in range(num_seqs)] - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int) - #print('>>> ctx lens', context_lens) - - # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int) - - # Create the KV caches. 
- key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, - kv_cache_dtype, dtype, seed, - device) - key_cache, value_cache = key_caches[0], value_caches[0] - - # TODO(charlifu) enable fp8 kv cache - # Using default kv_scale - # kv_scale = 1.0 - - # Call the paged attention kernel. - output = torch.empty_like(query) - PARTITION_SIZE_ROCM = 256 - num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) // - PARTITION_SIZE_ROCM) - assert PARTITION_SIZE_ROCM % block_size == 0 - num_seqs, num_heads, head_size = output.shape - tmp_output = torch.empty( - size=(num_seqs, num_heads, num_partitions, head_size), - dtype=output.dtype, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, num_partitions), - dtype=torch.float32, - ) - max_logits = torch.empty_like(exp_sums) - if version == "rocm": - ops.paged_attention_rocm( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - kv_cache_dtype, - ) - else: - raise AssertionError(f"Unknown version: {version}") - - # Run the reference implementation. - if kv_cache_dtype == "fp8": - # Convert cache data back to dtype. - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, - block_size, x) - dequantized_key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(key_cache, dequantized_key_cache) - key_cache = dequantized_key_cache - - value_cache_shape = value_cache.shape - dequantized_value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(value_cache, dequantized_value_cache) - value_cache = dequantized_value_cache - - ref_output = torch.empty_like(query) - ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - context_lens, - scale, - alibi_slopes, - ) - - # NOTE(woosuk): Due to the kernel-level differences in the two - # implementations, there is a small numerical difference in the two - # outputs. Thus, we use a relaxed tolerance for the test. - atol = get_default_atol(output) if is_hip() else 1e-3 - rtol = get_default_rtol(output) if is_hip() else 1e-5 - - # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, - # so we use a relaxed tolerance for the test. - atol, rtol = 1e-4, 1e-5 - if dtype == torch.bfloat16: - atol, rtol = 2e-4, 1e-5 - if use_alibi: - if dtype == torch.half: - atol, rtol = 5e-4, 1e-5 - if dtype == torch.bfloat16: - atol, rtol = 1e-3, 1e-5 - if kv_cache_dtype == "fp8": - atol, rtol = 1e-2, 1e-5 - assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) - - # TODO(woosuk): Add tests for USE_ALIBI=True. 
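# Illustrative sketch (not part of this patch): the ROCm paged attention op
# exercised above now takes explicit k_scale / v_scale arguments so that an
# fp8 ("fp8_e4m3") KV cache can be dequantized inside the kernel. Conceptually
# the scales are per-cache quantization factors; passing 1.0 for both keeps
# the previous unscaled behaviour. The tensor below is a placeholder chosen
# only to show the round trip.
import torch

k_scale = 1.0
key = torch.randn(4, 16, dtype=torch.float16)
key_fp8 = (key / k_scale).to(torch.float8_e4m3fn)   # quantize into the cache
key_dequant = key_fp8.to(torch.float16) * k_scale   # what the kernel undoes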
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -491,7 +365,8 @@ def test_paged_attention_rocm( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(is_hip(), reason="skip for rocm") +@pytest.mark.skipif(is_hip(), + reason="Xformers backend is not supported on ROCm.") @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ff5aa8bee3c27..678700055c992 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -146,12 +146,14 @@ def paged_attention_rocm( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, + k_scale: float, + v_scale: float, ) -> None: torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype) + kv_cache_dtype, k_scale, v_scale) # pos encoding ops diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6bd276ade1d41..70e6857584ace 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -17,8 +17,8 @@ logger = init_logger(__name__) -_PARTITION_SIZE = 256 -ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName +_PARTITION_SIZE_ROCM = 512 +_ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName class ROCmFlashAttentionBackend(AttentionBackend): @@ -489,14 +489,15 @@ def forward( num_seqs, num_heads, head_size = decode_query.shape block_size = value_cache.shape[3] gqa_ratio = num_heads // self.num_kv_heads - use_custom = use_rocm_custom_paged_attention( - decode_query.dtype, head_size, block_size, self.kv_cache_dtype, - gqa_ratio, decode_meta.max_decode_seq_len) + use_custom = _use_rocm_custom_paged_attention( + decode_query.dtype, head_size, block_size, gqa_ratio, + decode_meta.max_decode_seq_len) if use_custom: max_seq_len = decode_meta.max_decode_seq_len - max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) - assert _PARTITION_SIZE % block_size == 0 + max_num_partitions = ( + (max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 tmp_output = torch.empty( size=(num_seqs, num_heads, max_num_partitions, head_size), dtype=output.dtype, @@ -524,6 +525,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, + k_scale, + v_scale, ) else: output[num_prefill_tokens:] = PagedAttention.forward_decode( @@ -580,12 +583,11 @@ def _sdpa_attention( return output -def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, - block_size: int, kv_cache_dtype: str, - gqa_ratio: int, max_seq_len: int) -> bool: +def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, + block_size: int, gqa_ratio: int, + max_seq_len: int) -> bool: # rocm custom page attention not support on navi (gfx1*) - return (not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) + return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) - and kv_cache_dtype == "auto" and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) From e42c634acbd1b86b5becca51e8b8108a32a438d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=8F=E4=B8=80?= Date: Fri, 
20 Sep 2024 02:28:25 +0800 Subject: [PATCH 91/98] [Core] simplify logits resort in _apply_top_k_top_p (#8619) --- vllm/model_executor/layers/sampler.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 487f5a3d2a441..2ca86a4653cf4 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -433,12 +433,9 @@ def _apply_top_k_top_p( logits_sort.masked_fill_(top_p_mask, -float("inf")) # Re-sort the probabilities. - src = torch.arange(logits_idx.shape[-1], - device=logits_idx.device).expand_as(logits_idx) - logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1, - index=logits_idx, - src=src) - logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv) + logits = torch.empty_like(logits_sort).scatter_(dim=-1, + index=logits_idx, + src=logits_sort) return logits From ea4647b7d77c4738c5ed2ab77a2c9f5ad335f6fb Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 20 Sep 2024 03:15:55 +0800 Subject: [PATCH 92/98] [Doc] Add documentation for GGUF quantization (#8618) --- docs/source/index.rst | 1 + docs/source/quantization/gguf.rst | 73 +++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 docs/source/quantization/gguf.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 4b817c4ba9498..79f723eace762 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -107,6 +107,7 @@ Documentation quantization/supported_hardware quantization/auto_awq quantization/bnb + quantization/gguf quantization/int8 quantization/fp8 quantization/fp8_e5m2_kvcache diff --git a/docs/source/quantization/gguf.rst b/docs/source/quantization/gguf.rst new file mode 100644 index 0000000000000..9f00dc5563909 --- /dev/null +++ b/docs/source/quantization/gguf.rst @@ -0,0 +1,73 @@ +.. _gguf: + +GGUF +================== + +.. warning:: + + Please note that GGUF support in vLLM is highly experimental and under-optimized at the moment, it might be incompatible with other features. Currently, you can use GGUF as a way to reduce memory footprint. If you encounter any issues, please report them to the vLLM team. + +.. warning:: + + Currently, vllm only supports loading single-file GGUF models. If you have a multi-files GGUF model, you can use `gguf-split `_ tool to merge them to a single-file model. + +To run a GGUF model with vLLM, you can download and use the local GGUF model from `TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF `_ with the following command: + +.. code-block:: console + + $ wget https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf + $ # We recommend using the tokenizer from base model to avoid long-time and buggy tokenizer conversion. + $ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 + +You can also add ``--tensor-parallel-size 2`` to enable tensor parallelism inference with 2 GPUs: + +.. code-block:: console + + $ # We recommend using the tokenizer from base model to avoid long-time and buggy tokenizer conversion. + $ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tensor-parallel-size 2 + +.. warning:: + + We recommend using the tokenizer from base model instead of GGUF model. Because the tokenizer conversion from GGUF is time-consuming and unstable, especially for some models with large vocab size. 
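Once the server is running, it can be queried through the usual OpenAI-compatible API. The snippet below is a minimal sketch; it assumes the default port and that the served model name is the GGUF path passed to ``vllm serve``:

.. code-block:: python

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    completion = client.chat.completions.create(
        model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(completion.choices[0].message.content)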
+ +You can also use the GGUF model directly through the LLM entrypoint: + +.. code-block:: python + + from vllm import LLM, SamplingParams + + # In this script, we demonstrate how to pass input to the chat method: + conversation = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, + ] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + # Create an LLM. + llm = LLM(model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", + tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.chat(conversation, sampling_params) + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") From 9e99407e3ccbb290bae77af230da38c70a52a055 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 19 Sep 2024 12:16:28 -0700 Subject: [PATCH 93/98] Create SECURITY.md (#8642) --- SECURITY.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 SECURITY.md diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000000000..d9a392158472d --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,12 @@ +# Security Policy + +## Reporting a Vulnerability + +If you believe you have found a security vulnerability in vLLM, we encourage you to let us know right away. +We will investigate all legitimate reports and do our best to quickly fix the problem. + +Please report security issues using https://github.com/vllm-project/vllm/security/advisories/new + +--- +Please see PyTorch Security for more information how to securely interact with models: https://github.com/pytorch/pytorch/blob/main/SECURITY.md +This document mostly references the recommendation from PyTorch, thank you! From 6cb748e190a94e20987314025614b8bd806602f2 Mon Sep 17 00:00:00 2001 From: "Alexey Kondratiev(AMD)" <143633163+alexeykondrat@users.noreply.github.com> Date: Thu, 19 Sep 2024 16:06:32 -0400 Subject: [PATCH 94/98] [CI/Build] Re-enabling Entrypoints tests on ROCm, excluding ones that fail (#8551) --- .buildkite/run-amd-test.sh | 9 +++++++++ .buildkite/test-pipeline.yaml | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 9274a30e04325..45b20c9447c7d 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -94,6 +94,15 @@ if [[ $commands == *" kernels "* ]]; then --ignore=kernels/test_sampler.py" fi +#ignore certain Entrypoints tests +if [[ $commands == *" entrypoints/openai "* ]]; then + commands=${commands//" entrypoints/openai "/" entrypoints/openai \ + --ignore=entrypoints/openai/test_accuracy.py \ + --ignore=entrypoints/openai/test_audio.py \ + --ignore=entrypoints/openai/test_encoder_decoder.py \ + --ignore=entrypoints/openai/test_oot_registration.py "} +fi + PARALLEL_JOB_COUNT=8 # check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. 
if [[ $commands == *"--shard-id="* ]]; then diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 37207b677a1ee..379a67c4c8cf8 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -84,7 +84,7 @@ steps: - label: Entrypoints Test # 20min working_dir: "/vllm-workspace/tests" fast_check: true - #mirror_hardwares: [amd] + mirror_hardwares: [amd] source_file_dependencies: - vllm/ commands: From de6f90a13d7b98c4958ba107ec16cb6f95efb10f Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Thu, 19 Sep 2024 18:36:30 -0400 Subject: [PATCH 95/98] [Misc] guard against change in cuda library name (#8609) --- cmake/utils.cmake | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 1ea6d2b0f090e..730517a20129a 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -350,13 +350,14 @@ function (define_gpu_extension_target GPU_MOD_NAME) target_include_directories(${GPU_MOD_NAME} PRIVATE csrc ${GPU_INCLUDE_DIRECTORIES}) - # TODO: is torch_python_LIBRARY needed? - target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY} - ${GPU_LIBRARIES}) + target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES}) # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of # dependencies that are not necessary and may not be installed. if (GPU_LANGUAGE STREQUAL "CUDA") + if ("${CUDA_CUDA_LIB}" STREQUAL "") + set(CUDA_CUDA_LIB "${CUDA_CUDA_LIBRARY}") + endif() target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB} ${CUDA_LIBRARIES}) else() From 18ae428a0d8792d160d811a9cd5bb004d68ea8bd Mon Sep 17 00:00:00 2001 From: Amit Garg Date: Thu, 19 Sep 2024 17:54:02 -0700 Subject: [PATCH 96/98] [Bugfix] Fix Phi3.5 mini and MoE LoRA inference (#8571) --- vllm/model_executor/models/__init__.py | 2 +- vllm/model_executor/models/phi3.py | 17 +++++++++++++++++ vllm/model_executor/models/phimoe.py | 4 ++++ 3 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/models/phi3.py diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 591007e787f47..7427060922281 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -50,7 +50,7 @@ "OrionForCausalLM": ("orion", "OrionForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), - "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), + "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), diff --git a/vllm/model_executor/models/phi3.py b/vllm/model_executor/models/phi3.py new file mode 100644 index 0000000000000..02b2ff01c3832 --- /dev/null +++ b/vllm/model_executor/models/phi3.py @@ -0,0 +1,17 @@ +# coding=utf-8 +# Adapted from llama.py +"""Inference-only Phi3 model code inherit from Llama.py""" + +from vllm.model_executor.models.llama import LlamaForCausalLM + + +class Phi3ForCausalLM(LlamaForCausalLM): + + packed_modules_mapping = { + "qkv_proj": [ + "qkv_proj", + ], + "gate_up_proj": [ + "gate_up_proj", + ], + } diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 5036f55803c20..a3555a294bb66 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -491,6 +491,10 @@ class 
PhiMoEForCausalLM(nn.Module, SupportsLoRA): "o_proj", "embed_tokens", "lm_head", + "w1", + "w2", + "w3", + "gate", ] embedding_modules = { "embed_tokens": "input_embeddings", From 9e5ec35b1f8239453b1aaab28e7a02307db4ab1f Mon Sep 17 00:00:00 2001 From: William Lin Date: Thu, 19 Sep 2024 20:49:54 -0700 Subject: [PATCH 97/98] [bugfix] [AMD] add multi-step advance_step to ROCmFlashAttentionMetadata (#8474) --- vllm/attention/backends/rocm_flash_attn.py | 58 +++++++++++++++++++++- vllm/worker/multi_step_model_runner.py | 2 +- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 70e6857584ace..5560f44be4196 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,6 +1,6 @@ """Attention layer ROCm GPUs.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch @@ -15,6 +15,9 @@ from vllm.logger import init_logger from vllm.platforms import current_platform +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + logger = init_logger(__name__) _PARTITION_SIZE_ROCM = 512 @@ -180,6 +183,59 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: ) return self._cached_decode_metadata + def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, num_seqs: int, num_queries: int): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. 
Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + class ROCmFlashAttentionMetadataBuilder( CommonMetadataBuilder[ROCmFlashAttentionMetadata]): diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index ebcafbbab119a..c7295f872f70f 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -29,7 +29,7 @@ logger = init_logger(__name__) -MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"] +MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"] def seq_output_builder(): From 260d40b5ea48df9421325388abcc8d907a560fc5 Mon Sep 17 00:00:00 2001 From: Jiaxin Shan Date: Thu, 19 Sep 2024 23:20:56 -0700 Subject: [PATCH 98/98] [Core] Support Lora lineage and base model metadata management (#6315) --- docs/source/models/lora.rst | 64 +++++++++++++ tests/entrypoints/openai/test_cli_args.py | 91 +++++++++++++++++++ tests/entrypoints/openai/test_lora_lineage.py | 83 +++++++++++++++++ tests/entrypoints/openai/test_models.py | 6 +- tests/entrypoints/openai/test_serving_chat.py | 6 +- .../entrypoints/openai/test_serving_engine.py | 5 +- vllm/entrypoints/openai/api_server.py | 14 ++- vllm/entrypoints/openai/cli_args.py | 27 +++++- vllm/entrypoints/openai/run_batch.py | 9 +- vllm/entrypoints/openai/serving_chat.py | 11 ++- vllm/entrypoints/openai/serving_completion.py | 9 +- vllm/entrypoints/openai/serving_embedding.py | 6 +- vllm/entrypoints/openai/serving_engine.py | 43 ++++++--- .../openai/serving_tokenization.py | 7 +- vllm/lora/request.py | 1 + 15 files changed, 337 insertions(+), 45 deletions(-) create mode 100644 tests/entrypoints/openai/test_cli_args.py create mode 100644 tests/entrypoints/openai/test_lora_lineage.py diff --git a/docs/source/models/lora.rst b/docs/source/models/lora.rst index b3821ebdfceca..ef0177eaf2162 100644 --- a/docs/source/models/lora.rst +++ b/docs/source/models/lora.rst @@ -159,3 +159,67 @@ Example request to unload a LoRA adapter: -d '{ "lora_name": "sql_adapter" }' + + +New format for `--lora-modules` +------------------------------- + +In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example: + +.. code-block:: bash + + --lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/ + +This would only include the `name` and `path` for each LoRA module, but did not provide a way to specify a `base_model_name`. +Now, you can specify a base_model_name alongside the name and path using JSON format. For example: + +.. code-block:: bash + + --lora-modules '{"name": "sql-lora", "path": "/path/to/lora", "base_model_name": "meta-llama/Llama-2-7b"}' + +To provide the backward compatibility support, you can still use the old key-value format (name=path), but the `base_model_name` will remain unspecified in that case. 
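Since the JSON value has to survive shell quoting, it can be convenient to build the argument programmatically. A small sketch (adapter name, path, and base model are placeholders):

.. code-block:: python

    import json
    import shlex

    lora_module = {
        "name": "sql-lora",
        "path": "/path/to/lora",
        "base_model_name": "meta-llama/Llama-2-7b",
    }
    # Yields a single, safely quoted value for --lora-modules.
    print("--lora-modules " + shlex.quote(json.dumps(lora_module)))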
+ + +Lora model lineage in model card +-------------------------------- + +The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here's an explanation of how your current response supports this: + +- The `parent` field of LoRA model `sql-lora` now links to its base model `meta-llama/Llama-2-7b-hf`. This correctly reflects the hierarchical relationship between the base model and the LoRA adapter. +- The `root` field points to the artifact location of the lora adapter. + +.. code-block:: bash + + $ curl http://localhost:8000/v1/models + + { + "object": "list", + "data": [ + { + "id": "meta-llama/Llama-2-7b-hf", + "object": "model", + "created": 1715644056, + "owned_by": "vllm", + "root": "~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9/", + "parent": null, + "permission": [ + { + ..... + } + ] + }, + { + "id": "sql-lora", + "object": "model", + "created": 1715644056, + "owned_by": "vllm", + "root": "~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/", + "parent": meta-llama/Llama-2-7b-hf, + "permission": [ + { + .... + } + ] + } + ] + } diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py new file mode 100644 index 0000000000000..8ee7fb8b2c6bf --- /dev/null +++ b/tests/entrypoints/openai/test_cli_args.py @@ -0,0 +1,91 @@ +import json +import unittest + +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.entrypoints.openai.serving_engine import LoRAModulePath +from vllm.utils import FlexibleArgumentParser + +LORA_MODULE = { + "name": "module2", + "path": "/path/to/module2", + "base_model_name": "llama" +} + + +class TestLoraParserAction(unittest.TestCase): + + def setUp(self): + # Setting up argparse parser for tests + parser = FlexibleArgumentParser( + description="vLLM's remote OpenAI server.") + self.parser = make_arg_parser(parser) + + def test_valid_key_value_format(self): + # Test old format: name=path + args = self.parser.parse_args([ + '--lora-modules', + 'module1=/path/to/module1', + ]) + expected = [LoRAModulePath(name='module1', path='/path/to/module1')] + self.assertEqual(args.lora_modules, expected) + + def test_valid_json_format(self): + # Test valid JSON format input + args = self.parser.parse_args([ + '--lora-modules', + json.dumps(LORA_MODULE), + ]) + expected = [ + LoRAModulePath(name='module2', + path='/path/to/module2', + base_model_name='llama') + ] + self.assertEqual(args.lora_modules, expected) + + def test_invalid_json_format(self): + # Test invalid JSON format input, missing closing brace + with self.assertRaises(SystemExit): + self.parser.parse_args([ + '--lora-modules', + '{"name": "module3", "path": "/path/to/module3"' + ]) + + def test_invalid_type_error(self): + # Test type error when values are not JSON or key=value + with self.assertRaises(SystemExit): + self.parser.parse_args([ + '--lora-modules', + 'invalid_format' # This is not JSON or key=value format + ]) + + def test_invalid_json_field(self): + # Test valid JSON format but missing required fields + with self.assertRaises(SystemExit): + self.parser.parse_args([ + '--lora-modules', + '{"name": "module4"}' # Missing required 'path' field + ]) + + def test_empty_values(self): + # Test when no LoRA modules are provided + args = self.parser.parse_args(['--lora-modules', '']) + self.assertEqual(args.lora_modules, []) + + def test_multiple_valid_inputs(self): + 
# Test multiple valid inputs (both old and JSON format) + args = self.parser.parse_args([ + '--lora-modules', + 'module1=/path/to/module1', + json.dumps(LORA_MODULE), + ]) + expected = [ + LoRAModulePath(name='module1', path='/path/to/module1'), + LoRAModulePath(name='module2', + path='/path/to/module2', + base_model_name='llama') + ] + self.assertEqual(args.lora_modules, expected) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/entrypoints/openai/test_lora_lineage.py b/tests/entrypoints/openai/test_lora_lineage.py new file mode 100644 index 0000000000000..ab39684c2f31a --- /dev/null +++ b/tests/entrypoints/openai/test_lora_lineage.py @@ -0,0 +1,83 @@ +import json + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio +# downloading lora to test lora requests +from huggingface_hub import snapshot_download + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +# technically this needs Mistral-7B-v0.1 as base, but we're not testing +# generation quality here +LORA_NAME = "typeof/zephyr-7b-beta-lora" + + +@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def server_with_lora_modules_json(zephyr_lora_files): + # Define the json format LoRA module configurations + lora_module_1 = { + "name": "zephyr-lora", + "path": zephyr_lora_files, + "base_model_name": MODEL_NAME + } + + lora_module_2 = { + "name": "zephyr-lora2", + "path": zephyr_lora_files, + "base_model_name": MODEL_NAME + } + + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + # lora config below + "--enable-lora", + "--lora-modules", + json.dumps(lora_module_1), + json.dumps(lora_module_2), + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + "--max-num-seqs", + "64", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client_for_lora_lineage(server_with_lora_modules_json): + async with server_with_lora_modules_json.get_async_client( + ) as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI, + zephyr_lora_files): + models = await client_for_lora_lineage.models.list() + models = models.data + served_model = models[0] + lora_models = models[1:] + assert served_model.id == MODEL_NAME + assert served_model.root == MODEL_NAME + assert served_model.parent is None + assert all(lora_model.root == zephyr_lora_files + for lora_model in lora_models) + assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models) + assert lora_models[0].id == "zephyr-lora" + assert lora_models[1].id == "zephyr-lora2" diff --git a/tests/entrypoints/openai/test_models.py b/tests/entrypoints/openai/test_models.py index 5cd570f43e1a7..ae5bf404d3d2b 100644 --- a/tests/entrypoints/openai/test_models.py +++ b/tests/entrypoints/openai/test_models.py @@ -51,12 +51,14 @@ async def client(server): @pytest.mark.asyncio -async def test_check_models(client: openai.AsyncOpenAI): +async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files): models = await client.models.list() models = models.data served_model = models[0] lora_models = models[1:] assert served_model.id == MODEL_NAME - assert all(model.root == MODEL_NAME for model in 
models) + assert served_model.root == MODEL_NAME + assert all(lora_model.root == zephyr_lora_files + for lora_model in lora_models) assert lora_models[0].id == "zephyr-lora" assert lora_models[1].id == "zephyr-lora2" diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index de2a932199a01..db31745cc102e 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -7,10 +7,12 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.transformers_utils.tokenizer import get_tokenizer MODEL_NAME = "openai-community/gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" +BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] @dataclass @@ -37,7 +39,7 @@ async def _async_serving_chat_init(): serving_completion = OpenAIServingChat(engine, model_config, - served_model_names=[MODEL_NAME], + BASE_MODEL_PATHS, response_role="assistant", chat_template=CHAT_TEMPLATE, lora_modules=None, @@ -58,7 +60,7 @@ def test_serving_chat_should_set_correct_max_tokens(): serving_chat = OpenAIServingChat(mock_engine, MockModelConfig(), - served_model_names=[MODEL_NAME], + BASE_MODEL_PATHS, response_role="assistant", chat_template=CHAT_TEMPLATE, lora_modules=None, diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 6d9e620b4af7d..6199a75b5b4f8 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -8,9 +8,10 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest) -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing MODEL_NAME = "meta-llama/Llama-2-7b" +BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] LORA_LOADING_SUCCESS_MESSAGE = ( "Success: LoRA adapter '{lora_name}' added successfully.") LORA_UNLOADING_SUCCESS_MESSAGE = ( @@ -25,7 +26,7 @@ async def _async_serving_engine_init(): serving_engine = OpenAIServing(mock_engine_client, mock_model_config, - served_model_names=[MODEL_NAME], + BASE_MODEL_PATHS, lora_modules=None, prompt_adapters=None, request_logger=None) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index fd6f36e8768dd..5078a2654eb22 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -50,6 +50,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.logger import init_logger @@ -476,13 +477,18 @@ def init_app_state( else: request_logger = RequestLogger(max_log_len=args.max_log_len) + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] + state.engine_client = engine_client state.log_stats = not args.disable_log_stats state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, 
- served_model_names, + base_model_paths, args.response_role, lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, @@ -494,7 +500,7 @@ def init_app_state( state.openai_serving_completion = OpenAIServingCompletion( engine_client, model_config, - served_model_names, + base_model_paths, lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, request_logger=request_logger, @@ -503,13 +509,13 @@ def init_app_state( state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, model_config, - served_model_names, + base_model_paths, request_logger=request_logger, ) state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, - served_model_names, + base_model_paths, lora_modules=args.lora_modules, request_logger=request_logger, chat_template=args.chat_template, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index bbb0823de9a51..9d3071a97fbe6 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -31,8 +31,23 @@ def __call__( lora_list: List[LoRAModulePath] = [] for item in values: - name, path = item.split('=') - lora_list.append(LoRAModulePath(name, path)) + if item in [None, '']: # Skip if item is None or empty string + continue + if '=' in item and ',' not in item: # Old format: name=path + name, path = item.split('=') + lora_list.append(LoRAModulePath(name, path)) + else: # Assume JSON format + try: + lora_dict = json.loads(item) + lora = LoRAModulePath(**lora_dict) + lora_list.append(lora) + except json.JSONDecodeError: + parser.error( + f"Invalid JSON format for --lora-modules: {item}") + except TypeError as e: + parser.error( + f"Invalid fields for --lora-modules: {item} - {str(e)}" + ) setattr(namespace, self.dest, lora_list) @@ -95,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=None, nargs='+', action=LoRAParserAction, - help="LoRA module configurations in the format name=path. " - "Multiple modules can be specified.") + help="LoRA module configurations in either 'name=path' format" + "or JSON format. 
" + "Example (old format): 'name=path' " + "Example (new format): " + "'{\"name\": \"name\", \"local_path\": \"path\", " + "\"base_model_name\": \"id\"}'") parser.add_argument( "--prompt-adapters", type=nullable_str, diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index b745410fe6b3b..f5249a0c447b3 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -20,6 +20,7 @@ # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION @@ -196,6 +197,10 @@ async def main(args): engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) model_config = await engine.get_model_config() + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] if args.disable_log_requests: request_logger = None @@ -206,7 +211,7 @@ async def main(args): openai_serving_chat = OpenAIServingChat( engine, model_config, - served_model_names, + base_model_paths, args.response_role, lora_modules=None, prompt_adapters=None, @@ -216,7 +221,7 @@ async def main(args): openai_serving_embedding = OpenAIServingEmbedding( engine, model_config, - served_model_names, + base_model_paths, request_logger=request_logger, ) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index b84898dc39b0f..1ee4b3ce17cfa 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -23,7 +23,8 @@ ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo) -from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, OpenAIServing, PromptAdapterPath, TextTokensPrompt) @@ -47,7 +48,7 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], response_role: str, *, lora_modules: Optional[List[LoRAModulePath]], @@ -59,7 +60,7 @@ def __init__(self, tool_parser: Optional[str] = None): super().__init__(engine_client=engine_client, model_config=model_config, - served_model_names=served_model_names, + base_model_paths=base_model_paths, lora_modules=lora_modules, prompt_adapters=prompt_adapters, request_logger=request_logger, @@ -262,7 +263,7 @@ async def chat_completion_stream_generator( conversation: List[ConversationMessage], tokenizer: AnyTokenizer, ) -> AsyncGenerator[str, None]: - model_name = self.served_model_names[0] + model_name = self.base_model_paths[0].name created_time = int(time.time()) chunk_object_type: Final = "chat.completion.chunk" first_iteration = True @@ -596,7 +597,7 @@ async def chat_completion_full_generator( tokenizer: AnyTokenizer, ) -> Union[ErrorResponse, ChatCompletionResponse]: - model_name = self.served_model_names[0] + model_name = self.base_model_paths[0].name created_time = int(time.time()) final_res: Optional[RequestOutput] = None diff --git a/vllm/entrypoints/openai/serving_completion.py 
b/vllm/entrypoints/openai/serving_completion.py index 14fa60243c584..9abd74d0561d0 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -20,7 +20,8 @@ CompletionStreamResponse, ErrorResponse, UsageInfo) # yapf: enable -from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, OpenAIServing, PromptAdapterPath) from vllm.logger import init_logger @@ -45,7 +46,7 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], *, lora_modules: Optional[List[LoRAModulePath]], prompt_adapters: Optional[List[PromptAdapterPath]], @@ -54,7 +55,7 @@ def __init__( ): super().__init__(engine_client=engine_client, model_config=model_config, - served_model_names=served_model_names, + base_model_paths=base_model_paths, lora_modules=lora_modules, prompt_adapters=prompt_adapters, request_logger=request_logger, @@ -89,7 +90,7 @@ async def create_completion( return self.create_error_response( "suffix is not currently supported") - model_name = self.served_model_names[0] + model_name = self.base_model_paths[0].name request_id = f"cmpl-{random_uuid()}" created_time = int(time.time()) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index f111a3a8277b5..5d95e1369b884 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -14,7 +14,7 @@ EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.logger import init_logger from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput from vllm.utils import merge_async_iterators, random_uuid @@ -73,13 +73,13 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], *, request_logger: Optional[RequestLogger], ): super().__init__(engine_client=engine_client, model_config=model_config, - served_model_names=served_model_names, + base_model_paths=base_model_paths, lora_modules=None, prompt_adapters=None, request_logger=request_logger) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 72f9381abc7db..9c4e8d8bb671a 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -39,6 +39,12 @@ logger = init_logger(__name__) +@dataclass +class BaseModelPath: + name: str + model_path: str + + @dataclass class PromptAdapterPath: name: str @@ -49,6 +55,7 @@ class PromptAdapterPath: class LoRAModulePath: name: str path: str + base_model_name: Optional[str] = None AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest, @@ -66,7 +73,7 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], *, lora_modules: Optional[List[LoRAModulePath]], prompt_adapters: Optional[List[PromptAdapterPath]], @@ -79,17 +86,20 @@ def __init__( self.model_config = model_config self.max_model_len = model_config.max_model_len - self.served_model_names = served_model_names + self.base_model_paths = base_model_paths self.lora_id_counter = AtomicCounter(0) self.lora_requests = [] if 
lora_modules is not None: self.lora_requests = [ - LoRARequest( - lora_name=lora.name, - lora_int_id=i, - lora_path=lora.path, - ) for i, lora in enumerate(lora_modules, start=1) + LoRARequest(lora_name=lora.name, + lora_int_id=i, + lora_path=lora.path, + base_model_name=lora.base_model_name + if lora.base_model_name + and self._is_model_supported(lora.base_model_name) + else self.base_model_paths[0].name) + for i, lora in enumerate(lora_modules, start=1) ] self.prompt_adapter_requests = [] @@ -112,21 +122,23 @@ def __init__( async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ - ModelCard(id=served_model_name, + ModelCard(id=base_model.name, max_model_len=self.max_model_len, - root=self.served_model_names[0], + root=base_model.model_path, permission=[ModelPermission()]) - for served_model_name in self.served_model_names + for base_model in self.base_model_paths ] lora_cards = [ ModelCard(id=lora.lora_name, - root=self.served_model_names[0], + root=lora.local_path, + parent=lora.base_model_name if lora.base_model_name else + self.base_model_paths[0].name, permission=[ModelPermission()]) for lora in self.lora_requests ] prompt_adapter_cards = [ ModelCard(id=prompt_adapter.prompt_adapter_name, - root=self.served_model_names[0], + root=self.base_model_paths[0].name, permission=[ModelPermission()]) for prompt_adapter in self.prompt_adapter_requests ] @@ -169,7 +181,7 @@ async def _check_model( self, request: AnyRequest, ) -> Optional[ErrorResponse]: - if request.model in self.served_model_names: + if self._is_model_supported(request.model): return None if request.model in [lora.lora_name for lora in self.lora_requests]: return None @@ -187,7 +199,7 @@ def _maybe_get_adapters( self, request: AnyRequest ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[ None, PromptAdapterRequest]]: - if request.model in self.served_model_names: + if self._is_model_supported(request.model): return None, None for lora in self.lora_requests: if request.model == lora.lora_name: @@ -480,3 +492,6 @@ async def unload_lora_adapter( if lora_request.lora_name != lora_name ] return f"Success: LoRA adapter '{lora_name}' removed successfully." 
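# Illustrative sketch (not part of the upstream diff): with base_model_paths
# and the per-adapter base_model_name wired through as above, the lineage is
# visible to any OpenAI-compatible client via /v1/models. The server URL and
# the attribute access pattern below are assumptions for illustration.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
for model in client.models.list().data:
    # Base models report parent=None; LoRA adapters report their artifact
    # location as `root` and their base model as `parent`.
    print(model.id, getattr(model, "root", None), getattr(model, "parent", None))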
+ + def _is_model_supported(self, model_name): + return any(model.name == model_name for model in self.base_model_paths) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 8f8862897fc4e..6d9a1ae088079 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -16,7 +16,8 @@ TokenizeRequest, TokenizeResponse) # yapf: enable -from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, OpenAIServing) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import MistralTokenizer @@ -31,7 +32,7 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], *, lora_modules: Optional[List[LoRAModulePath]], request_logger: Optional[RequestLogger], @@ -39,7 +40,7 @@ def __init__( ): super().__init__(engine_client=engine_client, model_config=model_config, - served_model_names=served_model_names, + base_model_paths=base_model_paths, lora_modules=lora_modules, prompt_adapters=None, request_logger=request_logger) diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 47a59d80d3a45..c4b26dc92c6f4 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -28,6 +28,7 @@ class LoRARequest( lora_path: str = "" lora_local_path: Optional[str] = msgspec.field(default=None) long_lora_max_len: Optional[int] = None + base_model_name: Optional[str] = msgspec.field(default=None) def __post_init__(self): if 'lora_local_path' in self.__struct_fields__: