diff --git a/.github/workflows/png-lint.yml b/.github/workflows/png-lint.yml new file mode 100644 index 0000000000000..4932af943a07b --- /dev/null +++ b/.github/workflows/png-lint.yml @@ -0,0 +1,37 @@ +name: Lint PNG exports from excalidraw +on: + push: + branches: + - "main" + paths: + - '**/*.excalidraw.png' + - '.github/workflows/png-lint.yml' + pull_request: + branches: + - "main" + paths: + - '**/*.excalidraw.png' + - '.github/workflows/png-lint.yml' + +env: + LC_ALL: en_US.UTF-8 + +defaults: + run: + shell: bash + +permissions: + contents: read + +jobs: + actionlint: + runs-on: ubuntu-latest + steps: + - name: "Checkout" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + + - name: "Run png-lint.sh to check excalidraw exported images" + run: | + tools/png-lint.sh diff --git a/docs/source/assets/design/arch_overview/entrypoints.excalidraw.png b/docs/source/assets/design/arch_overview/entrypoints.excalidraw.png new file mode 100644 index 0000000000000..bbf46286cfe5d Binary files /dev/null and b/docs/source/assets/design/arch_overview/entrypoints.excalidraw.png differ diff --git a/docs/source/assets/design/arch_overview/llm_engine.excalidraw.png b/docs/source/assets/design/arch_overview/llm_engine.excalidraw.png new file mode 100644 index 0000000000000..ade1d602a9187 Binary files /dev/null and b/docs/source/assets/design/arch_overview/llm_engine.excalidraw.png differ diff --git a/docs/source/design/arch_overview.rst b/docs/source/design/arch_overview.rst new file mode 100644 index 0000000000000..a9e7b4bd69bc7 --- /dev/null +++ b/docs/source/design/arch_overview.rst @@ -0,0 +1,274 @@ +.. _arch_overview: + +Architecture Overview +====================== + +This document provides an overview of the vLLM architecture. + +.. contents:: Table of Contents + :local: + :depth: 2 + +Entrypoints +----------- + +vLLM provides a number of entrypoints for interacting with the system. The +following diagram shows the relationship between them. + +.. image:: /assets/design/arch_overview/entrypoints.excalidraw.png + :alt: Entrypoints Diagram + +LLM Class +^^^^^^^^^ + +The LLM class provides the primary Python interface for doing offline inference, +which is interacting with a model without using a separate model inference +server. + +Here is a sample of `LLM` class usage: + +.. code-block:: python + + from vllm import LLM, SamplingParams + + # Define a list of input prompts + prompts = [ + "Hello, my name is", + "The capital of France is", + "The largest ocean is", + ] + + # Define sampling parameters + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + # Initialize the LLM engine with the Qwen2.5-1.5B-Instruct model + llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct") + + # Generate outputs for the input prompts + outputs = llm.generate(prompts, sampling_params) + + # Print the generated outputs + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +More API details can be found in the :doc:`Offline Inference +` section of the API docs. + +The code for the `LLM` class can be found in `vllm/entrypoints/llm.py +<https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py>`_. + +OpenAI-compatible API server +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The second primary interface to vLLM is via its OpenAI-compatible API server. +This server can be started using the `vllm serve` command. + +.. code-block:: bash + + vllm serve <model> + +The code for the `vllm` CLI can be found in `vllm/scripts.py +<https://github.com/vllm-project/vllm/blob/main/vllm/scripts.py>`_.
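+
+Once the server is running, it exposes OpenAI-compatible endpoints, so it can be
+queried with the official ``openai`` Python client. The following is a minimal
+sketch, assuming the ``openai`` package is installed, the server was started
+with the same Qwen model used above, and it is listening on the default
+``http://localhost:8000``:
+
+.. code-block:: python
+
+    from openai import OpenAI
+
+    # The server does not check the API key unless --api-key is set,
+    # but the client still requires a value.
+    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+
+    completion = client.completions.create(
+        model="Qwen/Qwen2.5-1.5B-Instruct",
+        prompt="The capital of France is",
+        max_tokens=16,
+    )
+    print(completion.choices[0].text)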
+ +Sometimes you may see the API server entrypoint used directly instead of via the +`vllm` CLI command. For example: + +.. code-block:: bash + + python -m vllm.entrypoints.openai.api_server --model <model> + +That code can be found in `vllm/entrypoints/openai/api_server.py +<https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py>`_. + +More details on the API server can be found in the :doc:`OpenAI Compatible +Server ` document. + +LLM Engine +---------- + +The `LLMEngine` and `AsyncLLMEngine` classes are central to the functioning of +the vLLM system, handling model inference and asynchronous request processing. + +.. image:: /assets/design/arch_overview/llm_engine.excalidraw.png + :alt: LLMEngine Diagram + +LLMEngine +^^^^^^^^^ + +The `LLMEngine` class is the core component of the vLLM engine. It is +responsible for receiving requests from clients and generating outputs from the +model. The `LLMEngine` includes input processing, model execution (possibly +distributed across multiple hosts and/or GPUs), scheduling, and output +processing. + +- **Input Processing**: Handles tokenization of input text using the specified + tokenizer. + +- **Scheduling**: Chooses which requests are processed in each step. + +- **Model Execution**: Manages the execution of the language model, including + distributed execution across multiple GPUs. + +- **Output Processing**: Processes the outputs generated by the model, decoding the + token IDs from a language model into human-readable text. + +The code for `LLMEngine` can be found in `vllm/engine/llm_engine.py`_. + +.. _vllm/engine/llm_engine.py: https://github.com/vllm-project/vllm/tree/main/vllm/engine/llm_engine.py + +AsyncLLMEngine +^^^^^^^^^^^^^^ + +The `AsyncLLMEngine` class is an asynchronous wrapper for the `LLMEngine` class. +It uses `asyncio` to create a background loop that continuously processes +incoming requests. The `AsyncLLMEngine` is designed for online serving, where it +can handle multiple concurrent requests and stream outputs to clients. + +The OpenAI-compatible API server uses the `AsyncLLMEngine`. There is also a demo +API server that serves as a simpler example in +`vllm/entrypoints/api_server.py`_. + +.. _vllm/entrypoints/api_server.py: https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/api_server.py + +The code for `AsyncLLMEngine` can be found in `vllm/engine/async_llm_engine.py`_. + +.. _vllm/engine/async_llm_engine.py: https://github.com/vllm-project/vllm/tree/main/vllm/engine/async_llm_engine.py + +Worker +------ + +A worker is a process that runs the model inference. vLLM follows the common +practice of using one process to control one accelerator device, such as a GPU. +For example, if we use tensor parallelism of size 2 and pipeline parallelism of +size 2, we will have 4 workers in total. Workers are identified by their +``rank`` and ``local_rank``. ``rank`` is used for global orchestration, while +``local_rank`` is mainly used for assigning the accelerator device and accessing +local resources such as the file system and shared memory. + +Model Runner +------------ + +Every worker has one model runner object, responsible for loading and running +the model. Much of the model execution logic resides here, such as preparing +input tensors and capturing CUDA graphs. + +Model +----- + +Every model runner object has one model object, which is the actual +``torch.nn.Module`` instance. See :ref:`huggingface_integration` for how various +configurations affect the class we ultimately get.
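+
+To make the relationship between these three layers concrete, here is a small,
+self-contained sketch. The classes below are simplified stand-ins for
+illustration only; they are not vLLM's actual worker or model runner
+implementations:
+
+.. code-block:: python
+
+    # Illustrative stand-ins only; vLLM's real classes are more involved.
+    import torch
+    import torch.nn as nn
+
+    class ToyModel(nn.Module):
+        """Plays the role of the ``torch.nn.Module`` held by the model runner."""
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.proj = nn.Linear(16, 16)
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return self.proj(x)
+
+    class ToyModelRunner:
+        """Loads the model and prepares input tensors for each execution step."""
+
+        def __init__(self, device: torch.device) -> None:
+            self.device = device
+            self.model = ToyModel().to(device)
+
+        def execute(self, batch: torch.Tensor) -> torch.Tensor:
+            return self.model(batch.to(self.device))
+
+    class ToyWorker:
+        """One worker per accelerator device, identified by rank/local_rank."""
+
+        def __init__(self, rank: int, local_rank: int) -> None:
+            self.rank = rank              # used for global orchestration
+            self.local_rank = local_rank  # used to pick the local device
+            device = torch.device(
+                f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
+            self.model_runner = ToyModelRunner(device)
+
+    worker = ToyWorker(rank=0, local_rank=0)
+    print(worker.model_runner.execute(torch.randn(2, 16)).shape)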
+ +Class Hierarchy +--------------- + +The following figure shows the class hierarchy of vLLM: + + .. figure:: /assets/design/hierarchy.png + :alt: query + :width: 100% + :align: center + +There are several important design choices behind this class hierarchy: + +1. **Extensibility**: All classes in the hierarchy accept a configuration object +containing all the necessary information. The `VllmConfig +<https://github.com/vllm-project/vllm/blob/main/vllm/config.py>`__ +class is the main configuration object that is passed around. The class +hierarchy is quite deep, and every class needs to read the configuration it is +interested in. By encapsulating all configurations in one object, we can easily +pass the configuration object around and access the configuration we need. +Suppose we want to add a new feature (this is often the case given how fast the +field of LLM inference is evolving) that only touches the model runner. Since +the whole config object is passed around, we only need to add the new +configuration option to the `VllmConfig` class, and the model runner can read it +directly. We don't need to change the constructor of the engine, worker, or +model class to pass the new option. + +2. **Uniformity**: The model runner needs a unified interface to create and +initialize the model. vLLM supports more than 50 types of popular open-source +models. Each model has its own initialization logic. If the constructor +signature varied from model to model, the model runner could not call it +correctly without complicated and error-prone inspection logic. +By making the constructor of the model class uniform, the model runner can +easily create and initialize the model without knowing the specific model type. +This is also useful for composing models. Vision-language models often consist +of a vision model and a language model. By making the constructor uniform, we +can easily create a vision model and a language model and compose them into a +vision-language model. + +.. note:: + + To support this change, all vLLM models' signatures have been updated to: + + .. code-block:: python + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + + To avoid accidentally passing incorrect arguments, the constructor is now keyword-only. This ensures that the constructor will raise an error if old configurations are passed. vLLM developers have already made this change for all models within vLLM. For out-of-tree registered models, developers need to update their models, for example by adding shim code to adapt the old constructor signature to the new one: + + .. code-block:: python + + class MyOldModel(nn.Module): + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + ... + + from vllm.config import VllmConfig + class MyNewModel(MyOldModel): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + super().__init__(config, cache_config, quant_config, lora_config, prefix) + + from packaging.version import Version + from vllm import __version__ + + if Version(__version__) >= Version("0.6.4"): + MyModel = MyNewModel + else: + MyModel = MyOldModel + + This way, the model can work with both old and new versions of vLLM. + +3. 
**Sharding and Quantization at Initialization**: Certain features require +changing the model weights. For example, tensor parallelism needs to shard the +model weights, and quantization needs to quantize the model weights. There are +two possible ways to implement this feature. One way is to change the model +weights after the model is initialized. The other way is to change the model +weights during the model initialization. vLLM chooses the latter. The first +approach is not scalable to large models. Suppose we want to run a 405B model +(with roughly 810GB weights) with 16 H100 80GB GPUs. Ideally, every GPU should +only load 50GB weights. If we change the model weights after the model is +initialized, we need to load the full 810GB weights to every GPU and then shard +the weights, leading to a huge memory overhead. Instead, if we shard the weights +during the model initialization, every layer will only create a shard of the +weights it needs, leading to a much smaller memory overhead. The same idea +applies to quantization. Note that we also add an additional argument ``prefix`` +to the model's constructor so that the model can initialize itself differently +based on the prefix. This is useful for non-uniform quantization, where +different parts of the model are quantized differently. The ``prefix`` is +usually an empty string for the top-level model and a string like ``"vision"`` +or ``"language"`` for the sub-models. In general, it matches the name of the +module's state dict in the checkpoint file. + +One disadvantage of this design is that it is hard to write unit tests for +individual components in vLLM because every component needs to be initialized by +a complete config object. We solve this problem by providing a default +initialization function that creates a default config object with all fields set +to ``None``. If the component we want to test only cares about a few fields in +the config object, we can create a default config object and set the fields we +care about. This way, we can test the component in isolation. Note that many +tests in vLLM are end-to-end tests that test the whole system, so this is not a +big problem. + +In summary, the complete config object ``VllmConfig`` can be treated as an +engine-level global state that is shared among all vLLM classes. diff --git a/docs/source/design/class_hierarchy.rst b/docs/source/design/class_hierarchy.rst deleted file mode 100644 index 58a888b17ba53..0000000000000 --- a/docs/source/design/class_hierarchy.rst +++ /dev/null @@ -1,74 +0,0 @@ -.. _class_hierarchy: - -vLLM's Class Hierarchy -======================= - -This document describes the class hierarchy of vLLM. We will explain the relationships between the core classes, their responsibilities, and the design choices behind them to make vLLM more modular and extensible. - -1. **Entrypoints**: vLLM has two entrypoints: `command line usage `__ with ``vllm serve`` for launching an OpenAI-API compatible server, and `library-style usage `__ with the ``vllm.LLM`` class for running inference in a Python script. These are user-facing entrypoints that end-users interact with. Under the hood, both create an engine object to handle model inference. - -2. **Engine**: Each vLLM instance contains one engine object, orchestrating and serving as the control plane for model inference. Depending on the configuration, the engine can create multiple workers to handle the inference workload. - -3. **Worker**: A worker is a process that runs the model inference. 
vLLM follows the common practice of using one process to control one accelerator device, such as GPUs. For example, if we use tensor parallelism of size 2 and pipeline parallelism of size 2, we will have 4 workers in total. Workers are identified by their ``rank`` and ``local_rank``. ``rank`` is used for global orchestration, while ``local_rank`` is mainly used for assigning the accelerator device and accessing local resources such as the file system and shared memory. - -4. **Model Runner**: Every worker has one model runner object, responsible for loading and running the model. Much of the model execution logic resides here, such as preparing input tensors and capturing cudagraphs. - -5. **Model**: Every model runner object has one model object, which is the actual ``torch.nn.Module`` instance. See :ref:`huggingface_integration` for how various configurations affect the class we ultimately get. - -The following figure shows the class hierarchy of vLLM: - - .. figure:: ../assets/design/hierarchy.png - :alt: query - :width: 100% - :align: center - -There are several important design choices behind this class hierarchy: - -1. **Extensibility**: All classes in the hierarchy accept a configuration object containing all the necessary information. The `VllmConfig `__ class is the main configuration object that is passed around. The class hierarchy is quite deep, and every class needs to read the configuration it is interested in. By encapsulating all configurations in one object, we can easily pass the configuration object around and access the configuration we need. Suppose we want to add a new feature (this is often the case given how fast the field of LLM inference is evolving) that only touches the model runner. We will have to add a new configuration option in the `VllmConfig` class. Since we pass the whole config object around, we only need to add the configuration option to the `VllmConfig` class, and the model runner can access it directly. We don't need to change the constructor of the engine, worker, or model class to pass the new configuration option. - -2. **Uniformity**: The model runner needs a unified interface to create and initialize the model. vLLM supports more than 50 types of popular open-source models. Each model has its own initialization logic. If the constructor signature varies with models, the model runner does not know how to call the constructor accordingly, without complicated and error-prone inspection logic. By making the constructor of the model class uniform, the model runner can easily create and initialize the model without knowing the specific model type. This is also useful for composing models. Vision-language models often consist of a vision model and a language model. By making the constructor uniform, we can easily create a vision model and a language model and compose them into a vision-language model. - -.. note:: - - To support this change, all vLLM models' signatures have been updated to: - - .. code-block:: python - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - - To avoid accidentally passing incorrect arguments, the constructor is now keyword-only. This ensures that the constructor will raise an error if old configurations are passed. vLLM developers have already made this change for all models within vLLM. For out-of-tree registered models, developers need to update their models, for example by adding shim code to adapt the old constructor signature to the new one: - - .. 
code-block:: python - - class MyOldModel(nn.Module): - def __init__( - self, - config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: - ... - - from vllm.config import VllmConfig - class MyNewModel(MyOldModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - super().__init__(config, cache_config, quant_config, lora_config, prefix) - - if __version__ >= "0.6.4": - MyModel = MyNewModel - else: - MyModel = MyOldModel - - This way, the model can work with both old and new versions of vLLM. - -3. **Sharding and Quantization at Initialization**: Certain features require changing the model weights. For example, tensor parallelism needs to shard the model weights, and quantization needs to quantize the model weights. There are two possible ways to implement this feature. One way is to change the model weights after the model is initialized. The other way is to change the model weights during the model initialization. vLLM chooses the latter. The first approach is not scalable to large models. Suppose we want to run a 405B model (with roughly 810GB weights) with 16 H100 80GB GPUs. Ideally, every GPU should only load 50GB weights. If we change the model weights after the model is initialized, we need to load the full 810GB weights to every GPU and then shard the weights, leading to a huge memory overhead. Instead, if we shard the weights during the model initialization, every layer will only create a shard of the weights it needs, leading to a much smaller memory overhead. The same idea applies to quantization. Note that we also add an additional argument ``prefix`` to the model's constructor so that the model can initialize itself differently based on the prefix. This is useful for non-uniform quantization, where different parts of the model are quantized differently. The ``prefix`` is usually an empty string for the top-level model and a string like ``"vision"`` or ``"language"`` for the sub-models. In general, it matches the name of the module's state dict in the checkpoint file. - -One disadvantage of this design is that it is hard to write unit tests for individual components in vLLM because every component needs to be initialized by a complete config object. We solve this problem by providing a default initialization function that creates a default config object with all fields set to ``None``. If the component we want to test only cares about a few fields in the config object, we can create a default config object and set the fields we care about. This way, we can test the component in isolation. Note that many tests in vLLM are end-to-end tests that test the whole system, so this is not a big problem. - -In summary, the complete config object ``VllmConfig`` can be treated as an engine-level global state that is shared among all vLLM classes. diff --git a/docs/source/design/plugin_system.rst b/docs/source/design/plugin_system.rst index bfca702b9267a..5a96cc8b3a464 100644 --- a/docs/source/design/plugin_system.rst +++ b/docs/source/design/plugin_system.rst @@ -8,7 +8,7 @@ The community frequently requests the ability to extend vLLM with custom feature How Plugins Work in vLLM ------------------------ -Plugins are user-registered code that vLLM executes. 
Given vLLM's architecture (see :ref:`class_hierarchy`), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the `load_general_plugins `__ function in the ``vllm.plugins`` module. This function is called for every process created by vLLM before it starts any work. +Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see :ref:`arch_overview`), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the `load_general_plugins `__ function in the ``vllm.plugins`` module. This function is called for every process created by vLLM before it starts any work. How vLLM Discovers Plugins -------------------------- @@ -59,4 +59,4 @@ Guidelines for Writing Plugins Compatibility Guarantee ----------------------- -vLLM guarantees the interface of documented plugins, such as ``ModelRegistry.register_model``, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, ``"vllm_add_dummy_model.my_llava:MyLlava"`` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development. \ No newline at end of file +vLLM guarantees the interface of documented plugins, such as ``ModelRegistry.register_model``, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, ``"vllm_add_dummy_model.my_llava:MyLlava"`` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development. diff --git a/docs/source/index.rst b/docs/source/index.rst index b04acbbce4169..c2afd806c50f9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -157,7 +157,7 @@ Documentation :maxdepth: 2 :caption: Design - design/class_hierarchy + design/arch_overview design/huggingface_integration design/plugin_system design/input_processing/model_inputs_index diff --git a/format.sh b/format.sh index a57882d2ac3f9..b3dcdc15bf948 100755 --- a/format.sh +++ b/format.sh @@ -299,6 +299,10 @@ echo 'vLLM shellcheck:' tools/shellcheck.sh echo 'vLLM shellcheck: Done' +echo 'excalidraw png check:' +tools/png-lint.sh +echo 'excalidraw png check: Done' + if ! git diff --quiet &>/dev/null; then echo echo "๐Ÿ”๐Ÿ”There are files changed by the format checker or by you that are not added and committed:" diff --git a/tools/png-lint.sh b/tools/png-lint.sh new file mode 100755 index 0000000000000..a80fe9837342f --- /dev/null +++ b/tools/png-lint.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# Ensure that *.excalidraw.png files have the excalidraw metadata +# embedded in them. This ensures they can be loaded back into +# the tool and edited in the future. + +find . -iname '*.excalidraw.png' | while read -r file; do + if git check-ignore -q "$file"; then + continue + fi + if ! grep -q "excalidraw+json" "$file"; then + echo "$file was not exported from excalidraw with 'Embed Scene' enabled." 
+ exit 1 + fi +done diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 92fa87c7fa45b..ee4b6addfd466 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -793,7 +793,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, default=[], help="The pattern(s) to ignore when loading the model." - "Default to 'original/**/*' to avoid repeated loading of llama's " + " Defaults to `original/**/*` to avoid repeated loading of llama's " "checkpoints.") parser.add_argument( '--preemption-mode',