From 795ec7e720d8cab28b9f924ed495b3fc13563718 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 11 Dec 2024 17:28:00 +0800 Subject: [PATCH] [Misc] Split up pooling tasks (#10820) Signed-off-by: DarkLight1337 --- docs/source/index.rst | 2 + docs/source/models/generative_models.rst | 146 ++++++++++++++++ docs/source/models/pooling_models.rst | 99 +++++++++++ docs/source/models/supported_models.rst | 157 ++++++++++++------ docs/source/usage/compatibility_matrix.rst | 12 +- examples/offline_inference_embedding.py | 7 +- ...ine_inference_vision_language_embedding.py | 4 +- tests/compile/test_basic_correctness.py | 4 +- tests/core/test_scheduler_encoder_decoder.py | 2 +- .../openai/test_vision_embedding.py | 2 +- .../embedding/language/test_embedding.py | 2 +- .../models/embedding/language/test_scoring.py | 12 +- .../vision_language/test_dse_qwen2_vl.py | 2 +- .../vision_language/test_llava_next.py | 2 +- .../embedding/vision_language/test_phi3v.py | 2 +- tests/test_config.py | 17 +- vllm/config.py | 137 ++++++++++----- vllm/core/scheduler.py | 2 +- vllm/engine/arg_utils.py | 7 +- vllm/engine/llm_engine.py | 4 +- vllm/entrypoints/llm.py | 53 +++--- vllm/entrypoints/openai/api_server.py | 8 +- vllm/entrypoints/openai/run_batch.py | 4 +- vllm/model_executor/model_loader/utils.py | 2 +- vllm/v1/engine/core.py | 2 +- vllm/worker/cpu_worker.py | 2 +- vllm/worker/worker.py | 2 +- 27 files changed, 527 insertions(+), 168 deletions(-) create mode 100644 docs/source/models/generative_models.rst create mode 100644 docs/source/models/pooling_models.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index ebf1361976c5e..842013d6d49c4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -94,6 +94,8 @@ Documentation :caption: Models models/supported_models + models/generative_models + models/pooling_models models/adding_model models/enabling_multimodal_inputs diff --git a/docs/source/models/generative_models.rst b/docs/source/models/generative_models.rst new file mode 100644 index 0000000000000..fb71185600863 --- /dev/null +++ b/docs/source/models/generative_models.rst @@ -0,0 +1,146 @@ +.. _generative_models: + +Generative Models +================= + +vLLM provides first-class support for generative models, which covers most of LLMs. + +In vLLM, generative models implement the :class:`~vllm.model_executor.models.VllmModelForTextGeneration` interface. +Based on the final hidden states of the input, these models output log probabilities of the tokens to generate, +which are then passed through :class:`~vllm.model_executor.layers.Sampler` to obtain the final text. + +Offline Inference +----------------- + +The :class:`~vllm.LLM` class provides various methods for offline inference. +See :ref:`Engine Arguments ` for a list of options when initializing the model. + +For generative models, the only supported :code:`task` option is :code:`"generate"`. +Usually, this is automatically inferred so you don't have to specify it. + +``LLM.generate`` +^^^^^^^^^^^^^^^^ + +The :class:`~vllm.LLM.generate` method is available to all generative models in vLLM. +It is similar to `its counterpart in HF Transformers `__, +except that tokenization and detokenization are also performed automatically. + +.. code-block:: python + + llm = LLM(model="facebook/opt-125m") + outputs = llm.generate("Hello, my name is") + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +You can optionally control the language generation by passing :class:`~vllm.SamplingParams`. +For example, you can use greedy sampling by setting :code:`temperature=0`: + +.. code-block:: python + + llm = LLM(model="facebook/opt-125m") + params = SamplingParams(temperature=0) + outputs = llm.generate("Hello, my name is", params) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +A code example can be found in `examples/offline_inference.py `_. + +``LLM.beam_search`` +^^^^^^^^^^^^^^^^^^^ + +The :class:`~vllm.LLM.beam_search` method implements `beam search `__ on top of :class:`~vllm.LLM.generate`. +For example, to search using 5 beams and output at most 50 tokens: + +.. code-block:: python + + llm = LLM(model="facebook/opt-125m") + params = BeamSearchParams(beam_width=5, max_tokens=50) + outputs = llm.generate("Hello, my name is", params) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +``LLM.chat`` +^^^^^^^^^^^^ + +The :class:`~vllm.LLM.chat` method implements chat functionality on top of :class:`~vllm.LLM.generate`. +In particular, it accepts input similar to `OpenAI Chat Completions API `__ +and automatically applies the model's `chat template `__ to format the prompt. + +.. important:: + + In general, only instruction-tuned models have a chat template. + Base models may perform poorly as they are not trained to respond to the chat conversation. + +.. code-block:: python + + llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") + conversation = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, + ] + outputs = llm.chat(conversation) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +A code example can be found in `examples/offline_inference_chat.py `_. + +If the model doesn't have a chat template or you want to specify another one, +you can explicitly pass a chat template: + +.. code-block:: python + + from vllm.entrypoints.chat_utils import load_chat_template + + # You can find a list of existing chat templates under `examples/` + custom_template = load_chat_template(chat_template="") + print("Loaded chat template:", custom_template) + + outputs = llm.chat(conversation, chat_template=custom_template) + +Online Inference +---------------- + +Our `OpenAI Compatible Server <../serving/openai_compatible_server>`__ can be used for online inference. +Please click on the above link for more details on how to launch the server. + +Completions API +^^^^^^^^^^^^^^^ + +Our Completions API is similar to ``LLM.generate`` but only accepts text. +It is compatible with `OpenAI Completions API `__ +so that you can use OpenAI client to interact with it. +A code example can be found in `examples/openai_completion_client.py `_. + +Chat API +^^^^^^^^ + +Our Chat API is similar to ``LLM.chat``, accepting both text and :ref:`multi-modal inputs `. +It is compatible with `OpenAI Chat Completions API `__ +so that you can use OpenAI client to interact with it. +A code example can be found in `examples/openai_chat_completion_client.py `_. diff --git a/docs/source/models/pooling_models.rst b/docs/source/models/pooling_models.rst new file mode 100644 index 0000000000000..7fa66274c3c5a --- /dev/null +++ b/docs/source/models/pooling_models.rst @@ -0,0 +1,99 @@ +.. _pooling_models: + +Pooling Models +============== + +vLLM also supports pooling models, including embedding, reranking and reward models. + +In vLLM, pooling models implement the :class:`~vllm.model_executor.models.VllmModelForPooling` interface. +These models use a :class:`~vllm.model_executor.layers.Pooler` to aggregate the final hidden states of the input +before returning them. + +.. note:: + + We currently support pooling models primarily as a matter of convenience. + As shown in the :ref:`Compatibility Matrix `, most vLLM features are not applicable to + pooling models as they only work on the generation or decode stage, so performance may not improve as much. + +Offline Inference +----------------- + +The :class:`~vllm.LLM` class provides various methods for offline inference. +See :ref:`Engine Arguments ` for a list of options when initializing the model. + +For pooling models, we support the following :code:`task` options: + +- Embedding (:code:`"embed"` / :code:`"embedding"`) +- Classification (:code:`"classify"`) +- Sentence Pair Scoring (:code:`"score"`) +- Reward Modeling (:code:`"reward"`) + +The selected task determines the default :class:`~vllm.model_executor.layers.Pooler` that is used: + +- Embedding: Extract only the hidden states corresponding to the last token, and apply normalization. +- Classification: Extract only the hidden states corresponding to the last token, and apply softmax. +- Sentence Pair Scoring: Extract only the hidden states corresponding to the last token, and apply softmax. +- Reward Modeling: Extract all of the hidden states and return them directly. + +When loading `Sentence Transformers `__ models, +we attempt to override the default pooler based on its Sentence Transformers configuration file (:code:`modules.json`). + +You can customize the model's pooling method via the :code:`override_pooler_config` option, +which takes priority over both the model's and Sentence Transformers's defaults. + +``LLM.encode`` +^^^^^^^^^^^^^^ + +The :class:`~vllm.LLM.encode` method is available to all pooling models in vLLM. +It returns the aggregated hidden states directly. + +.. code-block:: python + + llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed") + outputs = llm.encode("Hello, my name is") + + outputs = model.encode(prompts) + for output in outputs: + embeddings = output.outputs.embedding + print(f"Prompt: {prompt!r}, Embeddings (size={len(embeddings)}: {embeddings!r}") + +A code example can be found in `examples/offline_inference_embedding.py `_. + +``LLM.score`` +^^^^^^^^^^^^^ + +The :class:`~vllm.LLM.score` method outputs similarity scores between sentence pairs. +It is primarily designed for `cross-encoder models `__. +These types of models serve as rerankers between candidate query-document pairs in RAG systems. + +.. note:: + + vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG. + To handle RAG at a higher level, you should use integration frameworks such as `LangChain `_. + +You can use `these tests `_ as reference. + +Online Inference +---------------- + +Our `OpenAI Compatible Server <../serving/openai_compatible_server>`__ can be used for online inference. +Please click on the above link for more details on how to launch the server. + +Embeddings API +^^^^^^^^^^^^^^ + +Our Embeddings API is similar to ``LLM.encode``, accepting both text and :ref:`multi-modal inputs `. + +The text-only API is compatible with `OpenAI Embeddings API `__ +so that you can use OpenAI client to interact with it. +A code example can be found in `examples/openai_embedding_client.py `_. + +The multi-modal API is an extension of the `OpenAI Embeddings API `__ +that incorporates `OpenAI Chat Completions API `__, +so it is not part of the OpenAI standard. Please see :ref:`this page ` for more details on how to use it. + +Score API +^^^^^^^^^ + +Our Score API is similar to ``LLM.score``. +Please see `this page <../serving/openai_compatible_server.html#score-api-for-cross-encoder-models>`__ for more details on how to use it. diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 6540e023c1ab0..b9957cf9563b1 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -3,11 +3,21 @@ Supported Models ================ -vLLM supports a variety of generative and embedding models from `HuggingFace (HF) Transformers `_. -This page lists the model architectures that are currently supported by vLLM. +vLLM supports generative and pooling models across various tasks. +If a model supports more than one task, you can set the task via the :code:`--task` argument. + +For each task, we list the model architectures that have been implemented in vLLM. Alongside each architecture, we include some popular models that use it. -For other models, you can check the :code:`config.json` file inside the model repository. +Loading a Model +^^^^^^^^^^^^^^^ + +HuggingFace Hub ++++++++++++++++ + +By default, vLLM loads models from `HuggingFace (HF) Hub `_. + +To determine whether a given model is supported, you can check the :code:`config.json` file inside the HF repository. If the :code:`"architectures"` field contains a model architecture listed below, then it should be supported in theory. .. tip:: @@ -17,38 +27,57 @@ If the :code:`"architectures"` field contains a model architecture listed below, from vllm import LLM - llm = LLM(model=...) # Name or path of your model + # For generative models (task=generate) only + llm = LLM(model=..., task="generate") # Name or path of your model output = llm.generate("Hello, my name is") print(output) - If vLLM successfully generates text, it indicates that your model is supported. + # For pooling models (task={embed,classify,reward}) only + llm = LLM(model=..., task="embed") # Name or path of your model + output = llm.encode("Hello, my name is") + print(output) + + If vLLM successfully returns text (for generative models) or hidden states (for pooling models), it indicates that your model is supported. Otherwise, please refer to :ref:`Adding a New Model ` and :ref:`Enabling Multimodal Inputs ` for instructions on how to implement your model in vLLM. Alternatively, you can `open an issue on GitHub `_ to request vLLM support. -.. note:: - To use models from `ModelScope `_ instead of HuggingFace Hub, set an environment variable: +ModelScope +++++++++++ - .. code-block:: shell +To use models from `ModelScope `_ instead of HuggingFace Hub, set an environment variable: - $ export VLLM_USE_MODELSCOPE=True +.. code-block:: shell - And use with :code:`trust_remote_code=True`. + $ export VLLM_USE_MODELSCOPE=True - .. code-block:: python +And use with :code:`trust_remote_code=True`. - from vllm import LLM +.. code-block:: python - llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model - output = llm.generate("Hello, my name is") - print(output) + from vllm import LLM + + llm = LLM(model=..., revision=..., task=..., trust_remote_code=True) -Text-only Language Models -^^^^^^^^^^^^^^^^^^^^^^^^^ + # For generative models (task=generate) only + output = llm.generate("Hello, my name is") + print(output) -Text Generation ---------------- + # For pooling models (task={embed,classify,reward}) only + output = llm.encode("Hello, my name is") + print(output) + +List of Text-only Language Models +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Generative Models ++++++++++++++++++ + +See :ref:`this page ` for more information on how to use generative models. + +Text Generation (``--task generate``) +------------------------------------- .. list-table:: :widths: 25 25 50 5 5 @@ -328,8 +357,24 @@ Text Generation .. note:: Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. -Text Embedding --------------- +Pooling Models +++++++++++++++ + +See :ref:`this page ` for more information on how to use pooling models. + +.. important:: + Since some model architectures support both generative and pooling tasks, + you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. + +Text Embedding (``--task embed``) +--------------------------------- + +Any text generation model can be converted into an embedding model by passing :code:`--task embed`. + +.. note:: + To get the best results, you should use pooling models that are specifically trained as such. + +The following table lists those that are tested in vLLM. .. list-table:: :widths: 25 25 50 5 5 @@ -371,13 +416,6 @@ Text Embedding - - -.. important:: - Some model architectures support both generation and embedding tasks. - In this case, you have to pass :code:`--task embedding` to run the model in embedding mode. - -.. tip:: - You can override the model's pooling method by passing :code:`--override-pooler-config`. - .. note:: :code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`. @@ -389,8 +427,8 @@ Text Embedding On the other hand, its 1.5B variant (:code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`) uses causal attention despite being described otherwise on its model card. -Reward Modeling ---------------- +Reward Modeling (``--task reward``) +----------------------------------- .. list-table:: :widths: 25 25 50 5 5 @@ -416,11 +454,8 @@ Reward Modeling For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. -.. note:: - As an interim measure, these models are supported in both offline and online inference via Embeddings API. - -Classification ---------------- +Classification (``--task classify``) +------------------------------------ .. list-table:: :widths: 25 25 50 5 5 @@ -437,11 +472,8 @@ Classification - ✅︎ - ✅︎ -.. note:: - As an interim measure, these models are supported in both offline and online inference via Embeddings API. - -Sentence Pair Scoring ---------------------- +Sentence Pair Scoring (``--task score``) +---------------------------------------- .. list-table:: :widths: 25 25 50 5 5 @@ -468,13 +500,10 @@ Sentence Pair Scoring - - -.. note:: - These models are supported in both offline and online inference via Score API. - .. _supported_mm_models: -Multimodal Language Models -^^^^^^^^^^^^^^^^^^^^^^^^^^ +List of Multimodal Language Models +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The following modalities are supported depending on the model: @@ -491,8 +520,15 @@ On the other hand, modalities separated by :code:`/` are mutually exclusive. - e.g.: :code:`T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs. -Text Generation ---------------- +See :ref:`this page ` on how to pass multi-modal inputs to the model. + +Generative Models ++++++++++++++++++ + +See :ref:`this page ` for more information on how to use generative models. + +Text Generation (``--task generate``) +------------------------------------- .. list-table:: :widths: 25 25 15 20 5 5 5 @@ -696,8 +732,24 @@ Text Generation The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 -Multimodal Embedding --------------------- +Pooling Models +++++++++++++++ + +See :ref:`this page ` for more information on how to use pooling models. + +.. important:: + Since some model architectures support both generative and pooling tasks, + you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. + +Text Embedding (``--task embed``) +--------------------------------- + +Any text generation model can be converted into an embedding model by passing :code:`--task embed`. + +.. note:: + To get the best results, you should use pooling models that are specifically trained as such. + +The following table lists those that are tested in vLLM. .. list-table:: :widths: 25 25 15 25 5 5 @@ -728,12 +780,7 @@ Multimodal Embedding - - ✅︎ -.. important:: - Some model architectures support both generation and embedding tasks. - In this case, you have to pass :code:`--task embedding` to run the model in embedding mode. - -.. tip:: - You can override the model's pooling method by passing :code:`--override-pooler-config`. +---- Model Support Policy ===================== diff --git a/docs/source/usage/compatibility_matrix.rst b/docs/source/usage/compatibility_matrix.rst index a93632ff36fb8..04dd72b1e3527 100644 --- a/docs/source/usage/compatibility_matrix.rst +++ b/docs/source/usage/compatibility_matrix.rst @@ -39,13 +39,13 @@ Feature x Feature - :abbr:`prmpt adptr (Prompt Adapter)` - :ref:`SD ` - CUDA graph - - :abbr:`emd (Embedding Models)` + - :abbr:`pooling (Pooling Models)` - :abbr:`enc-dec (Encoder-Decoder Models)` - :abbr:`logP (Logprobs)` - :abbr:`prmpt logP (Prompt Logprobs)` - :abbr:`async output (Async Output Processing)` - multi-step - - :abbr:`mm (Multimodal)` + - :abbr:`mm (Multimodal Inputs)` - best-of - beam-search - :abbr:`guided dec (Guided Decoding)` @@ -151,7 +151,7 @@ Feature x Feature - - - - * - :abbr:`emd (Embedding Models)` + * - :abbr:`pooling (Pooling Models)` - ✗ - ✗ - ✗ @@ -253,7 +253,7 @@ Feature x Feature - - - - * - :abbr:`mm (Multimodal)` + * - :abbr:`mm (Multimodal Inputs)` - ✅ - `✗ `__ - `✗ `__ @@ -386,7 +386,7 @@ Feature x Hardware - ✅ - ✗ - ✅ - * - :abbr:`emd (Embedding Models)` + * - :abbr:`pooling (Pooling Models)` - ✅ - ✅ - ✅ @@ -402,7 +402,7 @@ Feature x Hardware - ✅ - ✅ - ✗ - * - :abbr:`mm (Multimodal)` + * - :abbr:`mm (Multimodal Inputs)` - ✅ - ✅ - ✅ diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py index ae158eef2ca4c..17f6d992073d7 100644 --- a/examples/offline_inference_embedding.py +++ b/examples/offline_inference_embedding.py @@ -9,7 +9,12 @@ ] # Create an LLM. -model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True) +model = LLM( + model="intfloat/e5-mistral-7b-instruct", + task="embed", # You should pass task="embed" for embedding models + enforce_eager=True, +) + # Generate embedding. The output is a list of PoolingRequestOutputs. outputs = model.encode(prompts) # Print the outputs. diff --git a/examples/offline_inference_vision_language_embedding.py b/examples/offline_inference_vision_language_embedding.py index e1732d045f949..bf466109f0981 100644 --- a/examples/offline_inference_vision_language_embedding.py +++ b/examples/offline_inference_vision_language_embedding.py @@ -59,7 +59,7 @@ def run_e5_v(query: Query): llm = LLM( model="royokong/e5-v", - task="embedding", + task="embed", max_model_len=4096, ) @@ -88,7 +88,7 @@ def run_vlm2vec(query: Query): llm = LLM( model="TIGER-Lab/VLM2Vec-Full", - task="embedding", + task="embed", trust_remote_code=True, mm_processor_kwargs={"num_crops": 4}, ) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 99781c55b672e..87d5aefea6cb4 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -55,7 +55,7 @@ class TestSetting: # embedding model TestSetting( model="BAAI/bge-multilingual-gemma2", - model_args=["--task", "embedding"], + model_args=["--task", "embed"], pp_size=1, tp_size=1, attn_backend="FLASHINFER", @@ -65,7 +65,7 @@ class TestSetting: # encoder-based embedding model (BERT) TestSetting( model="BAAI/bge-base-en-v1.5", - model_args=["--task", "embedding"], + model_args=["--task", "embed"], pp_size=1, tp_size=1, attn_backend="XFORMERS", diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py index 7cd0416d321ef..16bea54936bc8 100644 --- a/tests/core/test_scheduler_encoder_decoder.py +++ b/tests/core/test_scheduler_encoder_decoder.py @@ -37,7 +37,7 @@ def test_scheduler_schedule_simple_encoder_decoder(): num_seq_group = 4 max_model_len = 16 scheduler_config = SchedulerConfig( - task="generate", + "generate", max_num_batched_tokens=64, max_num_seqs=num_seq_group, max_model_len=max_model_len, diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index 425f2a10ec855..43c63daacb17f 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -27,7 +27,7 @@ def server(): args = [ "--task", - "embedding", + "embed", "--dtype", "bfloat16", "--max-model-len", diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index 5ef8540265d14..f458ef5ef556d 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -54,7 +54,7 @@ def test_models( hf_outputs = hf_model.encode(example_prompts) with vllm_runner(model, - task="embedding", + task="embed", dtype=dtype, max_model_len=None, **vllm_extra_kwargs) as vllm_model: diff --git a/tests/models/embedding/language/test_scoring.py b/tests/models/embedding/language/test_scoring.py index 30fa5ea7b36c0..0c3115d195fc1 100644 --- a/tests/models/embedding/language/test_scoring.py +++ b/tests/models/embedding/language/test_scoring.py @@ -35,9 +35,7 @@ def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str): with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict([text_pair]).tolist() - with vllm_runner(model_name, - task="embedding", - dtype=dtype, + with vllm_runner(model_name, task="score", dtype=dtype, max_model_len=None) as vllm_model: vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) @@ -58,9 +56,7 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str): with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict(text_pairs).tolist() - with vllm_runner(model_name, - task="embedding", - dtype=dtype, + with vllm_runner(model_name, task="score", dtype=dtype, max_model_len=None) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) @@ -82,9 +78,7 @@ def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str): with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict(text_pairs).tolist() - with vllm_runner(model_name, - task="embedding", - dtype=dtype, + with vllm_runner(model_name, task="score", dtype=dtype, max_model_len=None) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) diff --git a/tests/models/embedding/vision_language/test_dse_qwen2_vl.py b/tests/models/embedding/vision_language/test_dse_qwen2_vl.py index 3dd8cb729f8a6..2641987b25a3a 100644 --- a/tests/models/embedding/vision_language/test_dse_qwen2_vl.py +++ b/tests/models/embedding/vision_language/test_dse_qwen2_vl.py @@ -93,7 +93,7 @@ def _run_test( # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). with vllm_runner(model, - task="embedding", + task="embed", dtype=dtype, enforce_eager=True, max_model_len=8192) as vllm_model: diff --git a/tests/models/embedding/vision_language/test_llava_next.py b/tests/models/embedding/vision_language/test_llava_next.py index 693abd7252d5e..f4cd8b81a0d7d 100644 --- a/tests/models/embedding/vision_language/test_llava_next.py +++ b/tests/models/embedding/vision_language/test_llava_next.py @@ -47,7 +47,7 @@ def _run_test( # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). with vllm_runner(model, - task="embedding", + task="embed", dtype=dtype, max_model_len=4096, enforce_eager=True) as vllm_model: diff --git a/tests/models/embedding/vision_language/test_phi3v.py b/tests/models/embedding/vision_language/test_phi3v.py index 6145aff1a5ea2..9374c23dd6ffe 100644 --- a/tests/models/embedding/vision_language/test_phi3v.py +++ b/tests/models/embedding/vision_language/test_phi3v.py @@ -39,7 +39,7 @@ def _run_test( # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner(model, task="embedding", dtype=dtype, + with vllm_runner(model, task="embed", dtype=dtype, enforce_eager=True) as vllm_model: vllm_outputs = vllm_model.encode(input_texts, images=input_images) diff --git a/tests/test_config.py b/tests/test_config.py index 45b0b938af215..4518adfc31bfc 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,11 +7,17 @@ from vllm.platforms import current_platform -@pytest.mark.parametrize(("model_id", "expected_task"), [ - ("facebook/opt-125m", "generate"), - ("intfloat/e5-mistral-7b-instruct", "embedding"), -]) -def test_auto_task(model_id, expected_task): +@pytest.mark.parametrize( + ("model_id", "expected_runner_type", "expected_task"), + [ + ("facebook/opt-125m", "generate", "generate"), + ("intfloat/e5-mistral-7b-instruct", "pooling", "embed"), + ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), + ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"), + ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"), + ], +) +def test_auto_task(model_id, expected_runner_type, expected_task): config = ModelConfig( model_id, task="auto", @@ -22,6 +28,7 @@ def test_auto_task(model_id, expected_task): dtype="float16", ) + assert config.runner_type == expected_runner_type assert config.task == expected_task diff --git a/vllm/config.py b/vllm/config.py index 2a9f0ebae997d..2d9a76fe7ddb1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -45,13 +45,27 @@ logger = init_logger(__name__) -_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 -TaskOption = Literal["auto", "generate", "embedding"] +TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", + "score", "reward"] -# "draft" is only used internally for speculative decoding -_Task = Literal["generate", "embedding", "draft"] +_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward", + "draft"] + +RunnerType = Literal["generate", "pooling", "draft"] + +_RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = { + "generate": ["generate"], + "pooling": ["embed", "classify", "score", "reward"], + "draft": ["draft"], +} + +_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = { + task: runner + for runner, tasks in _RUNNER_TASKS.items() for task in tasks +} HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig], PretrainedConfig]] @@ -144,7 +158,7 @@ class ModelConfig: def __init__( self, model: str, - task: Union[TaskOption, _Task], + task: Union[TaskOption, Literal["draft"]], tokenizer: str, tokenizer_mode: str, trust_remote_code: bool, @@ -295,6 +309,7 @@ def __init__( supported_tasks, task = self._resolve_task(task, self.hf_config) self.supported_tasks = supported_tasks self.task: Final = task + self.pooler_config = self._init_pooler_config(override_pooler_config) self._verify_quantization() @@ -323,7 +338,7 @@ def _init_pooler_config( override_pooler_config: Optional["PoolerConfig"], ) -> Optional["PoolerConfig"]: - if self.task == "embedding": + if self.runner_type == "pooling": user_config = override_pooler_config or PoolerConfig() base_config = get_pooling_config(self.model, self.revision) @@ -357,60 +372,90 @@ def _verify_tokenizer_mode(self) -> None: "either 'auto', 'slow' or 'mistral'.") self.tokenizer_mode = tokenizer_mode + def _get_preferred_task( + self, + architectures: List[str], + supported_tasks: Set[_ResolvedTask], + ) -> Optional[_ResolvedTask]: + model_id = self.model + if get_pooling_config(model_id, self.revision): + return "embed" + if ModelRegistry.is_cross_encoder_model(architectures): + return "score" + + suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [ + # Other models follow this pattern + ("ForCausalLM", "generate"), + ("ForConditionalGeneration", "generate"), + ("ForSequenceClassification", "classify"), + ("ChatModel", "generate"), + ("LMHeadModel", "generate"), + ("EmbeddingModel", "embed"), + ("RewardModel", "reward"), + ] + _, arch = ModelRegistry.inspect_model_cls(architectures) + + for suffix, pref_task in suffix_to_preferred_task: + if arch.endswith(suffix) and pref_task in supported_tasks: + return pref_task + + return None + def _resolve_task( self, - task_option: Union[TaskOption, _Task], + task_option: Union[TaskOption, Literal["draft"]], hf_config: PretrainedConfig, - ) -> Tuple[Set[_Task], _Task]: + ) -> Tuple[Set[_ResolvedTask], _ResolvedTask]: if task_option == "draft": return {"draft"}, "draft" architectures = getattr(hf_config, "architectures", []) - task_support: Dict[_Task, bool] = { + runner_support: Dict[RunnerType, bool] = { # NOTE: Listed from highest to lowest priority, # in case the model supports multiple of them "generate": ModelRegistry.is_text_generation_model(architectures), - "embedding": ModelRegistry.is_pooling_model(architectures), + "pooling": ModelRegistry.is_pooling_model(architectures), } - supported_tasks_lst: List[_Task] = [ - task for task, is_supported in task_support.items() if is_supported + supported_runner_types_lst: List[RunnerType] = [ + runner_type + for runner_type, is_supported in runner_support.items() + if is_supported + ] + + supported_tasks_lst: List[_ResolvedTask] = [ + task for runner_type in supported_runner_types_lst + for task in _RUNNER_TASKS[runner_type] ] supported_tasks = set(supported_tasks_lst) if task_option == "auto": selected_task = next(iter(supported_tasks_lst)) - if len(supported_tasks) > 1: - suffix_to_preferred_task: List[Tuple[str, _Task]] = [ - # Hardcode the models that are exceptions - ("AquilaModel", "generate"), - ("ChatGLMModel", "generate"), - # Other models follow this pattern - ("ForCausalLM", "generate"), - ("ForConditionalGeneration", "generate"), - ("ChatModel", "generate"), - ("LMHeadModel", "generate"), - ("EmbeddingModel", "embedding"), - ("RewardModel", "embedding"), - ("ForSequenceClassification", "embedding"), - ] - info, arch = ModelRegistry.inspect_model_cls(architectures) - - for suffix, pref_task in suffix_to_preferred_task: - if arch.endswith(suffix) and pref_task in supported_tasks: - selected_task = pref_task - break - else: - if (arch.endswith("Model") - and info.architecture.endswith("ForCausalLM") - and "embedding" in supported_tasks): - selected_task = "embedding" + if len(supported_tasks_lst) > 1: + preferred_task = self._get_preferred_task( + architectures, supported_tasks) + if preferred_task is not None: + selected_task = preferred_task logger.info( "This model supports multiple tasks: %s. " "Defaulting to '%s'.", supported_tasks, selected_task) else: + # Aliases + if task_option == "embedding": + preferred_task = self._get_preferred_task( + architectures, supported_tasks) + if preferred_task != "embed": + msg = ("The 'embedding' task will be restricted to " + "embedding models in a future release. Please " + "pass `--task classify`, `--task score`, or " + "`--task reward` explicitly for other pooling " + "models.") + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + task_option = preferred_task or "embed" + if task_option not in supported_tasks: msg = ( f"This model does not support the '{task_option}' task. " @@ -533,7 +578,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config, # Async postprocessor is not necessary with embedding mode # since there is no token generation - if self.task == "embedding": + if self.runner_type == "pooling": self.use_async_output_proc = False # Reminder: Please update docs/source/usage/compatibility_matrix.rst @@ -750,6 +795,14 @@ def is_cross_encoder(self) -> bool: architectures = getattr(self.hf_config, "architectures", []) return ModelRegistry.is_cross_encoder_model(architectures) + @property + def supported_runner_types(self) -> Set[RunnerType]: + return {_TASK_RUNNER[task] for task in self.supported_tasks} + + @property + def runner_type(self) -> RunnerType: + return _TASK_RUNNER[self.task] + class CacheConfig: """Configuration for the KV cache. @@ -1096,7 +1149,7 @@ def _verify_args(self) -> None: class SchedulerConfig: """Scheduler configuration.""" - task: str = "generate" # The task to use the model for. + runner_type: str = "generate" # The runner type to launch for the model. # Maximum number of tokens to be processed in a single iteration. max_num_batched_tokens: int = field(default=None) # type: ignore @@ -1164,11 +1217,11 @@ def __post_init__(self) -> None: # for higher throughput. self.max_num_batched_tokens = max(self.max_model_len, 2048) - if self.task == "embedding": - # For embedding, choose specific value for higher throughput + if self.runner_type == "pooling": + # Choose specific value for higher throughput self.max_num_batched_tokens = max( self.max_num_batched_tokens, - _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, + _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, ) if self.is_multimodal_model: # The value needs to be at least the number of multimodal tokens diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 94c62743883ec..c3bc6becf0995 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -337,7 +337,7 @@ def __init__( self.lora_config = lora_config version = "selfattn" - if (self.scheduler_config.task == "embedding" + if (self.scheduler_config.runner_type == "pooling" or self.cache_config.is_attention_free): version = "placeholder" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7b9adc401abcf..d485c2a9e7208 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1066,7 +1066,7 @@ def create_engine_config(self, if (is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora and not self.enable_prompt_adapter - and model_config.task != "embedding"): + and model_config.runner_type != "pooling"): self.enable_chunked_prefill = True logger.warning( "Chunked prefill is enabled by default for models with " @@ -1083,7 +1083,8 @@ def create_engine_config(self, "errors during the initial memory profiling phase, or result " "in low performance due to small KV cache space. Consider " "setting --max-model-len to a smaller value.", max_model_len) - elif self.enable_chunked_prefill and model_config.task == "embedding": + elif (self.enable_chunked_prefill + and model_config.runner_type == "pooling"): msg = "Chunked prefill is not supported for embedding models" raise ValueError(msg) @@ -1144,7 +1145,7 @@ def create_engine_config(self, " please file an issue with detailed information.") scheduler_config = SchedulerConfig( - task=model_config.task, + runner_type=model_config.runner_type, max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, max_model_len=model_config.max_model_len, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6eca304b45f07..9be30c635cb2c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -288,7 +288,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.model_executor = executor_class(vllm_config=vllm_config, ) - if self.model_config.task != "embedding": + if self.model_config.runner_type != "pooling": self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. @@ -1123,7 +1123,7 @@ def _process_model_outputs(self, seq_group.metrics.model_execute_time = ( o.model_execute_time) - if self.model_config.task == "embedding": + if self.model_config.runner_type == "pooling": self._process_sequence_group_outputs(seq_group, output) else: self.output_processor.process_prompt_logprob(seq_group, output) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 2a02187223a33..0bec978c4869c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -381,19 +381,20 @@ def generate( considered legacy and may be deprecated in the future. You should instead pass them via the ``inputs`` parameter. """ - task = self.llm_engine.model_config.task - if task != "generate": + runner_type = self.llm_engine.model_config.runner_type + if runner_type != "generate": messages = [ "LLM.generate() is only supported for (conditional) generation " "models (XForCausalLM, XForConditionalGeneration).", ] - supported_tasks = self.llm_engine.model_config.supported_tasks - if "generate" in supported_tasks: + supported_runner_types = self.llm_engine.model_config \ + .supported_runner_types + if "generate" in supported_runner_types: messages.append( - "Your model supports the 'generate' task, but is " - f"currently initialized for the '{task}' task. Please " - "initialize the model using `--task generate`.") + "Your model supports the 'generate' runner, but is " + f"currently initialized for the '{runner_type}' runner. " + "Please initialize vLLM using `--task generate`.") raise ValueError(" ".join(messages)) @@ -793,16 +794,18 @@ def encode( considered legacy and may be deprecated in the future. You should instead pass them via the ``inputs`` parameter. """ - task = self.llm_engine.model_config.task - if task != "embedding": - messages = ["LLM.encode() is only supported for embedding models."] + runner_type = self.llm_engine.model_config.runner_type + if runner_type != "pooling": + messages = ["LLM.encode() is only supported for pooling models."] - supported_tasks = self.llm_engine.model_config.supported_tasks - if "embedding" in supported_tasks: + supported_runner_types = self.llm_engine.model_config \ + .supported_runner_types + if "pooling" in supported_runner_types: messages.append( - "Your model supports the 'embedding' task, but is " - f"currently initialized for the '{task}' task. Please " - "initialize the model using `--task embedding`.") + "Your model supports the 'pooling' runner, but is " + f"currently initialized for the '{runner_type}' runner. " + "Please initialize vLLM using `--task embed`, " + "`--task classify`, `--task score` etc.") raise ValueError(" ".join(messages)) @@ -864,21 +867,23 @@ def score( A list of ``PoolingRequestOutput`` objects containing the generated scores in the same order as the input prompts. """ - task = self.llm_engine.model_config.task - if task != "embedding": - messages = ["LLM.score() is only supported for embedding models."] + runner_type = self.llm_engine.model_config.runner_type + if runner_type != "pooling": + messages = ["LLM.score() is only supported for pooling models."] - supported_tasks = self.llm_engine.model_config.supported_tasks - if "embedding" in supported_tasks: + supported_runner_types = self.llm_engine.model_config \ + .supported_runner_types + if "pooling" in supported_runner_types: messages.append( - "Your model supports the 'embedding' task, but is " - f"currently initialized for the '{task}' task. Please " - "initialize the model using `--task embedding`.") + "Your model supports the 'pooling' runner, but is " + f"currently initialized for the '{runner_type}' runner. " + "Please initialize vLLM using `--task embed`, " + "`--task classify`, `--task score` etc.") raise ValueError(" ".join(messages)) if not self.llm_engine.model_config.is_cross_encoder: - raise ValueError("Your model does not support the cross encoding") + raise ValueError("Your model does not support cross encoding") tokenizer = self.llm_engine.get_tokenizer() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0f93eb54111ad..a345f8caeeed2 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -573,7 +573,7 @@ def init_app_state( enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, - ) if model_config.task == "generate" else None + ) if model_config.runner_type == "generate" else None state.openai_serving_completion = OpenAIServingCompletion( engine_client, model_config, @@ -582,7 +582,7 @@ def init_app_state( prompt_adapters=args.prompt_adapters, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, - ) if model_config.task == "generate" else None + ) if model_config.runner_type == "generate" else None state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, model_config, @@ -590,13 +590,13 @@ def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, - ) if model_config.task == "embedding" else None + ) if model_config.runner_type == "pooling" else None state.openai_serving_scores = OpenAIServingScores( engine_client, model_config, base_model_paths, request_logger=request_logger - ) if (model_config.task == "embedding" \ + ) if (model_config.runner_type == "pooling" \ and model_config.is_cross_encoder) else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 00cdb3b6839f5..675daf54c0d0d 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -224,7 +224,7 @@ async def main(args): chat_template=None, chat_template_content_format="auto", enable_prompt_tokens_details=args.enable_prompt_tokens_details, - ) if model_config.task == "generate" else None + ) if model_config.runner_type == "generate" else None openai_serving_embedding = OpenAIServingEmbedding( engine, model_config, @@ -232,7 +232,7 @@ async def main(args): request_logger=request_logger, chat_template=None, chat_template_content_format="auto", - ) if model_config.task == "embedding" else None + ) if model_config.runner_type == "pooling" else None tracker = BatchProgressTracker() logger.info("Reading batch from %s...", args.input_file) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index cfb89e0f336bc..f15e7176b3d50 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -35,7 +35,7 @@ def get_model_architecture( architectures = ["QuantMixtralForCausalLM"] model_cls, arch = ModelRegistry.resolve_model_cls(architectures) - if model_config.task == "embedding": + if model_config.runner_type == "pooling": model_cls = as_embedding_model(model_cls) return model_cls, arch diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index fdb241e6753fb..55a5c4dff3a5c 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -42,7 +42,7 @@ def __init__( executor_class: Type[Executor], usage_context: UsageContext, ): - assert vllm_config.model_config.task != "embedding" + assert vllm_config.model_config.runner_type != "pooling" logger.info("Initializing an LLM engine (v%s) with config: %s", VLLM_VERSION, vllm_config) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 4fad1a3f4caeb..ba3d4a130a80b 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -163,7 +163,7 @@ def __init__( not in ["medusa", "mlp_speculator", "eagle"]) \ else {"return_hidden_states": True} ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner - if self.model_config.task == "embedding": + if self.model_config.runner_type == "pooling": ModelRunnerClass = CPUPoolingModelRunner elif self.model_config.is_encoder_decoder: ModelRunnerClass = CPUEncoderDecoderModelRunner diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 094dd5a5d08b3..832b9903b7abc 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -75,7 +75,7 @@ def __init__( else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner - if model_config.task == "embedding": + if model_config.runner_type == "pooling": ModelRunnerClass = PoolingModelRunner elif self.model_config.is_encoder_decoder: ModelRunnerClass = EncoderDecoderModelRunner