forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Misc] Split up pooling tasks (vllm-project#10820)
Signed-off-by: DarkLight1337 <[email protected]>
- Loading branch information
1 parent
40766ca
commit 795ec7e
Showing
27 changed files
with
527 additions
and
168 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
.. _generative_models: | ||
|
||
Generative Models | ||
================= | ||
|
||
vLLM provides first-class support for generative models, which covers most of LLMs. | ||
|
||
In vLLM, generative models implement the :class:`~vllm.model_executor.models.VllmModelForTextGeneration` interface. | ||
Based on the final hidden states of the input, these models output log probabilities of the tokens to generate, | ||
which are then passed through :class:`~vllm.model_executor.layers.Sampler` to obtain the final text. | ||
|
||
Offline Inference | ||
----------------- | ||
|
||
The :class:`~vllm.LLM` class provides various methods for offline inference. | ||
See :ref:`Engine Arguments <engine_args>` for a list of options when initializing the model. | ||
|
||
For generative models, the only supported :code:`task` option is :code:`"generate"`. | ||
Usually, this is automatically inferred so you don't have to specify it. | ||
|
||
``LLM.generate`` | ||
^^^^^^^^^^^^^^^^ | ||
|
||
The :class:`~vllm.LLM.generate` method is available to all generative models in vLLM. | ||
It is similar to `its counterpart in HF Transformers <https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate>`__, | ||
except that tokenization and detokenization are also performed automatically. | ||
|
||
.. code-block:: python | ||
llm = LLM(model="facebook/opt-125m") | ||
outputs = llm.generate("Hello, my name is") | ||
for output in outputs: | ||
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||
You can optionally control the language generation by passing :class:`~vllm.SamplingParams`. | ||
For example, you can use greedy sampling by setting :code:`temperature=0`: | ||
|
||
.. code-block:: python | ||
llm = LLM(model="facebook/opt-125m") | ||
params = SamplingParams(temperature=0) | ||
outputs = llm.generate("Hello, my name is", params) | ||
for output in outputs: | ||
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||
A code example can be found in `examples/offline_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py>`_. | ||
|
||
``LLM.beam_search`` | ||
^^^^^^^^^^^^^^^^^^^ | ||
|
||
The :class:`~vllm.LLM.beam_search` method implements `beam search <https://huggingface.co/docs/transformers/en/generation_strategies#beam-search-decoding>`__ on top of :class:`~vllm.LLM.generate`. | ||
For example, to search using 5 beams and output at most 50 tokens: | ||
|
||
.. code-block:: python | ||
llm = LLM(model="facebook/opt-125m") | ||
params = BeamSearchParams(beam_width=5, max_tokens=50) | ||
outputs = llm.generate("Hello, my name is", params) | ||
for output in outputs: | ||
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||
``LLM.chat`` | ||
^^^^^^^^^^^^ | ||
|
||
The :class:`~vllm.LLM.chat` method implements chat functionality on top of :class:`~vllm.LLM.generate`. | ||
In particular, it accepts input similar to `OpenAI Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`__ | ||
and automatically applies the model's `chat template <https://huggingface.co/docs/transformers/en/chat_templating>`__ to format the prompt. | ||
|
||
.. important:: | ||
|
||
In general, only instruction-tuned models have a chat template. | ||
Base models may perform poorly as they are not trained to respond to the chat conversation. | ||
|
||
.. code-block:: python | ||
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") | ||
conversation = [ | ||
{ | ||
"role": "system", | ||
"content": "You are a helpful assistant" | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "Hello" | ||
}, | ||
{ | ||
"role": "assistant", | ||
"content": "Hello! How can I assist you today?" | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "Write an essay about the importance of higher education.", | ||
}, | ||
] | ||
outputs = llm.chat(conversation) | ||
for output in outputs: | ||
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||
A code example can be found in `examples/offline_inference_chat.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_chat.py>`_. | ||
|
||
If the model doesn't have a chat template or you want to specify another one, | ||
you can explicitly pass a chat template: | ||
|
||
.. code-block:: python | ||
from vllm.entrypoints.chat_utils import load_chat_template | ||
# You can find a list of existing chat templates under `examples/` | ||
custom_template = load_chat_template(chat_template="<path_to_template>") | ||
print("Loaded chat template:", custom_template) | ||
outputs = llm.chat(conversation, chat_template=custom_template) | ||
Online Inference | ||
---------------- | ||
|
||
Our `OpenAI Compatible Server <../serving/openai_compatible_server>`__ can be used for online inference. | ||
Please click on the above link for more details on how to launch the server. | ||
|
||
Completions API | ||
^^^^^^^^^^^^^^^ | ||
|
||
Our Completions API is similar to ``LLM.generate`` but only accepts text. | ||
It is compatible with `OpenAI Completions API <https://platform.openai.com/docs/api-reference/completions>`__ | ||
so that you can use OpenAI client to interact with it. | ||
A code example can be found in `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_. | ||
|
||
Chat API | ||
^^^^^^^^ | ||
|
||
Our Chat API is similar to ``LLM.chat``, accepting both text and :ref:`multi-modal inputs <multimodal_inputs>`. | ||
It is compatible with `OpenAI Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`__ | ||
so that you can use OpenAI client to interact with it. | ||
A code example can be found in `examples/openai_chat_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_completion_client.py>`_. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
.. _pooling_models: | ||
|
||
Pooling Models | ||
============== | ||
|
||
vLLM also supports pooling models, including embedding, reranking and reward models. | ||
|
||
In vLLM, pooling models implement the :class:`~vllm.model_executor.models.VllmModelForPooling` interface. | ||
These models use a :class:`~vllm.model_executor.layers.Pooler` to aggregate the final hidden states of the input | ||
before returning them. | ||
|
||
.. note:: | ||
|
||
We currently support pooling models primarily as a matter of convenience. | ||
As shown in the :ref:`Compatibility Matrix <compatibility_matrix>`, most vLLM features are not applicable to | ||
pooling models as they only work on the generation or decode stage, so performance may not improve as much. | ||
|
||
Offline Inference | ||
----------------- | ||
|
||
The :class:`~vllm.LLM` class provides various methods for offline inference. | ||
See :ref:`Engine Arguments <engine_args>` for a list of options when initializing the model. | ||
|
||
For pooling models, we support the following :code:`task` options: | ||
|
||
- Embedding (:code:`"embed"` / :code:`"embedding"`) | ||
- Classification (:code:`"classify"`) | ||
- Sentence Pair Scoring (:code:`"score"`) | ||
- Reward Modeling (:code:`"reward"`) | ||
|
||
The selected task determines the default :class:`~vllm.model_executor.layers.Pooler` that is used: | ||
|
||
- Embedding: Extract only the hidden states corresponding to the last token, and apply normalization. | ||
- Classification: Extract only the hidden states corresponding to the last token, and apply softmax. | ||
- Sentence Pair Scoring: Extract only the hidden states corresponding to the last token, and apply softmax. | ||
- Reward Modeling: Extract all of the hidden states and return them directly. | ||
|
||
When loading `Sentence Transformers <https://huggingface.co/sentence-transformers>`__ models, | ||
we attempt to override the default pooler based on its Sentence Transformers configuration file (:code:`modules.json`). | ||
|
||
You can customize the model's pooling method via the :code:`override_pooler_config` option, | ||
which takes priority over both the model's and Sentence Transformers's defaults. | ||
|
||
``LLM.encode`` | ||
^^^^^^^^^^^^^^ | ||
|
||
The :class:`~vllm.LLM.encode` method is available to all pooling models in vLLM. | ||
It returns the aggregated hidden states directly. | ||
|
||
.. code-block:: python | ||
llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed") | ||
outputs = llm.encode("Hello, my name is") | ||
outputs = model.encode(prompts) | ||
for output in outputs: | ||
embeddings = output.outputs.embedding | ||
print(f"Prompt: {prompt!r}, Embeddings (size={len(embeddings)}: {embeddings!r}") | ||
A code example can be found in `examples/offline_inference_embedding.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_embedding.py>`_. | ||
|
||
``LLM.score`` | ||
^^^^^^^^^^^^^ | ||
|
||
The :class:`~vllm.LLM.score` method outputs similarity scores between sentence pairs. | ||
It is primarily designed for `cross-encoder models <https://www.sbert.net/examples/applications/cross-encoder/README.html>`__. | ||
These types of models serve as rerankers between candidate query-document pairs in RAG systems. | ||
|
||
.. note:: | ||
|
||
vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG. | ||
To handle RAG at a higher level, you should use integration frameworks such as `LangChain <https://github.com/langchain-ai/langchain>`_. | ||
|
||
You can use `these tests <https://github.com/vllm-project/vllm/blob/main/tests/models/embedding/language/test_scoring.py>`_ as reference. | ||
|
||
Online Inference | ||
---------------- | ||
|
||
Our `OpenAI Compatible Server <../serving/openai_compatible_server>`__ can be used for online inference. | ||
Please click on the above link for more details on how to launch the server. | ||
|
||
Embeddings API | ||
^^^^^^^^^^^^^^ | ||
|
||
Our Embeddings API is similar to ``LLM.encode``, accepting both text and :ref:`multi-modal inputs <multimodal_inputs>`. | ||
|
||
The text-only API is compatible with `OpenAI Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`__ | ||
so that you can use OpenAI client to interact with it. | ||
A code example can be found in `examples/openai_embedding_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_embedding_client.py>`_. | ||
|
||
The multi-modal API is an extension of the `OpenAI Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`__ | ||
that incorporates `OpenAI Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`__, | ||
so it is not part of the OpenAI standard. Please see :ref:`this page <multimodal_inputs>` for more details on how to use it. | ||
|
||
Score API | ||
^^^^^^^^^ | ||
|
||
Our Score API is similar to ``LLM.score``. | ||
Please see `this page <../serving/openai_compatible_server.html#score-api-for-cross-encoder-models>`__ for more details on how to use it. |
Oops, something went wrong.