diff --git a/.github/workflows/coverage-tests.yml b/.github/workflows/coverage-tests.yml
index bcac25393..f44dba5f9 100644
--- a/.github/workflows/coverage-tests.yml
+++ b/.github/workflows/coverage-tests.yml
@@ -38,14 +38,15 @@ jobs:
submodules: false
set-safe-directory: true
+ - name: Coverage Online Report
+ run: |
+ REPORT_ADDR=$(cat "/workspace/config/report_address")
+        echo "After 'Run Unit Tests and Check Coverage' completes (or is stopped), you can access the test diff-coverage report at http://${REPORT_ADDR}/${{ github.sha }}/cov-report-${{ inputs.backend }}/diff-cover-report-${{ inputs.backend }}.html"
+
- name: Run Unit Tests and Check Coverage
run: |
# Execute the test coverage script with the appropriate backend and id (commit SHA)
tests/scripts/unit_tests/test_coverage.sh --backend ${{ inputs.backend }} --id ${{ github.sha }} --status online
- - name: Coverage Online Report
- run: |
- REPORT_ADDR=$(cat "/workspace/config/report_address")
- echo "You can access the test diff-coverage report at the http://${REPORT_ADDR}/${{github.sha}}/cov-report-${{ inputs.backend }}/diff-cover-report-${{ inputs.backend }}.html"
diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml
index 06eefb64d..38c0509b7 100644
--- a/.github/workflows/unit-tests.yml
+++ b/.github/workflows/unit-tests.yml
@@ -40,11 +40,12 @@ jobs:
lfs: false
submodules: false
set-safe-directory: true
-
- - name: Run Unit Test
- run: tests/scripts/unit_tests/test_subset.sh --backend ${{ inputs.backend }} --subset ${{ inputs.subset }} --id ${{ github.sha }}
-
+
- name: Unit Test Coverage Online Report
run: |
REPORT_ADDR=$(cat "/workspace/config/report_address")
- echo "You can access the test coverage report at http://${REPORT_ADDR}/${{github.sha}}/cov-report-${{ inputs.backend }}/index.html"
+          echo "After 'Run Unit Test' completes (or is stopped), you can access the test coverage report at http://${REPORT_ADDR}/${{ github.sha }}/cov-report-${{ inputs.backend }}/index.html"
+
+ - name: Run Unit Test
+ run: tests/scripts/unit_tests/test_subset.sh --backend ${{ inputs.backend }} --subset ${{ inputs.subset }} --id ${{ github.sha }}
+
diff --git a/.gitignore b/.gitignore
index 3b579c698..8e520bcbb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,4 +8,6 @@ slurm*
logs
.vscode
log_file/*
-outputs
\ No newline at end of file
+outputs
+*.log
+*.out
diff --git a/README.md b/README.md
index a49ce14ab..809db47f4 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,5 @@
+[![FlagOpen](./flagopen.png)](https://flagopen.baai.ac.cn/)
+
## Latest News
- **[2024/11]** Released [v0.6.0](https://github.com/FlagOpen/FlagScale/tree/release/v0.6.0):
- Introduced general multi-dimensional heterogeneous parallelism and CPU-based communication between different chips.
@@ -12,7 +14,9 @@
[FlagScale](https://github.com/FlagOpen/FlagScale.git) is a comprehensive toolkit designed to support the entire lifecycle of large models, developed with the backing of the Beijing Academy of Artificial Intelligence (BAAI). It builds on the strengths of several prominent open-source projects, including [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) and [vllm](https://github.com/vllm-project/vllm), to provide a robust, end-to-end solution for managing and scaling large models.
-The primary objective of FlagScale is to enable seamless scalability across diverse hardware architectures while maximizing computational resource efficiency and enhancing model performance. By offering essential components for model development, training, and deployment, FlagScale aims to serve as an indispensable toolkit for optimizing both the speed and effectiveness of large model workflows.
+The primary objective of FlagScale is to enable seamless scalability across diverse hardware architectures while maximizing computational resource efficiency and enhancing model performance. By offering essential components for model development, training, and deployment, FlagScale seeks to establish itself as an indispensable toolkit for optimizing both the speed and effectiveness of large model workflows.
+
+FlagScale is also a part of [FlagAI-Open](https://flagopen.baai.ac.cn/), an open-source initiative by BAAI that aims to foster an open-source ecosystem for AI technologies. It serves as a platform where developers, researchers, and AI enthusiasts can collaborate on various AI projects, contribute to the development of cutting-edge AI solutions, and share their work with the global community.
## Quick Start
@@ -43,13 +47,15 @@ We recommend using the latest release of [NGC's PyTorch container](https://catal
cd vllm
pip install .
- cd megatron-energon
- pip install .
+ pip install -e ./megatron-energon
+ cp -r megatron-energon/src/megatron/energon megatron/megatron
```
### Run a Task
-FlagScale provides a unified runner for various tasks, including training and inference. Simply specify the configuration file to run the task with a single command. The runner will automatically load the configurations and execute the task. The following example demonstrates how to run a distributed training task.
+FlagScale provides a unified runner for various tasks, including training, inference, and serving. Simply specify the configuration file to run the task with a single command. The runner will automatically load the configurations and execute the task. The following examples demonstrate how to run a distributed training job and how to serve a model.
+
+#### Train
1. Start the distributed training job:
```sh
@@ -62,6 +68,18 @@ FlagScale provides a unified runner for various tasks, including training and in
python run.py --config-path ./examples/aquila/conf --config-name config action=stop
```
+#### Serve
+
+1. Start the server:
+ ```sh
+ python run.py --config-path ./examples/qwen/conf --config-name config_qwen2.5_7b action=run
+ ```
+2. Stop the server:
+ ```sh
+ python run.py --config-path ./examples/qwen/conf --config-name config_qwen2.5_7b action=stop
+ ```
+For more details, please refer to [Quick Start](./flagscale/serve/README.md).
+
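+For a quick check, you can query the service once it is up. The snippet below is a minimal sketch: it assumes the vLLM backend exposes its OpenAI-compatible API on the port set in the serve config (4567 in `serve_qwen2.5_7b.yaml`) and that the served model name matches the configured `model-tag`.
+
+```python
+import requests
+
+# Send a single chat request to the locally served Qwen2.5-7B-Instruct model.
+response = requests.post(
+    "http://localhost:4567/v1/chat/completions",
+    json={
+        "model": "/models/Qwen2.5-7B-Instruct",  # model-tag from the serve config
+        "messages": [{"role": "user", "content": "Hello!"}],
+        "max_tokens": 128,
+    },
+    timeout=60,
+)
+print(response.json()["choices"][0]["message"]["content"])
+```
+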
## License
This project is licensed under the [Apache License (Version 2.0)](https://github.com/FlagOpen/FlagScale/blob/main/LICENSE). This project also contains other third-party components under other open-source licenses. See the [LICENSE](https://github.com/FlagOpen/FlagScale/blob/main/LICENSE) file for more information.
diff --git a/examples/llama/conf/train/train_llama2_7b_tp_hetero.yaml b/examples/llama/conf/train/train_llama2_7b_tp_hetero.yaml
deleted file mode 100644
index 90995e2d9..000000000
--- a/examples/llama/conf/train/train_llama2_7b_tp_hetero.yaml
+++ /dev/null
@@ -1,67 +0,0 @@
-system:
- tensor_model_parallel_size: 4
- pipeline_model_parallel_size: 3
- disable_bias_linear: True
- use_flash_attn: True
- sequence_parallel: True
- use_distributed_optimizer: True
- hetero_mode: pp
- hetero_device_types: A100
- hetero_current_device_type: A100
- hetero_pipeline_stages: [3,16,8,8]
- process_meshes: [4,1,1,2,1,2]
- precision:
- bf16: True
- initial_loss_scale: 16384
- min_loss_scale: 1.0
- logging:
- log_interval: 1
- checkpoint:
- save_interval: 100
-
-model:
- use_mcore_models: True
- transformer_impl: transformer_engine
- num_layers: 32
- hidden_size: 4096
- ffn_hidden_size: 11008
- num_attention_heads: 32
- seq_length: 4096
- group_query_attention: False
- num_query_groups: 8
- max_position_embeddings: 4096
- norm_epsilon: 1e-5
- use_rotary_position_embeddings: True
- no_position_embedding: True
- swiglu: True
- normalization: RMSNorm
- untie_embeddings_and_output_weights: True
- init_method_std: 0.02
- attention_dropout: 0.0
- hidden_dropout: 0.0
- weight_decay: 0.1
- clip_grad: 1.0
- train_iters: 30
- eval_iters: 0
- eval_interval: 2000
- micro_batch_size: 1
- global_batch_size: 32
-
- optimizer:
- weight_decay: 1e-2
- adam_beta1: 0.9
- adam_beta2: 0.95
- lr_scheduler:
- lr: 0.00015
- min_lr: 1.0e-5
- lr_warmup_fraction: .01
- lr_decay_iters: 1
- lr_decay_style: cosine
-
-data:
- data_path: ${data_path:??}
- split: 1
- tokenizer:
- tokenizer_type: Llama2Tokenizer
- tokenizer_model: examples/llama/tokenizer.model
- vocab_size: 32000
diff --git a/examples/llama/conf/train/train_llama3_8b_hetero.yaml b/examples/llama/conf/train/train_llama3_8b_hetero.yaml
new file mode 100644
index 000000000..c73be22c8
--- /dev/null
+++ b/examples/llama/conf/train/train_llama3_8b_hetero.yaml
@@ -0,0 +1,100 @@
+system:
+ tensor_model_parallel_size: 4
+ pipeline_model_parallel_size: 2
+ disable_bias_linear: True
+ use_flash_attn: True
+ sequence_parallel: True
+ use_distributed_optimizer: True
+ precision:
+ bf16: True
+ attention_softmax_in_fp32: true
+ accumulate_allreduce_grads_in_fp32: true
+ logging:
+ log_interval: 1
+ tensorboard_log_interval: 1
+ wandb_project: "train-llama3-8B"
+ wandb_exp_name: "train-test-8B"
+ checkpoint:
+ load: outputs_llama3/checkpoint_mc
+ save_interval: 10
+ finetune: True
+ ckpt_format: "torch"
+
+model:
+ use_mcore_models: True
+ transformer_impl: transformer_engine
+ num_layers: 32
+ hidden_size: 4096
+ ffn_hidden_size: 14336
+ num_attention_heads: 32
+ seq_length: 4096
+ group_query_attention: True
+ num_query_groups: 8
+ max_position_embeddings: 8192
+ norm_epsilon: 1e-5
+ use_rotary_position_embeddings: True
+ no_position_embedding: True
+ swiglu: True
+ normalization: RMSNorm
+ rotary_interleaved_patch: False
+ position_embedding_type: rope
+ rotary_base: 500000
+ untie_embeddings_and_output_weights: True
+ init_method_std: 0.02
+ attention_dropout: 0.0
+ hidden_dropout: 0.0
+ clip_grad: 1.0
+ train_samples: 200000
+ eval_iters: 100
+ eval_interval: 1000
+ micro_batch_size: 1
+ global_batch_size: 16
+
+ hetero:
+ enable_hetero: True
+ hetero_use_cpu_communication: True
+ # mesh format [tp1,cp1,ep1,dp1,pp1,(tp2,cp2...)]
+
+ # 2 mesh, diff tp dp pp
+ hetero_pipeline_layer_split: [18, 14]
+ hetero_process_meshes: [2, 1, 1, 4, 1, 4, 1, 1, 2, 1]
+ hetero_device_types: ["A800", "A100"]
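+    # A worked reading of the two meshes above (per-mesh order [tp, cp, ep, dp, pp]):
+    #   [2, 1, 1, 4, 1] -> tp=2, cp=1, ep=1, dp=4, pp=1 on the "A800" mesh (8 GPUs)
+    #   [4, 1, 1, 2, 1] -> tp=4, cp=1, ep=1, dp=2, pp=1 on the "A100" mesh (8 GPUs)
+    # hetero_pipeline_layer_split [18, 14] assigns 18 + 14 = 32 layers to the two
+    # pipeline stages, matching num_layers and pipeline_model_parallel_size: 2.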
+
+ standalone_embedding_stage: False
+ hetero_current_device_type: "A800"
+
+ # recompute:
+ # recompute_granularity: "full"
+ # recompute_method: "uniform"
+ # recompute_num_layers: 1
+
+ # ## pp 2 stages and num_micro_batches 4
+ # recompute_granularity_per_stage_micro_batch:
+ # - [1, 3, 0, 1, 0]
+ # - [1, 3, 1, 1, 1]
+ # recompute_method_per_stage_micro_batch:
+ # - [1, 3, 0, 1, 0]
+ # - [1, 3, 0, 1, 0]
+ # recompute_num_layers_per_stage_micro_batch:
+ # - [1, 3, 2, 1, 2]
+ # - [1, 3, 1, 1, 1]
+
+
+ optimizer:
+ weight_decay: 1e-2
+ adam_beta1: 0.9
+ adam_beta2: 0.95
+ lr_scheduler:
+ lr: 1.0e-5
+ min_lr: 1.0e-6
+ lr_warmup_fraction: .1
+ lr_decay_style: cosine
+
+data:
+ data_path: examples/llama/pile-openwebtext_text_document/pile-openwebtext_text_document
+ split: 1
+ tokenizer:
+ tokenizer_type: Llama3TokenizerFS
+ tokenizer_path: meta-llama3/Meta-Llama-3-8B
+ vocab_size: 128256
+ make_vocab_size_divisible_by: 64
diff --git a/examples/llava_onevision/conf/train/train_llava_onevision_1.5b.yaml b/examples/llava_onevision/conf/train/train_llava_onevision_1.5b.yaml
index 142e59a8f..c9b44e0bd 100644
--- a/examples/llava_onevision/conf/train/train_llava_onevision_1.5b.yaml
+++ b/examples/llava_onevision/conf/train/train_llava_onevision_1.5b.yaml
@@ -44,6 +44,11 @@ model:
hidden_dropout: 0.0
clip_grad: 1.0
train_iters: 10
+ profile: False
+  profile_step_start: 10
+  profile_step_end: 20
+ profile_ranks: 7
+ use_pytorch_profiler: True
eval_iters: 0
micro_batch_size: 2
global_batch_size: 512
diff --git a/vllm/vllm/v1/tokenizer/__init__.py b/examples/qwen/__init__.py
similarity index 100%
rename from vllm/vllm/v1/tokenizer/__init__.py
rename to examples/qwen/__init__.py
diff --git a/examples/qwen/conf/config_qwen2.5_1.5b.yaml b/examples/qwen/conf/config_qwen2.5_1.5b.yaml
new file mode 100644
index 000000000..c3ab52d9f
--- /dev/null
+++ b/examples/qwen/conf/config_qwen2.5_1.5b.yaml
@@ -0,0 +1,30 @@
+defaults:
+ - _self_
+  - train: qwen_2.5_1.5b
+ # - train: train_mixtral_1.8b
+
+experiment:
+ exp_name: train_qwen_2.5_1.5b
+  exp_dir: ./outputs # output directory for logs and checkpoints
+ task:
+ type: train
+ backend: megatron
+ entrypoint: ./flagscale/train/train_aquila.py
+ runner:
+ backend: torchrun
+ nnodes: 2
+ nproc_per_node: 8
+ hostfile: torchrun # Please replace with your actual hostfile path
+ envs:
+ CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7"
+ CUDA_DEVICE_MAX_CONNECTIONS: 1
+ NCCL_SOCKET_IFNAME: eth0
+ NCCL_IB_DISABLE: 0
+ NCCL_IB_CUDA_SUPPORT: 1
+ NCCL_IB_GID_INDEX: 0
+ NCCL_DEBUG: INFO
+ OMP_NUM_THREADS: 4
+ GLOO_SOCKET_IFNAME: eth0
+ NCCL_IB_HCA: mlx5_2,mlx5_5
+
+action: run
diff --git a/examples/qwen/conf/config_qwen2.5_72b_tp.yaml b/examples/qwen/conf/config_qwen2.5_72b_tp.yaml
new file mode 100644
index 000000000..a68a1e752
--- /dev/null
+++ b/examples/qwen/conf/config_qwen2.5_72b_tp.yaml
@@ -0,0 +1,22 @@
+defaults:
+ - _self_
+ - serve: serve_qwen2.5_72b
+
+experiment:
+ exp_name: qwen2.5_72b
+ exp_dir: outputs/${experiment.exp_name}
+ task:
+ type: serve
+ backend: vllm
+ entrypoint: null
+ runner:
+ hostfile: null
+ envs:
+ CUDA_VISIBLE_DEVICES: 0,1,2,3
+ CUDA_DEVICE_MAX_CONNECTIONS: 1
+
+action: run
+
+hydra:
+ run:
+ dir: ${experiment.exp_dir}/hydra
diff --git a/examples/qwen/conf/config_qwen2.5_7b.yaml b/examples/qwen/conf/config_qwen2.5_7b.yaml
new file mode 100644
index 000000000..48609ecd3
--- /dev/null
+++ b/examples/qwen/conf/config_qwen2.5_7b.yaml
@@ -0,0 +1,22 @@
+defaults:
+ - _self_
+ - serve: serve_qwen2.5_7b
+
+experiment:
+ exp_name: qwen2.5_7b
+ exp_dir: outputs/${experiment.exp_name}
+ task:
+ type: serve
+ backend: vllm
+ entrypoint: null
+ runner:
+ hostfile: null
+ envs:
+ CUDA_VISIBLE_DEVICES: 0
+ CUDA_DEVICE_MAX_CONNECTIONS: 1
+
+action: run
+
+hydra:
+ run:
+ dir: ${experiment.exp_dir}/hydra
diff --git a/examples/qwen/conf/config_ssh_qwen2.5_7b.yaml b/examples/qwen/conf/config_ssh_qwen2.5_7b.yaml
new file mode 100644
index 000000000..3a94895e4
--- /dev/null
+++ b/examples/qwen/conf/config_ssh_qwen2.5_7b.yaml
@@ -0,0 +1,25 @@
+defaults:
+ - _self_
+ - serve: serve_qwen2.5_7b
+
+experiment:
+ exp_name: qwen2.5_7b
+ exp_dir: outputs/${experiment.exp_name}
+ task:
+ type: serve
+ backend: vllm
+ entrypoint: null
+ runner:
+    hostfile: /path/to/hostfile # format: {remote ip} slots={gpu num} type={gpu type} (e.g., x.x.x.x slots=8 type=A100)
+ ssh_port: 22 # replace with your ssh port
+ envs:
+ CUDA_VISIBLE_DEVICES: 0
+ CUDA_DEVICE_MAX_CONNECTIONS: 1
+ cmds:
+ before_start: source /root/miniconda3/bin/activate flagscale
+
+action: run
+
+hydra:
+ run:
+ dir: ${experiment.exp_dir}/hydra
diff --git a/examples/qwen/conf/serve/serve_qwen2.5_72b.yaml b/examples/qwen/conf/serve/serve_qwen2.5_72b.yaml
new file mode 100644
index 000000000..86d05e037
--- /dev/null
+++ b/examples/qwen/conf/serve/serve_qwen2.5_72b.yaml
@@ -0,0 +1,17 @@
+model_args:
+ vllm_model:
+ model-tag: /models/Qwen2.5-72B-Instruct
+ tensor-parallel-size: 4
+ gpu-memory-utilization: 0.9
+ max-model-len: 32768
+ max-num-seqs: 256
+ port: 4567
+ action-args:
+ - trust-remote-code
+ - enable-chunked-prefill
+
+deploy:
+ command-line-mode: true
+ models:
+ vllm_model:
+ num_gpus: 4
diff --git a/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml
new file mode 100644
index 000000000..0faa78b42
--- /dev/null
+++ b/examples/qwen/conf/serve/serve_qwen2.5_7b.yaml
@@ -0,0 +1,17 @@
+model_args:
+ vllm_model:
+ model-tag: /models/Qwen2.5-7B-Instruct
+ tensor-parallel-size: 1
+ gpu-memory-utilization: 0.9
+ max-model-len: 32768
+ max-num-seqs: 256
+ port: 4567
+ action-args:
+ - trust-remote-code
+ - enable-chunked-prefill
+
+deploy:
+ command-line-mode: true
+ models:
+ vllm_model:
+ num_gpus: 1
diff --git a/examples/qwen/conf/train/qwen_2.5_1.5b.yaml b/examples/qwen/conf/train/qwen_2.5_1.5b.yaml
new file mode 100644
index 000000000..2ea5368fb
--- /dev/null
+++ b/examples/qwen/conf/train/qwen_2.5_1.5b.yaml
@@ -0,0 +1,80 @@
+system:
+ tensor_model_parallel_size: 1
+ pipeline_model_parallel_size: 1
+ make_vocab_size_divisible_by: 128
+ disable_bias_linear: True
+ sequence_parallel: True
+ use_flash_attn: True
+ use_distributed_optimizer: True
+  distributed_timeout_minutes: 60
+ precision:
+ bf16: True
+ attention_softmax_in_fp32: True
+ accumulate_allreduce_grads_in_fp32: True
+ logging:
+ log_interval: 1
+ tensorboard_log_interval: 1
+ wandb_project: "train-qwen2.5-1.5B"
+ wandb_exp_name: "train-qwen2.5-1.5B"
+ checkpoint:
+ load: ${megatron_model__path:}
+    # To train the model, comment out ckpt_format, ckpt_convert_format, and ckpt_convert_save; they are only needed for checkpoint conversion.
+ ckpt_format: torch_dist # ${experiment.ckpt_format}
+ ckpt_convert_format: torch # ${experiment.ckpt_convert_format}
+ ckpt_convert_save: ${experiment.ckpt_convert_save}
+ save_interval: 5000000
+ rampup_save_interval: 50000
+
+model:
+ use_mcore_models: true
+ num_layers: 28
+ hidden_size: 1536
+ num_attention_heads: 12
+ num_query_groups: 2
+ group_query_attention: True
+ ffn_hidden_size: 8960
+ seq_length: 4096
+ max_position_embeddings: 4096
+ norm_epsilon: 1e-6
+ norm_init_weight: 0.02
+ use_rotary_position_embeddings: true
+ rotary_base: 1000000.0
+ no_position_embedding: true
+ reset_position_ids: true
+ add_qkv_bias: true
+ reset_attention_mask: true
+ swiglu: true
+ normalization: RMSNorm
+ untie_embeddings_and_output_weights: false
+ init_method_std: 0.02
+ attention_dropout: 0.0
+ hidden_dropout: 0.0
+ weight_decay: 0.0
+ clip_grad: 1.0
+ train_samples: 1478125
+ eval_iters: 0
+ eval_interval: 2000000
+ micro_batch_size: 1
+ global_batch_size: 512
+ finetune: true
+ transformer_impl: transformer_engine
+ seed: 42
+ #data_searching_range: [1156,1274]
+ optimizer:
+ weight_decay: 0.0
+ adam_beta1: 0.9
+ adam_beta2: 0.95
+ lr_scheduler:
+ lr: 1e-5
+ min_lr: 0
+ lr_warmup_samples: 21120
+ lr_decay_style: cosine
+
+data:
+ data_path: ${data_path:??}
+ split: 1
+ apply_sft_dataset_separated_loss_mask_if_existed: true
+ tokenizer:
+ tokenizer_type: HFTokenizerFS
+ tokenizer_path: ${HF_model_path:??}
+ vocab_size: 151665
diff --git a/examples/qwen/utils/__init__.py b/examples/qwen/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/qwen/utils/convo_dataset.py b/examples/qwen/utils/convo_dataset.py
new file mode 100644
index 000000000..a9234b7d2
--- /dev/null
+++ b/examples/qwen/utils/convo_dataset.py
@@ -0,0 +1,237 @@
+"""GPT style dataset."""
+
+import copy
+import hashlib
+import os
+import time
+
+import numpy as np
+import torch
+
+from megatron import print_rank_0
+from megatron.core import mpu
+from megatron.data.data_samplers import RandomSeedDataset
+
+class ConversationDatasetCPT(torch.utils.data.Dataset):
+ def __init__(self, conversations, tokenizer, maxlen, seed, num_samples, role_sep="\n\n"):
+ super(ConversationDatasetCPT, self).__init__()
+ self.conversations = conversations
+ self.tokenizer = tokenizer
+ self.maxlen = maxlen+1
+ self.seed = seed
+ self.num_samples = num_samples
+
+ ## TODO convo template
+ self.sep = role_sep
+
+ # rng state
+ np_rng = np.random.RandomState(seed=seed)
+ np_rng.shuffle(self.conversations)
+
+ def __getitem__(self, i):
+ source = self.conversations[i]
+
+ instruction = source['instruction']
+ conversations = source['conversations']
+
+ BOS_TOKEN = self.tokenizer.cls
+ EOS_TOKEN = self.tokenizer.eod
+ example = [BOS_TOKEN]
+
+ # instruction
+ instruction = self.tokenizer.tokenize(f"{instruction}")
+ example += instruction
+
+ labels = [-100] * len(example)
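+        # Positions labeled -100 are ignored by the loss, so the BOS token and the
+        # instruction are not supervised; only 'gpt' turns below keep real labels.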
+
+ for conversation in conversations:
+ role = conversation['from']
+ content = conversation['value']
+ content += self.sep
+
+ content = self.tokenizer.tokenize(f"{content}")
+
+ example += content
+ if role == 'gpt':
+ role_labels = copy.deepcopy(content)
+ else:
+ # masking
+ role_labels = [-100] * len(content)
+ labels += role_labels
+
+ example.append(EOS_TOKEN)
+ labels.append(EOS_TOKEN)
+
+ # maxlen
+ example = example[:self.maxlen]
+ labels = labels[:self.maxlen]
+
+ # padding
+ delta = self.maxlen - len(example)
+ if delta > 0:
+ example.extend([self.tokenizer.pad]*delta)
+ labels.extend([-100]*delta)
+
+ output = {
+ "tokens": np.array(example, dtype=np.int64),
+ "labels": np.array(labels, dtype=np.int64),
+ }
+ return output
+
+ def __len__(self):
+ return len(self.conversations)
+
+
+class ConversationDatasetV2(torch.utils.data.Dataset):
+ def __init__(self, conversations, tokenizer, maxlen, seed, num_samples):
+ super(ConversationDatasetV2, self).__init__()
+ self.conversations = conversations
+ self.tokenizer = tokenizer
+ self.maxlen = maxlen+1
+ self.seed = seed
+ self.num_samples = num_samples
+
+ # rng state
+ np_rng = np.random.RandomState(seed=seed)
+ np_rng.shuffle(self.conversations)
+
+
+ def __getitem__(self, i):
+        from examples.qwen.utils.convo_prompt import _add_speaker_and_signal
+        from examples.qwen.utils.convo_prompt import header
+
+ #source = self.conversations[self.sample_idx[i]]
+ source = self.conversations[i]
+ _add_speaker_and_signal(source)
+
+ source["chat_desc"] = header
+ chat_desc = source['chat_desc']
+ instruction = source['instruction']
+ conversations = source['conversations']
+
+ BOS_TOKEN = self.tokenizer.cls
+ EOS_TOKEN = self.tokenizer.eod
+ example = [BOS_TOKEN]
+
+ # chat_desc
+ example += self.tokenizer.tokenize(f"{chat_desc}")
+
+ # instruction
+ instruction = self.tokenizer.tokenize(f"{instruction}")
+ example += instruction
+
+ labels = copy.deepcopy(example)
+ # add zero-out
+ #labels = [-100] * len(example)
+
+ for conversation in conversations:
+ role = conversation['from']
+ content = conversation['value']
+ content = self.tokenizer.tokenize(f"{content}")
+ example += content
+ if role == 'gpt':
+ role_labels = copy.deepcopy(content)
+ else:
+ # masking
+ role_labels = [-100] * len(content)
+ labels += role_labels
+
+ example.append(EOS_TOKEN)
+ labels.append(EOS_TOKEN)
+
+ # maxlen
+ example = example[:self.maxlen]
+ labels = labels[:self.maxlen]
+
+ # padding
+ delta = self.maxlen - len(example)
+ if delta > 0:
+ example.extend([self.tokenizer.pad]*delta)
+ labels.extend([-100]*delta)
+
+ output = {
+ "tokens": np.array(example, dtype=np.int64),
+ "labels": np.array(labels, dtype=np.int64),
+ }
+ return output
+
+ def __len__(self):
+ #return len(self.sample_idx)
+ return len(self.conversations)
+
+
+def build_train_valid_test_datasets(train_valid_test_num_samples,
+ seq_length, seed, tokenizer,
+ train_data_prefix,
+ valid_data_prefix,
+ test_data_prefix=None,
+ finetune_dataset_type=None):
+ """Build train, valid, and test datasets."""
+    supported_dataset_types = dict(CPT=ConversationDatasetCPT)
+    dataset_cls = ConversationDatasetV2
+    if finetune_dataset_type in supported_dataset_types:
+        dataset_cls = supported_dataset_types[finetune_dataset_type]
+
+ def read_file(jsonl_file):
+ import jsonlines
+ conversations = []
+ with jsonlines.open(jsonl_file) as reader:
+ for line in reader:
+ conversations.append(line)
+ return conversations
+
+ train_dataset, valid_dataset, test_dataset = None, None, None
+ # Single dataset.
+ if train_data_prefix is not None:
+ train_conversations = read_file(train_data_prefix[0])
+ train_dataset = dataset_cls(
+ train_conversations,
+ tokenizer=tokenizer,
+ maxlen=seq_length,
+ seed=seed,
+ num_samples=train_valid_test_num_samples[0])
+ train_dataset = RandomSeedDataset(train_dataset)
+
+ if valid_data_prefix is not None:
+ valid_conversations = read_file(valid_data_prefix[0])
+ valid_dataset = dataset_cls(
+ valid_conversations,
+ tokenizer=tokenizer,
+ maxlen=seq_length,
+ seed=seed,
+ num_samples=train_valid_test_num_samples[1])
+ valid_dataset = RandomSeedDataset(valid_dataset)
+
+ if test_data_prefix is not None:
+ test_conversations = read_file(test_data_prefix[0])
+ test_dataset = dataset_cls(
+ test_conversations,
+ tokenizer=tokenizer,
+ maxlen=seq_length,
+ seed=seed,
+ num_samples=train_valid_test_num_samples[2])
+ test_dataset = RandomSeedDataset(test_dataset)
+
+ return (train_dataset, valid_dataset, test_dataset)
+
+if __name__ == "__main__":
+ train_valid_test_num_samples = [12000,2000,0]
+ seq_length = 2048
+ seed = 1234
+ from megatron.tokenizer.tokenizer import _AquilaTokenizer
+ tokenizer = _AquilaTokenizer(
+ '../examples/aquila/tokenizer/vocab.json',
+ '../examples/aquila/tokenizer/merges.txt')
+ print(f"{dir(tokenizer)}")
+ train_data_prefix = ['path/to/train/set']
+ valid_data_prefix = ['path/to/valid/set']
+ train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets(
+ train_valid_test_num_samples,
+ seq_length, seed, tokenizer,
+ train_data_prefix,
+ valid_data_prefix,
+ test_data_prefix=None)
+ for idx, sample in enumerate(train_dataset):
+ print(f"idx={idx} sample={type(sample['labels'])}")
+ break
+
diff --git a/examples/qwen/utils/convo_prompt.py b/examples/qwen/utils/convo_prompt.py
new file mode 100644
index 000000000..e0e947282
--- /dev/null
+++ b/examples/qwen/utils/convo_prompt.py
@@ -0,0 +1,166 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple, Any
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ instruction: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+
+ skip_next: bool = False
+ conv_id: Any = None
+
+ def get_prompt(self):
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ if self.instruction is not None and len(self.instruction) > 0:
+ ret += self.roles[2] + ": " + self.instruction + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+ return ret
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ if self.instruction is not None and len(self.instruction) > 0:
+ ret += self.roles[2] + ": " + self.instruction + self.sep
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ instruction=self.instruction,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ conv_id=self.conv_id)
+
+ def dict(self):
+ return {
+ "system": self.system,
+ "instruction": self.instruction,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ "conv_id": self.conv_id,
+ }
+
+
+conv_v1 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ instruction="",
+ roles=("Human", "Assistant", "System"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_v1_2 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ instruction="",
+ roles=("Human", "Assistant", "System"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_bair_v1 = Conversation(
+ system="BEGINNING OF CONVERSATION:",
+ instruction="",
+ roles=("USER", "GPT", "System"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+
+default_conversation = conv_v1_2
+conv_templates = {
+ "v1": conv_v1_2,
+ "bair_v1": conv_bair_v1,
+}
+
+# utils
+"""Add speaker and start/end signal on each round."""
+BEGIN_SIGNAL = "### "
+END_SIGNAL = "\n"
+header = f"{default_conversation.system}\n\n"
+unknown_role = "unknown" # use default unknown role
+roles = {
+ "human": default_conversation.roles[0], # human role
+ "gpt": default_conversation.roles[1], # gpt role
+}
+
+def _add_speaker_and_signal(source, get_conversation=True):
+ conversation = header
+ if "instruction" in source and source["instruction"] is not None and len(source["instruction"]) > 0:
+ source["instruction"] = (
+ BEGIN_SIGNAL
+            + default_conversation.roles[2]
+ + ": "
+ + source["instruction"]
+ + END_SIGNAL
+ )
+ if get_conversation:
+ conversation += source["instruction"]
+ for sentence in source["conversations"]:
+ sentence_from = sentence["from"].lower()
+ sentence["value"] = (
+ BEGIN_SIGNAL
+ + roles.get(sentence_from, unknown_role)
+ + ": "
+ + sentence["value"]
+ + END_SIGNAL
+ )
+ if get_conversation:
+ conversation += sentence["value"]
+ return conversation
+
+if __name__ == "__main__":
+ print(default_conversation.get_prompt())
diff --git a/examples/qwen/utils/cyg_conversation.py b/examples/qwen/utils/cyg_conversation.py
new file mode 100644
index 000000000..6bb9c3a0f
--- /dev/null
+++ b/examples/qwen/utils/cyg_conversation.py
@@ -0,0 +1,157 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple, Any
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ instruction: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+
+ skip_next: bool = False
+ conv_id: Any = None
+
+ def get_prompt(self):
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ if self.instruction is not None and len(self.instruction) > 0:
+ ret += self.roles[2] + ": " + self.instruction + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+ return ret
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ if self.instruction is not None and len(self.instruction) > 0:
+ ret += self.roles[2] + ": " + self.instruction + self.sep
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ instruction=self.instruction,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ conv_id=self.conv_id)
+
+ def dict(self):
+ return {
+ "system": self.system,
+ "instruction": self.instruction,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ "conv_id": self.conv_id,
+ }
+
+
+conv_v1 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ instruction="",
+ roles=("Human", "Assistant", "System"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_v1_2 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ instruction="",
+ roles=("Human", "Assistant", "System"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_bair_v1 = Conversation(
+ system="BEGINNING OF CONVERSATION:",
+ instruction="",
+ roles=("USER", "GPT", "System"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+
+default_conversation = conv_v1_2
+conv_templates = {
+ "v1": conv_v1_2,
+ "bair_v1": conv_bair_v1,
+}
+
+
+def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token):
+ conv = default_conversation.copy()
+
+ conv.append_message(conv.roles[1], None)
+ conv.append_message(conv.roles[0], text)
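+    # Turns are appended newest-first (current turn, then history popped from the end)
+    # so the most recent context fits within max_token; the list is reversed below.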
+
+ example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']
+
+ while(len(history) > 0 and (len(example) < max_token)):
+ tmp = history.pop()
+ if tmp[0] == 'ASSISTANT':
+ conv.append_message(conv.roles[1], tmp[1])
+ else:
+ conv.append_message(conv.roles[0], tmp[1])
+ example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']
+
+ if len(example) >= max_token:
+ conv.messages.pop()
+ conv.messages = conv.messages[::-1]
+ print('model in:', conv.get_prompt())
+ example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']
+ #example = example[1:-1]
+
+ return example
+
+if __name__ == "__main__":
+ print(default_conversation.get_prompt())
+
diff --git a/flagopen.png b/flagopen.png
new file mode 100644
index 000000000..82ffee806
Binary files /dev/null and b/flagopen.png differ
diff --git a/flagscale/inference/core/block_manager.py b/flagscale/inference/core/block_manager.py
index 34a0d3d52..93770b8dc 100644
--- a/flagscale/inference/core/block_manager.py
+++ b/flagscale/inference/core/block_manager.py
@@ -16,9 +16,7 @@
SeqId = int
EncoderSeqId = str
-# --- FLAGSCALE MODIFICATION BEG ---
-NegativeSeqId = int
-# --- FLAGSCALE MODIFICATION END ---
+NegativeSeqId = int # --- FLAGSCALE MODIFICATION ---
class SelfAttnBlockSpaceManager(BlockSpaceManager):
@@ -103,12 +101,10 @@ def __init__(
self.block_tables: Dict[SeqId, BlockTable] = {}
self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
- # --- FLAGSCALE MODIFICATION BEG ---
- self.negative_block_tables: Dict[NegativeSeqId, BlockTable] = {}
- # --- FLAGSCALE MODIFICATION END ---
+ self.negative_block_tables: Dict[NegativeSeqId, BlockTable] = {} # --- FLAGSCALE MODIFICATION ---
self._computed_blocks_tracker = ComputedBlocksTracker(
- self.block_allocator)
+ self.block_allocator, self.block_size, self.enable_caching)
self._last_access_blocks_tracker = LastAccessBlocksTracker(
self.block_allocator)
@@ -187,7 +183,6 @@ def allocate(self, seq_group: SequenceGroup) -> None:
self.block_tables[seq.seq_id] = block_table
# Track seq
- self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Assign the block table for each sequence.
@@ -195,7 +190,6 @@ def allocate(self, seq_group: SequenceGroup) -> None:
self.block_tables[seq.seq_id] = block_table.fork()
# Track seq
- self._computed_blocks_tracker.add_seq(seq.seq_id)
self._last_access_blocks_tracker.add_seq(seq.seq_id)
# Allocate cross-attention block table for encoder sequence
@@ -380,11 +374,13 @@ def get_common_computed_block_ids(
"""
computed_seq_block_ids = []
for seq in seqs:
- computed_seq_block_ids.append(
- self._computed_blocks_tracker.
- get_cached_computed_blocks_and_update(
- seq.seq_id,
- self.block_tables[seq.seq_id].physical_block_ids))
+ all_blocks = self.block_tables[seq.seq_id].physical_block_ids
+ num_cached_tokens = (
+ self._computed_blocks_tracker.get_num_cached_tokens(seq))
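+            # num_cached_tokens is block-aligned (hence the assert below), so the cached
+            # prefix maps onto the first num_cached_blocks physical block ids.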
+ assert num_cached_tokens % self.block_size == 0
+ num_cached_blocks = num_cached_tokens // self.block_size
+ computed_block_ids = all_blocks[:num_cached_blocks]
+ computed_seq_block_ids.append(computed_block_ids)
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return self.block_allocator.get_common_computed_block_ids(
@@ -398,7 +394,6 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self.block_tables[child_seq.seq_id] = src_block_table.fork()
# Track child seq
- self._computed_blocks_tracker.add_seq(child_seq.seq_id)
self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
def can_swap_in(self, seq_group: SequenceGroup,
@@ -459,7 +454,7 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool:
with num_lookahead_slots.
Args:
- seq_group (SequenceGroup): The sequence group to swap in.
+ seq_group (SequenceGroup): The sequence group to swap out.
num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0.
@@ -475,7 +470,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
swapping out the given sequence_group with num_lookahead_slots.
Args:
- sequence_group (SequenceGroup): The sequence group to swap in.
+ sequence_group (SequenceGroup): The sequence group to swap out.
Returns:
List[Tuple[int, int]]: The mapping of swapping block from
@@ -525,7 +520,7 @@ def _can_swap(self,
on to the 'device'.
Args:
- sequence_group (SequenceGroup): The sequence group to swap in.
+ sequence_group (SequenceGroup): The sequence group to swap in/out.
device (Device): device to swap the 'seq_group' on.
status (SequenceStatus): The status of sequence which is needed
for action. RUNNING for swap out and SWAPPED for swap in
@@ -569,3 +564,9 @@ def _can_swap(self,
return AllocStatus.OK
else:
return AllocStatus.LATER
+
+ def get_num_cached_tokens(self, seq: Sequence) -> int:
+ """Get the number of tokens in blocks that are already computed and
+ cached in the block manager for the sequence.
+ """
+ return self._computed_blocks_tracker.get_num_cached_tokens(seq)
diff --git a/flagscale/inference/core/data.py b/flagscale/inference/core/data.py
index 55b7240c0..5321d6a6e 100644
--- a/flagscale/inference/core/data.py
+++ b/flagscale/inference/core/data.py
@@ -1,11 +1,16 @@
# This file is modified from 'FlagScale/vllm/vllm/inputs/data.py'
-from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
+from dataclasses import dataclass
+from functools import cached_property
+from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal,
Optional, Tuple, Union, cast)
-from typing_extensions import NotRequired, TypedDict, TypeVar
+import torch
+from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never
if TYPE_CHECKING:
- from vllm.multimodal import MultiModalDataDict
+ from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs,
+ MultiModalPlaceholderDict)
+ from vllm.multimodal.inputs import MultiModalInputsV2
class TextPrompt(TypedDict):
@@ -40,15 +45,18 @@ class TokensPrompt(TypedDict):
prompt_token_ids: List[int]
"""A list of token IDs to pass to the model."""
+ token_type_ids: NotRequired[List[int]]
+ """A list of token type IDs to pass to the cross encoder model."""
+
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
- Optional multi-modal data to pass to the model,
+ DEPRECATED: Optional multi-modal data to pass to the model,
if the model supports it.
"""
mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
- Optional multi-modal processor kwargs to be forwarded to the
+ DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
@@ -133,21 +141,39 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
class TokenInputs(TypedDict):
"""Represents token-based inputs."""
+
+ type: Literal["token"]
+ """The type of inputs."""
+
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
- prompt: NotRequired[Optional[str]]
+ token_type_ids: NotRequired[List[int]]
+ """The token type IDs of the prompt."""
+
+ prompt: NotRequired[str]
"""
The original prompt text corresponding to the token IDs, if available.
"""
- multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
+ multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
- mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]]
+ multi_modal_inputs: NotRequired["MultiModalKwargs"]
+ """
+ Optional multi-modal inputs to pass to the model,
+ if the model supports it.
+ """
+
+ multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
+ """
+ Placeholder ranges for the multi-modal data.
+ """
+
+ mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
@@ -163,8 +189,11 @@ class TokenInputs(TypedDict):
def token_inputs(
prompt_token_ids: List[int],
+ token_type_ids: Optional[List[int]] = None,
prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
+ multi_modal_inputs: Optional["MultiModalKwargs"] = None,
+ multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# --- FLAGSCALE MODIFICATION BEG ---
negative_prompt_token_ids: Optional[List[int]] = None,
@@ -173,31 +202,31 @@ def token_inputs(
) -> TokenInputs:
"""Construct :class:`TokenInputs` from optional values."""
# --- FLAGSCALE MODIFICATION BEG ---
- inputs = TokenInputs(prompt_token_ids=prompt_token_ids,
+ inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids,
negative_prompt_token_ids=negative_prompt_token_ids)
# --- FLAGSCALE MODIFICATION END ---
if prompt is not None:
inputs["prompt"] = prompt
- if multi_modal_data is not None:
- inputs["multi_modal_data"] = multi_modal_data
- if mm_processor_kwargs is not None:
- inputs["mm_processor_kwargs"] = mm_processor_kwargs
# --- FLAGSCALE MODIFICATION BEG ---
if negative_prompt is not None:
inputs["negative_prompt"] = negative_prompt
# --- FLAGSCALE MODIFICATION END ---
+ if token_type_ids is not None:
+ inputs["token_type_ids"] = token_type_ids
+ if multi_modal_data is not None:
+ inputs["multi_modal_data"] = multi_modal_data
+ if multi_modal_inputs is not None:
+ inputs["multi_modal_inputs"] = multi_modal_inputs
+ if multi_modal_placeholders is not None:
+ inputs["multi_modal_placeholders"] = multi_modal_placeholders
+ if mm_processor_kwargs is not None:
+ inputs["mm_processor_kwargs"] = mm_processor_kwargs
return inputs
-SingletonInputs = TokenInputs
-"""
-A processed :class:`SingletonPrompt` which can be passed to
-:class:`vllm.sequence.Sequence`.
-"""
-
-DecoderOnlyInputs = TokenInputs
+DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputsV2"]
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
@@ -205,28 +234,123 @@ def token_inputs(
"""
-class EncoderDecoderInputs(TokenInputs):
+class EncoderDecoderInputs(TypedDict):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the required data for encoder-decoder models.
"""
- encoder_prompt_token_ids: List[int]
- """The token IDs of the encoder prompt."""
+ encoder: Union[TokenInputs, "MultiModalInputsV2"]
+ """The inputs for the encoder portion."""
+
+ decoder: Union[TokenInputs, "MultiModalInputsV2"]
+ """The inputs for the decoder portion."""
- encoder_prompt: NotRequired[Optional[str]]
- """
- The original encoder prompt text corresponding to the token IDs, if
- available.
- """
- encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
+SingletonInputs = Union[TokenInputs, "MultiModalInputsV2"]
+"""
+A processed :class:`SingletonPrompt` which can be passed to
+:class:`vllm.sequence.Sequence`.
+"""
+
+
+@dataclass
+class SingletonInputsAdapter:
"""
- Optional multi-modal data to pass to the encoder model,
- if the model supports it.
+ Unified interface to access the components of :class:`SingletonInputs`.
"""
+ inputs: SingletonInputs
+
+ @cached_property
+ def prompt(self) -> Optional[str]:
+ inputs = self.inputs
+
+ if inputs["type"] == "token" or inputs["type"] == "multimodal":
+ return inputs.get("prompt")
+
+ assert_never(inputs)
+ @cached_property
+ def prompt_token_ids(self) -> List[int]:
+ inputs = self.inputs
+
+ if inputs["type"] == "token" or inputs["type"] == "multimodal":
+ return inputs.get("prompt_token_ids", [])
+
+ assert_never(inputs)
+
+ @cached_property
+ def token_type_ids(self) -> List[int]:
+ inputs = self.inputs
+
+ if inputs["type"] == "token" or inputs["type"] == "multimodal":
+ return inputs.get("token_type_ids", [])
+
+ assert_never(inputs)
+
+ @cached_property
+ def prompt_embeds(self) -> Optional[torch.Tensor]:
+ inputs = self.inputs
+
+ if inputs["type"] == "token" or inputs["type"] == "multimodal":
+ return None
+
+ assert_never(inputs)
+
+ @cached_property
+ def multi_modal_data(self) -> "MultiModalDataDict":
+ inputs = self.inputs
+
+ if inputs["type"] == "token":
+ return inputs.get("multi_modal_data", {})
+
+ if inputs["type"] == "multimodal":
+ return inputs.get("mm_kwargs", {})
+
+ assert_never(inputs)
+
+ @cached_property
+ def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
+ inputs = self.inputs
+
+ if inputs["type"] == "token":
+ return inputs.get("multi_modal_inputs", {})
+
+ if inputs["type"] == "multimodal":
+ return inputs.get("mm_kwargs", {})
+
+ assert_never(inputs)
+
+ @cached_property
+ def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
+ inputs = self.inputs
+
+ if inputs["type"] == "token":
+ return inputs.get("multi_modal_placeholders", {})
+
+ if inputs["type"] == "multimodal":
+ return inputs.get("mm_placeholders", {})
+
+ assert_never(inputs)
+
+ @cached_property
+ def mm_processor_kwargs(self) -> Dict[str, Any]:
+ inputs = self.inputs
+
+ if inputs["type"] == "token":
+ return inputs.get("mm_processor_kwargs", {})
+
+ if inputs["type"] == "multimodal":
+ return {}
+
+ assert_never(inputs)
+
+
+ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
+"""
+The inputs to :data:`vllm.inputs.InputProcessor`.
+"""
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
@@ -253,10 +377,11 @@ def zip_enc_dec_prompts(
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
"""
Zip encoder and decoder prompts together into a list of
- :class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs
- may also be provided; if a dict is passed, the same dictionary will be
- used for every encoder/decoder prompt. If an iterable is provided, it will
- be zipped with the encoder/decoder prompts.
+ :class:`ExplicitEncoderDecoderPrompt` instances.
+
+ ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
+ dictionary will be used for every encoder/decoder prompt. If an iterable is
+ provided, it will be zipped with the encoder/decoder prompts.
"""
if mm_processor_kwargs is None:
mm_processor_kwargs = cast(Dict[str, Any], {})
@@ -282,34 +407,3 @@ def to_enc_dec_tuple_list(
return [(enc_dec_prompt["encoder_prompt"],
enc_dec_prompt["decoder_prompt"])
for enc_dec_prompt in enc_dec_prompts]
-
-
-def __getattr__(name: str):
- import warnings
-
- if name == "PromptInput":
- msg = ("PromptInput has been renamed to PromptType. "
- "The original name will be removed in an upcoming version.")
-
- warnings.warn(DeprecationWarning(msg), stacklevel=2)
-
- return PromptType
-
- if name == "LLMInputs":
- msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
- "The original name will be removed in an upcoming version.")
-
- warnings.warn(DeprecationWarning(msg), stacklevel=2)
-
- return DecoderOnlyInputs
-
- if name == "EncoderDecoderLLMInputs":
- msg = (
- "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
- "The original name will be removed in an upcoming version.")
-
- warnings.warn(DeprecationWarning(msg), stacklevel=2)
-
- return EncoderDecoderInputs
-
- raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/flagscale/inference/core/llm_engine.py b/flagscale/inference/core/llm_engine.py
index a765f6cf5..53b6d4a99 100644
--- a/flagscale/inference/core/llm_engine.py
+++ b/flagscale/inference/core/llm_engine.py
@@ -1,4 +1,5 @@
# This file is modified from 'FlagScale/vllm/vllm/engine/llm_engine.py'
+import copy
import time
from collections import Counter as collectionsCounter
from collections import deque
@@ -11,14 +12,12 @@
from typing import Set, Type, Union, cast, overload
import torch
-from typing_extensions import TypeVar
+from typing_extensions import TypeVar, deprecated
import vllm.envs as envs
-from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
- EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
- ObservabilityConfig, ParallelConfig,
- PromptAdapterConfig, SchedulerConfig,
- SpeculativeConfig)
+from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
+ ObservabilityConfig, ParallelConfig, SchedulerConfig,
+ VllmConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
@@ -32,8 +31,9 @@
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
- EncoderDecoderInputs, InputRegistry, PromptType)
+from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
+ PromptType, SingletonInputsAdapter)
+from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
@@ -41,7 +41,8 @@
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -81,7 +82,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
-_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
+_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
@dataclass
@@ -113,7 +114,7 @@ class SchedulerContext:
def __init__(self, multi_step_stream_outputs: bool = False):
self.output_queue: Deque[OutputData] = deque()
self.request_outputs: List[Union[RequestOutput,
- EmbeddingRequestOutput]] = []
+ PoolingRequestOutput]] = []
self.seq_group_metadata_list: Optional[
List[SequenceGroupMetadata]] = None
self.scheduler_outputs: Optional[SchedulerOutputs] = None
@@ -222,30 +223,35 @@ def validate_outputs(
def __init__(
self,
- model_config: ModelConfig,
- cache_config: CacheConfig,
- parallel_config: ParallelConfig,
- scheduler_config: SchedulerConfig,
- device_config: DeviceConfig,
- load_config: LoadConfig,
- lora_config: Optional[LoRAConfig],
- speculative_config: Optional[SpeculativeConfig],
- decoding_config: Optional[DecodingConfig],
- observability_config: Optional[ObservabilityConfig],
- prompt_adapter_config: Optional[PromptAdapterConfig],
+ vllm_config: VllmConfig,
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
+ mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
) -> None:
+
+ self.model_config = vllm_config.model_config
+ self.cache_config = vllm_config.cache_config
+ self.lora_config = vllm_config.lora_config
+ self.parallel_config = vllm_config.parallel_config
+ self.scheduler_config = vllm_config.scheduler_config
+ self.device_config = vllm_config.device_config
+ self.speculative_config = vllm_config.speculative_config # noqa
+ self.load_config = vllm_config.load_config
+ self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
+ )
+ self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
+ self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
+ )
+
logger.info(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
- "override_neuron_config=%s, "
- "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
+ "override_neuron_config=%s, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
@@ -257,57 +263,46 @@ def __init__(
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
- "chat_template_text_format=%s, mm_processor_kwargs=%s)",
+ "mm_processor_kwargs=%s, pooler_config=%r,"
+ "compilation_config=%r",
VLLM_VERSION,
- model_config.model,
- speculative_config,
- model_config.tokenizer,
- model_config.skip_tokenizer_init,
- model_config.tokenizer_mode,
- model_config.revision,
- model_config.override_neuron_config,
- model_config.rope_scaling,
- model_config.rope_theta,
- model_config.tokenizer_revision,
- model_config.trust_remote_code,
- model_config.dtype,
- model_config.max_model_len,
- load_config.download_dir,
- load_config.load_format,
- parallel_config.tensor_parallel_size,
- parallel_config.pipeline_parallel_size,
- parallel_config.disable_custom_all_reduce,
- model_config.quantization,
- model_config.enforce_eager,
- cache_config.cache_dtype,
- model_config.quantization_param_path,
- device_config.device,
- decoding_config,
- observability_config,
- model_config.seed,
- model_config.served_model_name,
- scheduler_config.num_scheduler_steps,
- scheduler_config.chunked_prefill_enabled,
- scheduler_config.multi_step_stream_outputs,
- cache_config.enable_prefix_caching,
- model_config.use_async_output_proc,
+ self.model_config.model,
+ self.speculative_config,
+ self.model_config.tokenizer,
+ self.model_config.skip_tokenizer_init,
+ self.model_config.tokenizer_mode,
+ self.model_config.revision,
+ self.model_config.override_neuron_config,
+ self.model_config.tokenizer_revision,
+ self.model_config.trust_remote_code,
+ self.model_config.dtype,
+ self.model_config.max_model_len,
+ self.load_config.download_dir,
+ self.load_config.load_format,
+ self.parallel_config.tensor_parallel_size,
+ self.parallel_config.pipeline_parallel_size,
+ self.parallel_config.disable_custom_all_reduce,
+ self.model_config.quantization,
+ self.model_config.enforce_eager,
+ self.cache_config.cache_dtype,
+ self.model_config.quantization_param_path,
+ self.device_config.device,
+ self.decoding_config,
+ self.observability_config,
+ self.model_config.seed,
+ self.model_config.served_model_name,
+ self.scheduler_config.num_scheduler_steps,
+ self.scheduler_config.chunked_prefill_enabled,
+ self.scheduler_config.multi_step_stream_outputs,
+ self.cache_config.enable_prefix_caching,
+ self.model_config.use_async_output_proc,
use_cached_outputs,
- model_config.chat_template_text_format,
- model_config.mm_processor_kwargs,
+ self.model_config.mm_processor_kwargs,
+ self.model_config.pooler_config,
+ vllm_config.compilation_config,
)
# TODO(woosuk): Print more configs in debug mode.
- self.model_config = model_config
- self.cache_config = cache_config
- self.lora_config = lora_config
- self.parallel_config = parallel_config
- self.scheduler_config = scheduler_config
- self.device_config = device_config
- self.speculative_config = speculative_config
- self.load_config = load_config
- self.decoding_config = decoding_config or DecodingConfig()
- self.prompt_adapter_config = prompt_adapter_config
- self.observability_config = observability_config or ObservabilityConfig(
- )
+
self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs
@@ -329,27 +324,17 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
- model_config)
+ self.model_config)
- self.input_preprocessor = InputPreprocessor(model_config,
- self.tokenizer)
+ self.input_preprocessor = InputPreprocessor(self.model_config,
+ self.tokenizer,
+ mm_registry)
self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(
- model_config)
-
- self.model_executor = executor_class(
- model_config=model_config,
- cache_config=cache_config,
- parallel_config=parallel_config,
- scheduler_config=scheduler_config,
- device_config=device_config,
- lora_config=lora_config,
- speculative_config=speculative_config,
- load_config=load_config,
- prompt_adapter_config=prompt_adapter_config,
- observability_config=self.observability_config,
- )
+ self.model_config)
+
+ self.model_executor = executor_class(vllm_config=vllm_config, )
if self.model_config.task != "embedding":
self._initialize_kv_caches()
@@ -359,36 +344,36 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
from vllm.model_executor.model_loader import (
get_architecture_class_name)
usage_message.report_usage(
- get_architecture_class_name(model_config),
+ get_architecture_class_name(self.model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype":
- str(model_config.dtype),
+ str(self.model_config.dtype),
"tensor_parallel_size":
- parallel_config.tensor_parallel_size,
+ self.parallel_config.tensor_parallel_size,
"block_size":
- cache_config.block_size,
+ self.cache_config.block_size,
"gpu_memory_utilization":
- cache_config.gpu_memory_utilization,
+ self.cache_config.gpu_memory_utilization,
# Quantization
"quantization":
- model_config.quantization,
+ self.model_config.quantization,
"kv_cache_dtype":
- str(cache_config.cache_dtype),
+ str(self.cache_config.cache_dtype),
# Feature flags
"enable_lora":
- bool(lora_config),
+ bool(self.lora_config),
"enable_prompt_adapter":
- bool(prompt_adapter_config),
+ bool(self.prompt_adapter_config),
"enable_prefix_caching":
- cache_config.enable_prefix_caching,
+ self.cache_config.enable_prefix_caching,
"enforce_eager":
- model_config.enforce_eager,
+ self.model_config.enforce_eager,
"disable_custom_all_reduce":
- parallel_config.disable_custom_all_reduce,
+ self.parallel_config.disable_custom_all_reduce,
})
if self.tokenizer:
@@ -407,7 +392,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
for _ in range(self.parallel_config.pipeline_parallel_size)
]
- if model_config.use_async_output_proc:
+ if self.model_config.use_async_output_proc:
process_model_outputs = weak_bind(self._process_model_outputs)
self.async_callbacks = [
@@ -427,11 +412,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = [
Scheduler(
- scheduler_config, cache_config, lora_config,
- parallel_config.pipeline_parallel_size,
+ self.scheduler_config, self.cache_config, self.lora_config,
+ self.parallel_config.pipeline_parallel_size,
self.async_callbacks[v_id]
- if model_config.use_async_output_proc else None)
- for v_id in range(parallel_config.pipeline_parallel_size)
+ if self.model_config.use_async_output_proc else None)
+ for v_id in range(self.parallel_config.pipeline_parallel_size)
]
# Metric Logging.
@@ -453,7 +438,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
"prometheus":
PrometheusStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
- labels=dict(model_name=model_config.served_model_name),
+ labels=dict(
+ model_name=self.model_config.served_model_name),
max_model_len=self.model_config.max_model_len),
}
self.stat_loggers["prometheus"].info("cache_config",
@@ -506,7 +492,7 @@ def _initialize_kv_caches(self) -> None:
@classmethod
def _get_executor_cls(cls,
- engine_config: EngineConfig) -> Type[ExecutorBase]:
+ engine_config: VllmConfig) -> Type[ExecutorBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class.
@@ -533,6 +519,14 @@ def _get_executor_cls(cls,
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
+ elif engine_config.device_config.device_type == "hpu":
+ if distributed_executor_backend == "ray":
+ initialize_ray_cluster(engine_config.parallel_config)
+ from vllm.executor.ray_hpu_executor import RayHPUExecutor
+ executor_class = RayHPUExecutor
+ else:
+ from vllm.executor.hpu_executor import HPUExecutor
+ executor_class = HPUExecutor
elif engine_config.device_config.device_type == "openvino":
from vllm.executor.openvino_executor import OpenVINOExecutor
executor_class = OpenVINOExecutor
@@ -576,11 +570,11 @@ def from_engine_args(
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
- engine_config = engine_args.create_engine_config()
+ engine_config = engine_args.create_engine_config(usage_context)
executor_class = cls._get_executor_cls(engine_config)
# Create the LLM engine.
engine = cls(
- **engine_config.to_dict(),
+ vllm_config=engine_config,
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
@@ -627,7 +621,7 @@ def _init_tokenizer(self) -> BaseTokenizerGroup:
model_config=self.model_config,
scheduler_config=self.scheduler_config,
parallel_config=self.parallel_config,
- enable_lora=bool(self.lora_config))
+ lora_config=self.lora_config)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
@@ -643,7 +637,7 @@ def _verify_args(self) -> None:
def _add_processed_request(
self,
request_id: str,
- processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
+ processed_inputs: ProcessorInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
@@ -668,38 +662,45 @@ def _add_processed_request(
)
return None
- self._validate_model_inputs(processed_inputs)
+ self._validate_model_inputs(processed_inputs, lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
- seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
- lora_request, prompt_adapter_request)
+ negative_decoder_inputs = None # --- FLAGSCALE MODIFICATION ---
+ if is_encoder_decoder_inputs(processed_inputs):
+ decoder_inputs = processed_inputs["decoder"]
+ encoder_inputs = processed_inputs["encoder"]
+ # --- FLAGSCALE MODIFICATION BEG ---
+ elif "negative_prompt_token_ids" in processed_inputs \
+ and processed_inputs["negative_prompt_token_ids"] is not None:
+ positive_inputs = processed_inputs.copy()
+ positive_inputs.pop("negative_prompt_token_ids")
+ negative_inputs = processed_inputs.copy()
+ negative_inputs["prompt_token_ids"] = negative_inputs["negative_prompt_token_ids"]
+ negative_inputs.pop("negative_prompt_token_ids")
+ decoder_inputs = positive_inputs
+ negative_decoder_inputs = negative_inputs
+ encoder_inputs = None
+ # --- FLAGSCALE MODIFICATION END ---
+ else:
+ decoder_inputs = processed_inputs
+ encoder_inputs = None
- encoder_seq = None
- if 'encoder_prompt_token_ids' in processed_inputs:
- encoder_seq = Sequence(seq_id,
- processed_inputs,
- block_size,
- eos_token_id,
- lora_request,
- prompt_adapter_request,
- from_decoder_prompt=False)
+ seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
+ lora_request, prompt_adapter_request)
# --- FLAGSCALE MODIFICATION BEG ---
- negative_seq = None
- if "negative_prompt_token_ids" in processed_inputs \
- and processed_inputs["negative_prompt_token_ids"]:
- negative_seq = Sequence(seq_id,
- processed_inputs,
- block_size,
- eos_token_id,
- lora_request,
- prompt_adapter_request,
- from_negative_prompt=True)
+ negative_seq = (None if negative_decoder_inputs is None else Sequence(
+ seq_id, negative_decoder_inputs, block_size, eos_token_id, lora_request,
+ prompt_adapter_request))
# --- FLAGSCALE MODIFICATION END ---
+ encoder_seq = (None if encoder_inputs is None else Sequence(
+ seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
+ prompt_adapter_request))
+
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
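The hunk above replaces the old `from_negative_prompt=True` construction: when a token prompt carries `negative_prompt_token_ids`, the processed inputs are split into a positive and a negative decoder-input dict before the `Sequence` objects are built. A minimal standalone sketch of that splitting step, using plain dicts and the same FlagScale-specific key names (the helper `split_negative_inputs` below is hypothetical, not part of the diff):

```python
from typing import Optional, Tuple


def split_negative_inputs(processed: dict) -> Tuple[dict, Optional[dict]]:
    """Split token inputs into (positive, negative) dicts, mirroring the
    FLAGSCALE branch in _add_processed_request above. Returns
    (positive_inputs, None) when no negative prompt is attached."""
    if processed.get("negative_prompt_token_ids") is None:
        return processed, None

    positive = processed.copy()
    positive.pop("negative_prompt_token_ids")

    negative = processed.copy()
    # The negative sequence decodes from the negative prompt tokens instead.
    negative["prompt_token_ids"] = negative.pop("negative_prompt_token_ids")
    return positive, negative


# One request carrying both a positive and a negative prompt.
pos, neg = split_negative_inputs({
    "prompt_token_ids": [1, 42, 7],
    "negative_prompt_token_ids": [1, 99],
})
assert pos == {"prompt_token_ids": [1, 42, 7]}
assert neg == {"prompt_token_ids": [1, 99]}
```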
@@ -740,7 +741,8 @@ def _add_processed_request(
def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop()
- @overload # DEPRECATED
+ @overload
+ @deprecated("'inputs' will be renamed to 'prompt'")
def add_request(
self,
request_id: str,
@@ -841,9 +843,21 @@ def add_request(
raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
+ if isinstance(params, SamplingParams) \
+ and (params.guided_decoding or params.logits_processors) \
+ and self.scheduler_config.num_scheduler_steps > 1:
+ raise ValueError(
+ "Guided decoding and logits processors are not supported "
+ "in multi-step decoding")
+
if arrival_time is None:
arrival_time = time.time()
+ if self.tokenizer is not None:
+ self._validate_token_prompt(
+ prompt,
+ tokenizer=self.get_tokenizer(lora_request=lora_request))
+
preprocessed_inputs = self.input_preprocessor.preprocess(
prompt,
request_id=request_id,
@@ -852,13 +866,6 @@ def add_request(
)
processed_inputs = self.input_processor(preprocessed_inputs)
- # This is a bit of a hack - copy the mm_processor_kwargs that were
- # used in the input processor to the processed output, since these
- # kwargs are presumed to be immutable and the values should be aligned
- # between the input processor (here) and the input mapper.
- processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
- "mm_processor_kwargs")
-
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
@@ -870,6 +877,27 @@ def add_request(
priority=priority,
)
+ def _validate_token_prompt(self, prompt: PromptType,
+ tokenizer: AnyTokenizer):
+ # Guard against out-of-vocab tokens.
+ # For some tokenizers, tokenizer.decode will happily return empty text
+ # for token ids that are out of vocab, and we don't detect token ids
+ # that are greater than the max token id before running the model.
+ # However, these token ids will later crash a cuda kernel at runtime
+ # with an index out of bounds error. This will crash the entire engine.
+ # This needs to happen before multimodal input pre-processing, which
+ # may add dummy tokens that aren't part of the tokenizer's
+ # vocabulary.
+ if is_token_prompt(prompt):
+ prompt_ids = prompt["prompt_token_ids"]
+ if len(prompt_ids) == 0:
+ # Empty prompt check is handled later
+ return
+ max_input_id = max(prompt_ids)
+ if max_input_id > tokenizer.max_token_id:
+ raise ValueError(
+ "Token id {} is out of vocabulary".format(max_input_id))
+
def _create_sequence_group_with_sampling(
self,
request_id: str,
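The new `_validate_token_prompt` above rejects out-of-vocabulary token ids before they can reach a CUDA kernel. A small sketch of the same check, assuming only a tokenizer-like object that exposes `max_token_id` (the stand-in class below is illustrative, not vLLM's tokenizer API):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeTokenizer:
    # Assumed attribute mirroring `tokenizer.max_token_id` in the diff.
    max_token_id: int


def validate_token_ids(prompt_token_ids: List[int],
                       tokenizer: FakeTokenizer) -> None:
    if not prompt_token_ids:
        return  # empty prompts are rejected later, as in the diff
    max_input_id = max(prompt_token_ids)
    if max_input_id > tokenizer.max_token_id:
        raise ValueError(f"Token id {max_input_id} is out of vocabulary")


validate_token_ids([3, 5, 10], FakeTokenizer(max_token_id=32_000))  # passes
try:
    validate_token_ids([3, 40_000], FakeTokenizer(max_token_id=32_000))
except ValueError as e:
    print(e)  # Token id 40000 is out of vocabulary
```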
@@ -1020,9 +1048,9 @@ def _update_num_computed_tokens_for_multi_step_prefill(
This function updates num_computed_tokens for prompt sequences
when Multi-Step is enabled.
- seq_group: SequenceGroup to update the num_computed_tokens for.
+ seq_group: SequenceGroup to update the num_computed_tokens for.
seq_group_meta: Metadata of the given SequenceGroup.
- is_first_step_output: Optional[bool] -
+ is_first_step_output: Optional[bool] -
When available, is_first_step_output indicates if the appended
output token is the output of the first-step in multi-step.
A value of None indicates that outputs from all steps in
@@ -1142,14 +1170,11 @@ def _process_model_outputs(self,
else:
seq_group.update_num_computed_tokens(
seq_group_meta.token_chunk_size or 0)
-
- # --- FLAGSCALE MODIFICATION BEG ---
- if seq_group.has_negative_seqs():
- assert self.scheduler_config.is_multi_step is False
+ # --- FLAGSCALE MODIFICATION BEG ---
seq_group.update_negative_num_computed_tokens(
- scheduled_seq_group.negative_token_chunk_size
+ scheduled_seq_group.negative_token_chunk_size or 0
)
- # --- FLAGSCALE MODIFICATION END ---
+ # --- FLAGSCALE MODIFICATION END ---
if outputs:
for o in outputs:
@@ -1300,13 +1325,13 @@ def _advance_to_next_step(
if seq_group_metadata.token_chunk_size
is not None else 0)
seq_group.update_num_computed_tokens(token_chunk_size)
-
- # --- FLAGSCALE MODIFICATION BEG ---
- if seq_group.has_negative_seqs():
- seq_group.update_negative_num_computed_tokens(
- scheduled_seq_group.negative_token_chunk_size
- )
- # --- FLAGSCALE MODIFICATION END ---
+ # --- FLAGSCALE MODIFICATION BEG ---
+ if seq_group.has_negative_seqs():
+ negative_token_chunk_size = (seq_group_metadata.negative_token_chunk_size
+ if seq_group_metadata.negative_token_chunk_size
+ is not None else 0)
+ seq_group.update_negative_num_computed_tokens(negative_token_chunk_size)
+ # --- FLAGSCALE MODIFICATION END ---
if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, (
@@ -1325,15 +1350,13 @@ def _advance_to_next_step(
seq_group.update_num_computed_tokens(1)
else:
seq.append_token_id(sample.output_token, sample.logprobs)
+ # --- FLAGSCALE MODIFICATION BEG ---
+ if seq_group.has_negative_seqs():
+ negative_seq = seq_group.negative_seqs[0]
+ negative_seq.append_token_id(sample.output_token, sample.logprobs)
+ # --- FLAGSCALE MODIFICATION END ---
- # --- FLAGSCALE MODIFICATION BEG ---
- if seq_group.has_negative_seqs():
- assert self.scheduler_config.is_multi_step is False
- negative_seq = seq_group.negative_seqs[0]
- negative_seq.append_token_id(sample.output_token, sample.logprobs)
- # --- FLAGSCALE MODIFICATION END ---
-
- def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
.. figure:: https://i.imgur.com/sv2HssD.png
@@ -1417,6 +1440,9 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
+ finished_requests_ids = self.scheduler[
+ virtual_engine].get_and_reset_finished_requests_ids()
+
# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
@@ -1428,13 +1454,13 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
+ else:
+ finished_requests_ids = list()
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
if not scheduler_outputs.is_empty():
- finished_requests_ids = self.scheduler[
- virtual_engine].get_and_reset_finished_requests_ids()
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
@@ -1550,8 +1576,8 @@ def _has_remaining_steps(
seq_group.state.remaining_steps != ref_remaining_steps
for seq_group in seq_group_metadata_list[1:]
]):
- raise AssertionError(("All running sequence groups should "
- "have the same remaining steps."))
+ raise AssertionError("All running sequence groups should "
+ "have the same remaining steps.")
return ref_remaining_steps > 0
@@ -1676,6 +1702,7 @@ def _get_stats(self,
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
+ num_tokens_iter = 0
time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []
num_preemption_iter = (0 if scheduler_outputs is None else
@@ -1684,10 +1711,19 @@ def _get_stats(self,
# Request stats
# Latency
time_e2e_requests: List[float] = []
+ time_queue_requests: List[float] = []
+ time_inference_requests: List[float] = []
+ time_prefill_requests: List[float] = []
+ time_decode_requests: List[float] = []
+ time_in_queue_requests: List[float] = []
+ model_forward_time_requests: List[float] = []
+ model_execute_time_requests: List[float] = []
# Metadata
num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
n_requests: List[int] = []
+ max_num_generation_tokens_requests: List[int] = []
+ max_tokens_requests: List[int] = []
finished_reason_requests: List[str] = []
# Lora requests
@@ -1716,7 +1752,7 @@ def _get_stats(self,
# not counted (to avoid double counting)
actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore
- num_generation_tokens_from_prefill_groups = 0.
+ num_generation_tokens_from_prefill_groups = 0
# NOTE: if scheduler_outputs.num_prefill_groups > 0 and
# the len of scheduler_outputs.scheduled_seq_groups is !=
# scheduler_outputs.num_prefill_groups, this means that
@@ -1777,6 +1813,27 @@ def _get_stats(self,
# Latency timings
time_e2e_requests.append(now -
seq_group.metrics.arrival_time)
+ if (seq_group.metrics.first_scheduled_time is not None and
+ seq_group.metrics.first_token_time is not None):
+ time_queue_requests.append(
+ seq_group.metrics.first_scheduled_time -
+ seq_group.metrics.arrival_time)
+ time_prefill_requests.append(
+ seq_group.metrics.first_token_time -
+ seq_group.metrics.first_scheduled_time)
+ time_decode_requests.append(
+ now - seq_group.metrics.first_token_time)
+ time_inference_requests.append(
+ now - seq_group.metrics.first_scheduled_time)
+ if seq_group.metrics.time_in_queue is not None:
+ time_in_queue_requests.append(
+ seq_group.metrics.time_in_queue)
+ if seq_group.metrics.model_forward_time is not None:
+ model_forward_time_requests.append(
+ seq_group.metrics.model_forward_time)
+ if seq_group.metrics.model_execute_time is not None:
+ model_execute_time_requests.append(
+ seq_group.metrics.model_execute_time * 1000)
# Metadata
num_prompt_tokens_requests.append(
len(seq_group.prompt_token_ids))
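The extra per-request latency buckets introduced above (`time_queue_requests`, `time_prefill_requests`, `time_decode_requests`, `time_inference_requests`) are plain differences of timestamps the engine already records. A quick worked example with made-up timestamps:

```python
# Hypothetical timestamps (seconds) for one finished request.
arrival_time = 100.0
first_scheduled_time = 100.4   # left the waiting queue
first_token_time = 101.1       # first output token emitted
now = 103.0                    # request finished

time_queue = first_scheduled_time - arrival_time        # 0.4 s waiting
time_prefill = first_token_time - first_scheduled_time  # 0.7 s prefill
time_decode = now - first_token_time                    # 1.9 s decoding
time_inference = now - first_scheduled_time             # 2.6 s prefill + decode
time_e2e = now - arrival_time                           # 3.0 s end to end

assert abs(time_inference - (time_prefill + time_decode)) < 1e-9
```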
@@ -1784,8 +1841,13 @@ def _get_stats(self,
seq.get_output_len()
for seq in seq_group.get_finished_seqs()
])
+ max_num_generation_tokens_requests.append(
+ max(seq.get_output_len()
+ for seq in seq_group.get_seqs()))
if seq_group.sampling_params is not None:
n_requests.append(seq_group.sampling_params.n)
+ max_tokens_requests.append(
+ seq_group.sampling_params.max_tokens)
finished_reason_requests.extend([
SequenceStatus.get_finished_reason(seq.status)
for seq in seq_group.get_finished_seqs()
@@ -1800,7 +1862,8 @@ def _get_stats(self,
num_generation_tokens_iter = (
actual_num_batched_tokens - num_prompt_tokens_iter +
num_generation_tokens_from_prefill_groups)
-
+ num_tokens_iter = (num_generation_tokens_iter +
+ num_prompt_tokens_iter)
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
if model_output and (model_output[0].spec_decode_worker_metrics
@@ -1826,6 +1889,7 @@ def _get_stats(self,
# Iteration stats
num_prompt_tokens_iter=num_prompt_tokens_iter,
num_generation_tokens_iter=num_generation_tokens_iter,
+ num_tokens_iter=num_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
spec_decode_metrics=spec_decode_metrics,
@@ -1834,10 +1898,20 @@ def _get_stats(self,
# Request stats
# Latency
time_e2e_requests=time_e2e_requests,
+ time_queue_requests=time_queue_requests,
+ time_inference_requests=time_inference_requests,
+ time_prefill_requests=time_prefill_requests,
+ time_decode_requests=time_decode_requests,
+ time_in_queue_requests=time_in_queue_requests,
+ model_forward_time_requests=model_forward_time_requests,
+ model_execute_time_requests=model_execute_time_requests,
# Metadata
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
+ max_num_generation_tokens_requests=
+ max_num_generation_tokens_requests,
n_requests=n_requests,
+ max_tokens_requests=max_tokens_requests,
finished_reason_requests=finished_reason_requests,
max_lora=str(max_lora_stat),
waiting_lora_adapters=list(waiting_lora_adapters.keys()),
@@ -1962,19 +2036,17 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None:
SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
metrics.model_execute_time)
- def is_encoder_decoder_model(self):
- return self.input_preprocessor.is_encoder_decoder_model()
-
- def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
- EncoderDecoderInputs]):
- if self.model_config.is_multimodal_model:
+ def _validate_model_inputs(self, inputs: ProcessorInputs,
+ lora_request: Optional[LoRARequest]):
+ if is_encoder_decoder_inputs(inputs):
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
- prompt_ids = inputs.get("prompt_token_ids")
- elif self.is_encoder_decoder_model():
- prompt_ids = inputs.get("encoder_prompt_token_ids")
+ prompt_inputs = inputs["decoder" if self.model_config.
+ is_multimodal_model else "encoder"]
else:
- prompt_ids = inputs.get("prompt_token_ids")
+ prompt_inputs = inputs
+
+ prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")
@@ -2005,7 +2077,11 @@ def _build_logits_processors(
logits_processors = []
- if (guided_decoding := sampling_params.guided_decoding) is not None:
+ if sampling_params.guided_decoding is not None:
+ # Defensively copy sampling params since guided decoding logits
+ # processors can have different state for each request
+ sampling_params = copy.copy(sampling_params)
+ guided_decoding = sampling_params.guided_decoding
logger.debug(
"Building guided decoding logits processor in "
@@ -2016,7 +2092,9 @@ def _build_logits_processors(
self.decoding_config.guided_decoding_backend
processor = get_local_guided_decoding_logits_processor(
- guided_params=guided_decoding, tokenizer=tokenizer)
+ guided_params=guided_decoding,
+ tokenizer=tokenizer,
+ model_config=self.model_config)
if processor:
logits_processors.append(processor)
diff --git a/flagscale/inference/core/logits_processor.py b/flagscale/inference/core/logits_processor.py
index 7ce2105fb..1964197e2 100644
--- a/flagscale/inference/core/logits_processor.py
+++ b/flagscale/inference/core/logits_processor.py
@@ -112,8 +112,14 @@ def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
- return hidden_states.index_select(0,
- sampling_metadata.selected_token_indices)
+ # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios
+ # (warmup, profile_run) we might not have selected_token_indices,
+ # so we skip pruning.
+ if sampling_metadata.selected_token_indices is not None:
+ return hidden_states.index_select(
+ 0, sampling_metadata.selected_token_indices)
+ else:
+ return hidden_states
def _apply_logits_processors(
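`_prune_hidden_states` keeps only the hidden-state rows for which logits are actually needed; the new guard skips the selection when `selected_token_indices` is absent (for example Gaudi warmup and profile runs). A toy sketch of the pruning itself, with made-up tensors:

```python
import torch

hidden_states = torch.arange(12, dtype=torch.float32).reshape(4, 3)
selected_token_indices = torch.tensor([0, 3])  # e.g. last token of each sequence

pruned = hidden_states.index_select(0, selected_token_indices)
print(pruned.shape)  # torch.Size([2, 3])

# With no indices recorded (warmup/profile run), the tensor passes through.
selected_token_indices = None
pruned = (hidden_states if selected_token_indices is None else
          hidden_states.index_select(0, selected_token_indices))
```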
@@ -152,10 +158,9 @@ def _apply_logits_processors(
else:
logits_row = logits[2 * logits_row_idx]
logits_row_idx *= 2
+ # --- FLAGSCALE MODIFICATION END ---
else:
logits_row = logits[logits_row_idx]
- # --- FLAGSCALE MODIFICATION END ---
-
past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
diff --git a/flagscale/inference/core/model_runner.py b/flagscale/inference/core/model_runner.py
index cc27dc6dc..49e1d1c66 100644
--- a/flagscale/inference/core/model_runner.py
+++ b/flagscale/inference/core/model_runner.py
@@ -19,13 +19,9 @@
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
-from vllm.compilation.compile_context import set_compile_context
-from vllm.compilation.levels import CompilationLevel
-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
- ModelConfig, ObservabilityConfig, ParallelConfig,
- PromptAdapterConfig, SchedulerConfig)
+from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
-from vllm.distributed import get_pp_group
+from vllm.distributed import get_kv_transfer_group, get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
@@ -41,7 +37,8 @@
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
- MultiModalInputs, MultiModalRegistry)
+ MultiModalKwargs, MultiModalPlaceholderMap,
+ MultiModalRegistry)
from vllm.platforms import current_platform
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -49,10 +46,10 @@
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
-from vllm.transformers_utils.config import uses_mrope
-from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
- flatten_2d_lists, is_pin_memory_available,
- supports_dynamo, weak_ref_tensor)
+from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
+ async_tensor_h2d, flatten_2d_lists,
+ is_pin_memory_available, supports_dynamo,
+ weak_ref_tensor)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
@@ -66,16 +63,7 @@
logger = init_logger(__name__)
LORA_WARMUP_RANK = 8
-_BATCH_SIZE_ALIGNMENT = 8
-# all the token sizes that **can** be captured by cudagraph.
-# they can be arbitrarily large.
-# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
-# the actual sizes to capture will be determined by the model,
-# depending on the model's max_num_seqs.
-# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
- _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
-]
+
_NUM_WARMUP_ITERS = 2
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
@@ -95,6 +83,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
+ token_types: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None
@@ -137,6 +126,18 @@ def from_broadcasted_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
+ # Exclude `async_callback` to be able to pickle this object
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ del state["async_callback"]
+ return state
+
+ # TODO: What happens when we unpickle this object?
+ # How can we update this callback to properly pass it to the engine?
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ self.__dict__.update({'async_callback': None})
+
@dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
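The `__getstate__`/`__setstate__` pair added above makes the model-input object picklable by dropping the non-picklable `async_callback` and restoring it as `None` after unpickling. A self-contained sketch of the same pattern (the dataclass below is a stand-in, not the real `ModelInputForGPU`):

```python
import pickle
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ModelInputSketch:
    """Stand-in for ModelInputForGPU: one data field plus a callback."""
    input_tokens: List[int]
    async_callback: Optional[Callable] = None

    # Exclude `async_callback` so the object can be pickled.
    def __getstate__(self):
        state = self.__dict__.copy()
        del state["async_callback"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.__dict__["async_callback"] = None  # restored as a placeholder


mi = ModelInputSketch(input_tokens=[1, 2, 3], async_callback=print)
restored = pickle.loads(pickle.dumps(mi))
assert restored.input_tokens == [1, 2, 3]
assert restored.async_callback is None
```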
@@ -191,6 +192,7 @@ class InterDataForSeqGroup:
def simple_reinit(self):
self.input_tokens[0].clear() # type: ignore
self.input_positions[0].clear() # type: ignore
+ self.token_types[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore
self.seq_lens[0] = 0 # type: ignore
self.orig_seq_lens[0] = 0 # type: ignore
@@ -217,6 +219,7 @@ def __init__(
# Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None,
input_positions: Optional[List[List[int]]] = None,
+ token_types: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None,
# The sequence length (may be capped to the sliding window).
@@ -242,7 +245,9 @@ def __init__(
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
# Multi-modal inputs.
- multi_modal_inputs: Optional[MultiModalInputs] = None,
+ multi_modal_kwargs: Optional[MultiModalKwargs] = None,
+ multi_modal_placeholder_maps: Optional[Dict[
+ str, MultiModalPlaceholderMap]] = None,
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False,
@@ -280,6 +285,12 @@ def __init__(
for seq_id in range(len(self.seq_ids)):
self.input_positions[seq_id].clear()
+ if token_types:
+ self.token_types = token_types
+ else:
+ for seq_id in range(len(self.seq_ids)):
+ self.token_types[seq_id].clear()
+
self.mrope_input_positions = None
if seq_lens:
@@ -343,6 +354,7 @@ def __init__(
else:
self.input_tokens = input_tokens or []
self.input_positions = input_positions or []
+ self.token_types = token_types or []
self.mrope_input_positions = mrope_input_positions or None
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []
@@ -361,7 +373,8 @@ def __init__(
prompt_adapter_prompt_mapping or [])
self.prompt_adapter_request = prompt_adapter_request
- self.multi_modal_inputs = multi_modal_inputs
+ self.multi_modal_kwargs = multi_modal_kwargs
+ self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
self.prefix_cache_hit = prefix_cache_hit
self.n_seqs = len(self.seq_ids)
@@ -374,6 +387,7 @@ def __post_init__(self):
self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)]
+ self.token_types = [[] for _ in range(self.n_seqs)]
self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs
@@ -472,10 +486,10 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
if is_negative:
seq_data = seq_group_metadata.negative_seq_data[inter_data.seq_ids[seq_idx]]
token_chunk_size = seq_group_metadata.negative_token_chunk_size
+ # --- FLAGSCALE MODIFICATION END ---
else:
seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
token_chunk_size = seq_group_metadata.token_chunk_size
- # --- FLAGSCALE MODIFICATION END ---
# Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens).
@@ -485,19 +499,22 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
context_len = seq_data.get_num_computed_tokens()
seq_len = min(seq_len, context_len + token_chunk_size)
elif self.runner.scheduler_config.is_multi_step or \
- self.runner.model_config.is_encoder_decoder_model:
+ self.runner.model_config.is_encoder_decoder:
context_len = seq_len - 1
else:
context_len = seq_data.get_num_computed_tokens()
# Compute tokens.
tokens = seq_data.get_token_ids()[context_len:seq_len]
+ token_types = seq_group_metadata.token_type_ids
inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
+ inter_data.token_types[seq_idx].extend(
+ token_types if token_types else [])
inter_data.query_lens[seq_idx] = seq_len - context_len
if seq_data.mrope_position_delta is not None:
@@ -535,6 +552,9 @@ def _compute_for_prefix_cache_hit(
# this may be larger than the sequence length if chunked
# prefill is enabled.
prefix_cache_len = len(computed_block_nums) * self.block_size
+ seq_group_metadata.seq_data[inter_data.seq_ids[
+ seq_idx]].update_num_cached_tokens(prefix_cache_len)
+
# The number of so far computed prompt tokens in this sequence.
context_len = inter_data.context_lens[seq_idx]
# The total number of prompt tokens in this sequence.
@@ -552,6 +572,8 @@ def _compute_for_prefix_cache_hit(
seq_idx][uncomputed_start:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][uncomputed_start:]
+ inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
+ uncomputed_start:]
context_len = prefix_cache_len
inter_data.context_lens[seq_idx] = context_len
@@ -566,6 +588,8 @@ def _compute_for_prefix_cache_hit(
seq_idx][-1:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][-1:]
+ inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
+ -1:]
inter_data.query_lens[seq_idx] = 1
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
@@ -615,7 +639,8 @@ def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
def _compute_prompt_adapter_input(
self, inter_data: InterDataForSeqGroup,
- seq_group_metadata: SequenceGroupMetadata):
+ seq_group_metadata: SequenceGroupMetadata,
+ is_negative: bool = False): # --- FLAGSCALE MODIFICATION ---
"""If prompt adapter is enabled, compute index and prompt mapping.
"""
# Note that when is_prompt=True, we expect only one sequence
@@ -642,19 +667,31 @@ def _compute_prompt_adapter_input(
and seq_group_metadata.sampling_params.prompt_logprobs else 1)
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
- seq_group_metadata: SequenceGroupMetadata):
+ seq_group_metadata: SequenceGroupMetadata,
+ is_negative: bool = False): # --- FLAGSCALE MODIFICATION ---
"""If multi-modal data is given, add it to the input."""
- mm_data = seq_group_metadata.multi_modal_data
+ # NOTE: mm_data only includes the subset of multi-modal items that
+ # intersect with the current prefill positions.
+ positions = inter_data.input_positions[0]
+ mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
+ seq_group_metadata,
+ range(positions[0], positions[0] + len(positions)))
if not mm_data:
return
- mm_kwargs = self.multi_modal_input_mapper(
- mm_data,
- mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs)
- inter_data.multi_modal_inputs = mm_kwargs
+ if self.runner.mm_registry.has_processor(self.runner.model_config):
+ mm_kwargs = mm_data
+ else:
+ mm_kwargs = self.multi_modal_input_mapper(
+ mm_data,
+ seq_group_metadata.mm_processor_kwargs,
+ )
+
+ inter_data.multi_modal_kwargs = mm_kwargs
+ inter_data.multi_modal_placeholder_maps = placeholder_maps
# special processing for mrope position deltas.
- if self.runner.model_is_mrope:
+ if self.runner.model_config.uses_mrope:
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
assert image_grid_thw is not None or video_grid_thw is not None, (
@@ -681,6 +718,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
context_len=inter_data.context_lens[seq_idx],
+ seq_len=inter_data.seq_lens[seq_idx],
)
seq_data.mrope_position_delta = mrope_position_delta
@@ -699,7 +737,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
encoder_seq_len = 0
- if self.runner.model_config.is_encoder_decoder_model:
+ if self.runner.model_config.is_encoder_decoder:
encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len()
inter_data = self.init_cached_inter_data(
@@ -729,8 +767,8 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
block_tables=seq_group_metadata.negative_block_tables,
computed_block_nums=[], # for prefix caching.
reinit=True,
- reinit_use_defaults=True
- )
+ reinit_use_defaults=True,
+ encoder_seq_len=encoder_seq_len)
self.inter_data_list.append(negative_inter_data)
for seq_idx in range(n_seqs):
@@ -746,7 +784,6 @@ def _use_captured_graph(self,
max_decode_seq_len: int,
max_encoder_seq_len: int = 0) -> bool:
return (decode_only and not self.runner.model_config.enforce_eager
- and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.runner.max_seq_len_to_capture
and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
and batch_size <= self.runner.max_batchsize_to_capture)
@@ -794,7 +831,7 @@ def _get_cuda_graph_pad_size(self,
max_encoder_seq_len):
return -1
- graph_batch_size = _get_graph_batch_size(batch_size)
+ graph_batch_size = VllmConfig.get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
return graph_batch_size - batch_size
@@ -804,9 +841,12 @@ def build(self) -> ModelInputForGPU:
"""
# Combine and flatten intermediate data.
input_tokens = []
+ token_types = []
for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens)
+ for cur_token_types in inter_data.token_types:
+ token_types.extend(cur_token_types)
if not input_tokens:
# This may happen when all prefill requests hit
@@ -845,7 +885,7 @@ def build(self) -> ModelInputForGPU:
if not inter_data.is_prompt:
max_decode_seq_len = max(max_decode_seq_len,
max(inter_data.seq_lens))
- if self.runner.model_config.is_encoder_decoder_model:
+ if self.runner.model_config.is_encoder_decoder:
max_encoder_seq_len = max(max_encoder_seq_len,
inter_data.encoder_seq_len)
@@ -875,6 +915,12 @@ def build(self) -> ModelInputForGPU:
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
self.runner.device,
self.runner.pin_memory)
+
+ token_types_tensor = async_tensor_h2d(token_types, torch.long,
+ self.runner.device,
+ self.runner.pin_memory) \
+ if token_types else None
+
if mrope_input_positions is not None:
for idx in range(3):
mrope_input_positions[idx].extend(
@@ -944,15 +990,16 @@ def build(self) -> ModelInputForGPU:
)
# Multi-modal data.
- multi_modal_inputs_list = [
- data.multi_modal_inputs for data in self.inter_data_list
- if data.multi_modal_inputs is not None
+ multi_modal_kwargs_list = [
+ data.multi_modal_kwargs for data in self.inter_data_list
+ if data.multi_modal_kwargs is not None
]
- multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
+ multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return self.model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
+ token_types=token_types_tensor,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
@@ -974,32 +1021,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def __init__(
self,
- model_config: ModelConfig,
- parallel_config: ParallelConfig,
- scheduler_config: SchedulerConfig,
- device_config: DeviceConfig,
- cache_config: CacheConfig,
- load_config: LoadConfig,
- lora_config: Optional[LoRAConfig],
+ vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
- prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
- observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
- self.model_config = model_config
- self.parallel_config = parallel_config
- self.scheduler_config = scheduler_config
- self.device_config = device_config
- self.cache_config = cache_config
- self.lora_config = lora_config
- self.load_config = load_config
+
+ ModelRunnerBase.__init__(self, vllm_config)
+ model_config = self.model_config
+ cache_config = self.cache_config
+
self.is_driver_worker = is_driver_worker
- self.prompt_adapter_config = prompt_adapter_config
self.return_hidden_states = return_hidden_states
- self.observability_config = observability_config
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
@@ -1008,7 +1043,7 @@ def __init__(
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
- self.max_batchsize_to_capture = _get_max_graph_batch_size(
+ self.max_batchsize_to_capture = VllmConfig.get_max_graph_batch_size(
self.scheduler_config.max_num_seqs)
self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
@@ -1024,7 +1059,7 @@ def __init__(
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
- # (max batch size to capture, max context len to capture / block size).
+ # (max batch size to capture, max seq len to capture / block size).
self.graph_block_tables = np.zeros(
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
dtype=np.int32)
@@ -1084,13 +1119,7 @@ def __init__(
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m:
- self.model = get_model(model_config=self.model_config,
- device_config=self.device_config,
- load_config=self.load_config,
- lora_config=self.lora_config,
- parallel_config=self.parallel_config,
- scheduler_config=self.scheduler_config,
- cache_config=self.cache_config)
+ self.model = get_model(vllm_config=self.vllm_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
@@ -1160,10 +1189,9 @@ def load_model(self) -> None:
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
- if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \
- and supports_dynamo():
- from vllm.plugins import get_torch_compile_backend
- backend = get_torch_compile_backend() or "eager"
+ if self.vllm_config.compilation_config.level ==\
+ CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
+ backend = self.vllm_config.compilation_config.init_backend()
self.model = torch.compile(
self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
@@ -1284,7 +1312,7 @@ def profile_run(self) -> None:
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
- seq_data, dummy_multi_modal_data = self.input_registry \
+ dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
@@ -1292,12 +1320,13 @@ def profile_run(self) -> None:
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
- seq_data={group_id: seq_data},
+ seq_data={group_id: dummy_data.seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
- multi_modal_data=dummy_multi_modal_data,
+ multi_modal_data=dummy_data.multi_modal_data,
+ multi_modal_placeholders=dummy_data.multi_modal_placeholders,
)
seqs.append(seq)
@@ -1324,14 +1353,7 @@ def profile_run(self) -> None:
dtype=self.model_config.dtype,
device=self.device)
- graph_batch_size = self.max_batchsize_to_capture
- batch_size_capture_list = [
- bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
- ]
- if self.model_config.enforce_eager:
- batch_size_capture_list = []
- with set_compile_context(batch_size_capture_list):
- self.execute_model(model_input, kv_caches, intermediate_tensors)
+ self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize()
return
@@ -1400,12 +1422,6 @@ def list_prompt_adapters(self) -> Set[int]:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.list_adapters()
- @property
- def model_is_mrope(self) -> bool:
- """Detect if the model has "mrope" rope_scaling type.
- mrope requires keep "rope_deltas" between prompt and decoding phases."""
- return uses_mrope(self.model_config.hf_config)
-
@torch.inference_mode()
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
"""Cuda graph capture a model.
@@ -1421,22 +1437,22 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
per sequence in the batch.
"""
assert not self.model_config.enforce_eager
- logger.info("Capturing the model for CUDA graphs. This may lead to "
+ logger.info("Capturing cudagraphs for decoding. This may lead to "
"unexpected consequences if the model is not static. To "
"run the model in eager mode, set 'enforce_eager=True' or "
"use '--enforce-eager' in the CLI.")
- logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
- "If you are running out of memory, consider decreasing "
- "`gpu_memory_utilization` or enforcing eager mode. "
- "You can also reduce the `max_num_seqs` as needed "
- "to decrease memory usage.")
+ logger.info("If out-of-memory error occurs during cudagraph capture,"
+ " consider decreasing `gpu_memory_utilization` or "
+ "switching to eager mode. You can also reduce the "
+ "`max_num_seqs` as needed to decrease memory usage.")
start_time = time.perf_counter()
+ start_free_gpu_memory = torch.cuda.mem_get_info()[0]
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = self.max_batchsize_to_capture
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
- if self.model_is_mrope:
+ if self.model_config.uses_mrope:
input_positions = torch.tile(input_positions, (3, 1))
# Prepare dummy previous_hidden_states only if needed by the model.
# This is used by draft models such as EAGLE.
@@ -1456,23 +1472,19 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
dtype=self.model_config.dtype,
device=self.device)
- graph_batch_size = self.max_batchsize_to_capture
- batch_size_capture_list = [
- bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
- ]
-
with self.attn_state.graph_capture(
max_batch_size), graph_capture() as graph_capture_context:
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for virtual_engine in range(
self.parallel_config.pipeline_parallel_size):
- for batch_size in reversed(batch_size_capture_list):
+ for batch_size in \
+ self.vllm_config.compilation_config.capture_sizes:
attn_metadata = (
self.attn_state.graph_capture_get_metadata_for_batch(
batch_size,
is_encoder_decoder_model=self.model_config.
- is_encoder_decoder_model))
+ is_encoder_decoder))
if self.lora_config:
lora_mapping = LoRAMapping(
@@ -1491,7 +1503,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
graph_runner = CUDAGraphRunner(
self.model, self.attn_backend.get_name(),
self.attn_state.graph_clone(batch_size),
- self.model_config.is_encoder_decoder_model)
+ self.model_config.is_encoder_decoder)
capture_inputs = {
"input_ids":
@@ -1522,22 +1534,25 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
self.model.get_seqlen_agnostic_capture_inputs(
batch_size)
})
- if self.model_config.is_encoder_decoder_model:
+ if self.model_config.is_encoder_decoder:
# add the additional inputs to capture for
# encoder-decoder models.
self._update_inputs_to_capture_for_enc_dec_model(
capture_inputs)
- with set_forward_context(attn_metadata):
+ with set_forward_context(attn_metadata, self.vllm_config):
graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = (
graph_runner)
end_time = time.perf_counter()
+ end_free_gpu_memory = torch.cuda.mem_get_info()[0]
elapsed_time = end_time - start_time
+ cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
# This usually takes < 10 seconds.
- logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
+ logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
+ elapsed_time, cuda_graph_size / GiB_bytes)
def _update_inputs_to_capture_for_enc_dec_model(self,
capture_inputs: Dict[str,
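The capture loop now also reports how much GPU memory the CUDA graphs consumed by diffing the free-memory figure from `torch.cuda.mem_get_info()` before and after capture. A minimal sketch of that measurement pattern, guarded so it also runs on a CPU-only machine (`GiB_bytes` is assumed to match `vllm.utils.GiB_bytes`):

```python
import torch

GiB_bytes = 1 << 30  # assumed to match vllm.utils.GiB_bytes


def free_gpu_bytes() -> int:
    # mem_get_info() returns (free_bytes, total_bytes) for the current device.
    return torch.cuda.mem_get_info()[0] if torch.cuda.is_available() else 0


start_free_gpu_memory = free_gpu_bytes()
# ... CUDA graph capture would happen here ...
end_free_gpu_memory = free_gpu_bytes()

cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
print(f"Graph capturing took {cuda_graph_size / GiB_bytes:.2f} GiB")
```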
@@ -1660,6 +1675,24 @@ def execute_model(
else:
model_executable = self.model
+ # Receive KV cache in distributed KV cache transfer setting
+ # In disagg prefill setting, it will also recv hidden states and bypass
+ # model forwarding
+ # In KV cache database setting, it will change the model input so that
+ # we can skip prefilling on tokens that successfully received KV caches
+ # NOTE: The receive operation is blocking
+ bypass_model_exec = False
+ if self.need_recv_kv(model_input, kv_caches):
+ hidden_or_intermediate_states, bypass_model_exec, model_input = \
+ get_kv_transfer_group().recv_kv_caches_and_hidden_states(
+ # model is used to know which layer the current worker
+ # is working on, so that we can receive KV for only those
+ # layers.
+ model_executable,
+ model_input,
+ kv_caches=kv_caches
+ )
+
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
@@ -1671,21 +1704,36 @@ def execute_model(
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
- with set_forward_context(model_input.attn_metadata):
- hidden_or_intermediate_states = model_executable(
- input_ids=model_input.input_tokens,
- positions=model_input.input_positions,
- kv_caches=kv_caches,
- attn_metadata=model_input.attn_metadata,
- intermediate_tensors=intermediate_tensors,
- **MultiModalInputs.as_kwargs(multi_modal_kwargs,
- device=self.device),
- **seqlen_agnostic_kwargs)
+ if not bypass_model_exec:
+ with set_forward_context(model_input.attn_metadata,
+ self.vllm_config):
+ hidden_or_intermediate_states = model_executable(
+ input_ids=model_input.input_tokens,
+ positions=model_input.input_positions,
+ kv_caches=kv_caches,
+ attn_metadata=model_input.attn_metadata,
+ intermediate_tensors=intermediate_tensors,
+ **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
+ device=self.device),
+ **seqlen_agnostic_kwargs)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.record()
+ # Sending KV cache in distributed KV cache transfer setting
+ # NOTE: the send operation is non-blocking
+ if self.need_send_kv(model_input, kv_caches):
+ get_kv_transfer_group().send_kv_caches_and_hidden_states(
+ # model_executable is used to know which layer the current
+ # worker is working on, so that we can send KV for only those
+ # layers.
+ model_executable,
+ model_input,
+ kv_caches,
+ hidden_or_intermediate_states,
+ )
+
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
if (self.is_driver_worker
@@ -1753,6 +1801,56 @@ def execute_model(
return [output]
+ def need_recv_kv(self, model_input, kv_caches) -> bool:
+ """Check if we need to receive kv-cache from the other worker.
+ We need to receive KV when
+ 1. current vLLM instance is KV cache consumer/decode vLLM instance
+ 2. this batch is not a profiling run
+ 3. this batch is a prefill run
+
+ Args:
+ model_input: input to the model executable
+ kv_caches: vLLM's paged memory
+ """
+
+ prefill_meta = model_input.attn_metadata.prefill_metadata
+
+ # check if the current run is profiling
+ is_profile_run = (kv_caches[0].numel() == 0)
+ # check if the current run is prefill
+ is_prefill_run = prefill_meta is not None
+
+ if self.vllm_config.kv_transfer_config is None:
+ return False
+
+ return self.vllm_config.kv_transfer_config.is_kv_consumer and (
+ not is_profile_run) and is_prefill_run
+
+ def need_send_kv(self, model_input, kv_caches) -> bool:
+ """Check if we need to send kv-cache to the other worker.
+ We need to send KV when
+ 1. current vLLM instance is KV cache producer/prefill vLLM instance
+ 2. this batch is not a profiling run
+ 3. this batch is a prefill run
+
+ Args:
+ model_input: input to the model executable
+ kv_caches: vLLM's paged memory
+ """
+
+ prefill_meta = model_input.attn_metadata.prefill_metadata
+
+ # check if the current run is profiling
+ is_profile_run = (kv_caches[0].numel() == 0)
+ # check if the current run is prefill
+ is_prefill_run = prefill_meta is not None
+
+ if self.vllm_config.kv_transfer_config is None:
+ return False
+
+ return self.vllm_config.kv_transfer_config.is_kv_producer and (
+ not is_profile_run) and is_prefill_run
+
# NOTE: this is nn.Module so the profiler can properly capture/group
# kernel calls made within the graph
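`need_recv_kv`/`need_send_kv` above reduce the KV-transfer decision to three conditions: a transfer config exists, the instance has the matching role (consumer for receive, producer for send), and the batch is a real prefill rather than a profiling run. A condensed sketch of that predicate with the inputs flattened to booleans (`need_kv_transfer` is a hypothetical helper, not part of the diff):

```python
def need_kv_transfer(has_kv_transfer_config: bool, has_matching_role: bool,
                     is_prefill_run: bool, is_profile_run: bool) -> bool:
    """Condensed form of the checks in need_recv_kv / need_send_kv above."""
    if not has_kv_transfer_config:
        return False
    return has_matching_role and is_prefill_run and not is_profile_run


# A decode-only batch never triggers a transfer, even with a matching role.
assert not need_kv_transfer(True, True, is_prefill_run=False, is_profile_run=False)
# A prefill batch on a configured producer/consumer does.
assert need_kv_transfer(True, True, is_prefill_run=True, is_profile_run=False)
```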
@@ -1791,7 +1889,7 @@ def capture(
# Run the model a few times without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
- # Note one iteration is not enough for torch.jit.script
+ # Note one iteration is not enough for torch.compile
for _ in range(_NUM_WARMUP_ITERS):
self.model(
input_ids=input_ids,
@@ -1904,37 +2002,3 @@ def forward(
return self.output_buffers["hidden_states"]
return self.output_buffers
-
-
-def _get_graph_batch_size(batch_size: int) -> int:
- """Returns the padded batch size given actual batch size.
-
- Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
- 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
- """
- if batch_size <= 2:
- return batch_size
- elif batch_size <= 4:
- return 4
- else:
- return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
- _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
-
-
-def _get_max_graph_batch_size(max_num_seqs: int) -> int:
- """
- max_num_seqs: Maximum number of sequences in a batch.
- _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
-
- pad the max_num_seqs if necessary by calling _get_graph_batch_size,
- which will deal with some edge cases like 1, 2, 4.
-
- if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
- if not, it means the padded size is larger than the largest size in
- _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
- """
- padded_size = _get_graph_batch_size(max_num_seqs)
- if padded_size in _BATCH_SIZES_TO_CAPTURE:
- return padded_size
- assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
- return _BATCH_SIZES_TO_CAPTURE[-1]
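The removed `_get_graph_batch_size`/`_get_max_graph_batch_size` helpers are superseded by `VllmConfig.get_graph_batch_size`/`get_max_graph_batch_size`; the padding rule itself (1, 2, 4, then multiples of the alignment) is unchanged. A sketch of that rule, assuming the alignment constant is still 8:

```python
_BATCH_SIZE_ALIGNMENT = 8  # assumed to still be the alignment used by VllmConfig


def padded_graph_batch_size(batch_size: int) -> int:
    """Pad an actual batch size up to the next CUDA-graph capture size."""
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1)
            // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


assert [padded_graph_batch_size(b) for b in (1, 3, 5, 9, 17)] == [1, 4, 8, 16, 24]
```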
diff --git a/flagscale/inference/core/preprocess.py b/flagscale/inference/core/preprocess.py
index c91d70424..1f93a2920 100644
--- a/flagscale/inference/core/preprocess.py
+++ b/flagscale/inference/core/preprocess.py
@@ -1,37 +1,26 @@
# This file is modified from 'FlagScale/vllm/vllm/inputs/preprocess.py'
import asyncio
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import List, Mapping, Optional, Union
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
-from vllm.utils import print_warning_once
+from vllm.utils import print_info_once, print_warning_once
# --- FLAGSCALE MODIFICATION BEG ---
-from vllm.inputs.data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType,
- SingletonPrompt)
+from vllm.inputs.data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
+ PromptType, SingletonInputs, SingletonPrompt, token_inputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
# --- FLAGSCALE MODIFICATION END ---
-if TYPE_CHECKING:
- from vllm.multimodal import MultiModalDataDict
-
logger = init_logger(__name__)
-# --- FLAGSCALE MODIFICATION BEG ---
-PromptComponents = Tuple[Optional[str], List[int],
- Optional["MultiModalDataDict"],
- Optional[Dict[str, Any]],
- Optional[None], Optional[None]]
-DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
- Optional["MultiModalDataDict"],
- Optional[Dict[str, Any]]]
-# --- FLAGSCALE MODIFICATION END ---
-
class InputPreprocessor:
@@ -39,11 +28,13 @@ def __init__(
self,
model_config: ModelConfig,
tokenizer: Optional[BaseTokenizerGroup],
+ mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None:
super().__init__()
self.model_config = model_config
self.tokenizer = tokenizer
+ self.mm_registry = mm_registry
def get_tokenizer_group(self) -> BaseTokenizerGroup:
if self.tokenizer is None:
@@ -79,7 +70,7 @@ def get_decoder_start_token_id(self) -> Optional[int]:
model config is unavailable.
'''
- if not self.is_encoder_decoder_model():
+ if not self.model_config.is_encoder_decoder:
print_warning_once("Using None for decoder start token id because "
"this is not an encoder/decoder model.")
return None
@@ -121,7 +112,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
"default" decoder prompt be .
However, it is possible that in the future
- other models may have different or more
+ other models may have different or more
complex logic for the default decoder prompt.
This motivates having a special helper method
for default decoder prompts.
@@ -138,7 +129,6 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
def _prepare_decoder_input_ids_for_generation(
self,
decoder_input_ids: Optional[List[int]],
- force_bos: bool = True,
) -> List[int]:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
@@ -168,8 +158,8 @@ def _prepare_decoder_input_ids_for_generation(
# use decoder_start_token_id as decoder_input_ids
decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
- if force_bos and (len(decoder_input_ids) == 0
- or decoder_input_ids[0] != decoder_start_token_id):
+ if (len(decoder_input_ids) == 0
+ or decoder_input_ids[0] != decoder_start_token_id):
decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
return decoder_input_ids
@@ -215,14 +205,79 @@ async def _tokenize_prompt_async(
prompt=prompt,
lora_request=lora_request)
- def _extract_prompt_components(
+ def _can_process_multimodal(self) -> bool:
+ model_config = self.model_config
+
+ if not model_config.is_multimodal_model:
+ raise ValueError("Your model does not support multi-modal inputs")
+
+ # Interim measure so we can handle models that have yet to be
+ # updated to use the new multi-modal processor
+ can_process_multimodal = self.mm_registry.has_processor(model_config)
+ if not can_process_multimodal:
+ print_info_once(
+ "Your model uses the legacy input pipeline instead of the new "
+ "multi-modal processor. Please note that the legacy pipeline "
+ "will be removed in a future release. For more details, see: "
+ "https://github.com/vllm-project/vllm/issues/10114")
+
+ return can_process_multimodal
+
+ def _process_multimodal(
+ self,
+ prompt: Union[str, List[int]],
+ mm_data: MultiModalDataDict,
+ mm_processor_kwargs: Optional[Mapping[str, object]],
+ lora_request: Optional[LoRARequest],
+ ) -> MultiModalInputsV2:
+ """
+ Apply the model's multi-modal processor to a multi-modal prompt,
+ returning the corresponding token IDs and metadata.
+ """
+ tokenizer_group = self.get_tokenizer_group()
+ tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
+
+ mm_processor = self.mm_registry.create_processor(
+ self.model_config, tokenizer)
+
+ if isinstance(prompt, list):
+ prompt = tokenizer.decode(prompt)
+ if mm_processor_kwargs is None:
+ mm_processor_kwargs = {}
+
+ return mm_processor.apply(prompt, mm_data, mm_processor_kwargs)
+
+ async def _process_multimodal_async(
+ self,
+ prompt: Union[str, List[int]],
+ mm_data: MultiModalDataDict,
+ mm_processor_kwargs: Optional[Mapping[str, object]],
+ lora_request: Optional[LoRARequest],
+ ) -> MultiModalInputsV2:
+ """Async version of :meth:`_process_multimodal`."""
+ tokenizer_group = self.get_tokenizer_group()
+ tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request
+ )
+
+ mm_processor = self.mm_registry.create_processor(
+ self.model_config, tokenizer)
+ if isinstance(prompt, list):
+ logger.warning("Passing `multi_modal_data` in TokensPrompt is"
+ "deprecated and will be removed in a future update")
+ prompt = tokenizer.decode(prompt)
+ if mm_processor_kwargs is None:
+ mm_processor_kwargs = {}
+
+ return mm_processor.apply(prompt, mm_data, mm_processor_kwargs)
+
+ def _prompt_to_llm_inputs(
self,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
- ) -> PromptComponents:
- '''
- Extract the components of any single encoder or decoder input prompt.
+ ) -> SingletonInputs:
+ """
+ Extract the singleton inputs from a prompt.
Arguments:
@@ -232,12 +287,8 @@ def _extract_prompt_components(
Returns:
- * prompt
- * prompt_token_ids
- * multi_modal_data
- * mm_processor_kwargs (request-level input processor/mapper overrides)
- '''
-
+ * :class:`SingletonInputs` instance
+ """
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "str":
@@ -247,29 +298,60 @@ def _extract_prompt_components(
request_id=request_id,
lora_request=lora_request,
)
- multi_modal_data = None
- mm_processor_kwargs = None
- negative_prompt_text = negative_prompt_token_ids = None # --- FLAGSCALE MODIFICATION ---
- elif parsed["type"] == "tokens":
- prompt_text = None
- prompt_token_ids = parsed["content"]["prompt_token_ids"]
- multi_modal_data = parsed["content"].get("multi_modal_data")
- mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
- # --- FLAGSCALE MODIFICATION BEG ---
- negative_prompt_text = None
- negative_prompt_token_ids = parsed["content"].get("negative_prompt_token_ids")
- # --- FLAGSCALE MODIFICATION END ---
- elif parsed["type"] == "text":
- prompt_text = parsed["content"]["prompt"]
+
+ return token_inputs(
+ prompt=prompt_text,
+ prompt_token_ids=prompt_token_ids,
+ )
+
+ if parsed["type"] == "tokens":
+ tokens_content = parsed["content"]
+
+ prompt_token_ids = tokens_content["prompt_token_ids"]
+ token_type_ids = tokens_content.get("token_type_ids")
+ multi_modal_data = tokens_content.get("multi_modal_data")
+ mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
+ negative_prompt_token_ids = tokens_content.get("negative_prompt_token_ids") # --- FLAGSCALE MODIFICATION ---
+
+ if multi_modal_data is not None and self._can_process_multimodal():
+ return self._process_multimodal(
+ prompt_token_ids,
+ multi_modal_data,
+ mm_processor_kwargs,
+ lora_request=lora_request,
+ )
+
+ return token_inputs(
+ prompt_token_ids=prompt_token_ids,
+ token_type_ids=token_type_ids,
+ multi_modal_data=multi_modal_data,
+ mm_processor_kwargs=mm_processor_kwargs,
+ negative_prompt_token_ids=negative_prompt_token_ids, # --- FLAGSCALE MODIFICATION ---
+ )
+
+ if parsed["type"] == "text":
+ text_content = parsed["content"]
+
+ prompt_text = text_content["prompt"]
+ multi_modal_data = text_content.get("multi_modal_data")
+ mm_processor_kwargs = text_content.get("mm_processor_kwargs")
+ negative_prompt_text = text_content.get("negative_prompt") # --- FLAGSCALE MODIFICATION ---
+
+ if multi_modal_data is not None and self._can_process_multimodal():
+ return self._process_multimodal(
+ prompt_text,
+ multi_modal_data,
+ mm_processor_kwargs,
+ lora_request=lora_request,
+ )
+
prompt_token_ids = self._tokenize_prompt(
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
- multi_modal_data = parsed["content"].get("multi_modal_data")
- mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
+
# --- FLAGSCALE MODIFICATION BEG ---
- negative_prompt_text = parsed["content"].get("negative_prompt")
negative_prompt_token_ids = None
if negative_prompt_text:
negative_prompt_token_ids = self._tokenize_prompt(
@@ -278,19 +360,26 @@ def _extract_prompt_components(
lora_request=lora_request,
)
# --- FLAGSCALE MODIFICATION END ---
- else:
- assert_never(parsed)
- return (prompt_text, prompt_token_ids, multi_modal_data,
- mm_processor_kwargs,
- negative_prompt_text, negative_prompt_token_ids) # --- FLAGSCALE MODIFICATION ---
+ return token_inputs(
+ prompt=prompt_text,
+ prompt_token_ids=prompt_token_ids,
+ multi_modal_data=multi_modal_data,
+ mm_processor_kwargs=mm_processor_kwargs,
+ # --- FLAGSCALE MODIFICATION BEG ---
+ negative_prompt_text=negative_prompt_text,
+ negative_prompt_token_ids=negative_prompt_token_ids,
+ # --- FLAGSCALE MODIFICATION END ---
+ )
+
+ assert_never(parsed)
- async def _extract_prompt_components_async(
+ async def _prompt_to_llm_inputs_async(
self,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
- ) -> PromptComponents:
+ ) -> SingletonInputs:
"""Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(prompt)
@@ -301,29 +390,58 @@ async def _extract_prompt_components_async(
request_id=request_id,
lora_request=lora_request,
)
- multi_modal_data = None
- mm_processor_kwargs = None
- negative_prompt_text = negative_prompt_token_ids = None # --- FLAGSCALE MODIFICATION ---
- elif parsed["type"] == "tokens":
- prompt_text = None
- prompt_token_ids = parsed["content"]["prompt_token_ids"]
- multi_modal_data = parsed["content"].get("multi_modal_data")
- mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
- # --- FLAGSCALE MODIFICATION BEG ---
- negative_prompt_text = None
- negative_prompt_token_ids = parsed["content"].get("negative_prompt_token_ids")
- # --- FLAGSCALE MODIFICATION END ---
- elif parsed["type"] == "text":
- prompt_text = parsed["content"]["prompt"]
+
+ return token_inputs(
+ prompt=prompt_text,
+ prompt_token_ids=prompt_token_ids,
+ )
+
+ if parsed["type"] == "tokens":
+ tokens_content = parsed["content"]
+
+ prompt_token_ids = tokens_content["prompt_token_ids"]
+ multi_modal_data = tokens_content.get("multi_modal_data")
+ mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
+ negative_prompt_token_ids = tokens_content.get("negative_prompt_token_ids") # --- FLAGSCALE MODIFICATION ---
+
+ if multi_modal_data is not None and self._can_process_multimodal():
+ return await self._process_multimodal_async(
+ prompt_token_ids,
+ multi_modal_data,
+ mm_processor_kwargs,
+ lora_request=lora_request,
+ )
+
+ return token_inputs(
+ prompt_token_ids=prompt_token_ids,
+ multi_modal_data=multi_modal_data,
+ mm_processor_kwargs=mm_processor_kwargs,
+ negative_prompt_token_ids=negative_prompt_token_ids, # --- FLAGSCALE MODIFICATION ---
+ )
+
+ if parsed["type"] == "text":
+ text_content = parsed["content"]
+
+ prompt_text = text_content["prompt"]
+ multi_modal_data = text_content.get("multi_modal_data")
+ mm_processor_kwargs = text_content.get("mm_processor_kwargs")
+ negative_prompt_text = text_content.get("negative_prompt") # --- FLAGSCALE MODIFICATION ---
+
+ if multi_modal_data is not None and self._can_process_multimodal():
+ return await self._process_multimodal_async(
+ prompt_text,
+ multi_modal_data,
+ mm_processor_kwargs,
+ lora_request=lora_request,
+ )
+
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
- multi_modal_data = parsed["content"].get("multi_modal_data")
- mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
+
# --- FLAGSCALE MODIFICATION BEG ---
- negative_prompt_text = parsed["content"].get("negative_prompt")
negative_prompt_token_ids = None
if negative_prompt_text:
negative_prompt_token_ids = await self._tokenize_prompt_async(
@@ -332,44 +450,50 @@ async def _extract_prompt_components_async(
lora_request=lora_request,
)
# --- FLAGSCALE MODIFICATION END ---
- else:
- assert_never(parsed)
- return (prompt_text, prompt_token_ids, multi_modal_data,
- mm_processor_kwargs,
- negative_prompt_text, negative_prompt_token_ids) # --- FLAGSCALE MODIFICATION ---
+ return token_inputs(
+ prompt=prompt_text,
+ prompt_token_ids=prompt_token_ids,
+ multi_modal_data=multi_modal_data,
+ mm_processor_kwargs=mm_processor_kwargs,
+ # --- FLAGSCALE MODIFICATION BEG ---
+ negative_prompt_text=negative_prompt_text,
+ negative_prompt_token_ids=negative_prompt_token_ids
+ # --- FLAGSCALE MODIFICATION END ---
+ )
+
+ assert_never(parsed)
def _build_enc_dec_llm_inputs(
self,
- encoder_comps: PromptComponents,
- decoder_comps: DecoderPromptComponents,
- mm_processor_kwargs: Dict[str, Any],
+ encoder_inputs: SingletonInputs,
+ decoder_inputs: Optional[SingletonInputs],
) -> EncoderDecoderInputs:
- encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps
- decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps
-
- # Reminder: Please update docs/source/serving/compatibility_matrix.rst
- # If the feature combo become valid
- if decoder_mm_data is not None:
- raise ValueError(
- "Multi-modality decoder inputs of encoder-decoder models are "
- "not supported yet")
-
- # For Multi-Modal models (e.g., mllama), the text input can be
- # <|image|><|begin_of_text|>hello world. And we should not add
- # another <|begin_of_text|> to the beginning.
- decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation(
- decoder_prompt_ids,
- force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
+ if (encoder_inputs["type"] == "token"
+ or encoder_inputs["type"] == "multimodal"):
+ pass
+ else:
+ assert_never(encoder_inputs)
+
+ if decoder_inputs is None:
+ dec_token_ids = self._prepare_decoder_input_ids_for_generation(
+ None)
+ decoder_inputs = token_inputs(dec_token_ids)
+ elif (decoder_inputs["type"] == "token"
+ or decoder_inputs["type"] == "multimodal"):
+ dec_token_ids = self._prepare_decoder_input_ids_for_generation(
+ decoder_inputs["prompt_token_ids"])
+ decoder_inputs["prompt_token_ids"] = dec_token_ids
+
+ if "multi_modal_data" in decoder_inputs:
+ raise ValueError("Multi-modal decoder inputs of encoder-"
+ "decoder models are not supported yet")
+ else:
+ assert_never(decoder_inputs)
return EncoderDecoderInputs(
- prompt_token_ids=decoder_prompt_ids,
- prompt=decoder_prompt,
- multi_modal_data=decoder_mm_data,
- mm_processor_kwargs=mm_processor_kwargs,
- encoder_prompt_token_ids=encoder_prompt_ids,
- encoder_prompt=encoder_prompt,
- encoder_multi_modal_data=encoder_mm_data,
+ encoder=encoder_inputs,
+ decoder=decoder_inputs,
)
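For orientation, a sketch of the shape `_build_enc_dec_llm_inputs` now returns: each side is its own `SingletonInputs`-style dict instead of flattened `encoder_*`/decoder fields. The token ids below are illustrative only.

enc_dec_inputs = {
    "encoder": {"type": "token", "prompt_token_ids": [101, 2023, 102]},
    "decoder": {"type": "token", "prompt_token_ids": [2]},  # leads with the decoder start token
}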
def _process_encoder_decoder_prompt(
@@ -377,10 +501,9 @@ def _process_encoder_decoder_prompt(
prompt: PromptType,
request_id: str,
) -> EncoderDecoderInputs:
- '''
+ """
For encoder/decoder models only:
- Process an input prompt into an
- :class:`EncoderDecoderInputs` instance.
+ Process an input prompt into an :class:`EncoderDecoderInputs` instance.
There are two types of input prompts:
singleton prompts which carry only the
@@ -399,7 +522,7 @@ def _process_encoder_decoder_prompt(
have any possible singleton type; thus this
method relies on helper functions to obtain
token ids for the sub-prompts.
-
+
Arguments:
* prompt: an input prompt
@@ -408,42 +531,32 @@ def _process_encoder_decoder_prompt(
Returns:
* :class:`EncoderDecoderInputs` instance
- '''
-
- encoder_comps: PromptComponents
- decoder_comps: DecoderPromptComponents
+ """
+ encoder_inputs: SingletonInputs
+ decoder_inputs: Optional[SingletonInputs]
if is_explicit_encoder_decoder_prompt(prompt):
- encoder_comps = self._extract_prompt_components(
+ encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
- decoder_comps = None, None, None, None
+ decoder_inputs = None
else:
- decoder_comps = self._extract_prompt_components(
+ decoder_inputs = self._prompt_to_llm_inputs(
decoder_input,
request_id=request_id,
)
- # Handle this carefully in case it was directly initialized by user
- mm_processor_kwargs = prompt.get("mm_processor_kwargs", {})
else:
- encoder_comps = self._extract_prompt_components(
+ encoder_inputs = self._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)
- # If there are no decoder components, we assume the
- # mm_processor_kwargs are in the encoder prompt
- mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
- -1] is not None else {}
- decoder_comps = None, None, None, None
-
- return self._build_enc_dec_llm_inputs(
- encoder_comps,
- decoder_comps,
- mm_processor_kwargs,
- )
+
+ decoder_inputs = None
+
+ return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
async def _process_encoder_decoder_prompt_async(
self,
@@ -451,63 +564,51 @@ async def _process_encoder_decoder_prompt_async(
request_id: str,
) -> EncoderDecoderInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
- encoder_comps: PromptComponents
- decoder_comps: DecoderPromptComponents
+ encoder_inputs: SingletonInputs
+ decoder_inputs: Optional[SingletonInputs]
if is_explicit_encoder_decoder_prompt(prompt):
- encoder_task = self._extract_prompt_components_async(
+ encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
- encoder_comps = await encoder_task
- decoder_comps = None, None, None, None
+ encoder_inputs = await encoder_task
+ decoder_inputs = None
else:
- decoder_task = self._extract_prompt_components_async(
+ decoder_task = self._prompt_to_llm_inputs_async(
decoder_input,
request_id=request_id,
)
- encoder_comps, decoder_comps = await asyncio.gather(
+ encoder_inputs, decoder_inputs = await asyncio.gather(
encoder_task, decoder_task)
- mm_processor_kwargs = prompt["mm_processor_kwargs"]
else:
- encoder_comps = await self._extract_prompt_components_async(
+ encoder_inputs = await self._prompt_to_llm_inputs_async(
prompt,
request_id=request_id,
)
- # If there are no decoder components, we assume the
- # mm_processor_kwargs are in the encoder prompt
- mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
- -1] is not None else {}
- decoder_comps = None, None, None, None
-
- return self._build_enc_dec_llm_inputs(
- encoder_comps,
- decoder_comps,
- mm_processor_kwargs,
- )
+
+ decoder_inputs = None
+
+ return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
def _build_decoder_only_llm_inputs(
self,
- prompt_comps: PromptComponents,
+ prompt_inputs: DecoderOnlyInputs,
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> DecoderOnlyInputs:
- # --- FLAGSCALE MODIFICATION BEG ---
- (prompt, prompt_token_ids, multi_modal_data,
- mm_processor_kwargs, negative_prompt, negative_prompt_token_ids) = prompt_comps
- # --- FLAGSCALE MODIFICATION END ---
-
- prompt_token_ids = self._apply_prompt_adapter(
- prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
+ if (prompt_inputs["type"] == "token"
+ or prompt_inputs["type"] == "multimodal"):
+ prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
+ prompt_inputs["prompt_token_ids"],
+ prompt_adapter_request=prompt_adapter_request,
+ )
+ else:
+ assert_never(prompt_inputs)
- return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids,
- prompt=prompt,
- multi_modal_data=multi_modal_data,
- mm_processor_kwargs=mm_processor_kwargs,
- negative_prompt_token_ids=negative_prompt_token_ids, # --- FLAGSCALE MODIFICATION ---
- negative_prompt=negative_prompt) # --- FLAGSCALE MODIFICATION ---
+ return prompt_inputs
def _process_decoder_only_prompt(
self,
@@ -516,7 +617,7 @@ def _process_decoder_only_prompt(
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> DecoderOnlyInputs:
- '''
+ """
For decoder-only models:
Process an input prompt into an :class:`DecoderOnlyInputs` instance.
@@ -530,9 +631,9 @@ def _process_decoder_only_prompt(
Returns:
* :class:`DecoderOnlyInputs` instance
- '''
+ """
- prompt_comps = self._extract_prompt_components(
+ prompt_comps = self._prompt_to_llm_inputs(
prompt,
request_id=request_id,
lora_request=lora_request,
@@ -551,7 +652,7 @@ async def _process_decoder_only_prompt_async(
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> DecoderOnlyInputs:
"""Async version of :meth:`_process_decoder_only_prompt`."""
- prompt_comps = await self._extract_prompt_components_async(
+ prompt_comps = await self._prompt_to_llm_inputs_async(
prompt,
request_id=request_id,
lora_request=lora_request,
@@ -568,9 +669,9 @@ def preprocess(
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
- ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
+ ) -> ProcessorInputs:
"""Preprocess the input prompt."""
- if self.is_encoder_decoder_model():
+ if self.model_config.is_encoder_decoder:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return self._process_encoder_decoder_prompt(
@@ -596,9 +697,9 @@ async def preprocess_async(
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
- ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
+ ) -> ProcessorInputs:
"""Async version of :meth:`preprocess`."""
- if self.is_encoder_decoder_model():
+ if self.model_config.is_encoder_decoder:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async(
@@ -617,6 +718,3 @@ async def preprocess_async(
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
-
- def is_encoder_decoder_model(self):
- return self.model_config.is_encoder_decoder_model
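A hedged usage sketch of the FlagScale negative-prompt path through the preprocessor above; `preprocessor` stands for an already-constructed `InputPreprocessor`, and the prompt text is arbitrary.

prompt = {
    "prompt": "a photo of a cat",
    "negative_prompt": "blurry, low quality",  # FLAGSCALE extension
}
inputs = preprocessor.preprocess(prompt, request_id="req-0")
# For decoder-only models the returned token inputs now also carry
# negative_prompt_token_ids, which downstream scheduling treats as the
# sequence group's negative sequences.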
diff --git a/flagscale/inference/core/sampling_metadata.py b/flagscale/inference/core/sampling_metadata.py
index 9762ddb30..f86c9cae2 100644
--- a/flagscale/inference/core/sampling_metadata.py
+++ b/flagscale/inference/core/sampling_metadata.py
@@ -301,6 +301,7 @@ def _prepare_seq_groups(
negative_prompt_logprob_len = (negative_query_len - num_prefill_sample
if do_sample else negative_query_len)
negative_sample_len = num_prefill_sample if do_sample else 0
+ # --- FLAGSCALE MODIFICATION END ---
else:
query_len, seq_len = query_lens[i], seq_lens[i]
# If we need sampling, exclude num_prefill_sample tokens from
@@ -308,22 +309,26 @@ def _prepare_seq_groups(
prompt_logprob_len = (query_len - num_prefill_sample
if do_sample else query_len)
sample_len = num_prefill_sample if do_sample else 0
- # --- FLAGSCALE MODIFICATION END ---
else:
# Decode
# --- FLAGSCALE MODIFICATION BEG ---
if has_negative:
+ # positive
prompt_logprob_len = 0
- query_len = query_lens[::2][i] if query_lens is not None else 1
+ query_len = query_lens[::2][i] if query_lens is not None and len(
+ query_lens[::2]) > 0 else 1
sample_len = len(seq_ids) * query_len if do_sample else 0
+ # negative
negative_prompt_logprob_len = 0
- negative_query_len = query_lens[1::2][i] if query_lens is not None else 1
+ negative_query_len = query_lens[1::2][i] if query_lens is not None and len(
+ query_lens[1::2]) > 0 else 1
negative_sample_len = len(seq_ids) * negative_query_len if do_sample else 0
+ # --- FLAGSCALE MODIFICATION END ---
else:
prompt_logprob_len = 0
- query_len = query_lens[i] if query_lens is not None else 1
+ query_len = query_lens[i] if query_lens is not None and len(
+ query_lens) > 0 else 1
sample_len = len(seq_ids) * query_len if do_sample else 0
- # --- FLAGSCALE MODIFICATION END ---
if sampling_params.seed is not None and generators is not None:
generator = generators.get(seq_group_metadata.request_id)
@@ -515,6 +520,7 @@ def from_sampling_metadata(
if do_penalties:
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
+ sampling_params = seq_group.sampling_params  # bind per-group params before use below
if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None):
prefill_len = len(seq_group.prompt_logprob_indices)
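The `[::2]`/`[1::2]` slicing in `_prepare_seq_groups` above assumes positive and negative sequences are interleaved in `query_lens` when CFG is active; a small illustration with made-up values:

query_lens = [5, 5, 7, 7]          # [pos_0, neg_0, pos_1, neg_1]
positive_lens = query_lens[::2]    # [5, 7]
negative_lens = query_lens[1::2]   # [5, 7]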
diff --git a/flagscale/inference/core/sampling_params.py b/flagscale/inference/core/sampling_params.py
index 4cd943a23..fbe33c808 100644
--- a/flagscale/inference/core/sampling_params.py
+++ b/flagscale/inference/core/sampling_params.py
@@ -199,9 +199,7 @@ class SamplingParams(
include_stop_str_in_output: bool = False
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
- # --- FLAGSCALE MODIFICATION BEG ---
- guidance_scale: Optional[float] = None
- # --- FLAGSCALE MODIFICATION END ---
+ guidance_scale: Optional[float] = None # --- FLAGSCALE MODIFICATION ---
# The below fields are not supposed to be used as an input.
# They are set in post_init.
@@ -297,8 +295,9 @@ def __post_init__(self) -> None:
raise ValueError(
f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.")
- self._real_n = self.n
- self.n = self.best_of
+ if not self._real_n:
+ self._real_n = self.n
+ self.n = self.best_of
if 0 < self.temperature < _MAX_TEMP:
logger.warning(
@@ -490,10 +489,8 @@ def __repr__(self) -> str:
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
- f"guided_decoding={self.guided_decoding}, "
- # --- FLAGSCALE MODIFICATION BEG ---
- f"guidance_scale={self.guidance_scale})")
- # --- FLAGSCALE MODIFICATION END ---
+ f"guided_decoding={self.guided_decoding}, "
+ f"guidance_scale={self.guidance_scale})") # --- FLAGSCALE MODIFICATION ---
class BeamSearchParams(
@@ -507,3 +504,4 @@ class BeamSearchParams(
ignore_eos: bool = False
temperature: float = 0.0
length_penalty: float = 1.0
+ include_stop_str_in_output: bool = False
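A hedged usage sketch of the `guidance_scale` field kept above; the parameter values are illustrative and the import path is shorthand for FlagScale's patched copy of `SamplingParams`.

from vllm import SamplingParams  # FlagScale patches its own copy of this class

params = SamplingParams(temperature=0.8, max_tokens=64, guidance_scale=1.5)
# guidance_scale presumably sets the classifier-free guidance strength;
# the default None leaves CFG disabled.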
diff --git a/flagscale/inference/core/scheduler.py b/flagscale/inference/core/scheduler.py
index df0a010a8..f2e5c1c10 100644
--- a/flagscale/inference/core/scheduler.py
+++ b/flagscale/inference/core/scheduler.py
@@ -3,6 +3,7 @@
import os
import random
import time
+import contextlib # --- FLAGSCALE MODIFICATION ---
from collections import deque
from dataclasses import dataclass, field
from typing import Callable, Deque, Dict, Iterable, List, Optional
@@ -57,11 +58,16 @@ class SchedulingBudget:
max_num_seqs: int
_request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
+ # Number of cached tokens in the batch.
+ _num_cached_tokens: int = 0
+ # Number of actual non-cached tokens in the batch.
_num_batched_tokens: int = 0
_num_curr_seqs: int = 0
def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
- assert num_new_tokens != 0
+ # We allow num_new_tokens to be 0 when the entire sequence has
+ # been cached.
+ assert num_new_tokens >= 0
assert num_new_seqs != 0
return (self.num_batched_tokens + num_new_tokens <= self.token_budget
and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
@@ -69,12 +75,18 @@ def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
def remaining_token_budget(self):
return self.token_budget - self.num_batched_tokens
- def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
+ def add_num_batched_tokens(self,
+ req_id: str,
+ num_batched_tokens: int,
+ num_cached_tokens: int = 0):
if req_id in self._request_ids_num_batched_tokens:
return
+ assert num_cached_tokens >= 0
+ assert num_batched_tokens >= 0
self._request_ids_num_batched_tokens.add(req_id)
self._num_batched_tokens += num_batched_tokens
+ self._num_cached_tokens += num_cached_tokens
def subtract_num_batched_tokens(self, req_id: str,
num_batched_tokens: int):
@@ -102,6 +114,10 @@ def num_batched_tokens(self):
def num_curr_seqs(self):
return self._num_curr_seqs
+ @property
+ def num_cached_tokens(self):
+ return self._num_cached_tokens
+
@dataclass
class ScheduledSequenceGroup:
@@ -111,9 +127,7 @@ class ScheduledSequenceGroup:
# 1 for decoding. Same as prompt tokens for prefill, but if prefill is
# chunked, it can be smaller than that.
token_chunk_size: int
- # --- FLAGSCALE MODIFICATION BEG ---
- negative_token_chunk_size: int = 0
- # --- FLAGSCALE MODIFICATION END ---
+ negative_token_chunk_size: int = 0 # --- FLAGSCALE MODIFICATION ---
@dataclass
@@ -295,11 +309,55 @@ def scheduler_running_outputs_builder():
def scheduled_seq_group_builder():
return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup),
- token_chunk_size=0,
- negative_token_chunk_size=0) # --- FLAGSCALE MODIFICATION ---
+ token_chunk_size=0, negative_token_chunk_size=0) # --- FLAGSCALE MODIFICATION ---
# return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
+# --- FLAGSCALE MODIFICATION BEG ---
+@contextlib.contextmanager
+def _switch_seq_group(seq_group: SequenceGroup, new_seqs: Optional[List[Sequence]]):
+
+ if new_seqs is None:
+ yield
+ else:
+ origin_seq = seq_group.seqs
+ origin_negative_seq = seq_group.negative_seqs
+
+ for idx, new_seq in enumerate(new_seqs):
+ new_seq.status = origin_seq[idx].status
+
+ seq_group.seqs = new_seqs
+ seq_group.negative_seqs = None
+ try:
+ yield
+ finally:
+ seq_group.seqs = origin_seq
+ seq_group.negative_seqs = origin_negative_seq
+
+
+def _update_num_new_tokens(
+ budget: SchedulingBudget,
+ num_new_tokens: int = 0,
+ num_new_tokens_negative: int = 0,
+):
+ remaining_token_budget = budget.remaining_token_budget()
+ if num_new_tokens + num_new_tokens_negative >= remaining_token_budget:
+ min_num_new_tokens = min(num_new_tokens, remaining_token_budget // 2)
+ min_num_new_tokens_negative = min(num_new_tokens_negative, remaining_token_budget // 2)
+
+ if min_num_new_tokens == num_new_tokens and \
+ min_num_new_tokens_negative != num_new_tokens_negative:
+ return 0, 0
+ elif min_num_new_tokens != num_new_tokens and \
+ min_num_new_tokens_negative == num_new_tokens_negative:
+ return 0, 0
+ else:
+ return min_num_new_tokens, min_num_new_tokens_negative
+
+ return num_new_tokens, num_new_tokens_negative
+# --- FLAGSCALE MODIFICATION END ---
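A worked illustration of `_update_num_new_tokens` with made-up numbers: 512 budget tokens remain and the request has 400 positive / 300 negative prefill tokens.

budget_left = 512
pos, neg = 400, 300
capped = (min(pos, budget_left // 2), min(neg, budget_left // 2))
assert capped == (256, 256)   # both sides are chunked, so both are scheduled
# With pos, neg = 100, 500 only the negative side would have to shrink, and
# the helper returns (0, 0) instead so the two prompts stay in lockstep and
# the request is postponed to a later step.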
+
+
class Scheduler:
def __init__(
@@ -546,10 +604,34 @@ def _schedule_running(
assert len(self._async_stopped) == 0
while running_queue:
seq_group = running_queue[0]
- num_running_tokens, negative_num_running_tokens = self._get_num_new_tokens( # --- FLAGSCALE MODIFICATION ---
- seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
+ # We discard the cached tokens info here because we don't need it
+ # for running sequence:
+ # 1. If a sequence is running with chunked prefill, the cached
+ # tokens info was already used for the first prefill.
+ # 2. If a sequence is running with non-chunked prefill, then it is
+ # a decoding sequence, and the cached tokens info is
+ # irrelevant.
+ num_uncached_new_tokens, _ = (
+ self._get_num_new_uncached_and_cached_tokens(
+ seq_group, SequenceStatus.RUNNING, enable_chunking,
+ budget))
+
+ # --- FLAGSCALE MODIFICATION BEG ---
+ with _switch_seq_group(seq_group, seq_group.negative_seqs):
+ num_uncached_new_tokens_negative, _ = (
+ self._get_num_new_uncached_and_cached_tokens(
+ seq_group, SequenceStatus.RUNNING, enable_chunking,
+ budget))
+ # --- FLAGSCALE MODIFICATION END ---
- if num_running_tokens + negative_num_running_tokens == 0: # --- FLAGSCALE MODIFICATION ---
+ num_running_tokens = num_uncached_new_tokens
+ # --- FLAGSCALE MODIFICATION BEG ---
+ num_running_tokens_negative = num_uncached_new_tokens_negative
+ num_running_tokens, num_running_tokens_negative = _update_num_new_tokens(
+ budget, num_running_tokens, num_running_tokens_negative)
+ # --- FLAGSCALE MODIFICATION END ---
+
+ if num_running_tokens + num_running_tokens_negative == 0: # --- FLAGSCALE MODIFICATION ---
# No budget => Stop
break
@@ -568,7 +650,7 @@ def _schedule_running(
# slot to keep all the sequence groups in the RUNNING state.
while not self._can_append_slots(seq_group, enable_chunking):
budget.subtract_num_batched_tokens(seq_group.request_id,
- num_running_tokens + negative_num_running_tokens) # --- FLAGSCALE MODIFICATION ---
+ num_running_tokens + num_running_tokens_negative) # --- FLAGSCALE MODIFICATION ---
num_running_seqs = seq_group.get_max_num_running_seqs()
budget.subtract_num_seqs(seq_group.request_id,
num_running_seqs)
@@ -624,7 +706,7 @@ def _schedule_running(
scheduled_seq_group.seq_group = seq_group
if is_prefill:
scheduled_seq_group.token_chunk_size = num_running_tokens
- scheduled_seq_group.negative_token_chunk_size = negative_num_running_tokens # --- FLAGSCALE MODIFICATION ---
+ scheduled_seq_group.negative_token_chunk_size = num_running_tokens_negative # --- FLAGSCALE MODIFICATION ---
prefill_seq_groups.append(scheduled_seq_group)
ret.prefill_seq_groups_list.append(seq_group)
else:
@@ -634,7 +716,7 @@ def _schedule_running(
ret.decode_seq_groups_list.append(seq_group)
budget.add_num_batched_tokens(seq_group.request_id,
- num_running_tokens + negative_num_running_tokens) # --- FLAGSCALE MODIFICATION ---
+ num_running_tokens + num_running_tokens_negative) # --- FLAGSCALE MODIFICATION ---
# OPTIMIZATION: Note that get_max_num_running_seqs is
# expensive. For the default scheduling case where
# enable_chunking is False, num_seqs are updated before running
@@ -722,14 +804,24 @@ def _schedule_swapped(
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.get_max_num_running_seqs()
- num_new_tokens, negative_num_new_tokens = self._get_num_new_tokens(seq_group, # --- FLAGSCALE MODIFICATION ---
- SequenceStatus.SWAPPED,
- enable_chunking, budget)
+ num_new_tokens_uncached, num_new_tokens_cached = (
+ self._get_num_new_uncached_and_cached_tokens(
+ seq_group, SequenceStatus.SWAPPED, enable_chunking,
+ budget))
- if (num_new_tokens + negative_num_new_tokens == 0
- or not budget.can_schedule(num_new_tokens=num_new_tokens + negative_num_new_tokens, # --- FLAGSCALE MODIFICATION ---
- num_new_seqs=num_new_seqs)):
+ # --- FLAGSCALE MODIFICATION BEG ---
+ with _switch_seq_group(seq_group, seq_group.negative_seqs):
+ num_new_tokens_uncached_negative, num_new_tokens_cached_negative = (
+ self._get_num_new_uncached_and_cached_tokens(
+ seq_group, SequenceStatus.SWAPPED, enable_chunking,
+ budget))
+ if num_new_tokens_uncached + num_new_tokens_uncached_negative == 0 \
+ or not budget.can_schedule(
+ num_new_tokens=num_new_tokens_uncached + num_new_tokens_uncached_negative,
+ num_new_seqs=num_new_seqs,
+ ):
break
+ # --- FLAGSCALE MODIFICATION END ---
if lora_int_id > 0 and curr_loras is not None:
curr_loras.add(lora_int_id)
@@ -739,13 +831,24 @@ def _schedule_swapped(
is_prefill = seq_group.is_prefill()
if is_prefill:
prefill_seq_groups.append(
- ScheduledSequenceGroup(seq_group,
- token_chunk_size=num_new_tokens,
- negative_token_chunk_size=negative_num_new_tokens)) # --- FLAGSCALE MODIFICATION ---
+ ScheduledSequenceGroup(
+ seq_group,
+ token_chunk_size=num_new_tokens_uncached +
+ num_new_tokens_cached,
+ negative_token_chunk_size=num_new_tokens_uncached_negative +
+ num_new_tokens_cached_negative # --- FLAGSCALE MODIFICATION ---
+ ))
else:
decode_seq_groups.append(
- ScheduledSequenceGroup(seq_group, token_chunk_size=1, negative_token_chunk_size=1 if seq_group.has_negative_seqs() else 0)) # --- FLAGSCALE MODIFICATION ---
- budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens + negative_num_new_tokens) # --- FLAGSCALE MODIFICATION ---
+ ScheduledSequenceGroup(
+ seq_group,
+ token_chunk_size=1,
+ negative_token_chunk_size=1 if seq_group.has_negative_seqs() else 0)) # --- FLAGSCALE MODIFICATION ---
+ budget.add_num_batched_tokens(
+ seq_group.request_id,
+ num_batched_tokens=num_new_tokens_uncached + num_new_tokens_uncached_negative, # --- FLAGSCALE MODIFICATION ---
+ num_cached_tokens=num_new_tokens_cached + num_new_tokens_cached_negative, # --- FLAGSCALE MODIFICATION ---
+ )
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
swapped_queue.extendleft(leftover_swapped)
@@ -811,33 +914,36 @@ def _schedule_priority_preemption(
if waiting_queue:
seq_group = waiting_queue.popleft()
num_new_seqs = seq_group.get_max_num_running_seqs()
- num_new_tokens = self._get_num_new_tokens(seq_group,
- SequenceStatus.WAITING,
- False, budget)
+ num_new_tokens_uncached, _ = (
+ self._get_num_new_uncached_and_cached_tokens(
+ seq_group, SequenceStatus.WAITING, False, budget))
#Only preempt if priority inversion exists
while running_queue and self._get_priority(
running_queue[-1]) > self._get_priority(seq_group):
#Only preempt if waiting sequence cannot be allocated
can_allocate = self.block_manager.can_allocate(seq_group)
- if (num_new_tokens and can_allocate == AllocStatus.OK
- and budget.can_schedule(num_new_tokens=num_new_tokens,
- num_new_seqs=num_new_seqs)):
+ if (num_new_tokens_uncached > 0
+ and can_allocate == AllocStatus.OK
+ and budget.can_schedule(
+ num_new_tokens=num_new_tokens_uncached,
+ num_new_seqs=num_new_seqs,
+ )):
break
#Adjust budget to remove the victim sequence group
vseq_group = running_queue.pop()
- num_running_tokens = self._get_num_new_tokens(
- vseq_group, SequenceStatus.RUNNING, False, budget)
- budget.subtract_num_batched_tokens(vseq_group.request_id,
- num_running_tokens)
+ num_running_tokens_uncached, _ = (
+ self._get_num_new_uncached_and_cached_tokens(
+ vseq_group, SequenceStatus.RUNNING, False, budget))
+ budget.subtract_num_batched_tokens(
+ vseq_group.request_id, num_running_tokens_uncached)
num_running_seqs = vseq_group.get_max_num_running_seqs()
budget.subtract_num_seqs(vseq_group.request_id,
num_running_seqs)
#Preempt out the victim sequence group
- self._preempt(vseq_group, blocks_to_swap_out,
- PreemptionMode.RECOMPUTE)
+ self._preempt(vseq_group, blocks_to_swap_out)
waiting_queue.appendleft(vseq_group)
force_preemption_count += 1
#Put the sequence back into the waiting queue
@@ -891,18 +997,39 @@ def _schedule_prefills(
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
- num_new_tokens, negative_num_new_tokens = self._get_num_new_tokens(seq_group, # --- FLAGSCALE MODIFICATION ---
- SequenceStatus.WAITING,
- enable_chunking, budget)
+ num_new_tokens_uncached, num_new_tokens_cached = (
+ self._get_num_new_uncached_and_cached_tokens(
+ seq_group, SequenceStatus.WAITING, enable_chunking,
+ budget))
+
+ # --- FLAGSCALE MODIFICATION BEG ---
+ with _switch_seq_group(seq_group, seq_group.negative_seqs):
+ num_new_tokens_uncached_negative, num_new_tokens_cached_negative = (
+ self._get_num_new_uncached_and_cached_tokens(
+ seq_group, SequenceStatus.WAITING, enable_chunking,
+ budget))
+ num_new_tokens_uncached, num_new_tokens_uncached_negative = _update_num_new_tokens(
+ budget, num_new_tokens_uncached, num_new_tokens_uncached_negative)
+
+ num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached
+ num_new_tokens_negative = num_new_tokens_uncached_negative + num_new_tokens_cached_negative
+ # --- FLAGSCALE MODIFICATION END ---
+
if not enable_chunking:
num_prompt_tokens = waiting_seqs[0].get_len()
assert num_new_tokens == num_prompt_tokens
+ # --- FLAGSCALE MODIFICATION BEG ---
+ if num_new_tokens_negative > 0:
+ assert self.scheduler_config.is_multi_step is False
+ assert self.cache_config.enable_prefix_caching is False
+ # --- FLAGSCALE MODIFICATION END ---
+
prompt_limit = self._get_prompt_limit(seq_group)
- if num_new_tokens + negative_num_new_tokens > prompt_limit: # --- FLAGSCALE MODIFICATION ---
+ if num_new_tokens + num_new_tokens_negative > prompt_limit: # --- FLAGSCALE MODIFICATION ---
logger.warning(
"Input prompt (%d tokens) (%d negative tokens) is too long"
- " and exceeds limit of %d", num_new_tokens, negative_num_new_tokens, prompt_limit) # --- FLAGSCALE MODIFICATION ---
+ " and exceeds limit of %d", num_new_tokens, num_new_tokens_negative, prompt_limit) # --- FLAGSCALE MODIFICATION ---
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
@@ -921,9 +1048,9 @@ def _schedule_prefills(
break
elif can_allocate == AllocStatus.NEVER:
logger.warning(
- "Input prompt (%d tokens) (%d negative tokens) is too long"
- " and exceeds the capacity of block_manager",
- num_new_tokens, negative_num_new_tokens) # --- FLAGSCALE MODIFICATION ---
+ "Input prompt (%d tokens) (%d negative tokens) + lookahead slots (%d) is "
+ "too long and exceeds the capacity of block_manager",
+ num_new_tokens, num_new_tokens_negative, num_lookahead_slots) # --- FLAGSCALE MODIFICATION ---
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
@@ -944,10 +1071,19 @@ def _schedule_prefills(
waiting_queue.popleft()
continue
+ if (budget.num_batched_tokens >=
+ self.scheduler_config.max_num_batched_tokens):
+ # We've reached the budget limit - since there might be
+ # continuous prefills in the running queue, we should break
+ # to avoid scheduling any new prefills.
+ break
+
num_new_seqs = seq_group.get_max_num_running_seqs()
- if (num_new_tokens + negative_num_new_tokens == 0 # --- FLAGSCALE MODIFICATION ---
- or not budget.can_schedule(num_new_tokens=num_new_tokens + negative_num_new_tokens, # --- FLAGSCALE MODIFICATION ---
- num_new_seqs=num_new_seqs)):
+ if num_new_tokens_uncached + num_new_tokens_uncached_negative == 0 \
+ or not budget.can_schedule(
+ num_new_tokens=num_new_tokens_uncached + num_new_tokens_uncached_negative,
+ num_new_seqs=num_new_seqs,
+ ): # --- FLAGSCALE MODIFICATION ---
break
# Can schedule this request.
@@ -976,8 +1112,12 @@ def _schedule_prefills(
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
token_chunk_size=num_new_tokens,
- negative_token_chunk_size=negative_num_new_tokens)) # --- FLAGSCALE MODIFICATION ---
- budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens + negative_num_new_tokens) # --- FLAGSCALE MODIFICATION ---
+ negative_token_chunk_size=num_new_tokens_negative)) # --- FLAGSCALE MODIFICATION ---
+ budget.add_num_batched_tokens(
+ seq_group.request_id,
+ num_batched_tokens=num_new_tokens_uncached + num_new_tokens_uncached_negative, # --- FLAGSCALE MODIFICATION ---
+ num_cached_tokens=num_new_tokens_cached + num_new_tokens_cached_negative, # --- FLAGSCALE MODIFICATION ---
+ )
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
# Queue requests that couldn't be scheduled.
@@ -1085,7 +1225,8 @@ def _schedule_default(self) -> SchedulerOutputs:
return SchedulerOutputs(
scheduled_seq_groups=scheduled_seq_groups,
num_prefill_groups=num_prefill_groups,
- num_batched_tokens=budget.num_batched_tokens,
+ num_batched_tokens=budget.num_batched_tokens +
+ budget.num_cached_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
@@ -1129,7 +1270,6 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
running_scheduled.swapped_out) == 0:
swapped_in = self._schedule_swapped(budget, curr_loras)
- # Schedule new prefills.
prefills = self._schedule_prefills(budget,
curr_loras,
enable_chunking=True)
@@ -1157,23 +1297,35 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
+ # Put prefills first due to Attention backend ordering assumption.
+ scheduled_seq_groups = (prefills.seq_groups +
+ running_scheduled.prefill_seq_groups +
+ swapped_in.prefill_seq_groups +
+ running_scheduled.decode_seq_groups +
+ swapped_in.decode_seq_groups)
+ num_prefill_groups = (len(prefills.seq_groups) +
+ len(swapped_in.prefill_seq_groups) +
+ len(running_scheduled.prefill_seq_groups))
+ # If all prompts, then we set num_lookahead_slots to 0
+ # this allows us to go through the `no_spec` path in
+ # `spec_decode_worker.py`
+ all_prefills = (len(scheduled_seq_groups) == num_prefill_groups)
+ num_lookahead_slots = (0 if
+ (all_prefills
+ and not self.scheduler_config.is_multi_step)
+ else running_scheduled.num_lookahead_slots)
return SchedulerOutputs(
- scheduled_seq_groups=(prefills.seq_groups +
- running_scheduled.prefill_seq_groups +
- swapped_in.prefill_seq_groups +
- running_scheduled.decode_seq_groups +
- swapped_in.decode_seq_groups),
- num_prefill_groups=(len(prefills.seq_groups) +
- len(swapped_in.prefill_seq_groups) +
- len(running_scheduled.prefill_seq_groups)),
- num_batched_tokens=budget.num_batched_tokens,
+ scheduled_seq_groups=scheduled_seq_groups,
+ num_prefill_groups=num_prefill_groups,
+ num_batched_tokens=budget.num_batched_tokens +
+ budget.num_cached_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=running_scheduled.blocks_to_copy +
swapped_in.blocks_to_copy,
ignored_seq_groups=prefills.ignored_seq_groups +
swapped_in.infeasible_seq_groups,
- num_lookahead_slots=running_scheduled.num_lookahead_slots,
+ num_lookahead_slots=num_lookahead_slots,
running_queue_size=len(self.running),
preempted=(len(running_scheduled.preempted) +
len(running_scheduled.swapped_out)),
@@ -1344,12 +1496,15 @@ def schedule(
negative_seq_data=negative_seq_data, # --- FLAGSCALE MODIFICATION ---
negative_block_tables=negative_block_tables, # --- FLAGSCALE MODIFICATION ---
state=seq_group.state,
+ token_type_ids=seq_group.token_type_ids,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
+ multi_modal_placeholders=seq_group.multi_modal_placeholders
+ if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
@@ -1493,12 +1648,8 @@ def _append_slots(self,
assert not seq_group.has_negative_seqs() # --- FLAGSCALE MODIFICATION ---
blocks_to_copy.extend(cows)
- def _preempt(
- self,
- seq_group: SequenceGroup,
- blocks_to_swap_out: List[Tuple[int, int]],
- preemption_mode: Optional[PreemptionMode] = None,
- ) -> PreemptionMode:
+ def _preempt(self, seq_group: SequenceGroup,
+ blocks_to_swap_out: List[Tuple[int, int]]) -> PreemptionMode:
# If preemption mode is not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences
@@ -1531,6 +1682,11 @@ def _preempt(
preemption_mode, self.num_cumulative_preemption + 1)
self.num_cumulative_preemption += 1
+ # --- FLAGSCALE MODIFICATION BEG ---
+ if seq_group.has_negative_seqs() and preemption_mode != PreemptionMode.RECOMPUTE:
+ raise NotImplementedError("Swap in/out does not support in CFG.")
+ # --- FLAGSCALE MODIFICATION END ---
+
if preemption_mode == PreemptionMode.RECOMPUTE:
self._preempt_by_recompute(seq_group)
elif preemption_mode == PreemptionMode.SWAP:
@@ -1634,89 +1790,178 @@ def _get_num_lookahead_slots(self, is_prefill: bool,
return self.scheduler_config.num_lookahead_slots
- def _get_num_new_tokens(self, seq_group: SequenceGroup,
- status: SequenceStatus, enable_chunking: bool,
- budget: SchedulingBudget) -> int:
- """Get the next new tokens to compute for a given sequence group
- that's in a given `status`.
+ def _get_num_new_uncached_and_cached_tokens(
+ self,
+ seq_group: SequenceGroup,
+ status: SequenceStatus,
+ enable_chunking: bool,
+ budget: SchedulingBudget,
+ ) -> Tuple[int, int]:
+ """
+ Returns the number of new uncached and cached tokens to schedule for a
+ given sequence group that's in a given `status`.
The API could chunk the number of tokens to compute based on `budget`
if `enable_chunking` is True. If a sequence group has multiple
sequences (e.g., running beam search), it means it is in decoding
phase, so chunking doesn't happen.
- Returns 0 if the new token cannot be computed due to token budget.
+ Returns (0, 0) if the new token cannot be computed due to token budget.
+
+ The cached tokens' blocks are already computed, and the attention
+ backend will reuse the cached blocks rather than recomputing them. So
+ the scheduler could schedule these cached tokens "for free".
+
+ Args:
+ seq_group: The sequence group to get the number of new tokens to
+ schedule.
+ status: The status of the sequences to get the number of new tokens
+ to schedule.
+ enable_chunking: Whether to chunk the number of tokens to compute.
+ budget: The budget to chunk the number of tokens to compute.
+
+
+ Returns:
+ A tuple of two ints. The first int is the number of new uncached
+ tokens to schedule. The second int is the number of cached tokens.
+ If no more new tokens can be scheduled, returns (0, 0).
"""
- # --- FLAGSCALE MODIFICATION BEG ---
- num_new_tokens = 0
- negative_num_new_tokens = 0
+ num_cached_new_tokens = 0
+ num_uncached_new_tokens = 0
+
seqs = seq_group.get_seqs(status=status)
+ # Compute the number of new uncached and cached tokens for
+ # each sequence.
for seq in seqs:
- num_new_tokens += seq.get_num_new_tokens()
- if seq_group.has_negative_seqs():
- negative_seq = seq_group.negative_seqs_dict[seq.seq_id]
- negative_num_new_tokens += negative_seq.get_num_new_tokens()
- assert num_new_tokens + negative_num_new_tokens > 0
-
- if seq_group.has_negative_seqs():
- if enable_chunking and len(seqs) == 1:
- assert not self.cache_config.enable_prefix_caching, "Prefix caching is not supported with negative sequences."
- remaining_token_budget = budget.remaining_token_budget()
- if num_new_tokens + negative_num_new_tokens < remaining_token_budget:
- return num_new_tokens, negative_num_new_tokens
- else:
- rt_num_new_tokens = min(num_new_tokens, remaining_token_budget // 2)
- rt_negative_num_new_tokens = min(negative_num_new_tokens, remaining_token_budget // 2)
+ if not seq.is_prefill():
+ # Decode sequences should always just have 1 uncached token
+ # TODO(rickyx): Actually is this still correct for multi-step?
+ num_uncached_new_tokens += 1
+ continue
- if rt_num_new_tokens == num_new_tokens and rt_negative_num_new_tokens != negative_num_new_tokens:
- return 0, 0
- elif rt_num_new_tokens != num_new_tokens and rt_negative_num_new_tokens == negative_num_new_tokens:
- return 0, 0
- else:
- return rt_num_new_tokens, rt_negative_num_new_tokens
- return num_new_tokens, negative_num_new_tokens
- # --- FLAGSCALE MODIFICATION END ---
+ num_computed_tokens_seq = seq.get_num_computed_tokens()
+ all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq
+ if not self.cache_config.enable_prefix_caching:
+ # If prefix caching is not enabled, all new tokens are uncached.
+ num_uncached_new_tokens += all_num_new_tokens_seq
+ continue
+
+ # NOTE: the cache token might be currently in a block that's in an
+ # evictor meaning that it's not yet allocated. However, we don't
+ # exclude such tokens in the cache count because it will be
+ # guaranteed to be allocated later if the sequence can be allocated.
+ num_cached_tokens_seq = self.block_manager.get_num_cached_tokens(
+ seq)
+
+ # Sanity check.
+ if num_cached_tokens_seq < num_computed_tokens_seq:
+ # This should only happen with chunked prefill, and
+ # the seq is still in prefill. The `num_cached_tokens_seq`
+ # is the value we calculated on scheduling the first prefill.
+ # For subsequent continuous prefill steps, we cached the
+ # number of cache tokens for the sequence so the cached token
+ # count could be less than the number of computed tokens.
+ # See comments on `ComputedBlocksTracker` for more details.
+ assert (
+ seq.is_prefill() and seq.status == SequenceStatus.RUNNING
+ and self.scheduler_config.chunked_prefill_enabled
+ ), ("Number of cached tokens should not be less than the "
+ "number of computed tokens for a sequence that's still "
+ f"in prefill. But there are {num_cached_tokens_seq} cached "
+ f"tokens and {num_computed_tokens_seq} computed tokens "
+ f"for sequence {seq.seq_id}.")
+
+ num_cached_new_tokens_seq = max(
+ 0, num_cached_tokens_seq - num_computed_tokens_seq)
+ num_uncached_new_tokens_seq = (all_num_new_tokens_seq -
+ num_cached_new_tokens_seq)
+
+ num_uncached_new_tokens += num_uncached_new_tokens_seq
+ num_cached_new_tokens += num_cached_new_tokens_seq
+
+ if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0:
+ # For a fully cached hit sequence, we actually need to recompute the
+ # last token. So we need at least 1 uncached token to schedule.
+ # See ModelRunner._compute_for_prefix_cache_hit for more details.
+ num_uncached_new_tokens = 1
+ num_cached_new_tokens -= 1
- # Chunk if a running request cannot fit in the given budget.
- # If number of seq > 1, it means it is doing beam search
- # in a decode phase. Do not chunk.
if enable_chunking and len(seqs) == 1:
- remaining_token_budget = budget.remaining_token_budget()
- if self.scheduler_config.is_multi_step:
- # The current multi-step + chunked prefill capability does
- # not actually support chunking prompts.
- #
- # Therefore, `num_new_tokens` is computed in the same fashion
- # for both multi-step+chunked-prefill &
- # multi-step+chunked-prefill+APC
- #
- # Prompts with more tokens than the current remaining budget
- # are postponed to future scheduler steps
- if num_new_tokens > self._get_prompt_limit(seq_group):
- # If the seq_group is in prompt-stage, pass the
- # num_new_tokens as-is so the caller can ignore
- # the sequence.
- pass
- else:
- num_new_tokens = 0 \
- if num_new_tokens > remaining_token_budget \
- else num_new_tokens
- elif self.cache_config.enable_prefix_caching:
- # When prefix caching is enabled, we always allocate
- # the number of new tokens that is dividable by the block
- # size to avoid partial block matching.
- block_size = self.cache_config.block_size
- remainder = budget.token_budget % block_size
- if remainder != 0:
- raise ValueError("When enabling chunked prefill and "
- "prefix caching, max_num_batched_tokens "
- "(chunk size) must be dividable by "
- "block size, but got chunk_size "
- f"({budget.token_budget}) % block_size "
- f"({block_size}) = {remainder}")
- if remaining_token_budget < num_new_tokens:
- num_new_tokens = (remaining_token_budget //
- block_size) * block_size
- else:
- num_new_tokens = min(num_new_tokens, remaining_token_budget)
- return num_new_tokens, negative_num_new_tokens # --- FLAGSCALE MODIFICATION ---
+ # Chunk if a running request cannot fit in the given budget.
+ # If number of seq > 1, it means it is doing beam search
+ # in a decode phase. Do not chunk.
+ num_uncached_new_tokens = self._chunk_new_tokens_to_schedule(
+ self.scheduler_config,
+ self.cache_config,
+ budget,
+ self._get_prompt_limit(seq_group),
+ num_uncached_new_tokens,
+ )
+
+ return num_uncached_new_tokens, num_cached_new_tokens
+
+ @staticmethod
+ def _chunk_new_tokens_to_schedule(
+ scheduler_config: SchedulerConfig,
+ cache_config: CacheConfig,
+ budget: SchedulingBudget,
+ prompt_limit: int,
+ num_new_tokens: int,
+ ) -> int:
+ """
+ Chunks the number of new tokens to schedule based on the budget when
+ chunked prefill is enabled.
+
+ Args:
+ scheduler_config: The scheduler config.
+ cache_config: The cache config.
+ budget: The budget to chunk the number of tokens to compute.
+ prompt_limit: The maximum number of tokens allowed in a prompt.
+ num_new_tokens: The number of new tokens to schedule.
+
+ Returns:
+ The number of new tokens to schedule after chunking.
+ """
+ remaining_token_budget = budget.remaining_token_budget()
+ if scheduler_config.is_multi_step:
+ # The current multi-step + chunked prefill capability does
+ # not actually support chunking prompts.
+ #
+ # Therefore, `num_new_tokens` is computed in the same fashion
+ # for both multi-step+chunked-prefill &
+ # multi-step+chunked-prefill+APC
+ #
+ # Prompts with more tokens than the current remaining budget
+ # are postponed to future scheduler steps
+ if num_new_tokens > prompt_limit:
+ # If the seq_group is in prompt-stage, pass the
+ # num_new_tokens as-is so the caller can ignore
+ # the sequence.
+ return num_new_tokens
+
+ return (0 if num_new_tokens > remaining_token_budget else
+ num_new_tokens)
+
+ if cache_config.enable_prefix_caching:
+ # Adjust the remaining token budget to be divisible by the block
+ # size when prefix caching is enabled.
+
+ # When prefix caching is enabled, we always allocate
+ # the number of new tokens that is divisible by the block
+ # size to avoid partial block matching.
+ block_size = cache_config.block_size
+ remainder = budget.token_budget % block_size
+ if remainder != 0:
+ raise ValueError("When enabling chunked prefill and "
+ "prefix caching, max_num_batched_tokens "
+ "(chunk size) must be dividable by "
+ "block size, but got chunk_size "
+ f"({budget.token_budget}) % block_size "
+ f"({block_size}) = {remainder}")
+ # Round down to block size.
+ remaining_token_budget = (remaining_token_budget // block_size *
+ block_size)
+
+ num_new_tokens = min(num_new_tokens, remaining_token_budget)
+
+ return num_new_tokens
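To make the block-size rounding at the end of `_chunk_new_tokens_to_schedule` concrete (illustrative numbers, assuming chunked prefill and prefix caching are both enabled):

block_size = 16
remaining_token_budget = 100
# Round the budget down to a whole number of blocks before chunking.
remaining_token_budget = (remaining_token_budget // block_size) * block_size  # 96
num_new_tokens = min(230, remaining_token_budget)  # schedule 96 tokens this step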
diff --git a/flagscale/inference/core/sequence.py b/flagscale/inference/core/sequence.py
index 996a3a33b..614987178 100644
--- a/flagscale/inference/core/sequence.py
+++ b/flagscale/inference/core/sequence.py
@@ -6,24 +6,21 @@
from array import array
from collections import defaultdict
from dataclasses import dataclass, field
-from functools import cached_property, reduce
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
+from functools import reduce
+from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
-from typing import Set, Tuple, Union, cast
+from typing import Set, Tuple, Union
import msgspec
import torch
-from vllm.inputs.parse import is_encoder_decoder_inputs
+from vllm.inputs import SingletonInputs, SingletonInputsAdapter
from vllm.lora.request import LoRARequest
+from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
-if TYPE_CHECKING:
- from vllm.inputs import SingletonInputs
- from vllm.multimodal.base import MultiModalDataDict
-
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
VLLM_INVALID_TOKEN_ID = -1
@@ -167,6 +164,8 @@ class SequenceData(msgspec.Struct,
...] = msgspec.field(default_factory=tuple)
# The number of tokens that are computed (that run against the model).
_num_computed_tokens: int = 0
+ # The number of tokens with prefix cache hit.
+ _num_cached_tokens: int = 0
_stage: SequenceStage = SequenceStage.PREFILL
_cached_all_token_ids: List[int] = msgspec.field(default_factory=list)
@@ -257,7 +256,8 @@ def output_token_ids(self) -> Tuple[int, ...]:
return tuple(self._output_token_ids)
@output_token_ids.setter
- def output_token_ids(self, new_output_token_ids: List[int]) -> None:
+ def output_token_ids(self,
+ new_output_token_ids: GenericSequence[int]) -> None:
self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
new_output_token_ids)
self._update_cached_all_tokens()
@@ -322,6 +322,14 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int):
if self.get_num_uncomputed_tokens() == 0:
self._stage = SequenceStage.DECODE
+ def get_num_cached_tokens(self) -> int:
+ """Return the number of tokens with prefix cache hit."""
+ return self._num_cached_tokens
+
+ def update_num_cached_tokens(self, num_cached_tokens: int):
+ """Update the number of tokens with prefix cache hit."""
+ self._num_cached_tokens = num_cached_tokens
+
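A minimal sketch of the new cached-token bookkeeping on `SequenceData`; `seq_data` and the 32-token hit count are illustrative, and how the hit count is computed is shown in `_get_num_new_uncached_and_cached_tokens` above.

seq_data.update_num_cached_tokens(32)   # record a prefix-cache hit
assert seq_data.get_num_cached_tokens() == 32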
def reset_state_for_recompute(self) -> None:
"""Reset the number of computed tokens from this sequence. It is
supposed to be called when a sequence needs to be started from
@@ -379,14 +387,9 @@ def __repr__(self) -> str:
class Sequence:
"""Stores the data, status, and block information of a sequence.
- The sequence is constructed from the :code:`SingletonInputs` instance
- passed in through the :code:`inputs` constructor argument.
-
- For encoder/decoder models, SingletonInputs encapsulates both a
- decoder and encoder prompt, creating an ambiguity about which
- prompt to construct the sequence from. The `from_decoder_prompt`
- constructor argument signals whether to construct the Sequence
- from the SingletonInputs decoder prompt, or encoder prompt.
+ The sequence is constructed from the :data:`DecoderOnlyInputs`
+ (for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder)
+ instance passed in through the :code:`inputs` constructor argument.
Args:
seq_id: The ID of the sequence.
@@ -396,59 +399,23 @@ class Sequence:
eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
lora_request: LoRA request.
prompt_adapter_request: Prompt Adapter request.
- from_decoder_prompt: Construct Sequence from SingletonInputs decoder
- prompt (True) or encoder prompt (False.) Must be
- True for decoder-only model.
-
"""
def __init__(
self,
seq_id: int,
- inputs: "SingletonInputs",
+ inputs: SingletonInputs,
block_size: int,
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
- from_decoder_prompt: bool = True,
- from_negative_prompt: bool = False, # --- FLAGSCALE MODIFICATION ---
) -> None:
self.seq_id = seq_id
- self.inputs = inputs
+ self.inputs = SingletonInputsAdapter(inputs)
self.block_size = block_size
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
- self.from_decoder_prompt = from_decoder_prompt
- # --- FLAGSCALE MODIFICATION BEG ---
- self.from_negative_prompt = from_negative_prompt
- # --- FLAGSCALE MODIFICATION END ---
-
- # For decoder-only models, a Sequence is constructed
- # from an DecoderOnlyInputs instance (the `inputs` arg.)
- #
- # For encoder/decoder models the same `inputs`
- # instance could be utilized to construct either an
- # encoder sequence or a decoder sequence, because
- # `DecoderOnlyInputs` has both decoder- and encoder-oriented
- # member variables (i.e. it encapsulates both an encoder
- # and a decoder prompt.) The decision of which type of sequence
- # to generate is determined by the `from_decoder_prompt` argument.
- #
- # When constructing a encoder sequence
- # (`from_decoder_prompt` False) it matters that
- # the `DecoderOnlyInputs` instance stored in `inputs` is valid
- # in the sense that its encoder-related member variables are
- # populated; below, an exception is raised if this is
- # not the case.
- #
- # When constructing a decoder sequence (`from_decoder_prompt` True)
- # it does not matter whether `inputs` has its encoder-related
- # member variables populated.
- if not (from_decoder_prompt or is_encoder_decoder_inputs(inputs)):
- raise ValueError("Cannot extract encoder input prompt from "
- f"invalid input {inputs}; did you forget the "
- "encoder input prompt fields?")
self.data = SequenceData.from_seqs(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
@@ -471,55 +438,33 @@ def __init__(
def n_blocks(self) -> int:
return (self.get_len() + self.block_size - 1) // self.block_size
- @cached_property
+ @property
def prompt(self) -> Optional[str]:
- # Select decoder or encoder input prompt str, as appropriate
- prompt_key: str = ("prompt"
- if self.from_decoder_prompt else "encoder_prompt")
-
- # --- FLAGSCALE MODIFICATION BEG ---
- if self.from_negative_prompt:
- assert self.from_decoder_prompt is True, "negative prompt is only supported for decoder"
- prompt_key: str = "negative_prompt"
- # --- FLAGSCALE MODIFICATION END ---
+ return self.inputs.prompt
- return cast(Optional[str], self.inputs.get(prompt_key))
-
- @cached_property
+ @property
def prompt_token_ids(self) -> List[int]:
- # Select decoder or encoder input prompt token ids, as appropriate
- prompt_token_ids_key: str = ("prompt_token_ids"
- if self.from_decoder_prompt else
- "encoder_prompt_token_ids")
+ return self.inputs.prompt_token_ids
- # --- FLAGSCALE MODIFICATION BEG ---
- if self.from_negative_prompt:
- assert self.from_decoder_prompt is True, "negative prompt is only supported for decoder"
- prompt_token_ids_key: str = "negative_prompt_token_ids"
- # --- FLAGSCALE MODIFICATION END ---
+ @property
+ def prompt_embeds(self) -> Optional[torch.Tensor]:
+ return self.inputs.prompt_embeds
- # Cache computed prompt token ids
- return cast(List[int], self.inputs.get(prompt_token_ids_key))
+ @property
+ def token_type_ids(self) -> List[int]:
+ return self.inputs.token_type_ids
@property
def multi_modal_data(self) -> "MultiModalDataDict":
- inputs = self.inputs
+ return self.inputs.multi_modal_data
- if (inputs.get("multi_modal_data")
- and inputs.get("encoder_multi_modal_data")):
- raise ValueError(
- "Multi-modal data in both encoder and decoder is not supported."
- )
-
- return cast(
- "MultiModalDataDict",
- (inputs.get("multi_modal_data")
- or inputs.get("encoder_multi_modal_data") or {}),
- )
+ @property
+ def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
+ return self.inputs.multi_modal_placeholders
@property
def mm_processor_kwargs(self) -> Dict[str, Any]:
- return self.inputs.get("mm_processor_kwargs") or {}
+ return self.inputs.mm_processor_kwargs
@property
def lora_int_id(self) -> int:
@@ -639,6 +584,9 @@ def get_num_new_tokens(self) -> int:
return 1
return self.data.get_num_uncomputed_tokens()
+ def get_num_computed_tokens(self) -> int:
+ return self.data.get_num_computed_tokens()
+
def is_prefill(self) -> bool:
return self.data.stage == SequenceStage.PREFILL
@@ -723,6 +671,7 @@ def __init__(
# --- FLAGSCALE MODIFICATION BEG ---
self.negative_seqs = negative_seqs
+ self.negative_seqs_dict = {}
if negative_seqs:
self.negative_seqs_dict = {seq.seq_id: seq for seq in negative_seqs}
assert self.is_single_seq is True
@@ -754,9 +703,17 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]:
if self.encoder_seq is not None else None)
@property
- def multi_modal_data(self) -> "MultiModalDataDict":
+ def token_type_ids(self) -> Optional[List[int]]:
+ return self.first_seq.token_type_ids
+
+ @property
+ def multi_modal_data(self) -> MultiModalDataDict:
return self.first_seq.multi_modal_data
+ @property
+ def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
+ return self.first_seq.multi_modal_placeholders
+
@property
def mm_processor_kwargs(self) -> Dict[str, Any]:
return self.first_seq.mm_processor_kwargs
@@ -969,7 +926,7 @@ class SequenceGroupMetadata(
multi_modal_data: Multi modal data.
mm_processor_kwargs: Multimodal input processor / mapper overrides.
encoder_seq_data: Optional sequence data for encoder prompt
- (SequenceGroup.encoder_seq). Should be None
+ (SequenceGroup.encoder_seq). Should be None
unless you are working with an encoder/decoder
model.
cross_block_table: Optional cross-attention block table associated
@@ -993,7 +950,9 @@ class SequenceGroupMetadata(
default_factory=lambda: SequenceGroupState())
# "MultiModalDataDict" types. We have to use Any due to msgspec
# doesn't allow to have union of 2 different dicts.
+ token_type_ids: Optional[List[int]] = None
multi_modal_data: Optional[Any] = None
+ multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[List[int]] = None
@@ -1227,7 +1186,7 @@ def get_all_seq_ids_and_request_ids(
sequence ids.
"""
seq_ids: List[int] = []
- request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
+ request_id_seq_ids_mapping: DefaultDict[str, Set[int]] = defaultdict(set)
for sg in seq_group_metadata_list:
for seq_id in sg.seq_data:
seq_ids.append(seq_id)
diff --git a/flagscale/runner/runner_inference.py b/flagscale/runner/runner_inference.py
index 376696a22..a3c64550c 100644
--- a/flagscale/runner/runner_inference.py
+++ b/flagscale/runner/runner_inference.py
@@ -81,6 +81,7 @@ def _generate_run_script_inference(
root_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
+ vllm_dir = os.path.join(root_dir, "vllm")
cmds_config = config.experiment.get("cmds", None)
if cmds_config:
before_start = cmds_config.get("before_start", "")
@@ -94,7 +95,7 @@ def _generate_run_script_inference(
f.write(f"\n")
f.write(f"cd {root_dir}\n")
f.write(f"\n")
- f.write(f"export PYTHONPATH={root_dir}\n")
+ f.write(f"export PYTHONPATH={vllm_dir}:{root_dir}\n")
f.write(f"\n")
f.write(f'cmd="{cmd}"\n')
f.write(f"\n")
diff --git a/flagscale/runner/runner_serve.py b/flagscale/runner/runner_serve.py
new file mode 100644
index 000000000..1955e8d95
--- /dev/null
+++ b/flagscale/runner/runner_serve.py
@@ -0,0 +1,306 @@
+import os
+import shlex
+
+import hydra
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import DictConfig, OmegaConf
+
+from flagscale.runner.runner_base import RunnerBase
+from flagscale.runner.runner_utils import (
+ get_free_port,
+ get_nnodes,
+ get_nproc_per_node,
+ logger,
+ parse_hostfile,
+ run_local_command,
+ run_scp_command,
+ run_ssh_command,
+)
+
+
+def _get_args_vllm(config: DictConfig):
+ # see the following link for more details
+ # https://github.com/facebookresearch/hydra/discussions/2750
+ OmegaConf.set_struct(config, False)
+
+ hydra_config = HydraConfig.get()
+ output_dir = hydra_config.runtime.output_dir
+ output_subdir = hydra_config.output_subdir
+ config_path = os.path.join(output_dir, f"{output_subdir}/config.yaml")
+ config_path = hydra.utils.to_absolute_path(config_path)
+
+ args = []
+ args.append(f"--config-path={config_path}")
+
+ return args
+
+
+def _update_config_serve(config: DictConfig):
+ exp_dir = os.path.abspath(config.experiment.exp_dir)
+ if not os.path.isdir(exp_dir):
+ os.makedirs(exp_dir)
+ assert os.path.isdir(exp_dir), f"Directory {exp_dir} does not exist."
+
+ OmegaConf.set_struct(config, False)
+
+ if config.get("logging", None) is None:
+ config.serve.logging = DictConfig({})
+
+ log_dir = os.path.join(exp_dir, f"serve_logs")
+ scripts_dir = os.path.join(log_dir, "scripts")
+ pids_dir = os.path.join(log_dir, "pids")
+
+ config.serve.logging.log_dir = log_dir
+ config.serve.logging.scripts_dir = scripts_dir
+ config.serve.logging.pids_dir = pids_dir
+
+ OmegaConf.set_struct(config, True)
+
+
+def _generate_run_script_serve(
+ config, host, node_rank, cmd, background=True, with_test=False
+):
+ logging_config = config.serve.logging
+
+ no_shared_fs = config.experiment.runner.get("no_shared_fs", False)
+ if no_shared_fs:
+ host_output_file = os.path.join(logging_config.log_dir, f"host.output")
+ else:
+ host_output_file = os.path.join(
+ logging_config.log_dir, f"host_{node_rank}_{host}.output"
+ )
+ host_run_script_file = os.path.join(
+ logging_config.scripts_dir, f"host_{node_rank}_{host}_run.sh"
+ )
+ host_pid_file = os.path.join(
+ logging_config.pids_dir, f"host_{node_rank}_{host}.pid"
+ )
+
+ os.makedirs(logging_config.scripts_dir, exist_ok=True)
+
+ root_dir = os.path.dirname(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ )
+ cmds_config = config.experiment.get("cmds", None)
+ if cmds_config:
+ before_start = cmds_config.get("before_start", "")
+ else:
+ before_start = ""
+ cmd += f" --log-dir={logging_config.log_dir}"
+ with open(host_run_script_file, "w") as f:
+ f.write("#!/bin/bash\n\n")
+ f.write("set -x\n")
+ f.write(f"{before_start}\n")
+ f.write(f"mkdir -p {logging_config.log_dir}\n")
+ f.write(f"mkdir -p {logging_config.pids_dir}\n")
+ f.write(f"\n")
+ f.write(f"cd {root_dir}\n")
+ f.write(f"\n")
+ f.write(f"export PYTHONPATH={root_dir}\n")
+ f.write(f"\n")
+ f.write(f'cmd="{cmd}"\n')
+ f.write(f"\n")
+ if with_test:
+ f.write(f'bash -c "$cmd; sync" \n')
+ else:
+            # TODO: need an option to control whether to append or overwrite the output file
+ # Now, it always appends to the output file
+ if background:
+ f.write(
+ f'nohup bash -c "$cmd; sync" >> {host_output_file} 2>&1 & echo $! > {host_pid_file}\n'
+ )
+ else:
+ f.write(f'bash -c "$cmd; sync" >> {host_output_file} 2>&1\n')
+ f.write("\n")
+ f.flush()
+ os.fsync(f.fileno())
+ os.chmod(host_run_script_file, 0o755)
+
+ return host_run_script_file
+
+
+def _generate_stop_script(config, host, node_rank):
+ logging_config = config.serve.logging
+
+ host_stop_script_file = os.path.join(
+ logging_config.scripts_dir, f"host_{node_rank}_{host}_stop.sh"
+ )
+
+ host_pid_file = os.path.join(
+ logging_config.pids_dir, f"host_{node_rank}_{host}.pid"
+ )
+
+ os.makedirs(logging_config.scripts_dir, exist_ok=True)
+
+ cmds_config = config.experiment.get("cmds", None)
+ if cmds_config:
+ after_stop = cmds_config.get("after_stop", "")
+ else:
+ after_stop = ""
+ with open(host_stop_script_file, "w") as f:
+ f.write("#!/bin/bash\n\n")
+ f.write("pkill -f 'python'\n")
+ f.write(f"{after_stop}\n")
+ f.flush()
+ os.fsync(f.fileno())
+ os.chmod(host_stop_script_file, 0o755)
+
+ return host_stop_script_file
+
+
+class SSHServeRunner(RunnerBase):
+ def __init__(self, config: DictConfig):
+ super().__init__(config)
+ self.task_type = getattr(self.config.experiment.task, "type", None)
+ assert self.task_type == "serve", f"Unsupported task type: {self.task_type}"
+ self.command_line_mode = getattr(
+ self.config.serve.deploy, "command-line-mode", None
+ )
+ self._prepare()
+
+ def _prepare(self):
+ _update_config_serve(self.config)
+ self.user_args = _get_args_vllm(self.config)
+ self.user_envs = self.config.experiment.get("envs", {})
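+        # In command-line mode, use the built-in vLLM launcher; otherwise run
+        # the user-provided entrypoint from the experiment task config.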
+ if self.command_line_mode:
+ self.user_script = "flagscale/serve/run_vllm.py"
+ else:
+ self.user_script = self.config.experiment.task.entrypoint
+ self.resources = parse_hostfile(
+ self.config.experiment.runner.get("hostfile", None)
+ )
+ logger.info("\n************** configuration **************")
+ logger.info(f"\n{OmegaConf.to_yaml(self.config)}")
+
+ def _run_each(
+ self,
+ host,
+ master_addr,
+ master_port,
+ nnodes,
+ node_rank,
+ nproc_per_node,
+ with_test=False,
+ dryrun=False,
+ ):
+ export_cmd = []
+ for k, v in self.user_envs.items():
+ export_cmd += [f"{k}={v}"]
+
+ cmd = shlex.join(export_cmd + ["python"] + [self.user_script] + self.user_args)
+
+ logging_config = self.config.serve.logging
+ host_run_script_file = _generate_run_script_serve(
+ self.config, host, node_rank, cmd, background=True, with_test=with_test
+ )
+
+ if host != "localhost":
+ ssh_port = self.config.experiment.runner.get("ssh_port", 22)
+ # Step 1: make sure the scripts_dir exists on the remote host
+ run_ssh_command(
+ host, f"mkdir -p {logging_config.scripts_dir}", ssh_port, dryrun
+ )
+
+ # Step 2: copy the host_run_script_file to the remote host
+ no_shared_fs = self.config.experiment.runner.get("no_shared_fs", False)
+ if no_shared_fs:
+ run_scp_command(
+ host,
+ host_run_script_file,
+ logging_config.scripts_dir,
+ ssh_port,
+ dryrun,
+ )
+
+ # Step 3: run the host_run_script_file on the remote host
+ run_ssh_command(host, f"bash {host_run_script_file}", ssh_port, dryrun)
+ else:
+ run_local_command(f"bash {host_run_script_file}", dryrun)
+
+ def run(self, with_test=False, dryrun=False):
+ num_visible_devices = None
+ visible_devices = self.user_envs.get("CUDA_VISIBLE_DEVICES", None)
+ if visible_devices is not None and isinstance(visible_devices, str):
+ visible_devices = visible_devices.split(",")
+ num_visible_devices = len(visible_devices)
+
+ runner_config = self.config.experiment.runner
+
+ # If hostfile is provided, use the resources from the hostfile
+ if self.resources is not None:
+ nnodes_from_hostfile = len(self.resources.keys())
+ nnodes_from_args = runner_config.get("nnodes", None)
+ nnodes = get_nnodes(nnodes_from_hostfile, nnodes_from_args)
+ available_ip = list(self.resources.keys())[0]
+ available_port = get_free_port()
+ for node_rank, (host, resource_info) in enumerate(self.resources.items()):
+ if node_rank >= nnodes:
+ break
+ nproc_from_hostfile = resource_info["slots"]
+ nproc_from_args = runner_config.get("nproc_per_node", None)
+ nproc_per_node = get_nproc_per_node(
+ nproc_from_hostfile, nproc_from_args, num_visible_devices
+ )
+ master_addr = runner_config.get("master_addr", available_ip)
+ master_port = runner_config.get("master_port", available_port)
+ self._run_each(
+ host,
+ master_addr,
+ master_port,
+ nnodes,
+ node_rank,
+ nproc_per_node,
+ with_test=with_test,
+ dryrun=dryrun,
+ )
+ else:
+ # If hostfile is not provided, run the job on localhost
+ nproc_from_args = runner_config.get("nproc_per_node", None)
+ nproc_per_node = get_nproc_per_node(
+ None, nproc_from_args, num_visible_devices
+ )
+ available_addr = runner_config.get("master_addr", "localhost")
+ available_port = runner_config.get("master_port", get_free_port())
+ self._run_each(
+ "localhost",
+ available_addr,
+ available_port,
+ 1,
+ 0,
+ nproc_per_node,
+ with_test=with_test,
+ dryrun=dryrun,
+ )
+
+ def _stop_each(self, host, node_rank):
+ host_stop_script_file = _generate_stop_script(self.config, host, node_rank)
+ logging_config = self.config.serve.logging
+
+ if host != "localhost":
+ ssh_port = self.config.experiment.runner.get("ssh_port", 22)
+ # Step 1: make sure the scripts_dir exists on the remote host
+ run_ssh_command(host, f"mkdir -p {logging_config.scripts_dir}", ssh_port)
+ # Step 2: copy the host_run_script_file to the remote host
+ no_shared_fs = self.config.experiment.runner.get("no_shared_fs", False)
+ if no_shared_fs:
+ run_scp_command(
+ host, host_stop_script_file, logging_config.scripts_dir, ssh_port
+ )
+ # Step 3: run the host_run_script_file on the remote host
+ run_ssh_command(host, f"bash {host_stop_script_file}", ssh_port)
+ else:
+ run_local_command(f"bash {host_stop_script_file}")
+
+ def stop(self):
+ if self.resources is None:
+ self._stop_each("localhost", 0)
+ return
+
+ nnodes = get_nnodes(
+ len(self.resources), self.config.experiment.runner.get("nnodes", None)
+ )
+
+ for node_rank, (host, _) in enumerate(self.resources.items()):
+ if node_rank >= nnodes:
+ break
+ self._stop_each(host, node_rank)
diff --git a/flagscale/runner/runner_train.py b/flagscale/runner/runner_train.py
index 8efead5a4..745da2c52 100644
--- a/flagscale/runner/runner_train.py
+++ b/flagscale/runner/runner_train.py
@@ -7,6 +7,7 @@
from flagscale.runner.runner_base import JobStatus, RunnerBase
from flagscale.runner.runner_utils import (
+ add_decive_extra_config,
flatten_dict_to_args,
get_free_port,
get_host_name_or_ip,
@@ -290,6 +291,7 @@ def _prepare(self):
self.user_args = _get_args_megatron(self.config)
self.rdzv_id = datetime.now().strftime("%Y%m%d_%H%M%S.%f")
self.user_envs = self.config.experiment.get("envs", {})
+ self.cur_envs = None # current node envs
self.user_script = self.config.experiment.task.entrypoint
self.resources = parse_hostfile(
self.config.experiment.runner.get("hostfile", None)
@@ -305,11 +307,13 @@ def _run_each(
nnodes,
node_rank,
nproc_per_node,
+ device_type=None,
with_test=False,
dryrun=False,
):
export_cmd = []
- for k, v in self.user_envs.items():
+
+ for k, v in self.cur_envs.items():
export_cmd += [f"{k}={v}"]
runner_cmd = _get_runner_cmd_train(
@@ -321,6 +325,13 @@ def _run_each(
nproc_per_node,
self.config,
)
+ # update hetero-current-device-type according to the device_type in hostfile
+ if device_type is not None:
+ if "--hetero-current-device-type" in self.user_args:
+ idx = self.user_args.index("--hetero-current-device-type")
+ self.user_args[idx + 1] = device_type
+ else:
+ self.user_args += ["--hetero-current-device-type", device_type]
cmd = shlex.join(export_cmd + runner_cmd + [self.user_script] + self.user_args)
@@ -355,11 +366,6 @@ def _run_each(
def run(self, with_test=False, dryrun=False, monitor=False, interval=10):
num_visible_devices = None
- visible_devices = self.user_envs.get("CUDA_VISIBLE_DEVICES", None)
- if visible_devices is not None and isinstance(visible_devices, str):
- visible_devices = visible_devices.split(",")
- num_visible_devices = len(visible_devices)
-
runner_config = self.config.experiment.runner
# If hostfile is provided, use the resources from the hostfile
@@ -372,6 +378,13 @@ def run(self, with_test=False, dryrun=False, monitor=False, interval=10):
for node_rank, (host, resource_info) in enumerate(self.resources.items()):
if node_rank >= nnodes:
break
+ self.cur_envs = add_decive_extra_config(
+ self.user_envs, resource_info["type"]
+ )
+ visible_devices = self.cur_envs.get("CUDA_VISIBLE_DEVICES", None)
+ if visible_devices is not None and isinstance(visible_devices, str):
+ visible_devices = visible_devices.split(",")
+ num_visible_devices = len(visible_devices)
nproc_from_hostfile = resource_info["slots"]
nproc_from_args = runner_config.get("nproc_per_node", None)
nproc_per_node = get_nproc_per_node(
@@ -386,11 +399,17 @@ def run(self, with_test=False, dryrun=False, monitor=False, interval=10):
nnodes,
node_rank,
nproc_per_node,
+ device_type=resource_info["type"],
with_test=with_test,
dryrun=dryrun,
)
else:
# If hostfile is not provided, run the job on localhost
+ self.cur_envs = self.user_envs
+ visible_devices = self.cur_envs.get("CUDA_VISIBLE_DEVICES", None)
+ if visible_devices is not None and isinstance(visible_devices, str):
+ visible_devices = visible_devices.split(",")
+ num_visible_devices = len(visible_devices)
nproc_from_args = runner_config.get("nproc_per_node", None)
nproc_per_node = get_nproc_per_node(
None, nproc_from_args, num_visible_devices
diff --git a/flagscale/runner/runner_utils.py b/flagscale/runner/runner_utils.py
index 37a27f105..8e030648c 100644
--- a/flagscale/runner/runner_utils.py
+++ b/flagscale/runner/runner_utils.py
@@ -3,6 +3,9 @@
import re
import socket
import subprocess
+import sys
+
+from omegaconf import DictConfig, OmegaConf
from flagscale.logger import logger
@@ -45,6 +48,10 @@ def parse_hostfile(hostfile_path):
else:
log_and_raise_error(f"Invalid entry in hostfile: {line}.")
+    assert all(info["type"] is None for _, info in resources.items()) or all(
+        info["type"] is not None for _, info in resources.items()
+    ), "All hosts must either all specify a machine type or all leave it unspecified."
+
if len(resources) == 0:
log_and_raise_error(
"Hostfile is empty or not formatted correctly. Please check the hostfile."
@@ -84,12 +91,24 @@ def run_local_command(cmd, dryrun=False, query=False):
return
if query:
result = subprocess.run(
- cmd, shell=True, check=True, capture_output=True, text=True
+ cmd,
+ shell=True,
+ check=True,
+ capture_output=True,
+ text=True,
+ encoding="utf-8",
+ errors="replace",
)
return result
else:
result = subprocess.run(
- cmd, shell=True, check=True, capture_output=True, text=True
+ cmd,
+ shell=True,
+ check=True,
+ capture_output=True,
+ text=True,
+ encoding="utf-8",
+ errors="replace",
)
if result.returncode != 0:
print(f"Command {cmd} failed with return code {result.returncode}.")
@@ -107,13 +126,22 @@ def run_ssh_command(host, cmd, port=None, dryrun=False, query=False):
logger.info(f"Running the ssh command: {ssh_cmd}")
if dryrun:
return
+ result = subprocess.run(
+ ssh_cmd,
+ shell=True,
+ check=True,
+ capture_output=True,
+ text=True,
+ encoding="utf-8",
+ errors="replace",
+ )
+ if result.returncode != 0:
+ print(f"SSH command {ssh_cmd} failed with return code {result.returncode}.")
+ print(f"Output: {result.stdout}")
+ print(f"Error: {result.stderr}")
+ sys.exit(result.returncode)
if query:
- result = subprocess.run(
- ssh_cmd, shell=True, check=True, text=True, stdout=subprocess.PIPE
- )
return result
- else:
- subprocess.run(ssh_cmd, shell=True, check=True)
def run_scp_command(host, src, dst, port=None, dryrun=False):
@@ -124,7 +152,20 @@ def run_scp_command(host, src, dst, port=None, dryrun=False):
logger.info(f"Run the scp command: {scp_cmd}")
if dryrun:
return
- subprocess.run(scp_cmd, shell=True, check=True)
+ result = subprocess.run(
+ scp_cmd,
+ shell=True,
+ check=True,
+ capture_output=True,
+ text=True,
+ encoding="utf-8",
+ errors="replace",
+ )
+ if result.returncode != 0:
+ print(f"SCP command {scp_cmd} failed with return code {result.returncode}.")
+ print(f"Output: {result.stdout}")
+ print(f"Error: {result.stderr}")
+ sys.exit(result.returncode)
def flatten_dict_to_args(config_dict, ignore_keys=[]):
@@ -188,3 +229,26 @@ def get_nproc_per_node(
return num_visible_devices
else:
return 1
+
+
+def add_decive_extra_config(config, device_type):
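+    """Build the per-node config for a given device type.
+
+    Keys whose value is a dict are treated as device-type sections and only the
+    section matching ``device_type`` is merged in; plain key/value pairs are
+    kept for every node.
+    """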
+ if device_type is None:
+ logger.warning(
+            "The device type is not specified in the hostfile; all nodes use the same arguments, including environment variables."
+ )
+ return OmegaConf.to_container(config, resolve=True)
+ cur_node_config = {}
+ temp_dict = {}
+ if isinstance(config, DictConfig):
+ temp_dict = OmegaConf.to_container(config, resolve=True)
+ else:
+ temp_dict = config
+ for key, value in temp_dict.items():
+ if isinstance(value, dict):
+ if key == device_type:
+ cur_node_config.update(value)
+ else:
+ continue
+ else:
+ cur_node_config[key] = value
+ return cur_node_config
diff --git a/flagscale/serve/README.md b/flagscale/serve/README.md
new file mode 100644
index 000000000..8a312f1b5
--- /dev/null
+++ b/flagscale/serve/README.md
@@ -0,0 +1,93 @@
+## Introduction
+
+We introduce support for deploying large models with FlagScale, leveraging the Ray framework for efficient orchestration and scalability. Currently, this implementation supports the Qwen model, enabling users to easily deploy and manage large-scale machine learning services.
+
+Planned key features include:
+
+- Easy distributed serving built on seamless integration with Ray.
+- Optimized resource management for large model inference.
+- Simplified deployment process for LLM and multimodal models.
+
+This enhancement will significantly improve the usability of FlagScale for large model deployment scenarios.
+
+## Setup
+
+[Install vLLM](../../README.md#setup)
+
+## Prepare Model
+
+[Download the Qwen2.5-7B-Instruct model from ModelScope](https://www.modelscope.cn/models/Qwen/Qwen2.5-7B-Instruct/summary):
+
+```shell
+pip install modelscope
+modelscope download --model Qwen/Qwen2.5-7B-Instruct --local_dir /models/Qwen2.5-7B-Instruct
+```
+
+## Serve run
+
+```shell
+cd FlagScale
+python run.py --config-path ./examples/qwen/conf --config-name config_qwen2.5_7b action=run
+```
+
+## Serve call
+
+```shell
+curl http://127.0.0.1:4567/v1/chat/completions -H "Content-Type: application/json" -d '{
+ "model": "/models/Qwen2.5-7B-Instruct",
+ "messages": [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Introduce Bruce Lee in details."}
+ ]
+ }'
+```
+
+## Serve stop
+
+```shell
+cd FlagScale
+python run.py --config-path ./examples/qwen/conf --config-name config_qwen2.5_7b action=stop
+```
+
+## Logs
+
+Since serve runs in distributed mode, the logs are stored separately. \
+By default, the logs are located in `/tmp/ray/session_latest/logs`. \
+The log of each worker is named `worker-[worker_id]-[job_id]-[pid].[out|err]`.
+
+## Config Template
+
+FlagScale.serve supports multiple deployment scenarios. For better performance and usability, it applies scenario-specific optimizations, which are enabled through different configurations.
+
+### Command Line Mode with vLLM
+
+If the original model is launched from the command line with vLLM, as in the example below, FlagScale.serve can deploy it easily.
+
+```shell
+vllm serve /models/Qwen2.5-7B-Instruct --tensor-parallel-size=1 --gpu-memory-utilization=0.9 --max-model-len=32768 --max-num-seqs=256 --port=4567 --trust-remote-code --enable-chunked-prefill
+```
+
+All arguments remain the same as in vLLM. Note that flag-style arguments without a value, such as `trust-remote-code` and `enable-chunked-prefill`, go into the **action-args** block of the config file, as shown below.
+
+```YAML
+model_args:
+ vllm_model:
+ model-tag: /models/Qwen2.5-7B-Instruct
+ tensor-parallel-size: 1
+ gpu-memory-utilization: 0.9
+ max-model-len: 32768
+ max-num-seqs: 256
+ port: 4567
+ action-args:
+ - trust-remote-code
+ - enable-chunked-prefill
+
+deploy:
+ command-line-mode: true
+ models:
+ vllm_model:
+ num_gpus: 1
+```
+
+### How to configure serve parameters
+The ***deploy*** block specifies the serve-level parameters. The ***models*** block specifies the Ray options (such as `num_gpus`) for each model decorated by `serve.remote`, keyed by the name passed to the decorator; a minimal sketch is shown below.
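+
+A minimal sketch, mirroring `flagscale/serve/run_vllm.py` and assuming the YAML above; the entrypoint name `my_serve.py` is hypothetical, and the script must be launched with `--config-path` (e.g., through the FlagScale runner) so that `flagscale.serve` can load the config:
+
+```python
+# my_serve.py (hypothetical custom entrypoint)
+import ray
+from flagscale import serve
+
+
+@serve.remote(name="vllm_model")  # must match a key under deploy.models
+def vllm_model(args):
+    # With the YAML above, num_gpus: 1 is merged into the ray.remote options,
+    # so this task is scheduled on one GPU.
+    return args["serve"]["model_args"]["vllm_model"]["model-tag"]
+
+
+if __name__ == "__main__":
+    ray.init()
+    print(ray.get(vllm_model.remote(serve.task_config)))
+```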
diff --git a/flagscale/serve/__init__.py b/flagscale/serve/__init__.py
new file mode 100644
index 000000000..f9aecd94b
--- /dev/null
+++ b/flagscale/serve/__init__.py
@@ -0,0 +1,6 @@
+from .utils import init, run, stop, prepare, remote, task_config
+
+
+__all__ = ["init", "run", "stop", "prepare", "remote", "task_config"]
+
+prepare()
diff --git a/flagscale/serve/run_vllm.py b/flagscale/serve/run_vllm.py
new file mode 100644
index 000000000..171239cbc
--- /dev/null
+++ b/flagscale/serve/run_vllm.py
@@ -0,0 +1,73 @@
+import os
+import sys
+import datetime
+import inspect
+import subprocess
+import argparse
+import logging as logger
+from omegaconf import OmegaConf
+import ray
+from flagscale import serve
+
+
+timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+
+
+@serve.remote(name="vllm_model")
+def vllm_model(args):
+
+ vllm_args = args["serve"]["model_args"]["vllm_model"]
+
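+    # Build the "vllm serve" command from the YAML config: "model-tag" is the
+    # positional model path, the other keys become --key=value options, and the
+    # entries under "action-args" become bare --flag switches.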
+ command = ["vllm", "serve"]
+ command.append(vllm_args["model-tag"])
+ for item in vllm_args:
+ if item not in {"model-tag", "action-args"}:
+ command.append(f"--{item}={vllm_args[item]}")
+ for arg in vllm_args["action-args"]:
+ command.append(f"--{arg}")
+
+ # Start the subprocess
+ logger.info(f"[Serve]: Starting vllm serve with command: {' '.join(command)}")
+ runtime_context = ray.get_runtime_context()
+ worker_id = runtime_context.get_worker_id()
+ job_id = runtime_context.get_job_id()
+ logger.info(
+ f"[Serve]: Current Job ID: {job_id} , \n[Serve]: ******** Worker ID: {worker_id} ********\n\n"
+ )
+ link_dir = os.path.join(
+ args.log_dir, f"session_latest_{timestamp}", f"worker-{worker_id}-"
+ )
+ logger.info(
+ f"\n\n[Serve]: ********************** {inspect.currentframe().f_code.co_name} Worker log path\
+ ********************** \n[Serve]: {link_dir} \n\n"
+ )
+
+ process = subprocess.Popen(command, stdout=sys.stdout, stderr=sys.stderr)
+ pid = os.getpid()
+ logger.info(f"[Serve]: Current vLLM PID: {pid} ")
+
+ stdout, stderr = process.communicate()
+ logger.info(f"[Serve]: Standard Output: {stdout}")
+ logger.info(f"[Serve]: Standard Error: {stderr}")
+
+ return process.returncode
+
+
+def main():
+ # Note: Custom log dir here may cause "OSError: AF_UNIX path length cannot exceed 107 bytes:"
+ ray.init(
+ log_to_driver=True,
+ logging_config=ray.LoggingConfig(encoding="TEXT", log_level="INFO"),
+ )
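+    # Symlink Ray's session log directory into the configured --log-dir so that
+    # worker logs are reachable from the experiment's serve_logs directory.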
+ link_dir = os.path.join(serve.task_config.log_dir, f"session_latest_{timestamp}")
+ tar_dir = ray._private.worker.global_worker.node._logs_dir
+ os.symlink(tar_dir, link_dir)
+
+ result = vllm_model.remote(serve.task_config)
+ return_code = ray.get(result)
+
+ logger.info(f"[Serve]: vLLM serve exited with return code: {return_code}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/flagscale/serve/utils.py b/flagscale/serve/utils.py
new file mode 100644
index 000000000..d4061e57b
--- /dev/null
+++ b/flagscale/serve/utils.py
@@ -0,0 +1,67 @@
+from omegaconf import OmegaConf
+import argparse
+import ray
+
+
+task_config = OmegaConf.create()
+
+
+class TaskManager:
+ def __init__(self):
+ pass
+
+
+def init():
+ ray.init(address="auto")
+
+
+def run():
+ ray.run()
+
+
+def stop():
+ ray.shutdown()
+
+
+def remote(*args, **kwargs):
+ """Transform a function into a Ray task"""
+ _load()
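+    # Options configured under serve.deploy.models[<name>] (e.g. num_gpus) are
+    # merged into the keyword arguments passed to ray.remote below.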
+ def _merge_kwargs(func_name, **kwargs):
+ new_kwargs = kwargs.copy()
+ models = task_config.serve.deploy.models
+
+ if func_name in models:
+ new_kwargs.update(models[func_name])
+ if "model_name" not in kwargs:
+ new_kwargs.pop("model_name", None)
+
+ return new_kwargs
+
+ new_kwargs = _merge_kwargs(kwargs["name"], **kwargs)
+
+ return ray.remote(*args, **new_kwargs)
+
+
+def _load() -> None:
+ """Load configuration for cluster init"""
+ parser = argparse.ArgumentParser(description="Start vllm serve with Ray")
+
+    parser.add_argument(
+        "--config-path", type=str, required=True, help="Path to the serve config file"
+    )
+    parser.add_argument("--log-dir", type=str, default="outputs", help="Path to the log directory")
+ args = parser.parse_args()
+
+ config = OmegaConf.load(args.config_path)
+
+ global task_config
+ task_config.update(config)
+ task_config.update({"log_dir": args.log_dir})
+
+ return
+
+
+def prepare() -> None:
+ # Load config
+ # _load()
+ return
diff --git a/flagscale/train/hetero/p2p_communication.py b/flagscale/train/hetero/p2p_communication.py
index 43af7ff18..590d59c69 100644
--- a/flagscale/train/hetero/p2p_communication.py
+++ b/flagscale/train/hetero/p2p_communication.py
@@ -65,7 +65,8 @@ def warm_up_comm_group_hetero(config: ModelParallelConfig):
for pp_group in pp_groups:
group_ranks = torch.distributed.get_process_group_ranks(pp_group)
- if rank == group_ranks[0]:
+ pipeline_rank = get_pipeline_model_parallel_rank()
+ if pipeline_rank == 0:
_communicate(
tensor_send_next=to_send_tensor,
tensor_send_prev=None,
@@ -75,7 +76,7 @@ def warm_up_comm_group_hetero(config: ModelParallelConfig):
config=config,
group=pp_group,
)
- elif rank == group_ranks[-1]:
+ elif pipeline_rank == len(group_ranks) - 1:
_communicate(
tensor_send_next=None,
tensor_send_prev=None,
diff --git a/flagscale/train/hetero/parallel_context.py b/flagscale/train/hetero/parallel_context.py
index 788bd35e6..6303be272 100644
--- a/flagscale/train/hetero/parallel_context.py
+++ b/flagscale/train/hetero/parallel_context.py
@@ -130,6 +130,7 @@ def __init__(
):
assert torch.distributed.is_initialized()
self._args = args
+ self._distributed_backend = args.distributed_backend
self._rank = torch.distributed.get_rank()
self._world_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size * data_parallel_size
self._offset = offset
@@ -205,7 +206,7 @@ def build_process_group(
ranks = self._rank_mapper.to_physical_ranks(logical_ranks)
group = torch.distributed.new_group(
ranks,
- backend="nccl",
+ backend=self._distributed_backend,
timeout=self._timeout,
pg_options=pg_options,
)
@@ -996,7 +997,11 @@ def get_pipeline_model_parallel_rank(self, group=None):
return rank
if group is None:
group = self.get_pipeline_model_parallel_group()[0]
- return torch.distributed.get_rank(group=group)
+ if group not in self._process_group_to_ranks: # local pipeline group
+ return torch.distributed.get_rank(group=group)
+ else:
+ ranks = self._process_group_to_ranks[group]
+ return ranks.index(self._rank)
def get_pipeline_model_parallel_split_rank(self):
"""Return pipeline model parallel split rank."""
@@ -1153,11 +1158,14 @@ def get_pipeline_model_parallel_next_rank(self, group=None):
if group is None:
group = self.get_pipeline_model_parallel_group()[0]
ranks = self._process_group_to_ranks.get(group, None)
- if ranks is None:
+ if ranks is None: # local pipeline group
current_process_mesh = self._process_meshes[self._current_process_mesh_index]
ranks = current_process_mesh.get_process_group_ranks(token="pp", independent_ep=False, check_initialized=True)
- assert ranks is not None, "Pipeline parallel group is not initialized"
- rank_in_pipeline = self.get_pipeline_model_parallel_rank(group)
+ assert ranks is not None, "Pipeline parallel group is not initialized"
+ rank_in_pipeline = torch.distributed.get_rank(group=group)
+ world_size = self.get_pipeline_model_parallel_world_size(group)
+ return ranks[(rank_in_pipeline + 1) % world_size]
+ rank_in_pipeline = ranks.index(self._rank)
world_size = self.get_pipeline_model_parallel_world_size(group)
return ranks[(rank_in_pipeline + 1) % world_size]
@@ -1166,11 +1174,14 @@ def get_pipeline_model_parallel_prev_rank(self, group=None):
if group is None:
group = self.get_pipeline_model_parallel_group()[0]
ranks = self._process_group_to_ranks.get(group, None)
- if ranks is None:
+ if ranks is None: # local pipeline group
current_process_mesh = self._process_meshes[self._current_process_mesh_index]
ranks = current_process_mesh.get_process_group_ranks(token="pp", independent_ep=False, check_initialized=True)
- assert ranks is not None, "Pipeline parallel group is not initialized"
- rank_in_pipeline = self.get_pipeline_model_parallel_rank(group)
+ assert ranks is not None, "Pipeline parallel group is not initialized"
+ rank_in_pipeline = torch.distributed.get_rank(group=group)
+ world_size = self.get_pipeline_model_parallel_world_size(group)
+ return ranks[(rank_in_pipeline - 1) % world_size]
+ rank_in_pipeline = ranks.index(self._rank)
world_size = self.get_pipeline_model_parallel_world_size(group)
return ranks[(rank_in_pipeline - 1) % world_size]
diff --git a/flagscale/train/train.py b/flagscale/train/train.py
index 7e5dc158e..f100b911b 100644
--- a/flagscale/train/train.py
+++ b/flagscale/train/train.py
@@ -1588,6 +1588,7 @@ def get_e2e_base_metrics():
torch.distributed.get_rank() in args.profile_ranks:
if args.use_pytorch_profiler:
prof.stop()
+                    print(prof.key_averages(group_by_stack_n=10, group_by_input_shape=True).table(sort_by="cpu_time_total"))
else:
torch.cuda.cudart().cudaProfilerStop()
diff --git a/flagscale/train/train_aquila_sft.py b/flagscale/train/train_aquila_sft.py
new file mode 100644
index 000000000..3085ce573
--- /dev/null
+++ b/flagscale/train/train_aquila_sft.py
@@ -0,0 +1,335 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+"""Pretrain GPT."""
+
+import os
+import sys
+from flagscale.utils import CustomModuleFinder
+sys.path.append(os.path.dirname(
+ os.path.dirname(os.path.abspath(__file__))))
+sys.meta_path.insert(0, CustomModuleFinder())
+
+import torch
+from functools import partial
+from contextlib import nullcontext
+import inspect
+
+from typing import Union
+from megatron.training import get_args
+from megatron.training import print_rank_0
+from megatron.training import get_timers
+from megatron.training import get_tokenizer
+from megatron.core import mpu
+from megatron.core.enums import ModelType
+from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
+from megatron.core.datasets.utils import get_blend_from_list
+from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
+from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
+import megatron.legacy.model
+from megatron.core.models.gpt import GPTModel
+
+from megatron.core.utils import StragglerDetector
+from megatron.core.transformer.spec_utils import import_module
+from megatron.training.utils import (
+ get_batch_on_this_ulysses_sp_rank,
+ get_batch_on_this_cp_rank,
+ get_batch_on_this_tp_rank,
+)
+from megatron.training.arguments import core_transformer_config_from_args
+from megatron.training.yaml_arguments import core_transformer_config_from_yaml
+from megatron.core.models.gpt.gpt_layer_specs import (
+ get_gpt_layer_local_spec,
+ get_gpt_layer_with_transformer_engine_spec,
+)
+from flagscale.datasets.sft_dataset import SFTDatasetConfig, SFTDataset
+from flagscale.train.extra_valid import extra_valid_dataset_provider
+from flagscale.train.train import pretrain
+from flagscale.train.global_vars import get_parallel_context
+
+
+stimer = StragglerDetector()
+
+def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
+ """Builds the model.
+
+    If use_legacy_models is set to True, this returns the legacy GPT model; otherwise, the mcore GPT model.
+
+ Args:
+        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
+        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.
+
+
+ Returns:
+ Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
+ """
+ args = get_args()
+ use_te = args.transformer_impl == "transformer_engine"
+
+ print_rank_0('building GPT model ...')
+ # Experimental loading arguments from yaml
+ config = None
+ if args.yaml_cfg is not None:
+ config = core_transformer_config_from_yaml(args, "language_model")
+ else:
+ para_ctx = get_parallel_context()
+ if para_ctx is not None:
+ config = para_ctx.get_transformer_config()
+
+ if config is None:
+ config = core_transformer_config_from_args(args)
+
+ if args.use_legacy_models:
+ model = megatron.legacy.model.GPTModel(
+ config,
+ num_tokentypes=0,
+ parallel_output=True,
+ pre_process=pre_process,
+ post_process=post_process,
+ )
+ else: # using core models
+ if args.spec is not None:
+ transformer_layer_spec = import_module(args.spec)
+ else:
+ if use_te:
+ transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention, args.fp8)
+ else:
+ transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention)
+
+ build_model_context = nullcontext
+ build_model_context_args = {}
+ if args.fp8_param_gather:
+ try:
+ from transformer_engine.pytorch import fp8_model_init
+
+ build_model_context = fp8_model_init
+ build_model_context_args["enabled"] = True
+
+ # Check if fp8_model_init supports preserve_high_precision_init_val
+ if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters:
+ build_model_context_args["preserve_high_precision_init_val"] = True
+ except:
+ raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.")
+
+ with build_model_context(**build_model_context_args):
+ model = GPTModel(
+ config=config,
+ transformer_layer_spec=transformer_layer_spec,
+ vocab_size=args.padded_vocab_size,
+ max_sequence_length=args.max_position_embeddings,
+ pre_process=pre_process,
+ post_process=post_process,
+ fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
+ parallel_output=True,
+ share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
+ position_embedding_type=args.position_embedding_type,
+ rotary_percent=args.rotary_percent,
+ rotary_base=args.rotary_base,
+ rope_scaling=args.use_rope_scaling
+ )
+
+ return model
+
+
+def get_batch(data_iterator):
+ """Generate a batch."""
+
+ # TODO: this is pretty hacky, find a better way
+ if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
+ return None, None, None, None, None
+
+ # get batches based on the TP rank you are on
+ batch = get_batch_on_this_tp_rank(data_iterator)
+
+ # slice batch along sequence dimension for context parallelism
+ batch = get_batch_on_this_cp_rank(batch)
+
+ # slice batch along sequence dimension for ulysses sequence parallelism
+ batch = get_batch_on_this_ulysses_sp_rank(batch)
+
+ return batch.values()
+
+
+def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
+ """Loss function.
+
+ Args:
+ loss_mask (torch.Tensor): Used to mask out some portions of the loss
+ output_tensor (torch.Tensor): The tensor with the losses
+
+ Returns:
+ the loss scalar for this micro-batch
+ the number of non-padded tokens in this microbatch
+ a dict containing reporting metrics on the loss and number of tokens across
+ the data parallel ranks
+ """
+
+ args = get_args()
+
+ losses = output_tensor.float()
+ loss_mask = loss_mask.view(-1).float()
+ total_tokens = loss_mask.sum()
+    loss = torch.cat([torch.sum(torch.masked_select(losses.view(-1), loss_mask == 1)).view(1), total_tokens.view(1)])
+
+ if args.ulysses_sp_parallel_size > 1:
+ torch.distributed.all_reduce(loss, group=mpu.get_ulysses_sp_parallel_group())
+
+ if args.context_parallel_size > 1:
+ torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
+
+ # Check individual rank losses are not NaN prior to DP all-reduce.
+ if args.check_for_nan_in_loss_and_grad:
+ global_rank = torch.distributed.get_rank()
+ assert not loss[0].isnan(), (
+ f'Rank {global_rank}: found NaN in local forward loss calculation. '
+ f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
+ )
+
+ # Reduce loss for logging.
+ reporting_loss = loss.clone().detach()
+ torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
+
+ local_num_tokens = loss[1].clone().detach().to(torch.int)
+ return (
+ loss[0] * args.context_parallel_size * args.ulysses_sp_parallel_size,
+ local_num_tokens,
+ {'lm loss': (reporting_loss[0], reporting_loss[1])},
+ )
+
+def forward_step(data_iterator, model: GPTModel):
+ """Forward training step.
+
+ Args:
+ data_iterator : Input data iterator
+ model (GPTModel): The GPT Model
+ """
+ args = get_args()
+ timers = get_timers()
+
+ # Get the batch.
+ timers('batch-generator', log_level=2).start()
+ global stimer
+ with stimer(bdata=True):
+ tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
+ data_iterator)
+ timers('batch-generator').stop()
+
+ with stimer:
+ output_tensor = model(tokens, position_ids, attention_mask,
+ labels=labels)
+
+ return output_tensor, partial(loss_func, loss_mask)
+
+
+def is_dataset_built_on_rank():
+ return (
+ mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
+ ) and mpu.get_tensor_model_parallel_rank() == 0
+
+
+def core_gpt_dataset_config_from_args(args):
+ tokenizer = get_tokenizer()
+
+ return GPTDatasetConfig(
+ random_seed=args.seed,
+ sequence_length=args.seq_length,
+ blend=get_blend_from_list(args.data_path),
+ blend_per_split=[
+ get_blend_from_list(args.train_data_path),
+ get_blend_from_list(args.valid_data_path),
+ get_blend_from_list(args.test_data_path)
+ ],
+ renormalize_blend_weights=args.renormalize_blend_weights,
+ split=args.split,
+ num_dataset_builder_threads=args.num_dataset_builder_threads,
+ path_to_cache=args.data_cache_path,
+ mmap_bin_files=args.mmap_bin_files,
+ tokenizer=tokenizer,
+ reset_position_ids=args.reset_position_ids,
+ reset_attention_mask=args.reset_attention_mask,
+ eod_mask_loss=args.eod_mask_loss,
+ create_attention_mask=args.create_attention_mask_in_dataloader,
+        s3_cache_path=args.s3_cache_path,
+ )
+
+
+def core_sft_dataset_config_from_args(args):
+ tokenizer = get_tokenizer()
+
+ return SFTDatasetConfig(
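+        # Note: this stops serving by killing every process matching 'python' on the host.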
+ random_seed=args.seed,
+ sequence_length=args.seq_length,
+ blend=get_blend_from_list(args.data_path),
+ blend_per_split=[
+ get_blend_from_list(args.train_data_path),
+ get_blend_from_list(args.valid_data_path),
+ get_blend_from_list(args.test_data_path)
+ ],
+ renormalize_blend_weights=args.renormalize_blend_weights,
+ split=args.split,
+ num_dataset_builder_threads=args.num_dataset_builder_threads,
+ path_to_cache=args.data_cache_path,
+ mmap_bin_files=args.mmap_bin_files,
+ tokenizer=tokenizer,
+ reset_position_ids=args.reset_position_ids,
+ reset_attention_mask=args.reset_attention_mask,
+ eod_mask_loss=args.eod_mask_loss,
+ create_attention_mask=args.create_attention_mask_in_dataloader,
+ apply_sft_dataset_separated_loss_mask_if_existed=args.apply_sft_dataset_separated_loss_mask_if_existed,
+ )
+
+
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+ """Build the train test and validation datasets.
+
+ Args:
+ train_val_test_num_samples : A list containing the number of samples in train test and validation.
+ """
+ args = get_args()
+
+ config = None
+ para_ctx = get_parallel_context()
+ if para_ctx is not None:
+ config = para_ctx.get_dataset_config()
+
+ if config is None:
+ if args.apply_sft_dataset_separated_loss_mask_if_existed:
+ config = core_sft_dataset_config_from_args(args)
+ else:
+ config = core_gpt_dataset_config_from_args(args)
+
+ if args.mock_data:
+ dataset_type = MockGPTDataset
+ elif args.apply_sft_dataset_separated_loss_mask_if_existed:
+ dataset_type = SFTDataset
+ else:
+ dataset_type = GPTDataset
+
+ print_rank_0("> building train, validation, and test datasets for GPT ...")
+
+ train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
+ dataset_type,
+ train_val_test_num_samples,
+ is_dataset_built_on_rank,
+ config
+ ).build()
+
+ print_rank_0("> finished creating GPT datasets ...")
+
+ return train_ds, valid_ds, test_ds
+
+
+if __name__ == "__main__":
+
+ # Temporary for transition to core datasets
+ train_valid_test_datasets_provider.is_distributed = True
+
+ extra_valid_dataset_provider.is_distributed = True
+
+ pretrain(
+ train_valid_test_datasets_provider,
+ model_provider,
+ ModelType.encoder_or_decoder,
+ forward_step,
+ args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
+ get_batch_fn=get_batch,
+ extra_valid_dataset_provider=extra_valid_dataset_provider
+ )
diff --git a/hardware/kunlunxin_R300p/237377e9/237377e9.patch b/hardware/kunlunxin_R300p/237377e9/237377e9.patch
new file mode 100644
index 000000000..f404971ec
--- /dev/null
+++ b/hardware/kunlunxin_R300p/237377e9/237377e9.patch
@@ -0,0 +1,117 @@
+From 115b26cc46200236cccfe072cf0049b39853b168 Mon Sep 17 00:00:00 2001
+From: brianlcy123
+Date: Sun, 24 Nov 2024 19:12:03 +0800
+Subject: [PATCH] [kunlunxin] add patch for mixtral
+
+---
+ .../megatron/core/dist_checkpointing/strategies/base.py | 4 ++--
+ .../megatron/core/distributed/param_and_grad_buffer.py | 7 ++++++-
+ megatron/megatron/core/transformer/moe/moe_utils.py | 6 +++---
+ megatron/megatron/core/transformer/moe/token_dispatcher.py | 4 ++--
+ megatron/megatron/training/checkpointing.py | 3 ++-
+ 5 files changed, 15 insertions(+), 9 deletions(-)
+
+diff --git a/megatron/megatron/core/dist_checkpointing/strategies/base.py b/megatron/megatron/core/dist_checkpointing/strategies/base.py
+index cc1c83b9..125779a0 100644
+--- a/megatron/megatron/core/dist_checkpointing/strategies/base.py
++++ b/megatron/megatron/core/dist_checkpointing/strategies/base.py
+@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
+ from collections import defaultdict
+ from enum import Enum
+ from pathlib import Path
+-from typing import Any, DefaultDict
++from typing import Any, DefaultDict, Dict, Tuple
+
+ from ..mapping import CheckpointingException, ShardedStateDict, StateDict
+ from .async_utils import AsyncCallsQueue, AsyncRequest
+@@ -20,7 +20,7 @@ class StrategyAction(Enum):
+
+
+ _import_trigger = None
+-default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict)
++default_strategies: DefaultDict[str, Dict[Tuple, Any]] = defaultdict(dict)
+
+ async_calls = AsyncCallsQueue()
+
+diff --git a/megatron/megatron/core/distributed/param_and_grad_buffer.py b/megatron/megatron/core/distributed/param_and_grad_buffer.py
+index 77ecd7be..c2761c6e 100644
+--- a/megatron/megatron/core/distributed/param_and_grad_buffer.py
++++ b/megatron/megatron/core/distributed/param_and_grad_buffer.py
+@@ -248,6 +248,11 @@ class ParamAndGradBuffer:
+ def _pad(number_to_be_padded: int, divisor: int) -> int:
+ return int(math.ceil(number_to_be_padded / divisor) * divisor)
+
++ import math
++
++ def _lcm(a, b):
++ return abs(a * b) // math.gcd(a, b)
++
+ def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int:
+ """
+ Pads end index of bucket if using distributed optimizer (to ensure uniform sharding).
+@@ -257,7 +262,7 @@ class ParamAndGradBuffer:
+ # This also helps cuBLAS pick more efficient algorithms for GEMMs.
+ # We now ensure that all buckets start at a memory address that is 256-byte
+ # aligned (128 values since params and grads use >= 16-bit precision).
+- return _pad(bucket_end_index, math.lcm(self.data_parallel_world_size, 128))
++ return _pad(bucket_end_index, _lcm(self.data_parallel_world_size, 128))
+ return bucket_end_index
+
+ def _pad_start_of_param_if_needed(param_start_index: int) -> int:
+diff --git a/megatron/megatron/core/transformer/moe/moe_utils.py b/megatron/megatron/core/transformer/moe/moe_utils.py
+index ee4bb690..a3c1fd69 100644
+--- a/megatron/megatron/core/transformer/moe/moe_utils.py
++++ b/megatron/megatron/core/transformer/moe/moe_utils.py
+@@ -366,8 +366,8 @@ def topk_softmax_with_capacity(
+
+ if capacity_factor is None:
+ # TopK without capacity
+- tokens_per_expert = torch.bincount(top_indices.view(-1), minlength=num_experts)
+- return probs, top_indices, tokens_per_expert
++ tokens_per_expert = torch.bincount(top_indices.cpu().view(-1), minlength=num_experts)
++ return probs, top_indices, tokens_per_expert.cuda()
+ else:
+ # TopK with capacity
+ expert_capacity = get_capacity(
+@@ -380,7 +380,7 @@ def topk_softmax_with_capacity(
+ # Maskout exceeded tokens
+ if drop_policy == "probs":
+ capacity_probs, capacity_indices = torch.topk(
+- topk_masked_gates, k=expert_capacity, dim=0, sorted=False
++ topk_masked_gates, k=expert_capacity, dim=0, sorted=True #mod by zh
+ )
+ capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
+ elif drop_policy == "position":
+diff --git a/megatron/megatron/core/transformer/moe/token_dispatcher.py b/megatron/megatron/core/transformer/moe/token_dispatcher.py
+index 84f3d450..6a0b4a28 100644
+--- a/megatron/megatron/core/transformer/moe/token_dispatcher.py
++++ b/megatron/megatron/core/transformer/moe/token_dispatcher.py
+@@ -179,10 +179,10 @@ class MoEAllGatherTokenDispatcher(MoETokenDispatcher):
+
+ with torch.no_grad():
+ tokens_per_expert = torch.bincount(
+- local_indices.view(-1), minlength=self.config.num_moe_experts
++ local_indices.cpu().view(-1), minlength=self.config.num_moe_experts
+ )
+ if self.num_local_experts < self.config.num_moe_experts:
+- tokens_per_expert = tokens_per_expert[
++ tokens_per_expert = tokens_per_expert.cuda()[
+ self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
+ ]
+ tokens_per_expert = tokens_per_expert.cpu().to(torch.long)
+diff --git a/megatron/megatron/training/checkpointing.py b/megatron/megatron/training/checkpointing.py
+index 6e58b317..6c650c4e 100644
+--- a/megatron/megatron/training/checkpointing.py
++++ b/megatron/megatron/training/checkpointing.py
+@@ -1057,7 +1057,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
+ restore_modelopt_state(model, state_dict)
+
+ # Model.
+- strict = False if args.retro_add_retriever else strict
++ # strict = False if args.retro_add_retriever else strict
++ strict = False
+ if len(model) == 1:
+ model[0].load_state_dict(state_dict['model'], strict=strict)
+ else:
+--
+2.25.1
diff --git a/hardware/kunlunxin_R300p/a44556c0/a44556c0.patch b/hardware/kunlunxin_R300p/a44556c0/a44556c0.patch
new file mode 100644
index 000000000..77c365928
--- /dev/null
+++ b/hardware/kunlunxin_R300p/a44556c0/a44556c0.patch
@@ -0,0 +1,107 @@
+From 9289f099424ba4d0dec83fb5715d4d2561f4c4d8 Mon Sep 17 00:00:00 2001
+From: brianlcy123
+Date: Thu, 21 Nov 2024 15:46:54 +0800
+Subject: [PATCH] [kunlunxin] add updated llama3 70b patch
+
+---
+ examples/llama/conf/config.yaml | 40 ++++++++++++++++++++-
+ megatron/megatron/training/arguments.py | 18 +++++-----
+ megatron/megatron/training/checkpointing.py | 5 +--
+ 3 files changed, 51 insertions(+), 12 deletions(-)
+
+diff --git a/examples/llama/conf/config.yaml b/examples/llama/conf/config.yaml
+index 592c45bf..27fb83ae 100644
+--- a/examples/llama/conf/config.yaml
++++ b/examples/llama/conf/config.yaml
+@@ -46,6 +46,44 @@ experiment:
+ CUDA_DEVICE_MAX_CONNECTIONS: 1
+ NVTE_APPLY_QK_LAYER_SCALING: 0
+ NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0
++ ALLGATHER_ASYNC: false
++ ALLREDUCE_ASYNC: false
++ ALLREDUCE_FUSION: 0
++ BKCL_CCIX_BUFFER_GM: 1
++ BKCL_CCIX_RING: 1
++ BKCL_ENABLE_XDR: 1
++ BKCL_FLAT_RING: 1
++ BKCL_KL3_TURBO_MODE: 1
++ BKCL_RDMA_FORCE_TREE: 1
++ BKCL_RDMA_NICS: ens11np0,ens11np0,ens13np0,ens13np0,ens15np0,ens15np0,ens17np0,ens17np0
++ BKCL_RDMA_PROXY_DISABLE: 1
++ BKCL_RING_BUFFER_GM: 1
++ BKCL_TIMEOUT: 360000
++ BKCL_TRANS_UNSUPPORTED_DATATYPE: 8
++ BKCL_TREE_THRESHOLD: 1
++ BKCL_XLINK_C2C: 1
++ BKCL_XLINK_D2D: 0
++ BKCL_XLINK_ETH: 0
++ CUDART_DUMMY_REGISTER: 1
++ FAST_SWIGLU_ENABLE: 1
++ USE_FAST_BF16_FC: true
++ USE_L3: 1
++ XDNN_USE_FAST_SWISH: true
++ XPU_ZEBU_MODE: 1
++ XPU_FORCE_USERMODE_LAUNCH: 1
++ DIST_MULTI_STREAM: true
++ XMLIR_DIST_SINGLETON_STREAM: true
++ XMLIR_FA_GEMM_TYPE: float
++ XBLAS_FC_HBM_VERSION: 40
++ XMLIR_PARALLEL_SAVE_MEMORY: false
++ XMLIR_DISABLE_CUDA_ALLOCATOR: true
++ XMLIR_XDNN_PYTORCH_CHECK_ENABLE_FALLBACK_BOOL: 0
++ XMLIR_ENABLE_FALLBACK_TO_CPU_BOOL: False
++ XMLIR_DUMP_FALLBACK_OP_LIST_BOOL: true
++ XMLIR_BATCH_PARALLEL: true
++ DIST_MULTI_STREAM: true
++ CUDA_DEVICE_MAX_CONNECTIONS: 8
++ XMLIR_DIST_ASYNC_ISEND_IRECV: 1
+ action: run
+
+ hydra:
+diff --git a/megatron/megatron/training/arguments.py b/megatron/megatron/training/arguments.py
+index e20f178b..7e79da2a 100644
+--- a/megatron/megatron/training/arguments.py
++++ b/megatron/megatron/training/arguments.py
+@@ -652,15 +652,15 @@ def validate_args(args, defaults={}):
+ if args.sequence_parallel:
+ args.async_tensor_model_parallel_allreduce = False
+
+- if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
+- if args.sequence_parallel:
+- raise RuntimeError(
+- "Using sequence parallelism requires setting the environment variable "
+- "CUDA_DEVICE_MAX_CONNECTIONS to 1")
+- if args.async_tensor_model_parallel_allreduce:
+- raise RuntimeError(
+- "Using async gradient all reduce requires setting the environment "
+- "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
++ # if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
++ # if args.sequence_parallel:
++ # raise RuntimeError(
++ # "Using sequence parallelism requires setting the environment variable "
++ # "CUDA_DEVICE_MAX_CONNECTIONS to 1")
++ # if args.async_tensor_model_parallel_allreduce:
++ # raise RuntimeError(
++ # "Using async gradient all reduce requires setting the environment "
++ # "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
+
+ # Disable bias gelu fusion if we are disabling bias altogether
+ if not args.add_bias_linear:
+diff --git a/megatron/megatron/training/checkpointing.py b/megatron/megatron/training/checkpointing.py
+index 01425f36..80fa0254 100644
+--- a/megatron/megatron/training/checkpointing.py
++++ b/megatron/megatron/training/checkpointing.py
+@@ -530,8 +530,9 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
+
+ torch.distributed.barrier(group=mpu.get_data_parallel_group())
+
+- if mpu.get_data_parallel_rank() == 0:
+- ensure_directory_exists(data_state_save_path)
++ # if mpu.get_data_parallel_rank() == 0:
++ # ensure_directory_exists(data_state_save_path)
++ ensure_directory_exists(data_state_save_path)
+
+ torch.distributed.barrier(group=mpu.get_data_parallel_group())
+
+--
+2.25.1
diff --git a/megatron/megatron/core/dist_checkpointing/serialization.py b/megatron/megatron/core/dist_checkpointing/serialization.py
index 5493c96bb..cf6f51e44 100644
--- a/megatron/megatron/core/dist_checkpointing/serialization.py
+++ b/megatron/megatron/core/dist_checkpointing/serialization.py
@@ -8,6 +8,7 @@
loading the sharded tensors.
"""
+import os
import logging
from pathlib import Path
from typing import Dict, Optional, Set, Tuple, Union
@@ -339,10 +340,17 @@ def save(
f'Checkpoint destination directory does not exist: {checkpoint_dir}'
)
- if next(checkpoint_dir.iterdir(), None) is not None:
- raise CheckpointingException(
- f'Checkpoint destination directory ({checkpoint_dir}) is not empty'
- )
+ # Skip the empty-directory check when FS_SFPT_CKPT_SAVE is set to a truthy value; defaults to False
+ single_file_per_tensor_ckpt = os.getenv('FS_SFPT_CKPT_SAVE', 'False').lower() in (
+ 'true',
+ '1',
+ 't',
+ )
+ if not single_file_per_tensor_ckpt:
+ if next(checkpoint_dir.iterdir(), None) is not None:
+ raise CheckpointingException(
+ f'Checkpoint destination directory ({checkpoint_dir}) is not empty'
+ )
if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
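The check above is gated on a truthy `FS_SFPT_CKPT_SAVE` value. A standalone sketch of that parsing pattern, which recurs throughout these changes (the helper name `_env_flag` is illustrative, not part of the diff):

```python
import os

def _env_flag(name: str, default: str = "False") -> bool:
    # Mirrors the pattern used in these changes: any of "true", "1" or "t"
    # (case-insensitive) enables the flag; anything else leaves it off.
    return os.getenv(name, default).lower() in ("true", "1", "t")

if __name__ == "__main__":
    # e.g. FS_SFPT_CKPT_SAVE=1 python env_flag_sketch.py
    print("single-file-per-tensor save enabled:", _env_flag("FS_SFPT_CKPT_SAVE"))
```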
diff --git a/megatron/megatron/core/dist_checkpointing/strategies/filesystem_async.py b/megatron/megatron/core/dist_checkpointing/strategies/filesystem_async.py
index 9d0be4d6e..7d942e366 100644
--- a/megatron/megatron/core/dist_checkpointing/strategies/filesystem_async.py
+++ b/megatron/megatron/core/dist_checkpointing/strategies/filesystem_async.py
@@ -5,19 +5,21 @@
import logging
import os
import queue
+import pickle
from contextlib import contextmanager
from itertools import chain
from pathlib import Path
from time import time
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union, cast
import psutil
import torch
from torch import multiprocessing as mp
from torch.distributed.checkpoint import FileSystemWriter
-from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item
+from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item, _metadata_fn
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType
from torch.distributed.checkpoint.storage import WriteResult
+from torch.distributed.checkpoint.metadata import Metadata
from torch.futures import Future
logger = logging.getLogger(__name__)
@@ -26,6 +28,40 @@
_results_queue = None
+_GLOBAL_PREVIOUS_METADATA = None
+
+_GLOBAL_PREVIOUS_COUNT = 0
+
+
+def get_previous_metadata():
+ """
+ Get the metadata from the previous save.
+ """
+ return _GLOBAL_PREVIOUS_METADATA
+
+
+def set_previous_metadata(metadata):
+ """
+ Set the metadata from the previous save.
+ """
+ global _GLOBAL_PREVIOUS_METADATA
+ _GLOBAL_PREVIOUS_METADATA = metadata
+
+
+def get_previous_count():
+ """
+ Get the count from the previous save.
+ """
+ return _GLOBAL_PREVIOUS_COUNT
+
+
+def set_previous_count(count):
+ """
+ Set the count from the previous save.
+ """
+ global _GLOBAL_PREVIOUS_COUNT
+ _GLOBAL_PREVIOUS_COUNT = count
+
def _get_write_results_queue():
global _results_queue
@@ -80,6 +116,13 @@ def __init__(self, *args, **kwargs):
self.write_buckets: Optional[List[WriteBucket]] = None
self.results_queue: Optional[mp.Queue] = None
+ # Get the value from the environment variable if it exists, otherwise default to False
+ self.single_file_per_tensor_ckpt = os.getenv('FS_SFPT_CKPT_SAVE', 'False').lower() in (
+ 'true',
+ '1',
+ 't',
+ )
+
def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
"""
First stage of async saving. Copy data to CPU and plan the local saving.
@@ -99,12 +142,17 @@ def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
start = time()
# move tensors from GPU to CPU before starting async writing
# We do D2H synchronously for now
- file_count = 0
+ if not self.single_file_per_tensor_ckpt:
+ file_count = 0
+ else:
+ file_count = get_previous_count()
def gen_file():
nonlocal file_count
file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
file_count += 1
+ if self.single_file_per_tensor_ckpt:
+ set_previous_count(file_count)
return file_name
# Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
@@ -314,6 +362,48 @@ def retrieve_write_results(self) -> List[WriteResult]:
)
return list(chain.from_iterable(write_results.values()))
+ def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
+ # Modified from the original implementation of torch.distributed.checkpoint.filesystem.FileSystemWriter
+ # https://github.com/pytorch/pytorch/blob/625c24a7f98a645b6f8758a01d7163a842582ce0/torch/distributed/checkpoint/filesystem.py#L574
+
+ if not self.single_file_per_tensor_ckpt:
+ storage_md = {}
+ else:
+ if get_previous_count() == 1:
+ storage_md = {}
+ else:
+ # Get the metadata from the previous save
+ prev_metadata = get_previous_metadata()
+ prev_metadata.state_dict_metadata.update(metadata.state_dict_metadata)
+ metadata = prev_metadata
+ storage_md = metadata.storage_data
+
+ for wr_list in results:
+ storage_md.update({wr.index: wr.storage_data for wr in wr_list})
+ metadata.storage_data = storage_md
+
+ if not self.single_file_per_tensor_ckpt or get_previous_count() == 1:
+ metadata.storage_meta = self.storage_meta()
+
+ tmp_path = cast(Path, self.fs.concat_path(self.path, f"{_metadata_fn}.tmp"))
+ with self.fs.create_stream(tmp_path, "wb") as metadata_file:
+ pickle.dump(metadata, metadata_file)
+ if self.sync_files:
+ try:
+ os.fsync(metadata_file.fileno())
+ except AttributeError:
+ os.sync()
+
+ # Delete in case other checkpoints were present.
+ if self.fs.exists(self.metadata_path):
+ self.fs.rm_file(self.metadata_path)
+
+ self.fs.rename(tmp_path, self.metadata_path)
+
+ # Store the metadata for the next save
+ if self.single_file_per_tensor_ckpt:
+ set_previous_metadata(metadata)
+
def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
"""
@@ -349,7 +439,6 @@ def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[Writ
idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
buckets[idx].append(item)
bucket_sizes[idx] += _item_size(item)
-
return buckets
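A minimal sketch of the file-name counter behaviour introduced above: with `FS_SFPT_CKPT_SAVE` enabled, `gen_file` resumes from the count recorded by the previous save instead of restarting at 0, so files from earlier per-tensor saves are not overwritten. The standalone function below is illustrative; the real code keeps the counter in module-level globals via `get_previous_count`/`set_previous_count`.

```python
DEFAULT_SUFFIX = ".distcp"  # stands in for torch.distributed.checkpoint's DEFAULT_SUFFIX
_previous_count = 0         # stands in for the module-level global above

def gen_files(prefix: str, n: int, resume: bool) -> list:
    """Generate n file names, optionally resuming from the previous count."""
    global _previous_count
    count = _previous_count if resume else 0
    names = [f"{prefix}{count + i}{DEFAULT_SUFFIX}" for i in range(n)]
    if resume:
        _previous_count = count + n
    return names

print(gen_files("__0_", 2, resume=True))  # ['__0_0.distcp', '__0_1.distcp']
print(gen_files("__0_", 2, resume=True))  # ['__0_2.distcp', '__0_3.distcp']
```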
diff --git a/megatron/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/megatron/core/dist_checkpointing/strategies/torch.py
index 077d94eb7..8a41ad669 100644
--- a/megatron/megatron/core/dist_checkpointing/strategies/torch.py
+++ b/megatron/megatron/core/dist_checkpointing/strategies/torch.py
@@ -2,6 +2,7 @@
""" Strategies using PyTorch distributed.checkpoint as an underlying format. """
import io
+import os
from collections import ChainMap, defaultdict
from dataclasses import dataclass
from itertools import product
@@ -404,7 +405,6 @@ def _replace_sharded_keys_with_state_dict_keys(
assert len(tensors) == len(rename_mapping[k])
for ten, recovered_k in zip(tensors, rename_mapping[k]):
recovered_sd[recovered_k] = ten
-
return unflatten_state_dict(recovered_sd, flat_mapping)
@@ -734,6 +734,13 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
Returns: loaded state dict
"""
+ # Get the value from the environment variable if it exists, otherwise default to False
+ single_file_per_tensor_ckpt = os.getenv('FS_SFPT_CKPT_LOAD', 'False').lower() in (
+ 'true',
+ '1',
+ 't',
+ )
+
# Apply N-D tensors resharding
sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation(
sharded_state_dict, get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
@@ -752,13 +759,24 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True)
# Load PyT Distributed format
- checkpoint.load_state_dict(
- pyt_state_dict,
- FileSystemReader(checkpoint_dir),
- planner=MCoreLoadPlanner(
- shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
- ),
- )
+ if not single_file_per_tensor_ckpt:
+ checkpoint.load_state_dict(
+ pyt_state_dict,
+ FileSystemReader(checkpoint_dir),
+ planner=MCoreLoadPlanner(
+ shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
+ ),
+ )
+ else:
+ checkpoint.load_state_dict(
+ pyt_state_dict,
+ FileSystemReader(checkpoint_dir),
+ planner=MCoreLoadPlanner(
+ shapes_validation_sharded_tensors=flexible_shape_sharded_tensors,
+ allow_partial_load=True,
+ ),
+ )
+
pyt_state_dict = cast(
Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
)
@@ -767,6 +785,13 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v)
for k, v in pyt_state_dict.items()
}
+
+ if single_file_per_tensor_ckpt:
+ mcore_state_dict = {
+ k: [None] if (not isinstance(v, list) and "_extra_state" in k) else v
+ for k, v in mcore_state_dict.items()
+ }
+
mcore_state_dict = _replace_sharded_keys_with_state_dict_keys(
mcore_state_dict, flat_mapping, rename_mapping
)
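The only behavioural difference in the branch above is that, with `FS_SFPT_CKPT_LOAD` enabled, `MCoreLoadPlanner` is constructed with `allow_partial_load=True`, so entries missing from a partial single-file-per-tensor checkpoint do not abort the load. A small standalone sketch of that switch (the `planner_kwargs` helper is illustrative, not part of the diff):

```python
import os

def planner_kwargs(flexible_shape_sharded_tensors):
    # With FS_SFPT_CKPT_LOAD enabled, the load planner additionally allows
    # partial loads; otherwise the default strict behaviour is kept.
    sfpt_load = os.getenv("FS_SFPT_CKPT_LOAD", "False").lower() in ("true", "1", "t")
    kwargs = {"shapes_validation_sharded_tensors": flexible_shape_sharded_tensors}
    if sfpt_load:
        kwargs["allow_partial_load"] = True
    return kwargs

print(planner_kwargs([]))  # {'shapes_validation_sharded_tensors': []}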
diff --git a/megatron/megatron/core/transformer/transformer_block.py b/megatron/megatron/core/transformer/transformer_block.py
index aabd8558b..ca978e6f1 100755
--- a/megatron/megatron/core/transformer/transformer_block.py
+++ b/megatron/megatron/core/transformer/transformer_block.py
@@ -1,5 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
-
+import os
from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Optional, Union
@@ -316,66 +316,6 @@ def checkpoint_handler(forward_func):
rotary_pos_emb,
)
- if self.config.recompute_method_per_stage_micro_batch != None:
- if self.config.virtual_pipeline_model_parallel_size != None:
- if (
- self.config.recompute_method_per_stage_micro_batch[
- parallel_state.get_virtual_pipeline_model_parallel_rank()
- * self.config.pipeline_model_parallel_size
- + parallel_state.get_pipeline_model_parallel_rank()
- ][self.current_microbatch]
- == 0
- ):
- self.config.recompute_method = 'uniform'
- elif (
- self.config.recompute_method_per_stage_micor_batch[
- parallel_state.get_virtual_pipeline_model_parallel_rank()
- * self.config.pipeline_model_parallel_size
- + parallel_state.get_pipeline_model_parallel_rank()
- ][self.current_microbatch]
- == 1
- ):
- self.config.recompute_method = 'block'
- else:
- raise ValueError("the item of recompute_method_per_stage_micor_batch must be '0' or '1' ")
- else:
- if (
- self.config.recompute_method_per_stage_micro_batch[
- parallel_state.get_pipeline_model_parallel_rank()
- ][self.current_microbatch]
- == 0
- ):
- self.config.recompute_method = 'uniform'
- elif (
- self.config.recompute_method_per_stage_micro_batch[
- parallel_state.get_pipeline_model_parallel_rank()
- ][self.current_microbatch]
- == 1
- ):
- self.config.recompute_method = 'block'
-
- if self.config.recompute_num_layers_per_stage_micro_batch != None:
- if self.config.virtual_pipeline_model_parallel_size != None:
- self.config.recompute_num_layers = self.config.recompute_num_layers_per_stage_micro_batch[
- parallel_state.get_virtual_pipeline_model_parallel_rank()
- * self.config.pipeline_model_parallel_size
- + parallel_state.get_pipeline_model_parallel_rank()
- ][self.current_microbatch]
- else:
- self.config.recompute_num_layers = self.config.recompute_num_layers_per_stage_micro_batch[
- parallel_state.get_pipeline_model_parallel_rank()
- ][self.current_microbatch]
-
- if (
- self.config.recompute_granularity_per_stage_micro_batch != None
- and self.config.recompute_granularity_per_stage_micro_batch[
- parallel_state.get_pipeline_model_parallel_rank()
- ][self.current_microbatch]
- == 0
- ):
- self.recompute_granularity = None
- self.recompute_method = None
-
if self.config.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
@@ -538,6 +478,71 @@ def forward(
else:
fp8_context = nullcontext()
+ if self.config.recompute_method_per_stage_micro_batch != None:
+ if self.config.virtual_pipeline_model_parallel_size != None:
+ if (
+ self.config.recompute_method_per_stage_micro_batch[
+ parallel_state.get_virtual_pipeline_model_parallel_rank()
+ * self.config.pipeline_model_parallel_size
+ + parallel_state.get_pipeline_model_parallel_rank()
+ ][self.current_microbatch]
+ == 0
+ ):
+ self.config.recompute_method = 'uniform'
+ elif (
+ self.config.recompute_method_per_stage_micro_batch[
+ parallel_state.get_virtual_pipeline_model_parallel_rank()
+ * self.config.pipeline_model_parallel_size
+ + parallel_state.get_pipeline_model_parallel_rank()
+ ][self.current_microbatch]
+ == 1
+ ):
+ self.config.recompute_method = 'block'
+ else:
+ raise ValueError("the item of recompute_method_per_stage_micor_batch must be '0' or '1' ")
+ else:
+ if (
+ self.config.recompute_method_per_stage_micro_batch[
+ parallel_state.get_pipeline_model_parallel_rank()
+ ][self.current_microbatch]
+ == 0
+ ):
+ self.config.recompute_method = 'uniform'
+ elif (
+ self.config.recompute_method_per_stage_micro_batch[
+ parallel_state.get_pipeline_model_parallel_rank()
+ ][self.current_microbatch]
+ == 1
+ ):
+ self.config.recompute_method = 'block'
+ else:
+ raise ValueError("the item of recompute_method_per_stage_micor_batch must be '0' or '1' ")
+
+ if self.config.recompute_num_layers_per_stage_micro_batch != None:
+ if self.config.virtual_pipeline_model_parallel_size != None:
+ self.config.recompute_num_layers = self.config.recompute_num_layers_per_stage_micro_batch[
+ parallel_state.get_virtual_pipeline_model_parallel_rank()
+ * self.config.pipeline_model_parallel_size
+ + parallel_state.get_pipeline_model_parallel_rank()
+ ][self.current_microbatch]
+ else:
+ self.config.recompute_num_layers = self.config.recompute_num_layers_per_stage_micro_batch[
+ parallel_state.get_pipeline_model_parallel_rank()
+ ][self.current_microbatch]
+ if self.config.recompute_num_layers == 0:
+ self.config.recompute_method = None
+ self.config.recompute_granularity = None
+
+ if (
+ self.config.recompute_granularity_per_stage_micro_batch != None
+ and self.config.recompute_granularity_per_stage_micro_batch[
+ parallel_state.get_pipeline_model_parallel_rank()
+ ][self.current_microbatch]
+ == 0
+ ):
+ self.config.recompute_granularity = None
+ self.config.recompute_method = None
+
with rng_context and fp8_context:
# Forward pass.
if self.config.recompute_granularity == 'full' and self.training:
@@ -625,6 +630,16 @@ def sharded_state_dict(
non_homogeneous_layers = metadata is not None and metadata.get(
'non_homogeneous_layers', False
)
+
+ # TODO: @aoyulong - This is a temporary solution to support single-file-per-tensor ckpt
+ non_homogeneous_layers_env = os.getenv('FS_NON_HOMOGENEOUS_LAYERS', 'False').lower() in (
+ 'true',
+ '1',
+ 't',
+ )
+ if non_homogeneous_layers_env:
+ non_homogeneous_layers = True
+
sharded_state_dict = {}
layer_prefix = f'{prefix}layers.'
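The block moved into `forward()` above selects the recompute method per pipeline stage and per micro-batch: the stage index is `vpp_rank * pp_size + pp_rank` when virtual pipeline parallelism is used, otherwise just `pp_rank`, and a table entry of 0 selects 'uniform' recompute while 1 selects 'block'. A self-contained sketch of that lookup (function and variable names are illustrative):

```python
def select_recompute_method(table, pp_rank, current_microbatch,
                            vpp_rank=None, pp_size=None):
    # Index by (virtual) pipeline stage first, then by micro-batch.
    stage = pp_rank if vpp_rank is None else vpp_rank * pp_size + pp_rank
    flag = table[stage][current_microbatch]
    if flag == 0:
        return "uniform"
    if flag == 1:
        return "block"
    raise ValueError("entries of recompute_method_per_stage_micro_batch must be 0 or 1")

# Two pipeline stages, three micro-batches each.
table = [[0, 0, 1], [1, 1, 1]]
print(select_recompute_method(table, pp_rank=0, current_microbatch=0))  # uniform
print(select_recompute_method(table, pp_rank=0, current_microbatch=2))  # block
```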
diff --git a/run.py b/run.py
index fba8ebb07..472f072ac 100644
--- a/run.py
+++ b/run.py
@@ -3,6 +3,7 @@
from flagscale.runner.runner_train import SSHTrainRunner, CloudTrainRunner
from flagscale.runner.runner_inference import SSHInferenceRunner
+from flagscale.runner.runner_serve import SSHServeRunner
@hydra.main(version_base=None, config_name="config")
@@ -47,6 +48,16 @@ def main(config: DictConfig) -> None:
runner.stop()
else:
raise ValueError(f"Unknown action {config.action}")
+ elif task_type == "serve":
+ runner = SSHServeRunner(config)
+ if config.action == "run":
+ runner.run()
+ elif config.action == "test":
+ runner.run(with_test=True)
+ elif config.action == "stop":
+ runner.stop()
+ else:
+ raise ValueError(f"Unknown action {config.action}")
else:
raise ValueError(f"Unknown task type {task_type}")
diff --git a/tests/functional_tests/test_cases/hetero_train/aquila/conf/tp2dp1pp1_tp2dp2pp1_tp1dp2pp1.yaml b/tests/functional_tests/test_cases/hetero_train/aquila/conf/tp2dp1pp1_tp2dp2pp1_tp1dp2pp1.yaml
index b73534e02..b2967ee26 100644
--- a/tests/functional_tests/test_cases/hetero_train/aquila/conf/tp2dp1pp1_tp2dp2pp1_tp1dp2pp1.yaml
+++ b/tests/functional_tests/test_cases/hetero_train/aquila/conf/tp2dp1pp1_tp2dp2pp1_tp1dp2pp1.yaml
@@ -10,9 +10,9 @@ experiment:
backend: megatron
entrypoint: flagscale/train/train_aquila.py
runner:
- backend: torchrun
+ backend: torchrun
+ ssh_port: null
shell_cmds: null
- ssh_port: null
envs:
HYDRA_FULL_ERROR: 1
CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7"
diff --git a/tests/functional_tests/test_cases/hetero_train/aquila/conf/tp2pp1_tp4pp1_tp2pp1.yaml b/tests/functional_tests/test_cases/hetero_train/aquila/conf/tp2pp1_tp4pp1_tp2pp1.yaml
index 4ac191c2f..1f7a16d7e 100644
--- a/tests/functional_tests/test_cases/hetero_train/aquila/conf/tp2pp1_tp4pp1_tp2pp1.yaml
+++ b/tests/functional_tests/test_cases/hetero_train/aquila/conf/tp2pp1_tp4pp1_tp2pp1.yaml
@@ -10,9 +10,9 @@ experiment:
backend: megatron
entrypoint: flagscale/train/train_aquila.py
runner:
- backend: torchrun
+ backend: torchrun
+ ssh_port: null
shell_cmds: null
- ssh_port: null
envs:
HYDRA_FULL_ERROR: 1
CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7"
diff --git a/tests/functional_tests/test_cases/train/aquila/conf/tp2_pp2.yaml b/tests/functional_tests/test_cases/train/aquila/conf/tp2_pp2.yaml
index 78b108a27..03930cdf8 100644
--- a/tests/functional_tests/test_cases/train/aquila/conf/tp2_pp2.yaml
+++ b/tests/functional_tests/test_cases/train/aquila/conf/tp2_pp2.yaml
@@ -10,9 +10,9 @@ experiment:
backend: megatron
entrypoint: flagscale/train/train_aquila.py
runner:
- backend: torchrun
+ backend: torchrun
+ ssh_port: null
shell_cmds: null
- ssh_port: null
envs:
HYDRA_FULL_ERROR: 1
CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7"
diff --git a/tests/functional_tests/test_cases/train/aquila/conf/tp4_pp2.yaml b/tests/functional_tests/test_cases/train/aquila/conf/tp4_pp2.yaml
index 81f4e1e60..b8354557b 100644
--- a/tests/functional_tests/test_cases/train/aquila/conf/tp4_pp2.yaml
+++ b/tests/functional_tests/test_cases/train/aquila/conf/tp4_pp2.yaml
@@ -10,9 +10,9 @@ experiment:
backend: megatron
entrypoint: flagscale/train/train_aquila.py
runner:
- backend: torchrun
+ backend: torchrun
+ ssh_port: null
shell_cmds: null
- ssh_port: null
envs:
HYDRA_FULL_ERROR: 1
CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7"
diff --git a/tests/functional_tests/test_cases/train/mixtral/conf/tp2_pp1_ep2.yaml b/tests/functional_tests/test_cases/train/mixtral/conf/tp2_pp1_ep2.yaml
index c48c4ef16..dce9ff56a 100644
--- a/tests/functional_tests/test_cases/train/mixtral/conf/tp2_pp1_ep2.yaml
+++ b/tests/functional_tests/test_cases/train/mixtral/conf/tp2_pp1_ep2.yaml
@@ -10,9 +10,9 @@ experiment:
backend: megatron
entrypoint: flagscale/train/train_mixtral.py
runner:
- backend: torchrun
+ backend: torchrun
+ ssh_port: null
shell_cmds: null
- ssh_port: null
envs:
HYDRA_FULL_ERROR: 1
CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7"
diff --git a/tests/functional_tests/test_cases/train/mixtral/conf/tp4_pp1_ep2.yaml b/tests/functional_tests/test_cases/train/mixtral/conf/tp4_pp1_ep2.yaml
index 95387cb5d..ef744d84f 100644
--- a/tests/functional_tests/test_cases/train/mixtral/conf/tp4_pp1_ep2.yaml
+++ b/tests/functional_tests/test_cases/train/mixtral/conf/tp4_pp1_ep2.yaml
@@ -10,9 +10,9 @@ experiment:
backend: megatron
entrypoint: flagscale/train/train_mixtral.py
runner:
- backend: torchrun
+ backend: torchrun
+ ssh_port: null
shell_cmds: null
- ssh_port: null
envs:
HYDRA_FULL_ERROR: 1
CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5,6,7"
diff --git a/tests/scripts/format_tests/test_format.sh b/tests/scripts/format_tests/test_format.sh
index 5ad2bbc3e..f01c67d25 100755
--- a/tests/scripts/format_tests/test_format.sh
+++ b/tests/scripts/format_tests/test_format.sh
@@ -13,7 +13,9 @@ flagscale/logger.py \
flagscale/patches_utils.py \
flagscale/datasets/sft_dataset.py \
flagscale/inference/inference_*.py \
-flagscale/inference/arguments.py"
+flagscale/inference/arguments.py \
+tools/checkpoint/sfpt_ckpt/*.py \
+"
# Function to run a command and continue even if it fails
run_command() {
diff --git a/tools/checkpoint/convert.py b/tools/checkpoint/convert.py
index e5a0f2b72..dde8204b2 100644
--- a/tools/checkpoint/convert.py
+++ b/tools/checkpoint/convert.py
@@ -32,7 +32,7 @@ def main():
allow_abbrev=False, conflict_handler='resolve')
# convert args
parser.add_argument('--model-type', type=str, default=[], nargs="+", required=True,
- choices=['mistral', 'mixtral', 'llama'],
+ choices=['aquila3_dense', 'aquila3_moe', 'mistral', 'mixtral', 'llama'],
help='Type of the model.')
parser.add_argument('--loader', type=str, default='mcore', choices=['mcore', 'transformers'],
help='Module name to load checkpoint, should be on python path')
@@ -44,6 +44,7 @@ def main():
help='Directory to save model checkpoint to')
parser.add_argument('--max-queue-size', type=int, default=50,
help='Maximum number of tensors in the queue')
+ extend_cases = [['mistral', 'mixtral'], ['aquila3_dense', 'aquila3_moe']]
known_args, _ = parser.parse_known_args()
loader = load_plugin('loader', known_args.loader)
@@ -85,4 +86,4 @@ def main():
if __name__ == '__main__':
- main()
\ No newline at end of file
+ main()
diff --git a/tools/checkpoint/run.sh b/tools/checkpoint/run.sh
index b433f7713..cfe7a0696 100755
--- a/tools/checkpoint/run.sh
+++ b/tools/checkpoint/run.sh
@@ -54,3 +54,30 @@ python convert.py \
--target-params-dtype bf16 \
--true-vocab-size 128256 \
--megatron-path
+
+python convert.py \
+ --model-type aquila3_dense \
+ --loader transformers \
+ --saver mcore \
+ --load-dir $loaddir \
+ --save-dir $outputs \
+ --target-tensor-parallel-size 1 \
+ --target-pipeline-parallel-size 1 \
+ --target-expert-parallel-size 1 \
+ --target-params-dtype bf16 \
+ --true-vocab-size 151665 \
+ --megatron-path
+
+# megatron to huggingface
+python convert.py \
+ --model-type llama \
+ --loader mcore \
+ --saver transformers \
+ --load-dir $loaddir \
+ --save-dir $outputs \
+ --target-tensor-parallel-size 1 \
+ --target-pipeline-parallel-size 1 \
+ --target-expert-parallel-size 1 \
+ --target-params-dtype bf16 \
+ --true-vocab-size 128256 \
+ --megatron-path
diff --git a/tools/checkpoint/sfpt_ckpt/README.md b/tools/checkpoint/sfpt_ckpt/README.md
new file mode 100644
index 000000000..88530118b
--- /dev/null
+++ b/tools/checkpoint/sfpt_ckpt/README.md
@@ -0,0 +1,73 @@
+# README
+
+This directory contains scripts for converting checkpoints between DCP (Distributed Checkpoint) and SFPT (Single File Per Tensor) formats.
+
+## Scripts
+
+- `dcp_to_sfpt.py` - Converts a DCP checkpoint to SFPT format.
+- `sfpt_to_dcp.py` - Converts an SFPT checkpoint to DCP format.
+
+## Usage
+
+**Convert DCP to SFPT:**
+1. Produce the DCP checkpoint with non-homogeneous layers from the training run.
+ * Add the environment variable to the experiment-level configuration file:
+ ```yaml
+ envs:
+ FS_NON_HOMOGENEOUS_LAYERS: True
+ ```
+
+ * Add the following to the task-level configuration file:
+ ```yaml
+ use_dist_ckpt: True
+ ckpt_format: torch_dist
+ ckpt_fully_parallel_save: True
+ ckpt_fully_parallel_load: True
+ ```
+
+2. Set the `PYTHONPATH` environment variable:
+
+ ```bash
+ # FlagScale_ROOT is the root directory of the FlagScale repository
+ export PYTHONPATH=$FlagScale_ROOT/megatron:$FlagScale_ROOT
+ ```
+
+3. Run the conversion script:
+ ```bash
+ torchrun --nnodes 1 --node_rank 0 --nproc_per_node 1 \
+ --master_addr localhost --master_port 1234 \
+ dcp_to_sfpt.py --input_dir /path/to/dcp_checkpoint --output_dir /path/to/output_sfpt_checkpoint
+ ```
+
+**Convert SFPT to DCP:**
+
+1. Set the `PYTHONPATH` environment variable:
+ ```bash
+ # FlagScale_ROOT is the root directory of the FlagScale repository
+ export PYTHONPATH=$FlagScale_ROOT/megatron:$FlagScale_ROOT
+ ```
+
+2. Run the conversion script:
+ ```bash
+ FS_SFPT_CKPT_SAVE=1 torchrun --nnodes 1 --node_rank 0 --nproc_per_node 1 \
+ --master_addr localhost --master_port 1234 \
+ sfpt_to_dcp.py --input_dir /path/to/sfpt_checkpoint --output_dir /path/to/output_dcp_checkpoint
+ ```
+
+3. Use the DCP checkpoint for further fine-tuning.
+ * Add the environment variables to the experiment-level configuration file:
+ ```yaml
+ envs:
+ FS_NON_HOMOGENEOUS_LAYERS: True
+ FS_SFPT_CKPT_LOAD: True
+ ```
+
+ * Add the following to the task-level configuration file:
+ ```yaml
+ use_dist_ckpt: True
+ ckpt_format: torch_dist
+ ckpt_fully_parallel_save: True
+ ckpt_fully_parallel_load: True
+ finetune: True
+ load: /path/to/output_dcp_checkpoint
+ ```
\ No newline at end of file
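The conversion described above (implemented by `dcp_to_sfpt.py` below) writes one `<tensor_key>.pt` file per tensor, each containing a single-entry dict. A quick way to spot-check the output, assuming a converted checkpoint already exists (the path and key below are placeholders):

```python
import torch

# Placeholder path; point it at one of the files produced by dcp_to_sfpt.py.
path = "/path/to/output_sfpt_checkpoint/decoder.layers.0.self_attention.linear_qkv.weight.pt"

state = torch.load(path, map_location="cpu")
assert len(state) == 1  # one tensor per file
key, tensor = next(iter(state.items()))
print(key, tuple(tensor.shape), tensor.dtype)
```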
diff --git a/tools/checkpoint/sfpt_ckpt/dcp_to_sfpt.py b/tools/checkpoint/sfpt_ckpt/dcp_to_sfpt.py
new file mode 100644
index 000000000..c29831994
--- /dev/null
+++ b/tools/checkpoint/sfpt_ckpt/dcp_to_sfpt.py
@@ -0,0 +1,111 @@
+import argparse
+import os
+from datetime import timedelta
+
+import torch
+from torch.distributed.checkpoint import (
+ BytesStorageMetadata,
+ FileSystemReader,
+ Metadata,
+ TensorStorageMetadata,
+)
+from torch.distributed.checkpoint.metadata import Metadata
+
+from megatron.core.dist_checkpointing import ShardedTensor, load
+from megatron.core.dist_checkpointing.mapping import ShardedObject
+
+
+def build_tensor_shared_state_dict(key, metadata: Metadata = None):
+ # Based on load_tensors_metadata from FlagScale/megatron/megatron/core/dist_checkpointing/strategies/torch.py
+ mcore_data = getattr(metadata, "mcore_data", {})
+ sharded_state_dict = {}
+ tp = metadata.state_dict_metadata[key]
+
+ nd_orig_global_shape = mcore_data.get(key, {}).get(
+ "nd_reformulated_orig_global_shape"
+ )
+ if nd_orig_global_shape is None:
+ # Regular tensor
+ sharded_state_dict[key] = ShardedTensor.from_rank_offsets(
+ key, torch.empty(tp.size, **tp.properties.__dict__, device="cpu")
+ )
+ else:
+ # N-D flattened tensor
+ unflat_ten = torch.empty(
+ nd_orig_global_shape, **tp.properties.__dict__, device="cpu"
+ )
+ flat_ten = unflat_ten.flatten()
+ sharded_state_dict[key] = ShardedTensor.from_rank_offsets_flat(
+ key,
+ flat_ten,
+ unflat_ten.shape,
+ flattened_range=slice(0, unflat_ten.numel()), # whole slice
+ )
+
+ return sharded_state_dict
+
+
+def build_sharded_state_dict(metadata_key, metadata):
+ # Based on load_sharded_metadata from FlagScale/megatron/megatron/core/dist_checkpointing/strategies/torch.py
+ storage_metadata = metadata.state_dict_metadata[metadata_key]
+ if isinstance(storage_metadata, BytesStorageMetadata):
+ sharded_state_dict = {}
+ sh_obj = ShardedObject.empty_from_unique_key(metadata_key)
+ sharded_state_dict[sh_obj.unique_key] = sh_obj
+ return sharded_state_dict
+ elif isinstance(storage_metadata, TensorStorageMetadata):
+ sharded_state_dict = build_tensor_shared_state_dict(metadata_key, metadata)
+ return sharded_state_dict
+
+
+def convert_dist_ckpt_to_sfpt_ckpt(input_dir, output_dir):
+ # Distributed checkpoint loading requires the distributed environment to be initialized
+ rank = int(os.getenv("RANK", "0"))
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
+ print(f"Rank: {rank}, World size: {world_size}")
+ torch.distributed.init_process_group(
+ backend="gloo", world_size=world_size, rank=rank
+ )
+
+ fs_reader = FileSystemReader(input_dir)
+ metadata = fs_reader.read_metadata()
+ state_dict_metadata = metadata.state_dict_metadata
+ for metadata_key, storage_metadata in state_dict_metadata.items():
+ # Skip optimizer state_dict
+ if "optimizer" not in metadata_key and isinstance(
+ storage_metadata, TensorStorageMetadata
+ ):
+ print(f"Processing {metadata_key}")
+ sharded_state_dict = build_sharded_state_dict(metadata_key, metadata)
+ loaded_state_dict = load(sharded_state_dict, input_dir)
+ sharded_tensor = loaded_state_dict[metadata_key]
+ unshared_tensor = sharded_tensor.data
+ path = os.path.join(output_dir, metadata_key)
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ with open(f"{path}.pt", "wb") as f:
+ torch.save({metadata_key: unshared_tensor}, f)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Convert distributed checkpoint to single-file-per-tensor checkpoint."
+ )
+ parser.add_argument(
+ "--input_dir",
+ type=str,
+ required=True,
+ help="Input directory containing the distributed checkpoint.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ required=True,
+ help="Output directory to save the single-file-per-tensor checkpoint.",
+ )
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ convert_dist_ckpt_to_sfpt_ckpt(args.input_dir, args.output_dir)
diff --git a/tools/checkpoint/sfpt_ckpt/sfpt_to_dcp.py b/tools/checkpoint/sfpt_ckpt/sfpt_to_dcp.py
new file mode 100644
index 000000000..5c65bc8aa
--- /dev/null
+++ b/tools/checkpoint/sfpt_ckpt/sfpt_to_dcp.py
@@ -0,0 +1,82 @@
+import argparse
+import os
+from argparse import Namespace
+from pathlib import Path
+
+import torch
+
+from megatron.core.dist_checkpointing import ShardedTensor, save
+from megatron.core.dist_checkpointing.serialization import (
+ get_default_save_common_strategy,
+)
+
+
+def convert_sfpt_ckpt_to_dist_ckpt(input_dir, output_dir):
+ # Distributed checkpoint loading requires the distributed environment to be initialized
+ rank = int(os.getenv("RANK", "0"))
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
+ print(f"Rank: {rank}, World size: {world_size}")
+ torch.distributed.init_process_group(
+ backend="gloo", world_size=world_size, rank=rank
+ )
+
+ input_ckpt_dir = os.path.join(input_dir)
+ if not os.path.isdir(input_ckpt_dir):
+ raise ValueError(f"Checkpoint directory {input_ckpt_dir} does not exist")
+
+ ckpt_output_dir = os.path.join(output_dir, "iter_0000000")
+ if not os.path.exists(ckpt_output_dir):
+ os.makedirs(ckpt_output_dir)
+
+ for root, dirs, files in os.walk(input_ckpt_dir):
+ for file in files:
+ file_path = os.path.join(root, file)
+ print(f"Processing file: {file_path}")
+ state_dict = torch.load(file_path)
+ assert len(state_dict) == 1
+ key = list(state_dict.keys())[0]
+ tensor = state_dict[key]
+ sharded_state_dict = {}
+ sharded_state_dict[key] = ShardedTensor.from_rank_offsets(
+ key,
+ tensor,
+ )
+ save(sharded_state_dict, ckpt_output_dir)
+
+ # Fake the minimal args required by the checkpoint loading process
+ state_dict = {}
+ args = Namespace(
+ tensor_model_parallel_size=1,
+ pipeline_model_parallel_size=1,
+ )
+ state_dict["args"] = args
+ common_strategy = get_default_save_common_strategy()
+ common_strategy.save_common(state_dict, Path(ckpt_output_dir))
+
+ # add the latest_checkpointed_iteration file
+ with open(os.path.join(output_dir, "latest_checkpointed_iteration.txt"), "w") as f:
+ f.write("0")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Convert single-file-per-tensor checkpoint to distributed checkpoint."
+ )
+ parser.add_argument(
+ "--input_dir",
+ type=str,
+ required=True,
+ help="Input directory containing the single-file-per-tensor checkpoint.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ required=True,
+ help="Output directory to save the distributed checkpoint.",
+ )
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ convert_sfpt_ckpt_to_dist_ckpt(args.input_dir, args.output_dir)
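As a quick sanity check on the output of `sfpt_to_dcp.py` (the path below is a placeholder): the script creates an `iter_0000000` directory for the distributed checkpoint and writes `0` to `latest_checkpointed_iteration.txt` so Megatron-style loaders can resolve it.

```python
import os

output_dir = "/path/to/output_dcp_checkpoint"  # placeholder

# Verify the layout written by sfpt_to_dcp.py.
assert os.path.isdir(os.path.join(output_dir, "iter_0000000"))
with open(os.path.join(output_dir, "latest_checkpointed_iteration.txt")) as f:
    assert f.read().strip() == "0"
print("DCP checkpoint layout looks as expected")
```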
diff --git a/vllm/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh b/vllm/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh
index b2e910e1b..a67fc89d5 100644
--- a/vllm/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh
+++ b/vllm/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh
@@ -41,6 +41,6 @@ while getopts "m:b:l:f:" OPT; do
done
lm_eval --model hf \
- --model_args pretrained=$MODEL,parallelize=True \
- --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
- --batch_size $BATCH_SIZE
+ --model_args "pretrained=$MODEL,parallelize=True" \
+ --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \
+ --batch_size "$BATCH_SIZE"
diff --git a/vllm/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/vllm/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
index 4d32b49a4..65be3c5d9 100644
--- a/vllm/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
+++ b/vllm/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
done
lm_eval --model vllm \
- --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \
- --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
- --batch_size $BATCH_SIZE
+ --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096" \
+ --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \
+ --batch_size "$BATCH_SIZE"
diff --git a/vllm/.buildkite/lm-eval-harness/run-tests.sh b/vllm/.buildkite/lm-eval-harness/run-tests.sh
index b4fdde6da..26f33b744 100644
--- a/vllm/.buildkite/lm-eval-harness/run-tests.sh
+++ b/vllm/.buildkite/lm-eval-harness/run-tests.sh
@@ -30,7 +30,7 @@ while getopts "c:t:" OPT; do
done
# Parse list of configs.
-IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG
+IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$CONFIG"
for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
do
diff --git a/vllm/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/vllm/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml
index eec2a51e2..64ba1b32f 100644
--- a/vllm/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml
+++ b/vllm/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml
@@ -9,8 +9,11 @@ steps:
- image: badouralix/curl-jq
command:
- sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh
+
- wait
+
- label: "A100"
+ # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
agents:
queue: A100
plugins:
@@ -18,7 +21,7 @@ steps:
podSpec:
priorityClassName: perf-benchmark
containers:
- - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ - image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT
command:
- bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
resources:
@@ -41,20 +44,48 @@ steps:
- name: devshm
emptyDir:
medium: Memory
- # - label: "H100"
- # agents:
- # queue: H100
- # plugins:
- # - docker#v5.11.0:
- # image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- # command:
- # - bash
- # - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh
- # mount-buildkite-agent: true
- # propagate-environment: true
- # ipc: host
- # gpus: all
- # environment:
- # - VLLM_USAGE_SOURCE
- # - HF_TOKEN
+ - label: "H200"
+ # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
+ agents:
+ queue: H200
+ plugins:
+ - docker#v5.12.0:
+ image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT
+ command:
+ - bash
+ - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
+ mount-buildkite-agent: true
+ propagate-environment: true
+ ipc: host
+ gpus: 4,5,6,7
+ volumes:
+ - /data/benchmark-hf-cache:/root/.cache/huggingface
+ environment:
+ - VLLM_USAGE_SOURCE
+ - HF_TOKEN
+
+ - block: "Run H100 Benchmark"
+ key: block-h100
+ depends_on: ~
+
+ - label: "H100"
+ # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
+ agents:
+ queue: H100
+ depends_on: block-h100
+ plugins:
+ - docker#v5.12.0:
+ image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT
+ command:
+ - bash
+ - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
+ mount-buildkite-agent: true
+ propagate-environment: true
+ ipc: host
+ gpus: all # see CUDA_VISIBLE_DEVICES for actual GPUs used
+ volumes:
+ - /data/benchmark-hf-cache:/root/.cache/huggingface
+ environment:
+ - VLLM_USAGE_SOURCE
+ - HF_TOKEN
diff --git a/vllm/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/vllm/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py
index f90e46428..9d3646e2f 100644
--- a/vllm/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py
+++ b/vllm/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py
@@ -56,7 +56,7 @@
def read_markdown(file):
if os.path.exists(file):
- with open(file, "r") as f:
+ with open(file) as f:
return f.read() + "\n"
else:
return f"{file} not found.\n"
@@ -75,14 +75,14 @@ def results_to_json(latency, throughput, serving):
# collect results
for test_file in results_folder.glob("*.json"):
- with open(test_file, "r") as f:
+ with open(test_file) as f:
raw_result = json.loads(f.read())
if "serving" in str(test_file):
# this result is generated via `benchmark_serving.py`
# attach the benchmarking command to raw_result
- with open(test_file.with_suffix(".commands"), "r") as f:
+ with open(test_file.with_suffix(".commands")) as f:
command = json.loads(f.read())
raw_result.update(command)
@@ -97,7 +97,7 @@ def results_to_json(latency, throughput, serving):
# this result is generated via `benchmark_latency.py`
# attach the benchmarking command to raw_result
- with open(test_file.with_suffix(".commands"), "r") as f:
+ with open(test_file.with_suffix(".commands")) as f:
command = json.loads(f.read())
raw_result.update(command)
@@ -119,7 +119,7 @@ def results_to_json(latency, throughput, serving):
# this result is generated via `benchmark_throughput.py`
# attach the benchmarking command to raw_result
- with open(test_file.with_suffix(".commands"), "r") as f:
+ with open(test_file.with_suffix(".commands")) as f:
command = json.loads(f.read())
raw_result.update(command)
@@ -157,6 +157,18 @@ def results_to_json(latency, throughput, serving):
throughput_results,
serving_results)
+ for df in [latency_results, serving_results, throughput_results]:
+ if df.empty:
+ continue
+
+ # Sort all dataframes by their respective "Test name" columns
+ df.sort_values(by="Test name", inplace=True)
+
+ # The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...",
+ # we want to turn it into "8xGPUTYPE"
+ df["GPU"] = df["GPU"].apply(
+ lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}")
+
# get markdown tables
latency_md_table = tabulate(latency_results,
headers='keys',
diff --git a/vllm/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py b/vllm/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py
index 6059588fe..052060c57 100644
--- a/vllm/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py
+++ b/vllm/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py
@@ -72,7 +72,7 @@ def main(args):
# collect results
for test_file in results_folder.glob("*_nightly_results.json"):
- with open(test_file, "r") as f:
+ with open(test_file) as f:
results = results + json.loads(f.read())
# generate markdown table
@@ -80,7 +80,7 @@ def main(args):
md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False)
- with open(args.description, "r") as f:
+ with open(args.description) as f:
description = f.read()
description = description.format(
diff --git a/vllm/.buildkite/nightly-benchmarks/scripts/launch-server.sh b/vllm/.buildkite/nightly-benchmarks/scripts/launch-server.sh
index e9d7d6a8d..fb5063db8 100644
--- a/vllm/.buildkite/nightly-benchmarks/scripts/launch-server.sh
+++ b/vllm/.buildkite/nightly-benchmarks/scripts/launch-server.sh
@@ -50,31 +50,30 @@ launch_trt_server() {
git clone https://github.com/triton-inference-server/tensorrtllm_backend.git
git lfs install
cd tensorrtllm_backend
- git checkout $trt_llm_version
- tensorrtllm_backend_dir=$(pwd)
+ git checkout "$trt_llm_version"
git submodule update --init --recursive
# build trtllm engine
cd /tensorrtllm_backend
- cd ./tensorrt_llm/examples/${model_type}
+ cd "./tensorrt_llm/examples/${model_type}"
python3 convert_checkpoint.py \
- --model_dir ${model_path} \
- --dtype ${model_dtype} \
- --tp_size ${model_tp_size} \
- --output_dir ${trt_model_path}
+ --model_dir "${model_path}" \
+ --dtype "${model_dtype}" \
+ --tp_size "${model_tp_size}" \
+ --output_dir "${trt_model_path}"
trtllm-build \
- --checkpoint_dir ${trt_model_path} \
+ --checkpoint_dir "${trt_model_path}" \
--use_fused_mlp \
--reduce_fusion disable \
--workers 8 \
- --gpt_attention_plugin ${model_dtype} \
- --gemm_plugin ${model_dtype} \
- --tp_size ${model_tp_size} \
- --max_batch_size ${max_batch_size} \
- --max_input_len ${max_input_len} \
- --max_seq_len ${max_seq_len} \
- --max_num_tokens ${max_num_tokens} \
- --output_dir ${trt_engine_path}
+ --gpt_attention_plugin "${model_dtype}" \
+ --gemm_plugin "${model_dtype}" \
+ --tp_size "${model_tp_size}" \
+ --max_batch_size "${max_batch_size}" \
+ --max_input_len "${max_input_len}" \
+ --max_seq_len "${max_seq_len}" \
+ --max_num_tokens "${max_num_tokens}" \
+ --output_dir "${trt_engine_path}"
# handle triton protobuf files and launch triton server
cd /tensorrtllm_backend
@@ -82,15 +81,15 @@ launch_trt_server() {
cp -r all_models/inflight_batcher_llm/* triton_model_repo/
cd triton_model_repo
rm -rf ./tensorrt_llm/1/*
- cp -r ${trt_engine_path}/* ./tensorrt_llm/1
+ cp -r "${trt_engine_path}"/* ./tensorrt_llm/1
python3 ../tools/fill_template.py -i tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,engine_dir:/tensorrtllm_backend/triton_model_repo/tensorrt_llm/1,decoupled_mode:true,batching_strategy:inflight_fused_batching,batch_scheduler_policy:guaranteed_no_evict,exclude_input_in_output:true,triton_max_batch_size:2048,max_queue_delay_microseconds:0,max_beam_width:1,max_queue_size:2048,enable_kv_cache_reuse:false
- python3 ../tools/fill_template.py -i preprocessing/config.pbtxt triton_max_batch_size:2048,tokenizer_dir:$model_path,preprocessing_instance_count:5
- python3 ../tools/fill_template.py -i postprocessing/config.pbtxt triton_max_batch_size:2048,tokenizer_dir:$model_path,postprocessing_instance_count:5,skip_special_tokens:false
- python3 ../tools/fill_template.py -i ensemble/config.pbtxt triton_max_batch_size:$max_batch_size
- python3 ../tools/fill_template.py -i tensorrt_llm_bls/config.pbtxt triton_max_batch_size:$max_batch_size,decoupled_mode:true,accumulate_tokens:"False",bls_instance_count:1
+ python3 ../tools/fill_template.py -i preprocessing/config.pbtxt "triton_max_batch_size:2048,tokenizer_dir:$model_path,preprocessing_instance_count:5"
+ python3 ../tools/fill_template.py -i postprocessing/config.pbtxt "triton_max_batch_size:2048,tokenizer_dir:$model_path,postprocessing_instance_count:5,skip_special_tokens:false"
+ python3 ../tools/fill_template.py -i ensemble/config.pbtxt triton_max_batch_size:"$max_batch_size"
+ python3 ../tools/fill_template.py -i tensorrt_llm_bls/config.pbtxt "triton_max_batch_size:$max_batch_size,decoupled_mode:true,accumulate_tokens:False,bls_instance_count:1"
cd /tensorrtllm_backend
python3 scripts/launch_triton_server.py \
- --world_size=${model_tp_size} \
+ --world_size="${model_tp_size}" \
--model_repo=/tensorrtllm_backend/triton_model_repo &
}
@@ -98,10 +97,7 @@ launch_trt_server() {
launch_tgi_server() {
model=$(echo "$common_params" | jq -r '.model')
tp=$(echo "$common_params" | jq -r '.tp')
- dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
- dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
port=$(echo "$common_params" | jq -r '.port')
- num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
server_args=$(json2args "$server_params")
if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then
@@ -129,10 +125,7 @@ launch_tgi_server() {
launch_lmdeploy_server() {
model=$(echo "$common_params" | jq -r '.model')
tp=$(echo "$common_params" | jq -r '.tp')
- dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
- dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
port=$(echo "$common_params" | jq -r '.port')
- num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
server_args=$(json2args "$server_params")
server_command="lmdeploy serve api_server $model \
@@ -149,10 +142,7 @@ launch_sglang_server() {
model=$(echo "$common_params" | jq -r '.model')
tp=$(echo "$common_params" | jq -r '.tp')
- dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
- dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
port=$(echo "$common_params" | jq -r '.port')
- num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
server_args=$(json2args "$server_params")
if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then
@@ -185,10 +175,7 @@ launch_vllm_server() {
model=$(echo "$common_params" | jq -r '.model')
tp=$(echo "$common_params" | jq -r '.tp')
- dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
- dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
port=$(echo "$common_params" | jq -r '.port')
- num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
server_args=$(json2args "$server_params")
if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then
@@ -217,19 +204,19 @@ launch_vllm_server() {
main() {
- if [[ $CURRENT_LLM_SERVING_ENGINE == "trt" ]]; then
+ if [[ "$CURRENT_LLM_SERVING_ENGINE" == "trt" ]]; then
launch_trt_server
fi
- if [[ $CURRENT_LLM_SERVING_ENGINE == "tgi" ]]; then
+ if [[ "$CURRENT_LLM_SERVING_ENGINE" == "tgi" ]]; then
launch_tgi_server
fi
- if [[ $CURRENT_LLM_SERVING_ENGINE == "lmdeploy" ]]; then
+ if [[ "$CURRENT_LLM_SERVING_ENGINE" == "lmdeploy" ]]; then
launch_lmdeploy_server
fi
- if [[ $CURRENT_LLM_SERVING_ENGINE == "sglang" ]]; then
+ if [[ "$CURRENT_LLM_SERVING_ENGINE" == "sglang" ]]; then
launch_sglang_server
fi
diff --git a/vllm/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh b/vllm/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh
index c6a1bbdeb..686f70dbe 100644
--- a/vllm/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh
+++ b/vllm/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh
@@ -16,10 +16,10 @@ main() {
fi
# initial annotation
- description="$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/nightly-descriptions.md"
+ #description="$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/nightly-descriptions.md"
# download results
- cd $VLLM_SOURCE_CODE_LOC/benchmarks
+ cd "$VLLM_SOURCE_CODE_LOC/benchmarks"
mkdir -p results/
/workspace/buildkite-agent artifact download 'results/*nightly_results.json' results/
ls
@@ -30,15 +30,15 @@ main() {
/workspace/buildkite-agent artifact upload "results.zip"
# upload benchmarking scripts
- cd $VLLM_SOURCE_CODE_LOC/
+ cd "$VLLM_SOURCE_CODE_LOC/"
zip -r nightly-benchmarks.zip .buildkite/ benchmarks/
/workspace/buildkite-agent artifact upload "nightly-benchmarks.zip"
- cd $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/
+ cd "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/"
# upload benchmarking pipeline
/workspace/buildkite-agent artifact upload "nightly-pipeline.yaml"
- cd $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/
+ cd "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/"
/workspace/buildkite-agent annotate --style "success" --context "nightly-benchmarks-results" --append < nightly-annotation.md
@@ -75,4 +75,4 @@ main() {
# /workspace/buildkite-agent annotate --style "success" --context "nightly-benchmarks-results" --append < nightly_results.md
}
-main "$@"
\ No newline at end of file
+main "$@"
diff --git a/vllm/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh b/vllm/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh
index dd8c15e07..3f38cf513 100644
--- a/vllm/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh
+++ b/vllm/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh
@@ -12,7 +12,7 @@ check_gpus() {
echo "Need at least 1 GPU to run benchmarking."
exit 1
fi
- declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}')
+ declare -g gpu_type="$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}')"
echo "GPU type is $gpu_type"
}
@@ -102,7 +102,7 @@ kill_gpu_processes() {
pkill -f text-generation
pkill -f lmdeploy
- while [ $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1) -ge 1000 ]; do
+ while [ "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1)" -ge 1000 ]; do
sleep 1
done
}
@@ -119,8 +119,8 @@ wait_for_server() {
ensure_installed() {
# Ensure that the given command is installed by apt-get
local cmd=$1
- if ! which $cmd >/dev/null; then
- apt-get update && apt-get install -y $cmd
+ if ! which "$cmd" >/dev/null; then
+ apt-get update && apt-get install -y "$cmd"
fi
}
@@ -173,13 +173,11 @@ run_serving_tests() {
echo "Reuse previous server for test case $test_name"
else
kill_gpu_processes
- bash $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/launch-server.sh \
+ bash "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/launch-server.sh" \
"$server_params" "$common_params"
fi
- wait_for_server
-
- if [ $? -eq 0 ]; then
+ if wait_for_server; then
echo ""
echo "$CURRENT_LLM_SERVING_ENGINE server is up and running."
else
@@ -190,13 +188,13 @@ run_serving_tests() {
# prepare tokenizer
# this is required for lmdeploy.
- cd $VLLM_SOURCE_CODE_LOC/benchmarks
+ cd "$VLLM_SOURCE_CODE_LOC/benchmarks"
rm -rf /tokenizer_cache
mkdir /tokenizer_cache
python3 ../.buildkite/nightly-benchmarks/scripts/download-tokenizer.py \
--model "$model" \
--cachedir /tokenizer_cache
- cd $VLLM_SOURCE_CODE_LOC/benchmarks
+ cd "$VLLM_SOURCE_CODE_LOC/benchmarks"
# change model name for lmdeploy (it will not follow standard hf name)
@@ -307,11 +305,11 @@ run_serving_tests() {
prepare_dataset() {
# download sharegpt dataset
- cd $VLLM_SOURCE_CODE_LOC/benchmarks
+ cd "$VLLM_SOURCE_CODE_LOC/benchmarks"
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
# duplicate sonnet by 4x, to allow benchmarking with input length 2048
- cd $VLLM_SOURCE_CODE_LOC/benchmarks
+ cd "$VLLM_SOURCE_CODE_LOC/benchmarks"
echo "" > sonnet_4x.txt
for _ in {1..4}
do
@@ -339,17 +337,17 @@ main() {
prepare_dataset
- cd $VLLM_SOURCE_CODE_LOC/benchmarks
+ cd "$VLLM_SOURCE_CODE_LOC/benchmarks"
declare -g RESULTS_FOLDER=results/
mkdir -p $RESULTS_FOLDER
- BENCHMARK_ROOT=$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/
+ BENCHMARK_ROOT="$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/"
# run the test
- run_serving_tests $BENCHMARK_ROOT/tests/nightly-tests.json
+ run_serving_tests "$BENCHMARK_ROOT/tests/nightly-tests.json"
# upload benchmark results to buildkite
python3 -m pip install tabulate pandas
- python3 $BENCHMARK_ROOT/scripts/summary-nightly-results.py
+ python3 "$BENCHMARK_ROOT/scripts/summary-nightly-results.py"
upload_to_buildkite
}
diff --git a/vllm/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh b/vllm/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
index a0b9a409b..0d16a8378 100644
--- a/vllm/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
+++ b/vllm/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
@@ -6,6 +6,7 @@
# Do not set -e, as the mixtral 8x22B model tends to crash occasionally
# and we still want to see other benchmarking results even when mixtral crashes.
+set -x
set -o pipefail
check_gpus() {
@@ -17,7 +18,7 @@ check_gpus() {
echo "Need at least 1 GPU to run benchmarking."
exit 1
fi
- declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}')
+ declare -g gpu_type=$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}')
echo "GPU type is $gpu_type"
}
@@ -85,15 +86,11 @@ kill_gpu_processes() {
ps -aux
lsof -t -i:8000 | xargs -r kill -9
- pkill -f pt_main_thread
- # this line doesn't work now
- # ps aux | grep python | grep openai | awk '{print $2}' | xargs -r kill -9
- pkill -f python3
- pkill -f /usr/bin/python3
+ pgrep python3 | xargs -r kill -9
# wait until GPU memory usage smaller than 1GB
- while [ $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1) -ge 1000 ]; do
+ while [ "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1)" -ge 1000 ]; do
sleep 1
done
@@ -117,7 +114,7 @@ upload_to_buildkite() {
fi
# Use the determined command to annotate and upload artifacts
- $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" <$RESULTS_FOLDER/benchmark_results.md
+ $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" < "$RESULTS_FOLDER/benchmark_results.md"
$BUILDKITE_AGENT_COMMAND artifact upload "$RESULTS_FOLDER/*"
}
@@ -150,7 +147,7 @@ run_latency_tests() {
# check if there is enough GPU to run the test
tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size')
if [[ $gpu_count -lt $tp ]]; then
- echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname."
+ echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
continue
fi
@@ -206,9 +203,9 @@ run_throughput_tests() {
throughput_args=$(json2args "$throughput_params")
# check if there is enough GPU to run the test
- tp=$(echo $throughput_params | jq -r '.tensor_parallel_size')
+ tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size')
if [[ $gpu_count -lt $tp ]]; then
- echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname."
+ echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
continue
fi
@@ -270,7 +267,7 @@ run_serving_tests() {
# check if there is enough GPU to run the test
tp=$(echo "$server_params" | jq -r '.tensor_parallel_size')
if [[ $gpu_count -lt $tp ]]; then
- echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $testname."
+ echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
continue
fi
@@ -278,7 +275,7 @@ run_serving_tests() {
server_model=$(echo "$server_params" | jq -r '.model')
client_model=$(echo "$client_params" | jq -r '.model')
if [[ $server_model != "$client_model" ]]; then
- echo "Server model and client model must be the same. Skip testcase $testname."
+ echo "Server model and client model must be the same. Skip testcase $test_name."
continue
fi
@@ -289,12 +286,11 @@ run_serving_tests() {
# run the server
echo "Running test case $test_name"
echo "Server command: $server_command"
- eval "$server_command" &
+ bash -c "$server_command" &
server_pid=$!
# wait until the server is alive
- wait_for_server
- if [ $? -eq 0 ]; then
+ if wait_for_server; then
echo ""
echo "vllm server is up and running."
else
@@ -323,7 +319,7 @@ run_serving_tests() {
echo "Running test case $test_name with qps $qps"
echo "Client command: $client_command"
- eval "$client_command"
+ bash -c "$client_command"
# record the benchmarking commands
jq_output=$(jq -n \
diff --git a/vllm/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py b/vllm/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py
index 4e4d4cd4c..92d6fad73 100644
--- a/vllm/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py
+++ b/vllm/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py
@@ -36,11 +36,11 @@
# collect results
for test_file in results_folder.glob("*.json"):
- with open(test_file, "r") as f:
+ with open(test_file) as f:
raw_result = json.loads(f.read())
# attach the benchmarking command to raw_result
- with open(test_file.with_suffix(".commands"), "r") as f:
+ with open(test_file.with_suffix(".commands")) as f:
command = json.loads(f.read())
raw_result.update(command)
diff --git a/vllm/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh b/vllm/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh
index f16862907..aa0f7ade8 100644
--- a/vllm/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh
+++ b/vllm/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh
@@ -1,12 +1,12 @@
#!/bin/sh
-TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-test-repo:pull" | jq -r .token)
-URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-test-repo/manifests/$BUILDKITE_COMMIT"
+TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-postmerge-repo:pull" | jq -r .token)
+URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-postmerge-repo/manifests/$BUILDKITE_COMMIT"
TIMEOUT_SECONDS=10
retries=0
while [ $retries -lt 1000 ]; do
- if [ $(curl -s --max-time $TIMEOUT_SECONDS -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" $URL) -eq 200 ]; then
+ if [ "$(curl -s --max-time "$TIMEOUT_SECONDS" -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" "$URL")" -eq 200 ]; then
exit 0
fi
@@ -16,4 +16,4 @@ while [ $retries -lt 1000 ]; do
sleep 5
done
-exit 1
\ No newline at end of file
+exit 1
diff --git a/vllm/.buildkite/release-pipeline.yaml b/vllm/.buildkite/release-pipeline.yaml
index 3b7fa0f2d..93e118fb3 100644
--- a/vllm/.buildkite/release-pipeline.yaml
+++ b/vllm/.buildkite/release-pipeline.yaml
@@ -1,33 +1,41 @@
steps:
- label: "Build wheel - CUDA 12.1"
agents:
- queue: cpu_queue
+ queue: cpu_queue_postmerge
commands:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- # rename the files to change linux -> manylinux1
- - "for f in artifacts/dist/*.whl; do mv -- \"$$f\" \"$${f/linux/manylinux1}\"; done"
- - "mv artifacts/dist/$(ls artifacts/dist) artifacts/dist/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
- - "aws s3 cp artifacts/dist/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl s3://vllm-wheels/$BUILDKITE_COMMIT/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
- - "aws s3 cp artifacts/dist/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl s3://vllm-wheels/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
+ - "bash .buildkite/upload-wheels.sh"
env:
DOCKER_BUILDKIT: "1"
- - block: "Build CUDA 11.8 wheel"
- key: block-build-cu118-wheel
-
+ # Note(simon): We can always build CUDA 11.8 wheel to ensure the build is working.
+ # However, this block can be uncommented to save some compute hours.
+ # - block: "Build CUDA 11.8 wheel"
+ # key: block-build-cu118-wheel
+
- label: "Build wheel - CUDA 11.8"
- depends_on: block-build-cu118-wheel
+ # depends_on: block-build-cu118-wheel
agents:
- queue: cpu_queue
+ queue: cpu_queue_postmerge
commands:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- # rename the files to change linux -> manylinux1
- - "for f in artifacts/dist/*.whl; do mv -- \"$$f\" \"$${f/linux/manylinux1}\"; done"
- - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/$BUILDKITE_COMMIT/"
- - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/"
+ - "bash .buildkite/upload-wheels.sh"
env:
DOCKER_BUILDKIT: "1"
+
+ - block: "Build release image"
+ depends_on: ~
+ key: block-release-image-build
+
+ - label: "Build release image"
+ depends_on: block-release-image-build
+ agents:
+ queue: cpu_queue_postmerge
+ commands:
+ - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
+ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain ."
+ - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
diff --git a/vllm/.buildkite/run-amd-test.sh b/vllm/.buildkite/run-amd-test.sh
index df201cdc7..3515ccd65 100755
--- a/vllm/.buildkite/run-amd-test.sh
+++ b/vllm/.buildkite/run-amd-test.sh
@@ -1,3 +1,5 @@
+#!/bin/bash
+
# This script runs test inside the corresponding ROCm docker container.
set -o pipefail
@@ -31,8 +33,8 @@ cleanup_docker() {
echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..."
# Remove dangling images (those that are not tagged and not used by any container)
docker image prune -f
- # Remove unused volumes
- docker volume prune -f
+ # Remove unused volumes / force the system prune for old images as well.
+ docker volume prune -f && docker system prune --force --filter "until=72h" --all
echo "Docker images and volumes cleanup completed."
else
echo "Disk usage is below $threshold%. No cleanup needed."
@@ -57,17 +59,17 @@ done
echo "--- Pulling container"
image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}"
container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
-docker pull ${image_name}
+docker pull "${image_name}"
remove_docker_container() {
- docker rm -f ${container_name} || docker image rm -f ${image_name} || true
+ docker rm -f "${container_name}" || docker image rm -f "${image_name}" || true
}
trap remove_docker_container EXIT
echo "--- Running container"
HF_CACHE="$(realpath ~)/huggingface"
-mkdir -p ${HF_CACHE}
+mkdir -p "${HF_CACHE}"
HF_MOUNT="/root/.cache/huggingface"
commands=$@
@@ -83,7 +85,6 @@ if [[ $commands == *" kernels "* ]]; then
--ignore=kernels/test_encoder_decoder_attn.py \
--ignore=kernels/test_flash_attn.py \
--ignore=kernels/test_flashinfer.py \
- --ignore=kernels/test_gguf.py \
--ignore=kernels/test_int8_quant.py \
--ignore=kernels/test_machete_gemm.py \
--ignore=kernels/test_mamba_ssm.py \
@@ -107,35 +108,36 @@ fi
PARALLEL_JOB_COUNT=8
# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs.
if [[ $commands == *"--shard-id="* ]]; then
+ # assign job count as the number of shards used
+ commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "}
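+  # (bash "${var//pattern/replacement}" substitution: every "--num-shards= "
+  # placeholder in $commands becomes "--num-shards=${PARALLEL_JOB_COUNT} ")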
for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do
- #replace shard arguments
- commands=${commands//"--shard-id= "/"--shard-id=${GPU} "}
- commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "}
- echo "Shard ${GPU} commands:$commands"
+ # assign shard-id for each shard
+ commands_gpu=${commands//"--shard-id= "/"--shard-id=${GPU} "}
+ echo "Shard ${GPU} commands:$commands_gpu"
docker run \
--device /dev/kfd --device /dev/dri \
--network host \
--shm-size=16gb \
--rm \
- -e HIP_VISIBLE_DEVICES=${GPU} \
+ -e HIP_VISIBLE_DEVICES="${GPU}" \
-e HF_TOKEN \
- -v ${HF_CACHE}:${HF_MOUNT} \
- -e HF_HOME=${HF_MOUNT} \
- --name ${container_name}_${GPU} \
- ${image_name} \
- /bin/bash -c "${commands}" \
+ -v "${HF_CACHE}:${HF_MOUNT}" \
+ -e "HF_HOME=${HF_MOUNT}" \
+ --name "${container_name}_${GPU}" \
+ "${image_name}" \
+ /bin/bash -c "${commands_gpu}" \
|& while read -r line; do echo ">>Shard $GPU: $line"; done &
PIDS+=($!)
done
#wait for all processes to finish and collect exit codes
- for pid in ${PIDS[@]}; do
- wait ${pid}
+ for pid in "${PIDS[@]}"; do
+ wait "${pid}"
STATUS+=($?)
done
- for st in ${STATUS[@]}; do
+ for st in "${STATUS[@]}"; do
if [[ ${st} -ne 0 ]]; then
echo "One of the processes failed with $st"
- exit ${st}
+ exit "${st}"
fi
done
else
@@ -146,9 +148,9 @@ else
--rm \
-e HIP_VISIBLE_DEVICES=0 \
-e HF_TOKEN \
- -v ${HF_CACHE}:${HF_MOUNT} \
- -e HF_HOME=${HF_MOUNT} \
- --name ${container_name} \
- ${image_name} \
+ -v "${HF_CACHE}:${HF_MOUNT}" \
+ -e "HF_HOME=${HF_MOUNT}" \
+ --name "${container_name}" \
+ "${image_name}" \
/bin/bash -c "${commands}"
fi
diff --git a/vllm/.buildkite/run-benchmarks.sh b/vllm/.buildkite/run-benchmarks.sh
index cbf6dda67..1641c1faa 100644
--- a/vllm/.buildkite/run-benchmarks.sh
+++ b/vllm/.buildkite/run-benchmarks.sh
@@ -1,3 +1,5 @@
+#!/bin/bash
+
# This script is run by buildkite to run the benchmarks and upload the results to buildkite
set -ex
diff --git a/vllm/.buildkite/run-cpu-test-ppc64le.sh b/vllm/.buildkite/run-cpu-test-ppc64le.sh
index fd60f5b6a..bc06838d8 100755
--- a/vllm/.buildkite/run-cpu-test-ppc64le.sh
+++ b/vllm/.buildkite/run-cpu-test-ppc64le.sh
@@ -1,39 +1,14 @@
+#!/bin/bash
+
# This script build the CPU docker image and run the offline inference inside the container.
# It serves a sanity check for compilation and basic model usage.
set -ex
-# Try building the docker image
-docker build -t cpu-test -f Dockerfile.ppc64le .
-
# Setup cleanup
-remove_docker_container() { docker rm -f cpu-test || true; }
+remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; }
trap remove_docker_container EXIT
remove_docker_container
-# Run the image, setting --shm-size=4g for tensor parallel.
-source /etc/environment
-#docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test
-docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN=$HF_TOKEN --name cpu-test cpu-test
-
-# Run basic model test
-docker exec cpu-test bash -c "
- pip install pytest matplotlib einops transformers_stream_generator
- pytest -v -s tests/models -m \"not vlm\" \
- --ignore=tests/models/test_embedding.py \
- --ignore=tests/models/test_oot_registration.py \
- --ignore=tests/models/test_registry.py \
- --ignore=tests/models/test_jamba.py \
- --ignore=tests/models/test_mamba.py \
- --ignore=tests/models/test_danube3_4b.py" # Mamba kernels and Danube3-4B on CPU is not supported
+# Try building the docker image
+docker build -t cpu-test -f Dockerfile.ppc64le .
-# online inference
-docker exec cpu-test bash -c "
- python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
- timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
- python3 benchmarks/benchmark_serving.py \
- --backend vllm \
- --dataset-name random \
- --model facebook/opt-125m \
- --num-prompts 20 \
- --endpoint /v1/completions \
- --tokenizer facebook/opt-125m"
diff --git a/vllm/.buildkite/run-cpu-test.sh b/vllm/.buildkite/run-cpu-test.sh
index c331a9c49..4f1729d46 100644
--- a/vllm/.buildkite/run-cpu-test.sh
+++ b/vllm/.buildkite/run-cpu-test.sh
@@ -1,57 +1,85 @@
+#!/bin/bash
+
# This script build the CPU docker image and run the offline inference inside the container.
# It serves a sanity check for compilation and basic model usage.
set -ex
+# allow to bind to different cores
+CORE_RANGE=${CORE_RANGE:-48-95}
+NUMA_NODE=${NUMA_NODE:-1}
+
# Try building the docker image
-numactl -C 48-95 -N 1 docker build -t cpu-test -f Dockerfile.cpu .
-numactl -C 48-95 -N 1 docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .
+numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build -t cpu-test -f Dockerfile.cpu .
+numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .
# Setup cleanup
-remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; }
+remove_docker_container() { docker rm -f cpu-test-"$NUMA_NODE" cpu-test-avx2-"$NUMA_NODE" || true; }
trap remove_docker_container EXIT
remove_docker_container
# Run the image, setting --shm-size=4g for tensor parallel.
-docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
- --cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test
-docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
- --cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2 cpu-test-avx2
-
-# offline inference
-docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
-
-# Run basic model test
-docker exec cpu-test bash -c "
- pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator
- pytest -v -s tests/models/encoder_decoder/language
- pytest -v -s tests/models/decoder_only/language \
- --ignore=tests/models/test_fp8.py \
- --ignore=tests/models/decoder_only/language/test_jamba.py \
- --ignore=tests/models/decoder_only/language/test_mamba.py \
- --ignore=tests/models/decoder_only/language/test_granitemoe.py \
- --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
-
-# Run compressed-tensor test
-docker exec cpu-test bash -c "
- pytest -s -v \
- tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
- tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
-
-# Run AWQ test
-docker exec cpu-test bash -c "
- pytest -s -v \
- tests/quantization/test_ipex_quant.py"
-
-# online inference
-docker exec cpu-test bash -c "
- export VLLM_CPU_KVCACHE_SPACE=10
- export VLLM_CPU_OMP_THREADS_BIND=48-92
- python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
- timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
- python3 benchmarks/benchmark_serving.py \
- --backend vllm \
- --dataset-name random \
- --model facebook/opt-125m \
- --num-prompts 20 \
- --endpoint /v1/completions \
- --tokenizer facebook/opt-125m"
+docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
+ --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test
+docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
+ --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2-"$NUMA_NODE" cpu-test-avx2
+
+function cpu_tests() {
+ set -e
+ export NUMA_NODE=$2
+
+ # offline inference
+ docker exec cpu-test-avx2-"$NUMA_NODE" bash -c "
+ set -e
+ python3 examples/offline_inference.py"
+
+ # Run basic model test
+ docker exec cpu-test-"$NUMA_NODE" bash -c "
+ set -e
+ pip install pytest pytest-asyncio \
+ decord einops librosa peft Pillow sentence-transformers soundfile \
+ transformers_stream_generator matplotlib datamodel_code_generator
+ pip install torchvision --index-url https://download.pytorch.org/whl/cpu
+ pytest -v -s tests/models/decoder_only/language -m cpu_model
+ pytest -v -s tests/models/embedding/language -m cpu_model
+ pytest -v -s tests/models/encoder_decoder/language -m cpu_model
+ pytest -v -s tests/models/decoder_only/audio_language -m cpu_model
+ pytest -v -s tests/models/decoder_only/vision_language -m cpu_model"
+
+ # Run compressed-tensor test
+ docker exec cpu-test-"$NUMA_NODE" bash -c "
+ set -e
+ pytest -s -v \
+ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
+ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
+
+ # Run AWQ test
+ docker exec cpu-test-"$NUMA_NODE" bash -c "
+ set -e
+ pytest -s -v \
+ tests/quantization/test_ipex_quant.py"
+
+ # Run chunked-prefill and prefix-cache test
+ docker exec cpu-test-"$NUMA_NODE" bash -c "
+ set -e
+ pytest -s -v -k cpu_model \
+ tests/basic_correctness/test_chunked_prefill.py"
+
+ # online inference
+ docker exec cpu-test-"$NUMA_NODE" bash -c "
+ set -e
+ export VLLM_CPU_KVCACHE_SPACE=10
+ export VLLM_CPU_OMP_THREADS_BIND=$1
+ python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
+ timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
+ python3 benchmarks/benchmark_serving.py \
+ --backend vllm \
+ --dataset-name random \
+ --model facebook/opt-125m \
+ --num-prompts 20 \
+ --endpoint /v1/completions \
+ --tokenizer facebook/opt-125m"
+}
+
+# All of the CPU tests are expected to finish in less than 25 mins.
+export -f cpu_tests
+timeout 30m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
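+
+# For reference, a minimal sketch of the pattern used above (hypothetical
+# function name): `export -f` makes the shell function visible to the child
+# shell that `timeout` starts, and the two positional arguments become $1
+# (OMP thread binding range) and $2 (NUMA node) inside cpu_tests:
+#
+#   my_tests() { echo "bind=$1 numa=$2"; }
+#   export -f my_tests
+#   timeout 5m bash -c "my_tests 48-95 1"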
diff --git a/vllm/.buildkite/run-hpu-test.sh b/vllm/.buildkite/run-hpu-test.sh
new file mode 100644
index 000000000..fa4f74fca
--- /dev/null
+++ b/vllm/.buildkite/run-hpu-test.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+# This script builds the HPU docker image and runs offline inference inside the container.
+# It serves as a sanity check for compilation and basic model usage.
+set -ex
+
+# Try building the docker image
+docker build -t hpu-test-env -f Dockerfile.hpu .
+
+# Setup cleanup
+remove_docker_container() { docker rm -f hpu-test || true; }
+trap remove_docker_container EXIT
+remove_docker_container
+
+# Run the image and launch offline inference
+docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference.py
\ No newline at end of file
diff --git a/vllm/.buildkite/run-multi-node-test.sh b/vllm/.buildkite/run-multi-node-test.sh
index 7ac4dcc4c..530bf90a8 100755
--- a/vllm/.buildkite/run-multi-node-test.sh
+++ b/vllm/.buildkite/run-multi-node-test.sh
@@ -14,7 +14,7 @@ DOCKER_IMAGE=$4
shift 4
COMMANDS=("$@")
-if [ ${#COMMANDS[@]} -ne $NUM_NODES ]; then
+if [ ${#COMMANDS[@]} -ne "$NUM_NODES" ]; then
echo "The number of commands must be equal to the number of nodes."
echo "Number of nodes: $NUM_NODES"
echo "Number of commands: ${#COMMANDS[@]}"
@@ -23,7 +23,7 @@ fi
echo "List of commands"
for command in "${COMMANDS[@]}"; do
- echo $command
+ echo "$command"
done
start_network() {
@@ -36,7 +36,7 @@ start_nodes() {
for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do
DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu))
GPU_DEVICES+=$(($DEVICE_NUM))
- if [ $node_gpu -lt $(($NUM_GPUS - 1)) ]; then
+ if [ "$node_gpu" -lt $(($NUM_GPUS - 1)) ]; then
GPU_DEVICES+=','
fi
done
@@ -49,17 +49,20 @@ start_nodes() {
# 3. map the huggingface cache directory to the container
# 3. assign ip addresses to the containers (head node: 192.168.10.10, worker nodes:
# starting from 192.168.10.11)
- docker run -d --gpus "$GPU_DEVICES" --shm-size=10.24gb -e HF_TOKEN -v ~/.cache/huggingface:/root/.cache/huggingface --name node$node --network docker-net --ip 192.168.10.$((10 + $node)) --rm $DOCKER_IMAGE /bin/bash -c "tail -f /dev/null"
+ docker run -d --gpus "$GPU_DEVICES" --shm-size=10.24gb -e HF_TOKEN \
+ -v ~/.cache/huggingface:/root/.cache/huggingface --name "node$node" \
+ --network docker-net --ip 192.168.10.$((10 + $node)) --rm "$DOCKER_IMAGE" \
+ /bin/bash -c "tail -f /dev/null"
# organize containers into a ray cluster
- if [ $node -eq 0 ]; then
+ if [ "$node" -eq 0 ]; then
# start the ray head node
- docker exec -d node$node /bin/bash -c "ray start --head --port=6379 --block"
+ docker exec -d "node$node" /bin/bash -c "ray start --head --port=6379 --block"
# wait for the head node to be ready
sleep 10
else
# start the ray worker nodes, and connect them to the head node
- docker exec -d node$node /bin/bash -c "ray start --address=192.168.10.10:6379 --block"
+ docker exec -d "node$node" /bin/bash -c "ray start --address=192.168.10.10:6379 --block"
fi
done
@@ -79,22 +82,22 @@ run_nodes() {
for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do
DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu))
GPU_DEVICES+=$(($DEVICE_NUM))
- if [ $node_gpu -lt $(($NUM_GPUS - 1)) ]; then
+ if [ "$node_gpu" -lt $(($NUM_GPUS - 1)) ]; then
GPU_DEVICES+=','
fi
done
GPU_DEVICES+='"'
echo "Running node$node with GPU devices: $GPU_DEVICES"
- if [ $node -ne 0 ]; then
- docker exec -d node$node /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}"
+ if [ "$node" -ne 0 ]; then
+ docker exec -d "node$node" /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}"
else
- docker exec node$node /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}"
+ docker exec "node$node" /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}"
fi
done
}
cleanup() {
for node in $(seq 0 $(($NUM_NODES-1))); do
- docker stop node$node
+ docker stop "node$node"
done
docker network rm docker-net
}
diff --git a/vllm/.buildkite/run-neuron-test.sh b/vllm/.buildkite/run-neuron-test.sh
index 252c0f7fe..9259391aa 100644
--- a/vllm/.buildkite/run-neuron-test.sh
+++ b/vllm/.buildkite/run-neuron-test.sh
@@ -1,3 +1,5 @@
+#!/bin/bash
+
# This script build the Neuron docker image and run the API server inside the container.
# It serves a sanity check for compilation and basic model usage.
set -e
@@ -12,10 +14,10 @@ if [ -f /tmp/neuron-docker-build-timestamp ]; then
current_time=$(date +%s)
if [ $((current_time - last_build)) -gt 86400 ]; then
docker system prune -f
- echo $current_time > /tmp/neuron-docker-build-timestamp
+ echo "$current_time" > /tmp/neuron-docker-build-timestamp
fi
else
- echo $(date +%s) > /tmp/neuron-docker-build-timestamp
+ date "+%s" > /tmp/neuron-docker-build-timestamp
fi
docker build -t neuron -f Dockerfile.neuron .
@@ -34,7 +36,7 @@ wait_for_server_to_start() {
timeout=300
counter=0
- while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do
+ while [ "$(curl -s -o /dev/null -w '%{http_code}' localhost:8000/health)" != "200" ]; do
sleep 1
counter=$((counter + 1))
if [ $counter -ge $timeout ]; then
diff --git a/vllm/.buildkite/run-openvino-test.sh b/vllm/.buildkite/run-openvino-test.sh
index 70e56596c..6b12f424f 100755
--- a/vllm/.buildkite/run-openvino-test.sh
+++ b/vllm/.buildkite/run-openvino-test.sh
@@ -1,3 +1,5 @@
+#!/bin/bash
+
# This script build the OpenVINO docker image and run the offline inference inside the container.
# It serves a sanity check for compilation and basic model usage.
set -ex
@@ -11,4 +13,4 @@ trap remove_docker_container EXIT
remove_docker_container
# Run the image and launch offline inference
-docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/vllm/examples/offline_inference.py
+docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference.py
diff --git a/vllm/.buildkite/run-tpu-test.sh b/vllm/.buildkite/run-tpu-test.sh
index 6989c94d4..770dad6ff 100644
--- a/vllm/.buildkite/run-tpu-test.sh
+++ b/vllm/.buildkite/run-tpu-test.sh
@@ -1,3 +1,5 @@
+#!/bin/bash
+
set -e
# Build the docker image.
@@ -12,4 +14,4 @@ remove_docker_container
# For HF_TOKEN.
source /etc/environment
# Run a simple end-to-end example.
-docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
+docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
diff --git a/vllm/.buildkite/run-xpu-test.sh b/vllm/.buildkite/run-xpu-test.sh
index 6ffa66d5e..e0a12afbe 100644
--- a/vllm/.buildkite/run-xpu-test.sh
+++ b/vllm/.buildkite/run-xpu-test.sh
@@ -1,3 +1,5 @@
+#!/bin/bash
+
# This script build the CPU docker image and run the offline inference inside the container.
# It serves a sanity check for compilation and basic model usage.
set -ex
@@ -10,5 +12,8 @@ remove_docker_container() { docker rm -f xpu-test || true; }
trap remove_docker_container EXIT
remove_docker_container
-# Run the image and launch offline inference
-docker run --network host --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test python3 examples/offline_inference.py
+# Run the image and test offline inference/tensor parallel
+docker run --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test sh -c '
+ python3 examples/offline_inference.py
+ python3 examples/offline_inference_cli.py -tp 2
+'
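+# (the single-quoted sh -c block runs both checks inside one container: the plain
+# offline example plus a 2-way tensor-parallel run via the CLI example)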
diff --git a/vllm/.buildkite/test-pipeline.yaml b/vllm/.buildkite/test-pipeline.yaml
index 8c98aa36a..bf0de3f69 100644
--- a/vllm/.buildkite/test-pipeline.yaml
+++ b/vllm/.buildkite/test-pipeline.yaml
@@ -9,7 +9,7 @@
# label(str): the name of the test. emoji allowed.
# fast_check(bool): whether to run this on each commit on fastcheck pipeline.
# fast_check_only(bool): run this test on fastcheck pipeline only
-# optional(bool): never run this test by default (i.e. need to unblock manually)
+# optional(bool): never run this test by default (i.e. need to unblock manually) unless it's a scheduled nightly run.
# command(str): the single command to run for tests. incompatible with commands.
# commands(list): the list of commands to run for test. incompatbile with command.
# mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd]
@@ -50,7 +50,9 @@ steps:
- tests/multimodal
- tests/test_utils
- tests/worker
+ - tests/standalone_tests/lazy_torch_compile.py
commands:
+ - python3 standalone_tests/lazy_torch_compile.py
- pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
@@ -59,6 +61,13 @@ steps:
- pytest -v -s test_utils.py # Utils
- pytest -v -s worker # Worker
+- label: Python-only Installation Test
+ source_file_dependencies:
+ - tests/standalone_tests/python_only_compile.sh
+ - setup.py
+ commands:
+ - bash standalone_tests/python_only_compile.sh
+
- label: Basic Correctness Test # 30min
#mirror_hardwares: [amd]
fast_check: true
@@ -119,6 +128,7 @@ steps:
- tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile
commands:
+ - pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
@@ -163,6 +173,14 @@ steps:
# OOM in the CI unless we run this separately
- pytest -v -s tokenization
+- label: V1 Test
+ #mirror_hardwares: [amd]
+ source_file_dependencies:
+ - vllm/
+ - tests/v1
+ commands:
+ - VLLM_USE_V1=1 pytest -v -s v1
+
- label: Examples Test # 15min
working_dir: "/vllm-workspace/examples"
#mirror_hardwares: [amd]
@@ -219,7 +237,7 @@ steps:
source_file_dependencies:
- vllm/lora
- tests/lora
- command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
+ command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore lora/test_long_context.py lora/test_chatglm3_tp.py lora/test_llama_tp.py
parallelism: 4
- label: "PyTorch Fullgraph Smoke Test" # 9min
@@ -229,6 +247,9 @@ steps:
- tests/compile
commands:
- pytest -v -s compile/test_basic_correctness.py
+ # these tests need to be separated, cannot combine
+ - pytest -v -s compile/piecewise/test_simple.py
+ - pytest -v -s compile/piecewise/test_toy_llama.py
- label: "PyTorch Fullgraph Test" # 18min
source_file_dependencies:
@@ -264,7 +285,6 @@ steps:
source_file_dependencies:
- benchmarks/
commands:
- - pip install aiohttp
- bash run-benchmarks.sh
- label: Quantization Test # 33min
@@ -301,55 +321,70 @@ steps:
##### models test #####
-- label: Basic Models Test # 3min
+- label: Basic Models Test # 30min
source_file_dependencies:
- vllm/
- tests/models
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- - pytest -v -s models/*.py --ignore=models/test_oot_registration.py
+ - pytest -v -s models/test_registry.py
+ - pytest -v -s models/test_initialization.py
-- label: Decoder-only Language Models Test (Standard) # 35min
+- label: Language Models Test (Standard) # 42min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
- tests/models/decoder_only/language
+ - tests/models/embedding/language
+ - tests/models/encoder_decoder/language
commands:
- - pytest -v -s models/decoder_only/language/test_models.py
- - pytest -v -s models/decoder_only/language/test_big_models.py
+ - pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
+ - pytest -v -s models/embedding/language -m core_model
-- label: Decoder-only Language Models Test (Extended) # 1h20min
- nightly: true
+- label: Language Models Test (Extended) # 50min
+ optional: true
source_file_dependencies:
- vllm/
- tests/models/decoder_only/language
+ - tests/models/embedding/language
+ - tests/models/encoder_decoder/language
commands:
- - pytest -v -s models/decoder_only/language --ignore=models/decoder_only/language/test_models.py --ignore=models/decoder_only/language/test_big_models.py
+ - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
+ - pytest -v -s models/embedding/language -m 'not core_model'
-- label: Decoder-only Multi-Modal Models Test # 1h31min
+- label: Multi-Modal Models Test (Standard) # 26min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
- tests/models/decoder_only/audio_language
- tests/models/decoder_only/vision_language
+ - tests/models/embedding/vision_language
+ - tests/models/encoder_decoder/vision_language
commands:
- - pytest -v -s models/decoder_only/audio_language
- - pytest -v -s models/decoder_only/vision_language
+ - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
+ - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
+ - pytest -v -s models/embedding/vision_language -m core_model
+ - pytest -v -s models/encoder_decoder/language -m core_model
+ - pytest -v -s models/encoder_decoder/vision_language -m core_model
-- label: Other Models Test # 6min
- #mirror_hardwares: [amd]
+- label: Multi-Modal Models Test (Extended) # 1h15m
+ optional: true
source_file_dependencies:
- vllm/
- - tests/models/embedding/language
+ - tests/models/decoder_only/audio_language
+ - tests/models/decoder_only/vision_language
- tests/models/embedding/vision_language
- - tests/models/encoder_decoder/language
- tests/models/encoder_decoder/vision_language
commands:
- - pytest -v -s models/embedding/language
- - pytest -v -s models/embedding/vision_language
- - pytest -v -s models/encoder_decoder/language
- - pytest -v -s models/encoder_decoder/vision_language
+ - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
+ # HACK - run phi3v tests separately to sidestep this transformers bug
+ # https://github.com/huggingface/transformers/issues/34307
+ - pytest -v -s models/decoder_only/vision_language/test_phi3v.py
+ - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
+ - pytest -v -s models/embedding/vision_language -m 'not core_model'
+ - pytest -v -s models/encoder_decoder/language -m 'not core_model'
+ - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
# This test is used only in PR development phase to test individual models and should never run on main
- label: Custom Models Test
@@ -402,6 +437,9 @@ steps:
- vllm/model_executor/models/
- tests/distributed/
- vllm/compilation
+ - vllm/worker/worker_base.py
+ - vllm/worker/worker.py
+ - vllm/worker/model_runner.py
commands:
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
@@ -410,12 +448,12 @@ steps:
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
- - pytest models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
+ - pytest models/decoder_only/vision_language/test_models.py -v -s -m distributed_2_gpus
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
+ - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py
- label: Multi-step Tests (4 GPUs) # 36min
working_dir: "/vllm-workspace/tests"
@@ -448,18 +486,22 @@ steps:
- pytest -v -s distributed/test_pp_cudagraph.py
- pytest -v -s distributed/test_pipeline_parallel.py
-- label: LoRA Long Context (Distributed) # 11min
- # This test runs llama 13B, so it is required to run on 4 GPUs.
+- label: LoRA TP Test (Distributed)
num_gpus: 4
- soft_fail: true
source_file_dependencies:
- vllm/lora
- - tests/lora/test_long_context
+ - tests/lora
commands:
# FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ # This test runs llama 13B, so it is required to run on 4 GPUs.
- pytest -v -s -x lora/test_long_context.py
+ # There is some Tensor Parallelism related processing logic in LoRA that
+ # requires multi-GPU testing for validation.
+ - pytest -v -s -x lora/test_chatglm3_tp.py
+ - pytest -v -s -x lora/test_llama_tp.py
+
- label: Weight Loading Multiple GPU Test # 33min
working_dir: "/vllm-workspace/tests"
@@ -487,6 +529,7 @@ steps:
- label: Distributed Tests (A100) # optional
gpu: a100
+ optional: true
num_gpus: 4
source_file_dependencies:
- vllm/
@@ -494,11 +537,13 @@ steps:
# NOTE: don't test llama model here, it seems hf implementation is buggy
# see https://github.com/vllm-project/vllm/pull/5689 for details
- pytest -v -s distributed/test_custom_all_reduce.py
+ - torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py
- TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m distributed_2_gpus
- pytest -v -s -x lora/test_mixtral.py
- label: LM Eval Large Models # optional
gpu: a100
+ optional: true
num_gpus: 4
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
diff --git a/vllm/.buildkite/upload-wheels.sh b/vllm/.buildkite/upload-wheels.sh
new file mode 100644
index 000000000..7345dd4e6
--- /dev/null
+++ b/vllm/.buildkite/upload-wheels.sh
@@ -0,0 +1,43 @@
+#!/usr/bin/env bash
+
+set -ex
+
+# Assume wheels are in artifacts/dist/*.whl
+wheel_files=(artifacts/dist/*.whl)
+
+# Check that exactly one wheel is found
+if [[ ${#wheel_files[@]} -ne 1 ]]; then
+ echo "Error: Expected exactly one wheel file in artifacts/dist/, but found ${#wheel_files[@]}"
+ exit 1
+fi
+
+# Get the single wheel file
+wheel="${wheel_files[0]}"
+
+# Rename 'linux' to 'manylinux1' in the wheel filename
+new_wheel="${wheel/linux/manylinux1}"
+mv -- "$wheel" "$new_wheel"
+wheel="$new_wheel"
+
+# Extract the version from the wheel
+version=$(unzip -p "$wheel" '**/METADATA' | grep '^Version: ' | cut -d' ' -f2)
+echo "Version: $version"
+
+# If the version contains "dev", rename it to v1.0.0.dev for consistency
+if [[ $version == *dev* ]]; then
+ suffix="${version##*.}"
+ if [[ $suffix == cu* ]]; then
+ new_version="1.0.0.dev+${suffix}"
+ else
+ new_version="1.0.0.dev"
+ fi
+ new_wheel="${wheel/$version/$new_version}"
+ mv -- "$wheel" "$new_wheel"
+ wheel="$new_wheel"
+ version="$new_version"
+fi
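+
+# Worked example (hypothetical filename): a wheel built as
+#   vllm-0.6.4.dev0-cp38-abi3-linux_x86_64.whl
+# is renamed to ...-manylinux1_x86_64.whl above, and because its version
+# "0.6.4.dev0" contains "dev" (and its last dot-component "dev0" is not a cu*
+# suffix) it is re-versioned to vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl.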
+
+# Upload the wheel to S3
+aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/"
+aws s3 cp "$wheel" "s3://vllm-wheels/nightly/"
+aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
\ No newline at end of file
diff --git a/vllm/.github/CODEOWNERS b/vllm/.github/CODEOWNERS
index cd721971d..3cb91fc0f 100644
--- a/vllm/.github/CODEOWNERS
+++ b/vllm/.github/CODEOWNERS
@@ -3,13 +3,16 @@
# This lists cover the "core" components of vLLM that require careful review
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-/vllm/core @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-/vllm/engine/llm_engine.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-/vllm/executor/executor_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-/vllm/worker/worker_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-/vllm/worker/worker.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-/vllm/model_executor/layers/sampler.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
-CMakeLists.txt @tlrmchlsmth @WoosukKwon
+/vllm/core @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
+/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
+/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
+/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
+/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
+/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
+CMakeLists.txt @tlrmchlsmth
+
+# vLLM V1
+/vllm/v1 @WoosukKwon @robertgshaw2-neuralmagic @njhill @ywang96 @comaniac @alexm-neuralmagic
# Test ownership
/tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo
diff --git a/vllm/.github/FUNDING.yml b/vllm/.github/FUNDING.yml
index 71f4e5201..d1f6105a4 100644
--- a/vllm/.github/FUNDING.yml
+++ b/vllm/.github/FUNDING.yml
@@ -1,2 +1,2 @@
github: [vllm-project]
-open_collective: [vllm]
+open_collective: vllm
diff --git a/vllm/.github/PULL_REQUEST_TEMPLATE.md b/vllm/.github/PULL_REQUEST_TEMPLATE.md
index be0afc630..51a73c857 100644
--- a/vllm/.github/PULL_REQUEST_TEMPLATE.md
+++ b/vllm/.github/PULL_REQUEST_TEMPLATE.md
@@ -2,73 +2,4 @@ FILL IN THE PR DESCRIPTION HERE
FIX #xxxx (*link existing issues this PR will resolve*)
-**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE**
-
----
-
-<details>
-<summary><b> PR Checklist (Click to Expand) </b></summary>
-
-<p>Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.</p>
-
-<h3>PR Title and Classification</h3>
-<p>Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:</p>
-<ul>
-    <li><code>[Bugfix]</code> for bug fixes.</li>
-    <li><code>[CI/Build]</code> for build or continuous integration improvements.</li>
-    <li><code>[Doc]</code> for documentation fixes and improvements.</li>
-    <li><code>[Model]</code> for adding a new model or improving an existing model. Model name should appear in the title.</li>
-    <li><code>[Frontend]</code> For changes on the vLLM frontend (e.g., OpenAI API server, <code>LLM</code> class, etc.)</li>
-    <li><code>[Kernel]</code> for changes affecting CUDA kernels or other compute kernels.</li>
-    <li><code>[Core]</code> for changes in the core vLLM logic (e.g., <code>LLMEngine</code>, <code>AsyncLLMEngine</code>, <code>Scheduler</code>, etc.)</li>
-    <li><code>[Hardware][Vendor]</code> for hardware-specific changes. Vendor name should appear in the prefix (e.g., <code>[Hardware][AMD]</code>).</li>
-    <li><code>[Misc]</code> for PRs that do not fit the above categories. Please use this sparingly.</li>
-</ul>
-<p><strong>Note:</strong> If the PR spans more than one category, please include all relevant prefixes.</p>
-
-<h3>Code Quality</h3>
-
-<p>The PR need to meet the following code quality standards:</p>
-
-<ul>
-    <li>We adhere to Google Python style guide and Google C++ style guide.</li>
-    <li>Pass all linter checks. Please use <code>format.sh</code> to format your code.</li>
-    <li>The code need to be well-documented to ensure future contributors can easily understand the code.</li>
-    <li>Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.</li>
-    <li>Please add documentation to <code>docs/source/</code> if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.</li>
-</ul>
-
-<h3>Adding or changing kernels</h3>
-<p>Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.</p>
-<ul>
-    <li>Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual</li>
-    <li>Custom operations that return <code>Tensors</code> require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.</li>
-    <li>Use <code>torch.libary.opcheck()</code> to test the function registration and meta-function for any registered ops. See <code>tests/kernels</code> for examples.</li>
-    <li>When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.</li>
-    <li>If a new custom type is needed, see the following document: Custom Class Support in PT2.</li>
-</ul>
-
-<h3>Notes for Large Changes</h3>
-<p>Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with <code>rfc-required</code> and might not go through the PR.</p>
-
-<h3>What to Expect for the Reviews</h3>
-
-<p>The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:</p>
-
-<ul>
-    <li>After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.</li>
-    <li>After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.</li>
-    <li>After the review, the reviewer will put an <code>action-required</code> label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.</li>
-    <li>Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.</li>
-</ul>
-
-<h3>Thank You</h3>
-
-<p>Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!</p>
-
-</details>
-
+**BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html **
diff --git a/vllm/.github/dependabot.yml b/vllm/.github/dependabot.yml
index 6fddca0d6..683b70cd8 100644
--- a/vllm/.github/dependabot.yml
+++ b/vllm/.github/dependabot.yml
@@ -5,3 +5,27 @@ updates:
directory: "/"
schedule:
interval: "weekly"
+ - package-ecosystem: "pip"
+ directory: "/"
+ schedule:
+ interval: "weekly"
+ labels: ["dependencies"]
+ open-pull-requests-limit: 5
+ reviewers: ["khluu", "simon-mo"]
+ allow:
+ - dependency-type: "all"
+ ignore:
+ - dependency-name: "*"
+ update-types: ["version-update:semver-patch"]
+ - dependency-name: "torch"
+ - dependency-name: "torchvision"
+ - dependency-name: "xformers"
+ - dependency-name: "lm-format-enforcer"
+ - dependency-name: "gguf"
+ - dependency-name: "compressed-tensors"
+ - dependency-name: "ray[adag]"
+ - dependency-name: "lm-eval"
+ groups:
+ minor-update:
+ applies-to: version-updates
+ update-types: ["minor"]
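+      # group all minor-version bumps of the allowed pip dependencies into a
+      # single combined Dependabot PR instead of one PR per package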
diff --git a/vllm/.github/mergify.yml b/vllm/.github/mergify.yml
index 2a3dee7c6..ca4bd7ee2 100644
--- a/vllm/.github/mergify.yml
+++ b/vllm/.github/mergify.yml
@@ -13,13 +13,14 @@ pull_request_rules:
- name: label-ci-build
description: Automatically apply ci/build label
conditions:
- - files~=^\.github/
- - files~=\.buildkite/
- - files~=^cmake/
- - files=CMakeLists.txt
- - files~=^Dockerfile
- - files~=^requirements.*\.txt
- - files=setup.py
+ - or:
+ - files~=^\.github/
+ - files~=\.buildkite/
+ - files~=^cmake/
+ - files=CMakeLists.txt
+ - files~=^Dockerfile
+ - files~=^requirements.*\.txt
+ - files=setup.py
actions:
label:
add:
@@ -45,7 +46,9 @@ pull_request_rules:
comment:
message: |
This pull request has merge conflicts that must be resolved before it can be
- merged. @{{author}} please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
+ merged. Please rebase the PR, @{{author}}.
+
+ https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
- name: remove 'needs-rebase' label when conflict is resolved
conditions:
diff --git a/vllm/.github/scripts/cleanup_pr_body.sh b/vllm/.github/scripts/cleanup_pr_body.sh
new file mode 100755
index 000000000..3246c6f9b
--- /dev/null
+++ b/vllm/.github/scripts/cleanup_pr_body.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+
+set -eu
+
+# ensure 1 argument is passed
+if [ "$#" -ne 1 ]; then
+    echo "Usage: $0 <pr_number>"
+ exit 1
+fi
+
+PR_NUMBER=$1
+OLD=/tmp/orig_pr_body.txt
+NEW=/tmp/new_pr_body.txt
+
+gh pr view --json body --template "{{.body}}" "${PR_NUMBER}" > "${OLD}"
+cp "${OLD}" "${NEW}"
+
+# Remove "FIX #xxxx (*link existing issues this PR will resolve*)"
+sed -i '/FIX #xxxx.*$/d' "${NEW}"
+
+# Remove "FILL IN THE PR DESCRIPTION HERE"
+sed -i '/FILL IN THE PR DESCRIPTION HERE/d' "${NEW}"
+
+# Remove all lines after and including "**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE**"
+sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}"
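+# (sed address range /pattern/,$ selects from the first matching line through
+# the end of the file, and d deletes that whole range in place)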
+
+# Remove HTML <details> section that includes <summary> text of "PR Checklist (Click to Expand)"
+python3 - <<EOF
+import re
+
+with open("${NEW}", "r") as file:
+    content = file.read()
+
+pattern = re.compile(r'<details>.*?<summary>.*?PR Checklist \(Click to Expand\).*?</summary>.*?</details>', re.DOTALL)
+content = re.sub(pattern, '', content)
+
+with open("${NEW}", "w") as file:
+ file.write(content)
+EOF
+
+# Run this only if ${NEW} is different than ${OLD}
+if ! cmp -s "${OLD}" "${NEW}"; then
+ gh pr edit --body-file "${NEW}" "${PR_NUMBER}"
+ echo
+ echo "Updated PR body:"
+ echo
+ cat "${NEW}"
+else
+ echo "No changes needed"
+fi
diff --git a/vllm/.github/workflows/actionlint.yml b/vllm/.github/workflows/actionlint.yml
index b80749aaa..0226cf0ca 100644
--- a/vllm/.github/workflows/actionlint.yml
+++ b/vllm/.github/workflows/actionlint.yml
@@ -6,12 +6,14 @@ on:
paths:
- '.github/workflows/*.ya?ml'
- '.github/workflows/actionlint.*'
+ - '.github/workflows/matchers/actionlint.json'
pull_request:
branches:
- "main"
paths:
- '.github/workflows/*.ya?ml'
- '.github/workflows/actionlint.*'
+ - '.github/workflows/matchers/actionlint.json'
env:
LC_ALL: en_US.UTF-8
@@ -28,7 +30,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: "Checkout"
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
diff --git a/vllm/.github/workflows/clang-format.yml b/vllm/.github/workflows/clang-format.yml
index 68d60d736..68149d2dc 100644
--- a/vllm/.github/workflows/clang-format.yml
+++ b/vllm/.github/workflows/clang-format.yml
@@ -6,9 +6,21 @@ on:
push:
branches:
- main
+ paths:
+ - '**/*.h'
+ - '**/*.cpp'
+ - '**/*.cu'
+ - '**/*.cuh'
+ - '.github/workflows/clang-format.yml'
pull_request:
branches:
- main
+ paths:
+ - '**/*.h'
+ - '**/*.cpp'
+ - '**/*.cu'
+ - '**/*.cuh'
+ - '.github/workflows/clang-format.yml'
jobs:
clang-format:
@@ -17,9 +29,9 @@ jobs:
matrix:
python-version: ["3.11"]
steps:
- - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
diff --git a/vllm/.github/workflows/cleanup_pr_body.yml b/vllm/.github/workflows/cleanup_pr_body.yml
new file mode 100644
index 000000000..0085a1cc2
--- /dev/null
+++ b/vllm/.github/workflows/cleanup_pr_body.yml
@@ -0,0 +1,26 @@
+name: Cleanup PR Body
+
+on:
+ pull_request_target:
+ types: [opened, reopened, edited]
+
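+  # pull_request_target runs with the base repository's permissions and, by
+  # default, checks out the base branch, so the cleanup script invoked below
+  # comes from the trusted main branch rather than the PR head.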
+permissions:
+ pull-requests: write
+
+jobs:
+ update-description:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+ - name: Set up Python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: '3.12'
+
+ - name: Update PR description
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ run: .github/scripts/cleanup_pr_body.sh "${{ github.event.number }}"
diff --git a/vllm/.github/workflows/codespell.yml b/vllm/.github/workflows/codespell.yml
new file mode 100644
index 000000000..68887adaa
--- /dev/null
+++ b/vllm/.github/workflows/codespell.yml
@@ -0,0 +1,45 @@
+name: codespell
+
+on:
+ # Trigger the workflow on push or pull request,
+ # but only for the main branch
+ push:
+ branches:
+ - main
+ paths:
+ - "**/*.py"
+ - "**/*.md"
+ - "**/*.rst"
+ - pyproject.toml
+ - requirements-lint.txt
+ - .github/workflows/codespell.yml
+ pull_request:
+ branches:
+ - main
+ paths:
+ - "**/*.py"
+ - "**/*.md"
+ - "**/*.rst"
+ - pyproject.toml
+ - requirements-lint.txt
+ - .github/workflows/codespell.yml
+
+jobs:
+ codespell:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.12"]
+ steps:
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r requirements-lint.txt
+ - name: Spelling check with codespell
+ run: |
+ codespell --toml pyproject.toml
diff --git a/vllm/.github/workflows/mypy.yaml b/vllm/.github/workflows/mypy.yaml
index 5f1e5f8ee..73eeacf1f 100644
--- a/vllm/.github/workflows/mypy.yaml
+++ b/vllm/.github/workflows/mypy.yaml
@@ -6,20 +6,35 @@ on:
push:
branches:
- main
+ paths:
+ - '**/*.py'
+ - '.github/workflows/mypy.yaml'
+ - 'tools/mypy.sh'
+ - 'pyproject.toml'
pull_request:
branches:
- main
+ # This workflow is only relevant when one of the following files changes.
+ # However, we have github configured to expect and require this workflow
+  # to run and pass before github will auto-merge a pull request. Until github
+ # allows more flexible auto-merge policy, we can just run this on every PR.
+ # It doesn't take that long to run, anyway.
+ #paths:
+ # - '**/*.py'
+ # - '.github/workflows/mypy.yaml'
+ # - 'tools/mypy.sh'
+ # - 'pyproject.toml'
jobs:
mypy:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
@@ -33,4 +48,4 @@ jobs:
- name: Mypy
run: |
echo "::add-matcher::.github/workflows/matchers/mypy.json"
- tools/mypy.sh 1
+ tools/mypy.sh 1 ${{ matrix.python-version }}
diff --git a/vllm/.github/workflows/png-lint.yml b/vllm/.github/workflows/png-lint.yml
new file mode 100644
index 000000000..4932af943
--- /dev/null
+++ b/vllm/.github/workflows/png-lint.yml
@@ -0,0 +1,37 @@
+name: Lint PNG exports from excalidraw
+on:
+ push:
+ branches:
+ - "main"
+ paths:
+ - '*.excalidraw.png'
+ - '.github/workflows/png-lint.yml'
+ pull_request:
+ branches:
+ - "main"
+ paths:
+ - '*.excalidraw.png'
+ - '.github/workflows/png-lint.yml'
+
+env:
+ LC_ALL: en_US.UTF-8
+
+defaults:
+ run:
+ shell: bash
+
+permissions:
+ contents: read
+
+jobs:
+ actionlint:
+ runs-on: ubuntu-latest
+ steps:
+ - name: "Checkout"
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ fetch-depth: 0
+
+ - name: "Run png-lint.sh to check excalidraw exported images"
+ run: |
+ tools/png-lint.sh
diff --git a/vllm/.github/workflows/publish.yml b/vllm/.github/workflows/publish.yml
index f959a1cac..c1051d10a 100644
--- a/vllm/.github/workflows/publish.yml
+++ b/vllm/.github/workflows/publish.yml
@@ -21,7 +21,7 @@ jobs:
upload_url: ${{ steps.create_release.outputs.upload_url }}
steps:
- name: Checkout
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Extract branch info
shell: bash
@@ -48,13 +48,13 @@ jobs:
fail-fast: false
matrix:
os: ['ubuntu-20.04']
- python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
+ python-version: ['3.9', '3.10', '3.11', '3.12']
pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt.
cuda-version: ['11.8', '12.1']
steps:
- name: Checkout
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Setup ccache
uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14
@@ -68,7 +68,7 @@ jobs:
bash -x .github/workflows/scripts/env.sh
- name: Set up Python
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
diff --git a/vllm/.github/workflows/ruff.yml b/vllm/.github/workflows/ruff.yml
index 9cc8a9e91..7266cc378 100644
--- a/vllm/.github/workflows/ruff.yml
+++ b/vllm/.github/workflows/ruff.yml
@@ -6,33 +6,47 @@ on:
push:
branches:
- main
+ paths:
+ - "**/*.py"
+ - pyproject.toml
+ - requirements-lint.txt
+ - .github/workflows/matchers/ruff.json
+ - .github/workflows/ruff.yml
pull_request:
branches:
- main
+ # This workflow is only relevant when one of the following files changes.
+ # However, we have github configured to expect and require this workflow
+  # to run and pass before github will auto-merge a pull request. Until github
+ # allows more flexible auto-merge policy, we can just run this on every PR.
+ # It doesn't take that long to run, anyway.
+ #paths:
+ # - "**/*.py"
+ # - pyproject.toml
+ # - requirements-lint.txt
+ # - .github/workflows/matchers/ruff.json
+ # - .github/workflows/ruff.yml
jobs:
ruff:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+ python-version: ["3.12"]
steps:
- - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
- with:
- python-version: ${{ matrix.python-version }}
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install -r requirements-lint.txt
- - name: Analysing the code with ruff
- run: |
- echo "::add-matcher::.github/workflows/matchers/ruff.json"
- ruff check --output-format github .
- - name: Spelling check with codespell
- run: |
- codespell --toml pyproject.toml
- - name: Run isort
- run: |
- isort . --check-only
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r requirements-lint.txt
+ - name: Analysing the code with ruff
+ run: |
+ echo "::add-matcher::.github/workflows/matchers/ruff.json"
+ ruff check --output-format github .
+ - name: Run isort
+ run: |
+ isort . --check-only
diff --git a/vllm/.github/workflows/scripts/cuda-install.sh b/vllm/.github/workflows/scripts/cuda-install.sh
index 312c6e82f..3d0b7a1fe 100644
--- a/vllm/.github/workflows/scripts/cuda-install.sh
+++ b/vllm/.github/workflows/scripts/cuda-install.sh
@@ -1,16 +1,16 @@
#!/bin/bash
# Replace '.' with '-' ex: 11.8 -> 11-8
-cuda_version=$(echo $1 | tr "." "-")
+cuda_version=$(echo "$1" | tr "." "-")
# Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
-OS=$(echo $2 | tr -d ".\-")
+OS=$(echo "$2" | tr -d ".\-")
# Installs CUDA
-wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
+wget -nv "https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb"
sudo dpkg -i cuda-keyring_1.1-1_all.deb
rm cuda-keyring_1.1-1_all.deb
sudo apt -qq update
-sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
+sudo apt -y install "cuda-${cuda_version}" "cuda-nvcc-${cuda_version}" "cuda-libraries-dev-${cuda_version}"
sudo apt clean
# Test nvcc
diff --git a/vllm/.github/workflows/scripts/pytorch-install.sh b/vllm/.github/workflows/scripts/pytorch-install.sh
index dfc1851d7..e3cda7dad 100644
--- a/vllm/.github/workflows/scripts/pytorch-install.sh
+++ b/vllm/.github/workflows/scripts/pytorch-install.sh
@@ -6,7 +6,7 @@ cuda_version=$3
# Install torch
$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
-$python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --extra-index-url https://download.pytorch.org/whl/cu${cuda_version//./}
+$python_executable -m pip install torch=="${pytorch_version}+cu${cuda_version//./}" --extra-index-url "https://download.pytorch.org/whl/cu${cuda_version//./}"
# Print version information
$python_executable --version
diff --git a/vllm/.github/workflows/shellcheck.yml b/vllm/.github/workflows/shellcheck.yml
new file mode 100644
index 000000000..4b1587e37
--- /dev/null
+++ b/vllm/.github/workflows/shellcheck.yml
@@ -0,0 +1,37 @@
+name: Lint shell scripts
+on:
+ push:
+ branches:
+ - "main"
+ paths:
+ - '**/*.sh'
+ - '.github/workflows/shellcheck.yml'
+ pull_request:
+ branches:
+ - "main"
+ paths:
+ - '**/*.sh'
+ - '.github/workflows/shellcheck.yml'
+
+env:
+ LC_ALL: en_US.UTF-8
+
+defaults:
+ run:
+ shell: bash
+
+permissions:
+ contents: read
+
+jobs:
+ shellcheck:
+ runs-on: ubuntu-latest
+ steps:
+ - name: "Checkout"
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ fetch-depth: 0
+
+ - name: "Check shell scripts"
+ run: |
+ tools/shellcheck.sh
diff --git a/vllm/.github/workflows/sphinx-lint.yml b/vllm/.github/workflows/sphinx-lint.yml
new file mode 100644
index 000000000..e0bb24276
--- /dev/null
+++ b/vllm/.github/workflows/sphinx-lint.yml
@@ -0,0 +1,32 @@
+name: Lint documentation
+
+on:
+ push:
+ branches:
+ - main
+ paths:
+ - "docs/**"
+ pull_request:
+ branches:
+ - main
+ paths:
+ - "docs/**"
+
+jobs:
+ sphinx-lint:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.12"]
+ steps:
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r requirements-lint.txt
+ - name: Linting docs
+ run: tools/sphinx-lint.sh
diff --git a/vllm/.github/workflows/yapf.yml b/vllm/.github/workflows/yapf.yml
index 9f06b35c1..ff441f944 100644
--- a/vllm/.github/workflows/yapf.yml
+++ b/vllm/.github/workflows/yapf.yml
@@ -6,26 +6,33 @@ on:
push:
branches:
- main
+ paths:
+ - "**/*.py"
+ - .github/workflows/yapf.yml
pull_request:
branches:
- main
+ paths:
+ - "**/*.py"
+ - .github/workflows/yapf.yml
+
jobs:
yapf:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+ python-version: ["3.12"]
steps:
- - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
- with:
- python-version: ${{ matrix.python-version }}
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install yapf==0.32.0
- pip install toml==0.10.2
- - name: Running yapf
- run: |
- yapf --diff --recursive .
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install yapf==0.32.0
+ pip install toml==0.10.2
+ - name: Running yapf
+ run: |
+ yapf --diff --recursive .
diff --git a/vllm/.gitignore b/vllm/.gitignore
index 1ea6e3419..ceef6a5fb 100644
--- a/vllm/.gitignore
+++ b/vllm/.gitignore
@@ -202,3 +202,4 @@ benchmarks/*.json
# Linting
actionlint
+shellcheck*/
diff --git a/vllm/.readthedocs.yaml b/vllm/.readthedocs.yaml
index 42cbf18a0..284196bc2 100644
--- a/vllm/.readthedocs.yaml
+++ b/vllm/.readthedocs.yaml
@@ -6,17 +6,16 @@ version: 2
build:
os: ubuntu-22.04
tools:
- python: "3.8"
+ python: "3.12"
sphinx:
- configuration: docs/source/conf.py
- fail_on_warning: true
+ configuration: docs/source/conf.py
+ fail_on_warning: true
# If using Sphinx, optionally build your docs in additional formats such as PDF
formats: []
# Optionally declare the Python requirements required to build your docs
python:
- install:
- - requirements: docs/requirements-docs.txt
-
+ install:
+ - requirements: docs/requirements-docs.txt
diff --git a/vllm/.shellcheckrc b/vllm/.shellcheckrc
new file mode 100644
index 000000000..f3b6eedf8
--- /dev/null
+++ b/vllm/.shellcheckrc
@@ -0,0 +1,9 @@
+# rules currently disabled:
+#
+# SC1091 (info): Not following: was not specified as input (see shellcheck -x)
+# SC2004 (style): $/${} is unnecessary on arithmetic variables.
+# SC2129 (style): Consider using { cmd1; cmd2; } >> file instead of individual redirects.
+# SC2155 (warning): Declare and assign separately to avoid masking return values.
+# SC2164 (warning): Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
+#
+disable=SC1091,SC2004,SC2129,SC2155,SC2164
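
The rc file above suppresses these findings repository-wide rather than fixing every occurrence. As an example of what SC2155 flags, the sketch below contrasts the suppressed pattern with the form shellcheck prefers; `find_model_dir` is a stand-in helper used only for illustration.

```bash
#!/bin/bash
find_model_dir() { echo "/tmp/models"; }   # stand-in helper for illustration

demo() {
    # Pattern SC2155 warns about (currently suppressed via .shellcheckrc):
    local path=$(find_model_dir)   # 'local' masks the exit status of find_model_dir
    # Form shellcheck prefers, if the rule were re-enabled:
    local path2
    path2=$(find_model_dir)
    echo "$path" "$path2"
}
demo
```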
diff --git a/vllm/CMakeLists.txt b/vllm/CMakeLists.txt
index 1a6a311e9..c78cdc77a 100644
--- a/vllm/CMakeLists.txt
+++ b/vllm/CMakeLists.txt
@@ -31,13 +31,13 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
#
-set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
+set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
# Supported NVIDIA architectures.
-set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
+set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")
# Supported AMD GPU architectures.
-set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")
+set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")
#
# Supported/expected torch versions for CUDA/ROCm.
@@ -49,8 +49,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm
#
-set(TORCH_SUPPORTED_VERSION_CUDA "2.5.0")
-set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")
+set(TORCH_SUPPORTED_VERSION_CUDA "2.5.1")
+set(TORCH_SUPPORTED_VERSION_ROCM "2.5.1")
#
# Try to find python package with an executable that exactly matches
@@ -128,9 +128,9 @@ endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
#
- # For cuda we want to be able to control which architectures we compile for on
+ # For cuda we want to be able to control which architectures we compile for on
# a per-file basis in order to cut down on compile time. So here we extract
- # the set of architectures we want to compile for and remove the from the
+ # the set of architectures we want to compile for and remove the from the
# CMAKE_CUDA_FLAGS so that they are not applied globally.
#
clear_cuda_arches(CUDA_ARCH_FLAGS)
@@ -138,7 +138,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")
# Filter the target architectures by the supported supported archs
# since for some files we will build for all CUDA_ARCHS.
- cuda_archs_loose_intersection(CUDA_ARCHS
+ cuda_archs_loose_intersection(CUDA_ARCHS
"${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
else()
@@ -187,13 +187,16 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
set(VLLM_EXT_SRC
"csrc/cache_kernels.cu"
- "csrc/attention/attention_kernels.cu"
+ "csrc/attention/paged_attention_v1.cu"
+ "csrc/attention/paged_attention_v2.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
+ "csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
+ "csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp")
@@ -204,7 +207,19 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use")
- FetchContent_Declare(
+ # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
+ if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
+ set(VLLM_CUTLASS_SRC_DIR $ENV{VLLM_CUTLASS_SRC_DIR})
+ endif()
+
+ if(VLLM_CUTLASS_SRC_DIR)
+ if(NOT IS_ABSOLUTE VLLM_CUTLASS_SRC_DIR)
+ get_filename_component(VLLM_CUTLASS_SRC_DIR "${VLLM_CUTLASS_SRC_DIR}" ABSOLUTE)
+ endif()
+ message(STATUS "The VLLM_CUTLASS_SRC_DIR is set, using ${VLLM_CUTLASS_SRC_DIR} for compilation")
+ FetchContent_Declare(cutlass SOURCE_DIR ${VLLM_CUTLASS_SRC_DIR})
+ else()
+ FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG v3.5.1
@@ -214,7 +229,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
GIT_SHALLOW TRUE
- )
+ )
+ endif()
FetchContent_MakeAvailable(cutlass)
list(APPEND VLLM_EXT_SRC
@@ -222,7 +238,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
- "csrc/quantization/gguf/gguf_kernel.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
@@ -234,9 +249,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
- cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.9;9.0" ${CUDA_ARCHS})
+ cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS})
if (MARLIN_ARCHS)
- set(MARLIN_SRCS
+ set(MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
@@ -277,7 +292,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"in CUDA target architectures")
endif()
- # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
+ # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
# build any 3x kernels
set(SCALED_MM_3X_ARCHS)
endif()
@@ -286,7 +301,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
- "7.5;8.0;8.6;8.9;9.0" "${CUDA_ARCHS}")
+ "7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
@@ -316,10 +331,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS)
#
- # For the Machete kernels we automatically generate sources for various
+ # For the Machete kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
- set(MACHETE_GEN_SCRIPT
+ set(MACHETE_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py)
file(MD5 ${MACHETE_GEN_SCRIPT} MACHETE_GEN_SCRIPT_HASH)
@@ -329,8 +344,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
if (NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH}
OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH})
execute_process(
- COMMAND ${CMAKE_COMMAND} -E env
- PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
+ COMMAND ${CMAKE_COMMAND} -E env
+ PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT}
RESULT_VARIABLE machete_generation_result
OUTPUT_VARIABLE machete_generation_output
@@ -340,11 +355,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
if (NOT machete_generation_result EQUAL 0)
message(FATAL_ERROR "Machete generation failed."
- " Result: \"${machete_generation_result}\""
+ " Result: \"${machete_generation_result}\""
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
else()
- set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH}
+ set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH}
CACHE STRING "Last run machete generate script hash" FORCE)
message(STATUS "Machete generation completed successfully.")
endif()
@@ -366,7 +381,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}")
else()
- if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0
+ if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0
AND MACHETE_ARCHS)
message(STATUS "Not building Machete kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
@@ -392,8 +407,8 @@ define_gpu_extension_target(
USE_SABI 3
WITH_SOABI)
-# If CUTLASS is compiled on NVCC >= 12.5, it by default uses
-# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
+# If CUTLASS is compiled on NVCC >= 12.5, it by default uses
+# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
# driver API. This causes problems when linking with earlier versions of CUDA.
# Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
@@ -412,7 +427,7 @@ set_gencode_flags_for_srcs(
CUDA_ARCHS "${CUDA_ARCHS}")
if(VLLM_GPU_LANG STREQUAL "CUDA")
- cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.9;9.0" "${CUDA_ARCHS}")
+ cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)
set(MARLIN_MOE_SRC
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
@@ -471,9 +486,9 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
return()
endif ()
-# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target
-# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the
-# arches in the CUDA case (and instead set the gencodes on a per file basis)
+# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target
+# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the
+# arches in the CUDA case (and instead set the gencodes on a per file basis)
# we need to manually set VLLM_GPU_ARCHES here.
if(VLLM_GPU_LANG STREQUAL "CUDA")
foreach(_ARCH ${CUDA_ARCHS})
@@ -507,7 +522,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
- GIT_TAG 5259c586c403a4e4d8bf69973c159b40cc346fb9
+ GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
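
Among the CMake changes above, the new `VLLM_CUTLASS_SRC_DIR` hook lets a build reuse a local CUTLASS checkout instead of fetching the pinned v3.5.1 tag. A minimal sketch of how that might be used is shown below; the checkout path is a placeholder, and the editable install is just one way to trigger the CMake build.

```bash
# Placeholder path; the checkout should be compatible with the pinned CUTLASS v3.5.1.
export VLLM_CUTLASS_SRC_DIR=/path/to/local/cutlass
pip install -e .   # relative paths are resolved to absolute ones by the CMake logic above
```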
diff --git a/vllm/CONTRIBUTING.md b/vllm/CONTRIBUTING.md
index 5f79356bd..6d46a6dca 100644
--- a/vllm/CONTRIBUTING.md
+++ b/vllm/CONTRIBUTING.md
@@ -1,50 +1,3 @@
# Contributing to vLLM
-Thank you for your interest in contributing to vLLM! Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large. There are several ways you can contribute to the project:
-
-- Identify and report any issues or bugs.
-- Request or add support for a new model.
-- Suggest or implement new features.
-- Improve documentation or contribute a how-to guide.
-
-We also believe in the power of community support; thus, answering queries, offering PR reviews, and assisting others are also highly regarded and beneficial contributions.
-
-Finally, one of the most impactful ways to support us is by raising awareness about vLLM. Talk about it in your blog posts and highlight how it's driving your incredible projects. Express your support on social media if you're using vLLM, or simply offer your appreciation by starring our repository!
-
-
-## Developing
-
-Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation. Check out the [building from source](https://docs.vllm.ai/en/latest/getting_started/installation.html#build-from-source) documentation for details.
-
-
-## Testing
-
-```bash
-pip install -r requirements-dev.txt
-
-# linting and formatting
-bash format.sh
-# Static type checking
-mypy
-# Unit tests
-pytest tests/
-```
-**Note:** Currently, the repository does not pass the ``mypy`` tests.
-
-## Contribution Guidelines
-
-### Issues
-
-If you encounter a bug or have a feature request, please [search existing issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue) first to see if it has already been reported. If not, please [file a new issue](https://github.com/vllm-project/vllm/issues/new/choose), providing as much relevant information as possible.
-
-> [!IMPORTANT]
-> If you discover a security vulnerability, please follow the instructions [here](/SECURITY.md#reporting-a-vulnerability).
-
-### Pull Requests & Code Reviews
-
-Please check the PR checklist in the [PR template](.github/PULL_REQUEST_TEMPLATE.md) for detailed guide for contribution.
-
-### Thank You
-
-Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM.
-All of your contributions help make vLLM a great tool and community for everyone!
+You may find information about contributing to vLLM on [docs.vllm.ai](https://docs.vllm.ai/en/latest/contributing/overview.html).
diff --git a/vllm/DCO b/vllm/DCO
new file mode 100644
index 000000000..49b8cb054
--- /dev/null
+++ b/vllm/DCO
@@ -0,0 +1,34 @@
+Developer Certificate of Origin
+Version 1.1
+
+Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
+
+Everyone is permitted to copy and distribute verbatim copies of this
+license document, but changing it is not allowed.
+
+
+Developer's Certificate of Origin 1.1
+
+By making a contribution to this project, I certify that:
+
+(a) The contribution was created in whole or in part by me and I
+ have the right to submit it under the open source license
+ indicated in the file; or
+
+(b) The contribution is based upon previous work that, to the best
+ of my knowledge, is covered under an appropriate open source
+ license and I have the right under that license to submit that
+ work with modifications, whether created in whole or in part
+ by me, under the same open source license (unless I am
+ permitted to submit under a different license), as indicated
+ in the file; or
+
+(c) The contribution was provided directly to me by some other
+ person who certified (a), (b) or (c) and I have not modified
+ it.
+
+(d) I understand and agree that this project and the contribution
+ are public and that a record of the contribution (including all
+ personal information I submit with it, including my sign-off) is
+ maintained indefinitely and may be redistributed consistent with
+ this project or the open source license(s) involved.
diff --git a/vllm/Dockerfile b/vllm/Dockerfile
index 0a562253c..682f046d4 100644
--- a/vllm/Dockerfile
+++ b/vllm/Dockerfile
@@ -191,6 +191,18 @@ ADD . /vllm-workspace/
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-dev.txt
+# install development dependencies (for testing)
+RUN --mount=type=cache,target=/root/.cache/pip \
+ python3 -m pip install -e tests/vllm_test_utils
+
+# enable fast downloads from hf (for testing)
+RUN --mount=type=cache,target=/root/.cache/pip \
+ python3 -m pip install hf_transfer
+ENV HF_HUB_ENABLE_HF_TRANSFER 1
+
+# Copy in the v1 package for testing (it isn't distributed yet)
+COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1
+
# doc requires source code
# we hide them inside `test_docs/` , so that this source code
# will not be imported by other tests
@@ -206,7 +218,7 @@ FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
- pip install accelerate hf_transfer 'modelscope!=1.15.0' bitsandbytes>=0.44.0 timm==0.9.10
+ pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.44.0' timm==0.9.10
ENV VLLM_USAGE_SOURCE production-docker-image
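
The test image now installs `hf_transfer` and sets `HF_HUB_ENABLE_HF_TRANSFER` so model downloads in CI go through the faster Rust downloader. The same setup can be reproduced outside the image roughly as below; the model id is only an example.

```bash
# Sketch of the same setup outside the image; the model id is an example.
pip install huggingface_hub hf_transfer
export HF_HUB_ENABLE_HF_TRANSFER=1
huggingface-cli download facebook/opt-125m
```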
diff --git a/vllm/Dockerfile.arm b/vllm/Dockerfile.arm
new file mode 100644
index 000000000..093ee2209
--- /dev/null
+++ b/vllm/Dockerfile.arm
@@ -0,0 +1,62 @@
+# This vLLM Dockerfile is used to construct an image that can build and run vLLM on ARM CPU platform.
+
+FROM ubuntu:22.04 AS cpu-test-arm
+
+ENV CCACHE_DIR=/root/.cache/ccache
+
+ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache
+
+RUN --mount=type=cache,target=/var/cache/apt \
+ apt-get update -y \
+ && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
+ && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
+ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
+
+# tcmalloc provides better memory allocation efficiency, e.g., holding memory in caches to speed up access of commonly-used objects.
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip install py-cpuinfo # Use this to gather CPU info and optimize based on ARM Neoverse cores
+
+# Set LD_PRELOAD for tcmalloc on ARM
+ENV LD_PRELOAD="/usr/lib/aarch64-linux-gnu/libtcmalloc_minimal.so.4"
+
+RUN echo 'ulimit -c 0' >> ~/.bashrc
+
+WORKDIR /workspace
+
+ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
+ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
+RUN --mount=type=cache,target=/root/.cache/pip \
+ --mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
+ pip install --upgrade pip && \
+ pip install -r requirements-build.txt
+
+FROM cpu-test-arm AS build
+
+WORKDIR /workspace/vllm
+
+RUN --mount=type=cache,target=/root/.cache/pip \
+ --mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
+ --mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
+ pip install -v -r requirements-cpu.txt
+
+COPY . .
+ARG GIT_REPO_CHECK=0
+RUN --mount=type=bind,source=.git,target=.git \
+ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
+
+# Disabling AVX512 specific optimizations for ARM
+ARG VLLM_CPU_DISABLE_AVX512="true"
+ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
+
+RUN --mount=type=cache,target=/root/.cache/pip \
+ --mount=type=cache,target=/root/.cache/ccache \
+ --mount=type=bind,source=.git,target=.git \
+ VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \
+ pip install dist/*.whl && \
+ rm -rf dist
+
+WORKDIR /workspace/
+
+RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
+
+ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
\ No newline at end of file
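
Dockerfile.arm builds a CPU-only wheel with the AVX512 paths disabled and exposes the OpenAI-compatible server as the entrypoint, so building and serving follow the same pattern as the existing CPU image. A sketch is below; the image tag and model id are placeholders.

```bash
# Image tag and model id are placeholders.
docker build -f Dockerfile.arm -t vllm-cpu-arm .
docker run --rm -p 8000:8000 vllm-cpu-arm --model facebook/opt-125m
```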
diff --git a/vllm/Dockerfile.cpu b/vllm/Dockerfile.cpu
index f1a21d6bd..ebe226cf6 100644
--- a/vllm/Dockerfile.cpu
+++ b/vllm/Dockerfile.cpu
@@ -16,13 +16,13 @@ RUN --mount=type=cache,target=/var/cache/apt \
# intel-openmp provides additional performance improvement vs. openmp
# tcmalloc provides better memory allocation efficiency, e.g, holding memory in caches to speed up access of commonly-used objects.
RUN --mount=type=cache,target=/root/.cache/pip \
- pip install intel-openmp
+ pip install intel-openmp==2025.0.1
ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so"
RUN echo 'ulimit -c 0' >> ~/.bashrc
-RUN pip install intel_extension_for_pytorch==2.4.0
+RUN pip install intel_extension_for_pytorch==2.5.0
WORKDIR /workspace
@@ -62,4 +62,8 @@ WORKDIR /workspace/
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
+# install development dependencies (for testing)
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip install -e tests/vllm_test_utils
+
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
diff --git a/vllm/Dockerfile.hpu b/vllm/Dockerfile.hpu
new file mode 100644
index 000000000..87e0c1a6a
--- /dev/null
+++ b/vllm/Dockerfile.hpu
@@ -0,0 +1,21 @@
+FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+
+COPY ./ /workspace/vllm
+
+WORKDIR /workspace/vllm
+
+RUN pip install -v -r requirements-hpu.txt
+
+ENV no_proxy=localhost,127.0.0.1
+ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true
+
+RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install
+
+# install development dependencies (for testing)
+RUN python3 -m pip install -e tests/vllm_test_utils
+
+WORKDIR /workspace/
+
+RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
+
+ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
diff --git a/vllm/Dockerfile.neuron b/vllm/Dockerfile.neuron
index 3d9d8e7da..76dbd4c04 100644
--- a/vllm/Dockerfile.neuron
+++ b/vllm/Dockerfile.neuron
@@ -31,11 +31,14 @@ RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
RUN python3 -m pip install -U \
- cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \
+ 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-r requirements-neuron.txt
ENV VLLM_TARGET_DEVICE neuron
RUN --mount=type=bind,source=.git,target=.git \
- pip install --no-build-isolation -v -e . \
+ pip install --no-build-isolation -v -e .
+
+# install development dependencies (for testing)
+RUN python3 -m pip install -e tests/vllm_test_utils
CMD ["/bin/bash"]
diff --git a/vllm/Dockerfile.openvino b/vllm/Dockerfile.openvino
index a05ff452c..8bd188ffd 100644
--- a/vllm/Dockerfile.openvino
+++ b/vllm/Dockerfile.openvino
@@ -22,4 +22,7 @@ RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVIC
COPY examples/ /workspace/examples
COPY benchmarks/ /workspace/benchmarks
+# install development dependencies (for testing)
+RUN python3 -m pip install -e tests/vllm_test_utils
+
CMD ["/bin/bash"]
diff --git a/vllm/Dockerfile.ppc64le b/vllm/Dockerfile.ppc64le
index cd5fcf481..971248577 100644
--- a/vllm/Dockerfile.ppc64le
+++ b/vllm/Dockerfile.ppc64le
@@ -21,7 +21,7 @@ RUN --mount=type=bind,source=.git,target=.git \
# These packages will be in rocketce eventually
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
- cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \
+ 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
torch==2.3.1 \
-r requirements-cpu.txt \
xformers uvloop==0.20.0
@@ -29,6 +29,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
RUN --mount=type=bind,source=.git,target=.git \
VLLM_TARGET_DEVICE=cpu python3 setup.py install
+# install development dependencies (for testing)
+RUN python3 -m pip install -e tests/vllm_test_utils
+
WORKDIR /workspace/
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
diff --git a/vllm/Dockerfile.rocm b/vllm/Dockerfile.rocm
index d35889f05..e733994f8 100644
--- a/vllm/Dockerfile.rocm
+++ b/vllm/Dockerfile.rocm
@@ -51,9 +51,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
*"rocm-6.2"*) \
python3 -m pip uninstall -y torch torchvision \
&& python3 -m pip install --pre \
- torch==2.6.0.dev20240918 \
- setuptools-scm>=8 \
- torchvision==0.20.0.dev20240918 \
+ torch==2.6.0.dev20241113+rocm6.2 \
+ 'setuptools-scm>=8' \
+ torchvision==0.20.0.dev20241113+rocm6.2 \
--extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2;; \
*) ;; esac
@@ -121,6 +121,8 @@ ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
+RUN python3 -m pip install --upgrade pip
+
# Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade numba scipy huggingface-hub[cli] pytest-shard
@@ -166,4 +168,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
if ls libs/*.whl; then \
python3 -m pip install libs/*.whl; fi
+# install development dependencies (for testing)
+RUN python3 -m pip install -e tests/vllm_test_utils
+
CMD ["/bin/bash"]
diff --git a/vllm/Dockerfile.tpu b/vllm/Dockerfile.tpu
index bdfab3f61..b617932a8 100644
--- a/vllm/Dockerfile.tpu
+++ b/vllm/Dockerfile.tpu
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20240828"
+ARG NIGHTLY_DATE="20241017"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
FROM $BASE_IMAGE
@@ -9,12 +9,6 @@ RUN apt-get update && apt-get install -y \
git \
ffmpeg libsm6 libxext6 libgl1
-# Install the TPU and Pallas dependencies.
-RUN --mount=type=cache,target=/root/.cache/pip \
- python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
-RUN --mount=type=cache,target=/root/.cache/pip \
- python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-
# Build vLLM.
COPY . .
ARG GIT_REPO_CHECK=0
@@ -25,8 +19,10 @@ ENV VLLM_TARGET_DEVICE="tpu"
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=.git,target=.git \
python3 -m pip install \
- cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \
-r requirements-tpu.txt
RUN python3 setup.py develop
+# install development dependencies (for testing)
+RUN python3 -m pip install -e tests/vllm_test_utils
+
CMD ["/bin/bash"]
diff --git a/vllm/Dockerfile.xpu b/vllm/Dockerfile.xpu
index 0ecb46df6..a374f20d7 100644
--- a/vllm/Dockerfile.xpu
+++ b/vllm/Dockerfile.xpu
@@ -30,9 +30,19 @@ COPY requirements-common.txt /workspace/vllm/requirements-common.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install --no-cache-dir \
- --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ \
-r requirements-xpu.txt
+RUN git clone https://github.com/intel/pti-gpu && \
+ cd pti-gpu/sdk && \
+ git checkout 6c491f07a777ed872c2654ca9942f1d0dde0a082 && \
+ mkdir build && \
+ cd build && \
+ cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/icpx_toolchain.cmake -DBUILD_TESTING=OFF .. && \
+ make -j && \
+ cmake --install . --config Release --prefix "/usr/local"
+
+ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/"
+
COPY . .
ARG GIT_REPO_CHECK
RUN --mount=type=bind,source=.git,target=.git \
@@ -54,5 +64,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
ENV VLLM_USAGE_SOURCE production-docker-image \
TRITON_XPU_PROFILE 1
-
+# install development dependencies (for testing)
+RUN python3 -m pip install -e tests/vllm_test_utils
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
diff --git a/vllm/README.md b/vllm/README.md
index 0836d8723..cfeb24cbb 100644
--- a/vllm/README.md
+++ b/vllm/README.md
@@ -13,10 +13,12 @@ Easy, fast, and cheap LLM serving for everyone
| Documentation | Blog | Paper | Discord | Twitter/X | Developer Slack |
+---
*Latest News* 🔥
-- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
-- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/sessioncatalog?tab.day=20241001&search.sessiontracks=1719251906298001uzJ2) from other vLLM contributors and users!
+- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
+- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
+- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users!
- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing).
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
@@ -42,7 +44,7 @@ vLLM is fast with:
- Speculative decoding
- Chunked prefill
-**Performance benchmark**: We include a performance benchmark at the end of [our blog post](https://blog.vllm.ai/2024/09/05/perf-update.html). It compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [SGLang](https://github.com/sgl-project/sglang) and [LMDeploy](https://github.com/InternLM/lmdeploy)). The implementation is under [nightly-benchmarks folder](.buildkite/nightly-benchmarks/) and you can [reproduce](https://github.com/vllm-project/vllm/issues/8176) this benchmark using our one-click runnable script.
+**Performance benchmark**: We include a performance benchmark at the end of [our blog post](https://blog.vllm.ai/2024/09/05/perf-update.html). It compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [SGLang](https://github.com/sgl-project/sglang) and [LMDeploy](https://github.com/InternLM/lmdeploy)). The implementation is under [nightly-benchmarks folder](.buildkite/nightly-benchmarks/) and you can [reproduce](https://github.com/vllm-project/vllm/issues/8176) this benchmark using our one-click runnable script.
vLLM is flexible and easy to use with:
@@ -98,6 +100,7 @@ vLLM is a community project. Our compute resources for development and testing a
- Dropbox
- Google Cloud
- Lambda Lab
+- Nebius
- NVIDIA
- Replicate
- Roblox
diff --git a/vllm/benchmarks/README.md b/vllm/benchmarks/README.md
index 192d6c402..2aa4a2850 100644
--- a/vllm/benchmarks/README.md
+++ b/vllm/benchmarks/README.md
@@ -6,3 +6,14 @@ You can download the dataset by running:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
+
+## Downloading the ShareGPT4V dataset
+
+The json file refers to several image datasets (coco, llava, etc.). The benchmark scripts
+will ignore a datapoint if the referred image is missing.
+```bash
+wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json
+mkdir coco -p
+wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip
+unzip coco/train2017.zip -d coco/
+```
diff --git a/vllm/benchmarks/backend_request_func.py b/vllm/benchmarks/backend_request_func.py
index 4813fde27..b67849038 100644
--- a/vllm/benchmarks/backend_request_func.py
+++ b/vllm/benchmarks/backend_request_func.py
@@ -24,6 +24,7 @@ class RequestFuncInput:
model: str
best_of: int = 1
logprobs: Optional[int] = None
+ extra_body: Optional[dict] = None
multi_modal_content: Optional[dict] = None
ignore_eos: bool = False
@@ -36,6 +37,7 @@ class RequestFuncOutput:
ttft: float = 0.0 # Time to first token
itl: List[float] = field(
default_factory=list) # List of inter-token latencies
+ tpot: float = 0.0 # avg next-token latencies
prompt_len: int = 0
error: str = ""
@@ -54,6 +56,7 @@ async def async_request_tgi(
"do_sample": True,
"temperature": 0.01, # TGI does not accept 0.0 temperature.
"top_p": 0.99, # TGI does not accept 1.0 top_p.
+ "truncate": request_func_input.prompt_len,
# TGI does not accept ignore_eos flag.
}
payload = {
@@ -79,7 +82,7 @@ async def async_request_tgi(
# any data, we should skip it.
if chunk_bytes.startswith(":"):
continue
- chunk = remove_prefix(chunk_bytes, "data:")
+ chunk = chunk_bytes.removeprefix("data:")
data = json.loads(chunk)
timestamp = time.perf_counter()
@@ -144,8 +147,8 @@ async def async_request_trt_llm(
if not chunk_bytes:
continue
- chunk = remove_prefix(chunk_bytes.decode("utf-8"),
- "data:")
+ chunk = chunk_bytes.decode("utf-8").removeprefix(
+ "data:")
data = json.loads(chunk)
output.generated_text += data["text_output"]
@@ -241,6 +244,8 @@ async def async_request_openai_completions(
"stream": True,
"ignore_eos": request_func_input.ignore_eos,
}
+ if request_func_input.extra_body:
+ payload.update(request_func_input.extra_body)
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
@@ -256,13 +261,14 @@ async def async_request_openai_completions(
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
+ first_chunk_received = False
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
- chunk = remove_prefix(chunk_bytes.decode("utf-8"),
- "data: ")
+ chunk = chunk_bytes.decode("utf-8").removeprefix(
+ "data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
@@ -274,7 +280,8 @@ async def async_request_openai_completions(
if data["choices"][0]["text"]:
timestamp = time.perf_counter()
# First token
- if ttft == 0.0:
+ if not first_chunk_received:
+ first_chunk_received = True
ttft = time.perf_counter() - st
output.ttft = ttft
@@ -285,9 +292,14 @@ async def async_request_openai_completions(
most_recent_timestamp = timestamp
generated_text += data["choices"][0]["text"]
-
+ if first_chunk_received:
+ output.success = True
+ else:
+ output.success = False
+ output.error = (
+                        "Never received a valid chunk to calculate TTFT. "
+ "This response will be marked as failed!")
output.generated_text = generated_text
- output.success = True
output.latency = latency
else:
output.error = response.reason or ""
@@ -324,10 +336,12 @@ async def async_request_openai_chat_completions(
},
],
"temperature": 0.0,
- "max_tokens": request_func_input.output_len,
+ "max_completion_tokens": request_func_input.output_len,
"stream": True,
"ignore_eos": request_func_input.ignore_eos,
}
+ if request_func_input.extra_body:
+ payload.update(request_func_input.extra_body)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
@@ -349,8 +363,8 @@ async def async_request_openai_chat_completions(
if not chunk_bytes:
continue
- chunk = remove_prefix(chunk_bytes.decode("utf-8"),
- "data: ")
+ chunk = chunk_bytes.decode("utf-8").removeprefix(
+ "data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
@@ -389,14 +403,6 @@ async def async_request_openai_chat_completions(
return output
-# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
-# introduced in Python 3.9
-def remove_prefix(text: str, prefix: str) -> str:
- if text.startswith(prefix):
- return text[len(prefix):]
- return text
-
-
def get_model(pretrained_model_name_or_path: str) -> str:
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
from modelscope import snapshot_download
diff --git a/vllm/benchmarks/benchmark_guided.py b/vllm/benchmarks/benchmark_guided.py
new file mode 100644
index 000000000..1a0e62598
--- /dev/null
+++ b/vllm/benchmarks/benchmark_guided.py
@@ -0,0 +1,494 @@
+"""Benchmark guided decoding throughput."""
+import argparse
+import dataclasses
+import json
+import os
+import random
+import time
+from typing import List
+
+import datasets
+import pandas as pd
+import uvloop
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.entrypoints.openai.api_server import (
+ build_async_engine_client_from_engine_args)
+from vllm.sampling_params import GuidedDecodingParams
+from vllm.utils import FlexibleArgumentParser, merge_async_iterators
+
+
+@dataclasses.dataclass
+class SampleRequest:
+ """A class representing a single inference request for benchmarking.
+
+ Attributes:
+ prompt: The input text prompt for the model.
+ multi_modal_data: Optional dictionary containing multi-modal data (e.g.
+ images).
+ prompt_len: The length of the prompt in tokens.
+ expected_output_len: The expected length of the output in tokens.
+ """
+ prompt: str
+ prompt_len: int
+ expected_output_len: int
+ schema: dict
+ structure_type: str = 'json'
+ completion: str = None
+
+
+def run_vllm(requests: List[SampleRequest],
+ engine_args: EngineArgs,
+ n: int,
+ guided_decoding_rate: float = 1.0,
+ warmup: bool = False) -> float:
+ from vllm import LLM, SamplingParams
+ llm = LLM(**vars(engine_args))
+
+ # Add the requests to the engine.
+ prompts: List[str] = []
+ sampling_params: List[SamplingParams] = []
+ # create a list containing random selected true or false
+ guided_decoding_req_idx = random.sample(
+ range(len(requests)), int(len(requests) * guided_decoding_rate))
+
+ if warmup:
+ print(">>>>> Running warmup prompt, for the first 5")
+ # We setup the first 5 requests to warmup FSM
+ # if using xgrammar dataset, we will skip warmup
+ warmup_requests = requests[:5]
+ for i, request in enumerate(warmup_requests):
+ prompts.append(request.prompt)
+ sampling_params.append(
+ SamplingParams(
+ n=n,
+ temperature=1.0,
+ top_p=1.0,
+ ignore_eos=True,
+ max_tokens=request.expected_output_len,
+ guided_decoding=GuidedDecodingParams(json=request.schema)
+ if guided_decoding_rate > 0 else None,
+ ))
+ llm.generate(prompts, sampling_params, use_tqdm=False)
+
+ print(">>>>> Benchmark started...")
+ prompts = []
+ sampling_params = []
+ for i, request in enumerate(requests):
+ prompts.append(request.prompt)
+ sampling_params.append(
+ SamplingParams(
+ n=n,
+ temperature=1.0,
+ top_p=1.0,
+ ignore_eos=True,
+ max_tokens=request.expected_output_len,
+ guided_decoding=GuidedDecodingParams(
+ **{request.structure_type: request.schema})
+ if i in guided_decoding_req_idx else None,
+ ))
+
+ start = time.perf_counter()
+ outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
+ ret = []
+ for output, request in zip(outputs, requests):
+ generated_text = output.outputs[0].text
+ ret.append({
+ "generated": generated_text,
+ "expected": request.completion
+ })
+ end = time.perf_counter()
+ return end - start, ret
+
+
+async def run_vllm_async(
+ requests: List[SampleRequest],
+ engine_args: AsyncEngineArgs,
+ n: int,
+ guided_decoding_rate: float = 1.0,
+ warmup: bool = False,
+ disable_frontend_multiprocessing: bool = False) -> float:
+ from vllm import SamplingParams
+
+ async with build_async_engine_client_from_engine_args(
+ engine_args, disable_frontend_multiprocessing) as llm:
+
+ # Add the requests to the engine.
+ prompts: List[str] = []
+ sampling_params: List[SamplingParams] = []
+ guided_decoding_req_idx = random.sample(
+ range(len(requests)), int(len(requests) * guided_decoding_rate))
+
+ if warmup:
+ print(">>>>>> Running warmup prompt, for the first 5")
+ # We setup the first 5 requests to warmup FSM
+ # if using xgrammar dataset, we will skip warmup
+ warmup_requests = requests[:5]
+ for i, request in enumerate(warmup_requests):
+ prompts.append(request.prompt)
+ sampling_params.append(
+ SamplingParams(
+ n=n,
+ temperature=1.0,
+ top_p=1.0,
+ ignore_eos=True,
+ max_tokens=request.expected_output_len,
+ guided_decoding=GuidedDecodingParams(
+ json=request.schema)
+ if guided_decoding_rate > 0 else None,
+ ))
+ generators = []
+ for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
+ generator = llm.generate(prompt, sp, request_id=f"test{i}")
+ generators.append(generator)
+ all_gens = merge_async_iterators(*generators)
+ async for i, res in all_gens:
+ pass
+
+ print(">>>>> Benchmark started...")
+ prompts = []
+ sampling_params = []
+ for i, request in enumerate(requests):
+ prompts.append(request.prompt)
+ sampling_params.append(
+ SamplingParams(
+ n=n,
+ temperature=1.0,
+ top_p=1.0,
+ ignore_eos=True,
+ max_tokens=request.expected_output_len,
+ guided_decoding=GuidedDecodingParams(json=request.schema)
+ if i in guided_decoding_req_idx else None,
+ ))
+
+ generators = []
+ start_time = []
+ latencies = []
+ start = time.perf_counter()
+ for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
+ generator = llm.generate(prompt, sp, request_id=f"test{i}")
+ generators.append(generator)
+ start_time.append(time.perf_counter())
+ latencies.append([])
+ all_gens = merge_async_iterators(*generators)
+ generated_texts = [''] * len(requests)
+ async for i, res in all_gens:
+ generated_texts[i] = res.outputs[0].text
+ lat = time.perf_counter() - start_time[i]
+ latencies[i].append(lat)
+ ret = [{
+ 'generated': gt,
+ 'expected': req.completion
+ } for gt, req in zip(generated_texts, requests)]
+ end = time.perf_counter()
+ first_latency = pd.Series([lat[0] * 1000 for lat in latencies])
+ next_latency = pd.Series([(lat[-1] - lat[0]) / len(lat[1:]) * 1000
+ for lat in latencies])
+ return end - start, ret, (first_latency, next_latency)
+
+
+def sample_requests(tokenizer: PreTrainedTokenizerBase,
+ args: argparse.Namespace) -> List[SampleRequest]:
+ if args.dataset == 'json':
+ if args.json_schema_path is None:
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ args.json_schema_path = os.path.join(dir_path,
+ "structured_schemas",
+ "structured_schema_1.json")
+ with open(args.json_schema_path) as f:
+ schema = json.load(f)
+ prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
+ input_len = len(tokenizer(prompt).input_ids)
+ print(f"Input length of the prompt: {input_len} tokens")
+ requests = [
+ SampleRequest(prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=schema,
+ structure_type=args.structure_type)
+ for _ in range(args.num_prompts)
+ ]
+
+ elif args.dataset == "grammar":
+ schema = """
+ ?start: select_statement
+
+ ?select_statement: "SELECT " column_list " FROM " table_name
+
+ ?column_list: column_name ("," column_name)*
+
+ ?table_name: identifier
+
+ ?column_name: identifier
+
+ ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
+ """
+ prompt = "Generate an SQL query to show the 'username' \
+ and 'email' from the 'users' table."
+
+ input_len = len(tokenizer(prompt).input_ids)
+ print(f"Input length of the prompt: {input_len} tokens")
+ requests = [
+ SampleRequest(prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=schema,
+ structure_type=args.structure_type)
+ for _ in range(args.num_prompts)
+ ]
+
+ elif args.dataset == "regex":
+ regex = r"\w+@\w+\.com\n"
+ args.regex = regex
+ prompt = "Generate an email address for Alan Turing, \
+ who works in Enigma. End in .com and new line. \
+ Example result: alan.turing@enigma.com\n"
+
+ input_len = len(tokenizer(prompt).input_ids)
+ print(f"Input length of the prompt: {input_len} tokens")
+ requests = [
+ SampleRequest(prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=regex,
+ structure_type=args.structure_type)
+ for _ in range(args.num_prompts)
+ ]
+
+ elif args.dataset == "choice":
+ choice = ["Positive", "Negative"]
+ args.choice = choice
+ prompt = "Classify this sentiment: vLLM is wonderful!"
+ input_len = len(tokenizer(prompt).input_ids)
+ print(f"Input length of the prompt: {input_len} tokens")
+ requests = [
+ SampleRequest(prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=choice,
+ structure_type=args.structure_type)
+ for _ in range(args.num_prompts)
+ ]
+
+ elif args.dataset == "xgrammar_bench":
+ args.warmup = False
+ requests: List[SampleRequest] = []
+ dataset = datasets.load_dataset("NousResearch/json-mode-eval",
+ split="train")
+ print(f"dataset has {len(dataset)} entries")
+ len_dataset = len(dataset)
+ for data_point_idx in range(args.num_prompts):
+ idx = data_point_idx
+ while idx >= len_dataset:
+ idx -= len_dataset
+ schema = dataset["schema"][idx]
+ prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
+ tokenize=False)
+ input_len = len(tokenizer(prompt).input_ids)
+ completion = dataset["completion"][idx]
+
+ requests.append(
+ SampleRequest(prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=schema,
+ completion=completion))
+
+ return requests
+
+
+def evaluate(ret, args):
+
+ def _eval_correctness_json(expected, actual):
+ # extract json string from string using regex
+ import re
+ actual = actual.replace('\n', '').replace(' ', '').strip()
+ try:
+ actual = re.search(r'\{.*\}', actual).group()
+ actual = json.loads(actual)
+ except Exception:
+ return False
+
+ return True
+
+ def _eval_correctness_choice(expected, actual):
+ return actual in args.choice
+
+ def _eval_correctness_regex(expected, actual):
+ import re
+ return re.match(args.regex, actual) is not None
+
+ def _eval_correctness(expected, actual):
+ if args.structure_type == 'json':
+ return _eval_correctness_json(expected, actual)
+ elif args.structure_type == 'regex':
+ return _eval_correctness_regex(expected, actual)
+ elif args.structure_type == 'choice':
+ return _eval_correctness_choice(expected, actual)
+ else:
+ return None
+
+ scores = []
+ for res in ret:
+ score = _eval_correctness(res['expected'], res['generated'])
+ res['correctness'] = score
+ scores.append(score)
+
+ not_none_scores = [score for score in scores if score is not None]
+
+ return (sum(not_none_scores) / len(not_none_scores) *
+ 100) if len(not_none_scores) > 0 else None
+
+
+def main(args: argparse.Namespace):
+ print(args)
+ random.seed(args.seed)
+
+ # async engine is working for 'regex', 'choice' and 'grammar'
+ if args.dataset == 'grammar':
+ args.structure_type = 'grammar'
+ args.async_engine = False
+ elif args.dataset == 'regex':
+ args.structure_type = 'regex'
+ args.async_engine = False
+ elif args.dataset == 'choice':
+ args.structure_type = 'choice'
+ args.async_engine = False
+ else:
+ args.structure_type = 'json'
+
+ if args.no_guided_decoding:
+ args.guided_decoding_ratio = 0
+ if args.save_results:
+ result_file_name = f'{args.guided_decoding_ratio}guided'
+ result_file_name += f"_{args.model.split('/')[-1]}"
+ result_file_name += f"_{args.dataset}"
+ result_file_name += f"_{args.num_prompts}"
+ result_file_name += f"_out{args.output_len}"
+ result_file_name += f"_async{args.async_engine}"
+ result_file_name += f"_warmup{args.warmup}"
+ result_file_name += f"_chunkedprefill{args.enable_chunked_prefill}"
+ result_file_name += ".txt"
+ else:
+ result_file_name = None
+
+ # Synthesize a prompt with the given input length.
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer, trust_remote_code=args.trust_remote_code)
+ requests = sample_requests(tokenizer, args)
+
+ if args.async_engine:
+ engine_args = AsyncEngineArgs.from_cli_args(args)
+ elapsed_time, ret, (first_latency, next_latency) = uvloop.run(
+ run_vllm_async(requests, engine_args, args.n,
+ args.guided_decoding_ratio, args.warmup,
+ args.disable_frontend_multiprocessing))
+ else:
+ engine_args = EngineArgs.from_cli_args(args)
+ elapsed_time, ret = run_vllm(requests, engine_args, args.n,
+ args.guided_decoding_ratio, args.warmup)
+ first_latency, next_latency = None, None
+
+ score = evaluate(ret, args)
+ total_num_tokens = sum(request.prompt_len + request.expected_output_len
+ for request in requests)
+ total_output_tokens = sum(request.expected_output_len
+ for request in requests)
+ if first_latency is not None:
+ latency_breakdown = "\nFirst token latency(msecs):\n"
+ latency_breakdown += f"{first_latency.describe()}"
+ latency_breakdown += "\nNext token latency(msecs):\n"
+ latency_breakdown += f"{next_latency.describe()}"
+ print(
+ f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
+ f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
+ f"{total_output_tokens / elapsed_time:.2f} output tokens/s",
+ f"Correct rate is {score} %",
+ f"{latency_breakdown if first_latency is not None else ''}")
+
+ # Output JSON results if specified
+ if args.output_json or result_file_name:
+ results = {
+ "elapsed_time": elapsed_time,
+ "num_requests": len(requests),
+ "total_num_tokens": total_num_tokens,
+ "total_output_tokens": total_output_tokens,
+ "requests_per_second": len(requests) / elapsed_time,
+ "tokens_per_second": f"{total_num_tokens / elapsed_time:.2f}",
+ "output_tokens_per_second":
+ f"{total_output_tokens / elapsed_time:.2f}",
+ "correct_rate(%)": score
+ }
+ results = {"outputs": ret, **results}
+ if first_latency is not None:
+ results["first_token_latency(msecs)"] = first_latency.describe(
+ ).to_dict()
+ results["next_token_latency(msecs)"] = next_latency.describe(
+ ).to_dict()
+ if args.output_json:
+ with open(args.output_json, "w") as f:
+ json.dump(results, f, indent=4)
+ elif result_file_name:
+ with open(result_file_name, "w") as f:
+ json.dump(results, f, indent=4)
+
+
+if __name__ == "__main__":
+ parser = FlexibleArgumentParser(description="Benchmark guided decoding.")
+ parser = AsyncEngineArgs.add_cli_args(parser)
+
+ parser.add_argument("--output-len",
+ type=int,
+ default=512,
+ help="Output length for each request. Overrides the "
+ "output length from the dataset.")
+ parser.add_argument(
+ "--dataset",
+ default='json',
+ choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench'])
+ parser.add_argument("--json_schema_path",
+ type=str,
+ default=None,
+ help="Path to json schema.")
+ parser.add_argument("--n",
+ type=int,
+ default=1,
+ help="Number of generated sequences per prompt.")
+ parser.add_argument("--num-prompts",
+ type=int,
+ default=10,
+ help="Number of prompts to process.")
+ parser.add_argument(
+ '--output-json',
+ type=str,
+ default=None,
+ help='Path to save the throughput results in JSON format.')
+ parser.add_argument("--async-engine",
+ action='store_true',
+ default=False,
+ help="Use vLLM async engine rather than LLM class.")
+ parser.add_argument("--no-guided-decoding",
+ action='store_true',
+ default=False,
+ help="Whether to disable JSON decoding or not.")
+ parser.add_argument("--guided-decoding-ratio",
+ type=float,
+ default=1.0,
+ help="Ratio of Guided Decoding requests")
+ parser.add_argument("--disable-frontend-multiprocessing",
+ action='store_true',
+ default=False,
+ help="Disable decoupled async engine frontend.")
+ parser.add_argument("--warmup",
+ action="store_true",
+ default=False,
+ help="Run warmup prompts before benchmark.")
+ parser.add_argument("--save-results",
+ action="store_true",
+ default=False,
+ help="save output results.")
+ args = parser.parse_args()
+ if args.tokenizer is None:
+ args.tokenizer = args.model
+ main(args)
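
The new guided-decoding benchmark is driven entirely from the command line: engine flags come from `AsyncEngineArgs`, and `--dataset json` falls back to the bundled `structured_schemas/structured_schema_1.json` when no schema path is given. A typical invocation might look like the sketch below; the model id is a placeholder.

```bash
# Model id is a placeholder; any model served by vLLM should work.
python benchmarks/benchmark_guided.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --dataset json \
    --num-prompts 10 \
    --output-len 512 \
    --async-engine \
    --warmup \
    --save-results
```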
diff --git a/vllm/benchmarks/benchmark_prefix_caching.py b/vllm/benchmarks/benchmark_prefix_caching.py
index 1aac02999..5e9381f71 100644
--- a/vllm/benchmarks/benchmark_prefix_caching.py
+++ b/vllm/benchmarks/benchmark_prefix_caching.py
@@ -54,13 +54,30 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
print(f"cost time {end_time - start_time}")
-def sample_requests(
+@dataclasses.dataclass
+class Request:
+ prompt: str
+ prompt_len: int
+ output_len: int
+
+
+def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
+ vocab = tokenizer.get_vocab()
+ # Remove the special tokens.
+ vocab = {
+ k: v
+ for k, v in vocab.items() if k not in tokenizer.all_special_ids
+ }
+ return random.choices(list(vocab.values()), k=length)
+
+
+def sample_requests_from_dataset(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
input_length_range: Tuple[int, int],
fixed_output_len: Optional[int],
-) -> List[Tuple[str, int, int]]:
+) -> List[Request]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
@@ -77,31 +94,55 @@ def sample_requests(
random.shuffle(dataset)
min_len, max_len = input_length_range
+ assert min_len >= 0 and max_len >= min_len, "input_length_range too small"
# Filter out sequences that are too long or too short
- filtered_dataset: List[Tuple[str, int, int]] = []
+ filtered_requests: List[Request] = []
+
for i in range(len(dataset)):
- if len(filtered_dataset) == num_requests:
+ if len(filtered_requests) == num_requests:
break
# Tokenize the prompts and completions.
- prompt = dataset[i][0]
- prompt_token_ids = tokenizer(prompt).input_ids
+ prompt_token_ids = tokenizer(dataset[i][0]).input_ids
+ prompt = tokenizer.decode(prompt_token_ids)
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
- output_len = len(completion_token_ids
- ) if fixed_output_len is None else fixed_output_len
- if prompt_len < 4 or output_len < 4:
- # Prune too short sequences.
- continue
+ output_len = (len(completion_token_ids)
+ if fixed_output_len is None else fixed_output_len)
if min_len <= prompt_len <= max_len:
- filtered_dataset.append((prompt, prompt_len, output_len))
+ filtered_requests.append(Request(prompt, prompt_len, output_len))
+
+ return filtered_requests
- return filtered_dataset
+
+def sample_requests_from_random(
+ num_requests: int,
+ tokenizer: PreTrainedTokenizerBase,
+ input_length_range: Tuple[int, int],
+ fixed_output_len: Optional[int],
+ prefix_len: int,
+) -> List[Request]:
+
+ requests = []
+ prefix_token_ids = sample_tokens(tokenizer, prefix_len)
+ min_len, max_len = input_length_range
+
+ for i in range(num_requests):
+ unique_part_token_ids = sample_tokens(
+ tokenizer,
+ random.randint(min_len - prefix_len, max_len - prefix_len))
+ prompt_token_ids = prefix_token_ids + unique_part_token_ids
+ prompt = tokenizer.decode(prompt_token_ids)
+ prompt_len = len(prompt_token_ids)
+ assert (min_len <= prompt_len <= max_len
+ ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
+ requests.append(Request(prompt, prompt_len, fixed_output_len))
+ return requests
-def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
+def repeat_and_sort_requests(requests: List[Request],
repeat_count: int,
sort: bool = False) -> List[str]:
repeated_requests = requests * repeat_count
@@ -109,7 +150,7 @@ def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
repeated_requests.sort(key=lambda x: x[1])
else:
random.shuffle(repeated_requests)
- return [req[0] for req in repeated_requests]
+ return [req.prompt for req in repeated_requests]
def main(args):
@@ -117,9 +158,12 @@ def main(args):
input_length_range = tuple(map(int, args.input_length_range.split(':')))
random.seed(args.seed)
if args.dataset_path is not None:
- print(f"Start to sample {args.num_prompts} prompts"
- "from {args.dataset_path}")
- filtered_datasets = sample_requests(
+ if args.prefix_len > 0:
+ raise ValueError("prefix-len is not supported when "
+ "dataset-path is provided.")
+ print(f"Start to sample {args.num_prompts} prompts "
+ f"from {args.dataset_path}")
+ filtered_requests = sample_requests_from_dataset(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
@@ -127,9 +171,22 @@ def main(args):
fixed_output_len=args.output_len,
)
else:
- prompt_len = len(tokenizer(PROMPT).input_ids)
- filtered_datasets = [(PROMPT, prompt_len, args.output_len)
- ] * args.num_prompts
+ print(f"Start to sample {args.num_prompts} prompts from random")
+ filtered_requests = sample_requests_from_random(
+ num_requests=args.num_prompts,
+ tokenizer=tokenizer,
+ input_length_range=input_length_range,
+ fixed_output_len=args.output_len,
+ prefix_len=args.prefix_len,
+ )
+
+ # Print some helpful stats of the requests.
+ print(f"Sampled {len(filtered_requests)} requests.")
+ prompt_lens = [req.prompt_len for req in filtered_requests]
+ print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}")
+ print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}")
+ print(f"Min Prompt Length: {min(prompt_lens)}")
+ print(f"Max Prompt Length: {max(prompt_lens)}")
engine_args = EngineArgs.from_cli_args(args)
@@ -137,18 +194,11 @@ def main(args):
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
- print("Testing filtered datasets")
- prompts = repeat_and_sort_requests(filtered_datasets,
+ print("Testing filtered requests")
+ prompts = repeat_and_sort_requests(filtered_requests,
repeat_count=args.repeat_count,
sort=args.sort)
- print("------warm up------")
- test_prefix(
- llm=llm,
- prompts=prompts,
- sampling_params=sampling_params,
- )
-
print("------start generating------")
test_prefix(
llm=llm,
@@ -168,20 +218,29 @@ def main(args):
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--num-prompts',
type=int,
- default=1,
+ required=True,
help="Number of the prompts sampled from dataset")
parser.add_argument('--repeat-count',
type=int,
- default=100,
+ default=1,
help='Number of times to repeat each prompt')
parser.add_argument('--sort',
action='store_true',
help='Sort prompts by input length')
parser.add_argument('--input-length-range',
type=str,
- default='128:256',
+ required=True,
help='Range of input lengths for sampling prompts,'
'specified as "min:max" (e.g., "128:256").')
+ parser.add_argument(
+ "--prefix-len",
+ type=int,
+ default=0,
+ help="Specifies the length of a common prefix to be "
+ "added to the input prompt. The input-length-range will "
+ "subtract this length when filtering prompts. Only used "
+ "when dataset-path is not provided.",
+ )
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
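
A rough sketch of what sample_requests_from_random constructs, using plain integers as stand-in token ids (no real tokenizer involved): every prompt starts with the same prefix_len ids, which is what lets a prefix-caching server reuse KV-cache entries across requests.

    import random

    def make_shared_prefix_requests(num_requests, prefix_len, min_len, max_len,
                                    vocab_size=32000, seed=0):
        """Toy version of sample_requests_from_random: every prompt shares the
        same prefix_len leading ids, so a prefix-caching server can reuse that
        part of the KV cache across requests."""
        rng = random.Random(seed)
        prefix = [rng.randrange(vocab_size) for _ in range(prefix_len)]
        requests = []
        for _ in range(num_requests):
            unique_len = rng.randint(min_len - prefix_len, max_len - prefix_len)
            ids = prefix + [rng.randrange(vocab_size) for _ in range(unique_len)]
            assert min_len <= len(ids) <= max_len
            requests.append(ids)
        return requests

    reqs = make_shared_prefix_requests(num_requests=4, prefix_len=64,
                                       min_len=128, max_len=256)
    print([len(r) for r in reqs])                # lengths within 128:256
    print(len({tuple(r[:64]) for r in reqs}))    # 1 -> common shared prefix
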
diff --git a/vllm/benchmarks/benchmark_serving.py b/vllm/benchmarks/benchmark_serving.py
index 0d205014b..325669214 100644
--- a/vllm/benchmarks/benchmark_serving.py
+++ b/vllm/benchmarks/benchmark_serving.py
@@ -199,6 +199,56 @@ def sample_sonnet_requests(
return sampled_requests
+def sample_mmmu_pro_vision_requests(
+ dataset,
+ num_requests: int,
+ tokenizer: PreTrainedTokenizerBase,
+ fixed_output_len: Optional[int] = None,
+) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
+ sampled_requests: List[Tuple[str, int, int, Dict[str,
+ Collection[str]]]] = []
+ for data in dataset:
+ if len(sampled_requests) == num_requests:
+ break
+
+ # MMMU-Pro vision direct prompt
+ # Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5
+ prompt = (
+ "Answer with the option letter from the given choices directly. "
+ "The last line of your response should be of the following "
+ "format: 'Answer: $LETTER' (without quotes) where LETTER is one of "
+ "options.")
+
+ prompt_token_ids = tokenizer(prompt).input_ids
+ if fixed_output_len is None:
+ # Default max output len is set to 128
+ print("--hf-output-len is not provided. Using default value 128.")
+ fixed_output_len = 128
+
+ prompt_len = len(prompt_token_ids)
+ output_len = fixed_output_len
+
+ assert isinstance(
+ data["image"],
+ Image), ("Input image format must be `PIL.Image.Image`, "
+ f"given {type(data['image'])}.")
+ image: Image = data["image"]
+ image = image.convert("RGB")
+ image_data = io.BytesIO()
+ image.save(image_data, format='JPEG')
+ image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
+ mm_content = {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/jpeg;base64,{image_base64}"
+ },
+ }
+
+ sampled_requests.append((prompt, prompt_len, output_len, mm_content))
+
+ return sampled_requests
+
+
def sample_hf_requests(
dataset_path: str,
dataset_subset: str,
@@ -208,6 +258,21 @@ def sample_hf_requests(
random_seed: int,
fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
+
+ # Special case for MMMU-Pro vision dataset
+ if dataset_path == 'MMMU/MMMU_Pro' and dataset_subset == 'vision':
+ assert dataset_split == "test"
+ dataset = load_dataset(dataset_path,
+ name=dataset_subset,
+ split=dataset_split,
+ streaming=True)
+ assert "image" in dataset.features, (
+ "MMMU/MMMU_Pro vision dataset must have 'image' column.")
+ filter_func = lambda x: isinstance(x["image"], Image)
+ dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
+ return sample_mmmu_pro_vision_requests(dataset, num_requests,
+ tokenizer, fixed_output_len)
+
dataset = load_dataset(dataset_path,
name=dataset_subset,
split=dataset_split,
@@ -251,6 +316,19 @@ def sample_hf_requests(
"url": f"data:image/jpeg;base64,{image_base64}"
},
}
+ elif "image" in data and isinstance(data["image"], str):
+ if (data["image"].startswith("http://") or \
+ data["image"].startswith("file://")):
+ image_url = data["image"]
+ else:
+ image_url = f"file://{data['image']}"
+
+ mm_content = {
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ },
+ }
else:
mm_content = None
@@ -297,8 +375,33 @@ def sample_random_requests(
async def get_request(
input_requests: List[Tuple[str, int, int]],
request_rate: float,
+ burstiness: float = 1.0,
) -> AsyncGenerator[Tuple[str, int, int], None]:
+ """
+ Asynchronously generates requests at a specified rate
+ with OPTIONAL burstiness.
+
+ Args:
+ input_requests:
+ A list of input requests, each represented as a tuple.
+ request_rate:
+ The rate at which requests are generated (requests/s).
+ burstiness (optional):
+ The burstiness factor of the request generation.
+ Only takes effect when request_rate is not inf.
+ Default value is 1, which follows a Poisson process.
+ Otherwise, the request intervals follow a gamma distribution.
+ A lower burstiness value (0 < burstiness < 1) results
+ in more bursty requests, while a higher burstiness value
+ (burstiness > 1) results in a more uniform arrival of requests.
+ """
input_requests = iter(input_requests)
+
+ # Calculate scale parameter theta to maintain the desired request_rate.
+ assert burstiness > 0, (
+ f"A positive burstiness factor is expected, but given {burstiness}.")
+ theta = 1.0 / (request_rate * burstiness)
+
for request in input_requests:
yield request
@@ -306,8 +409,9 @@ async def get_request(
# If the request rate is infinity, then we don't need to wait.
continue
- # Sample the request interval from the exponential distribution.
- interval = np.random.exponential(1.0 / request_rate)
+ # Sample the request interval from the gamma distribution.
+ # If burstiness is 1, it follows exponential distribution.
+ interval = np.random.gamma(shape=burstiness, scale=theta)
# The next request will be sent after the interval.
await asyncio.sleep(interval)
@@ -406,9 +510,9 @@ def calculate_metrics(
median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
for p in selected_percentiles],
- mean_e2el_ms=np.median(e2els or 0) * 1000,
+ mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
- median_e2el_ms=np.mean(e2els or 0) * 1000,
+ median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
for p in selected_percentiles],
)
@@ -426,6 +530,7 @@ async def benchmark(
logprobs: Optional[int],
best_of: int,
request_rate: float,
+ burstiness: float,
disable_tqdm: bool,
profile: bool,
selected_percentile_metrics: List[str],
@@ -480,7 +585,13 @@ async def benchmark(
if profile_output.success:
print("Profiler started")
+ if burstiness == 1.0:
+ distribution = "Poisson process"
+ else:
+ distribution = "Gamma distribution"
+
print(f"Traffic request rate: {request_rate}")
+ print(f"Burstiness factor: {burstiness} ({distribution})")
print(f"Maximum request concurrency: {max_concurrency}")
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -502,7 +613,7 @@ async def limited_request_func(request_func_input, pbar):
benchmark_start_time = time.perf_counter()
tasks: List[asyncio.Task] = []
- async for request in get_request(input_requests, request_rate):
+ async for request in get_request(input_requests, request_rate, burstiness):
prompt, prompt_len, output_len, mm_content = request
request_func_input = RequestFuncInput(model=model_id,
prompt=prompt,
@@ -769,6 +880,7 @@ def main(args: argparse.Namespace):
logprobs=args.logprobs,
best_of=args.best_of,
request_rate=args.request_rate,
+ burstiness=args.burstiness,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
@@ -807,6 +919,7 @@ def main(args: argparse.Namespace):
# Traffic
result_json["request_rate"] = (
args.request_rate if args.request_rate < float("inf") else "inf")
+ result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency
# Merge with benchmark result
@@ -922,8 +1035,20 @@ def main(args: argparse.Namespace):
default=float("inf"),
help="Number of requests per second. If this is inf, "
"then all the requests are sent at time 0. "
- "Otherwise, we use Poisson process to synthesize "
- "the request arrival times.",
+ "Otherwise, we use Poisson process or gamma distribution "
+ "to synthesize the request arrival times.",
+ )
+ parser.add_argument(
+ "--burstiness",
+ type=float,
+ default=1.0,
+ help="Burstiness factor of the request generation. "
+ "Only take effect when request_rate is not inf. "
+ "Default value is 1, which follows Poisson process. "
+ "Otherwise, the request intervals follow a gamma distribution. "
+ "A lower burstiness value (0 < burstiness < 1) results in more "
+ "bursty requests. A higher burstiness value (burstiness > 1) "
+ "results in a more uniform arrival of requests.",
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument(
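
A small numpy-only sketch of the arrival-time model introduced above: with theta = 1 / (request_rate * burstiness), the gamma-distributed intervals keep the mean rate at request_rate while the coefficient of variation scales as 1 / sqrt(burstiness).

    import numpy as np

    def sample_intervals(request_rate, burstiness, n=100_000, seed=0):
        # shape * scale = burstiness * 1 / (request_rate * burstiness)
        #               = 1 / request_rate, so the mean rate is unchanged.
        rng = np.random.default_rng(seed)
        theta = 1.0 / (request_rate * burstiness)
        return rng.gamma(shape=burstiness, scale=theta, size=n)

    for b in (0.5, 1.0, 2.0):
        iv = sample_intervals(request_rate=10.0, burstiness=b)
        # Mean stays ~0.1 s; cv = 1/sqrt(burstiness), so b < 1 is burstier
        # and b > 1 gives a more uniform arrival pattern (b == 1 is Poisson).
        print(f"burstiness={b}: mean={iv.mean():.3f}s cv={iv.std() / iv.mean():.2f}")
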
diff --git a/vllm/benchmarks/benchmark_serving_guided.py b/vllm/benchmarks/benchmark_serving_guided.py
new file mode 100644
index 000000000..4435d87e1
--- /dev/null
+++ b/vllm/benchmarks/benchmark_serving_guided.py
@@ -0,0 +1,881 @@
+r"""Benchmark online serving throughput with guided decoding.
+
+On the server side, run one of the following commands:
+ (vLLM OpenAI API server)
+    vllm serve <your_model> --disable-log-requests
+
+    (TGI backend)
+    ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
+
+On the client side, run:
+    python benchmarks/benchmark_serving_guided.py \
+        --backend <backend> \
+        --model <your_model> \
+ --dataset json \
+ --guided-decoding-ratio 1.0 \
+ --guided-decoding-backend xgrammar \
+ --request-rate 10 \
+ --num-prompts 1000
+
+ when using tgi backend, add
+ --endpoint /generate_stream
+ to the end of the command above.
+"""
+import argparse
+import asyncio
+import dataclasses
+import json
+import os
+import random
+import time
+import warnings
+from dataclasses import dataclass
+from typing import AsyncGenerator, List, Optional, Tuple
+
+import datasets
+import numpy as np
+import pandas as pd
+from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
+ RequestFuncOutput)
+from tqdm.asyncio import tqdm
+from transformers import PreTrainedTokenizerBase
+
+try:
+ from vllm.transformers_utils.tokenizer import get_tokenizer
+except ImportError:
+ from backend_request_func import get_tokenizer
+
+try:
+ from vllm.utils import FlexibleArgumentParser
+except ImportError:
+ from argparse import ArgumentParser as FlexibleArgumentParser
+
+MILLISECONDS_TO_SECONDS_CONVERSION = 1000
+
+
+@dataclass
+class BenchmarkMetrics:
+ completed: int
+ total_input: int
+ total_output: int
+ request_throughput: float
+ request_goodput: float
+ output_throughput: float
+ total_token_throughput: float
+ mean_ttft_ms: float
+ median_ttft_ms: float
+ std_ttft_ms: float
+ percentiles_ttft_ms: List[Tuple[float, float]]
+ mean_tpot_ms: float
+ median_tpot_ms: float
+ std_tpot_ms: float
+ percentiles_tpot_ms: List[Tuple[float, float]]
+ mean_itl_ms: float
+ median_itl_ms: float
+ std_itl_ms: float
+ percentiles_itl_ms: List[Tuple[float, float]]
+ # E2EL stands for end-to-end latency per request.
+ # It is the time taken on the client side from sending
+ # a request to receiving a complete response.
+ mean_e2el_ms: float
+ median_e2el_ms: float
+ std_e2el_ms: float
+ percentiles_e2el_ms: List[Tuple[float, float]]
+
+
+@dataclasses.dataclass
+class SampleRequest:
+ """A class representing a single inference request for benchmarking.
+
+ Attributes:
+        prompt: The input text prompt for the model.
+        prompt_len: The length of the prompt in tokens.
+        expected_output_len: The expected length of the output in tokens.
+        schema: The guided-decoding constraint (JSON schema, grammar, regex,
+            or choice list) attached to the request.
+ """
+ prompt: str
+ prompt_len: int
+ expected_output_len: int
+ schema: dict
+ structure_type: str
+    completion: Optional[str] = None
+
+
+def sample_requests(tokenizer: PreTrainedTokenizerBase,
+ args: argparse.Namespace) -> List[SampleRequest]:
+ if args.dataset == 'json':
+ if args.json_schema_path is None:
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ args.json_schema_path = os.path.join(dir_path,
+ "structured_schemas",
+ "structured_schema_1.json")
+ with open(args.json_schema_path) as f:
+ schema = json.load(f)
+ prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
+ input_len = len(tokenizer(prompt).input_ids)
+ print(f"Input length of the prompt: {input_len} tokens")
+ requests = [
+ SampleRequest(prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=schema,
+ structure_type=args.structure_type)
+ for _ in range(args.num_prompts)
+ ]
+
+ elif args.dataset == "grammar":
+ schema = """
+ ?start: select_statement
+
+ ?select_statement: "SELECT " column_list " FROM " table_name
+
+ ?column_list: column_name ("," column_name)*
+
+ ?table_name: identifier
+
+ ?column_name: identifier
+
+ ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
+ """
+ prompt = "Generate an SQL query to show the 'username' \
+ and 'email' from the 'users' table."
+
+ input_len = len(tokenizer(prompt).input_ids)
+ print(f"Input length of the prompt: {input_len} tokens")
+ requests = [
+ SampleRequest(prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=schema,
+ structure_type=args.structure_type)
+ for _ in range(args.num_prompts)
+ ]
+
+ elif args.dataset == "regex":
+ regex = r"\w+@\w+\.com\n"
+ args.regex = regex
+ prompt = "Generate an email address for Alan Turing, \
+ who works in Enigma. End in .com and new line. \
+ Example result: alan.turing@enigma.com\n"
+
+ input_len = len(tokenizer(prompt).input_ids)
+ print(f"Input length of the prompt: {input_len} tokens")
+ requests = [
+ SampleRequest(prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=regex,
+ structure_type=args.structure_type)
+ for _ in range(args.num_prompts)
+ ]
+
+ elif args.dataset == "choice":
+ choice = ["Positive", "Negative"]
+ args.choice = choice
+ prompt = "Classify this sentiment: vLLM is wonderful!"
+ input_len = len(tokenizer(prompt).input_ids)
+ print(f"Input length of the prompt: {input_len} tokens")
+ requests = [
+ SampleRequest(prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=choice,
+ structure_type=args.structure_type)
+ for _ in range(args.num_prompts)
+ ]
+
+ elif args.dataset == "xgrammar_bench":
+ requests: List[SampleRequest] = []
+ dataset = datasets.load_dataset("NousResearch/json-mode-eval",
+ split="train")
+ print(f"dataset has {len(dataset)} entries")
+ len_dataset = len(dataset)
+ for data_point_idx in range(args.num_prompts):
+ idx = data_point_idx
+ while idx >= len_dataset:
+ idx -= len_dataset
+ schema = dataset["schema"][idx]
+ prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
+ tokenize=False)
+ input_len = len(tokenizer(prompt).input_ids)
+ completion = dataset["completion"][idx]
+
+ requests.append(
+ SampleRequest(prompt=prompt,
+ prompt_len=input_len,
+ expected_output_len=args.output_len,
+ schema=schema,
+ structure_type=args.structure_type,
+ completion=completion))
+
+ return requests
+
+
+async def get_request(
+ input_requests: List[SampleRequest],
+ request_rate: float,
+ burstiness: float = 1.0,
+) -> AsyncGenerator[Tuple[int, SampleRequest], None]:
+ """
+ Asynchronously generates requests at a specified rate
+ with OPTIONAL burstiness.
+
+ Args:
+ input_requests:
+            A list of input requests, each represented as a SampleRequest.
+ request_rate:
+ The rate at which requests are generated (requests/s).
+ burstiness (optional):
+ The burstiness factor of the request generation.
+ Only takes effect when request_rate is not inf.
+ Default value is 1, which follows a Poisson process.
+ Otherwise, the request intervals follow a gamma distribution.
+ A lower burstiness value (0 < burstiness < 1) results
+ in more bursty requests, while a higher burstiness value
+ (burstiness > 1) results in a more uniform arrival of requests.
+ """
+ input_requests = iter(input_requests)
+
+ # Calculate scale parameter theta to maintain the desired request_rate.
+ assert burstiness > 0, (
+ f"A positive burstiness factor is expected, but given {burstiness}.")
+ theta = 1.0 / (request_rate * burstiness)
+
+ for i, request in enumerate(input_requests):
+ yield i, request
+
+ if request_rate == float("inf"):
+ # If the request rate is infinity, then we don't need to wait.
+ continue
+
+ # Sample the request interval from the gamma distribution.
+ # If burstiness is 1, it follows exponential distribution.
+ interval = np.random.gamma(shape=burstiness, scale=theta)
+ # The next request will be sent after the interval.
+ await asyncio.sleep(interval)
+
+
+def calculate_metrics(
+    input_requests: List[SampleRequest],
+ outputs: List[RequestFuncOutput],
+ dur_s: float,
+ tokenizer: PreTrainedTokenizerBase,
+ selected_percentile_metrics: List[str],
+ selected_percentiles: List[float],
+) -> Tuple[BenchmarkMetrics, List[int]]:
+ actual_output_lens: List[int] = []
+ total_input = 0
+ completed = 0
+ good_completed = 0
+ itls: List[float] = []
+ tpots: List[float] = []
+ all_tpots: List[float] = []
+ ttfts: List[float] = []
+ e2els: List[float] = []
+ for i in range(len(outputs)):
+ if outputs[i].success:
+ # We use the tokenizer to count the number of output tokens for all
+ # serving backends instead of looking at len(outputs[i].itl) since
+ # multiple output tokens may be bundled together
+ # Note : this may inflate the output token count slightly
+ output_len = len(
+ tokenizer(outputs[i].generated_text,
+ add_special_tokens=False).input_ids)
+ actual_output_lens.append(output_len)
+ total_input += input_requests[i].prompt_len
+ tpot = 0
+ if output_len > 1:
+ tpot = (outputs[i].latency - outputs[i].ttft) / (output_len -
+ 1)
+ tpots.append(tpot)
+ outputs[i].tpot = sum(tpots) / len(tpots) if len(tpots) else 0
+ # Note: if output_len <= 1, we regard tpot as 0 for goodput
+ all_tpots.append(tpot)
+ itls += outputs[i].itl
+ ttfts.append(outputs[i].ttft)
+ e2els.append(outputs[i].latency)
+ completed += 1
+ else:
+ actual_output_lens.append(0)
+
+ if completed == 0:
+ warnings.warn(
+ "All requests failed. This is likely due to a misconfiguration "
+ "on the benchmark arguments.",
+ stacklevel=2)
+ metrics = BenchmarkMetrics(
+ completed=completed,
+ total_input=total_input,
+ total_output=sum(actual_output_lens),
+ request_throughput=completed / dur_s,
+ request_goodput=good_completed / dur_s,
+ output_throughput=sum(actual_output_lens) / dur_s,
+ total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
+ mean_ttft_ms=np.mean(ttfts or 0) *
+ 1000, # ttfts is empty if streaming is not supported by backend
+ std_ttft_ms=np.std(ttfts or 0) * 1000,
+ median_ttft_ms=np.median(ttfts or 0) * 1000,
+ percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
+ for p in selected_percentiles],
+ mean_tpot_ms=np.mean(tpots or 0) * 1000,
+ std_tpot_ms=np.std(tpots or 0) * 1000,
+ median_tpot_ms=np.median(tpots or 0) * 1000,
+ percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
+ for p in selected_percentiles],
+ mean_itl_ms=np.mean(itls or 0) * 1000,
+ std_itl_ms=np.std(itls or 0) * 1000,
+ median_itl_ms=np.median(itls or 0) * 1000,
+ percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
+ for p in selected_percentiles],
+ mean_e2el_ms=np.mean(e2els or 0) * 1000,
+ std_e2el_ms=np.std(e2els or 0) * 1000,
+ median_e2el_ms=np.median(e2els or 0) * 1000,
+ percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
+ for p in selected_percentiles],
+ )
+
+ return metrics, actual_output_lens
+
+
+async def benchmark(
+ backend: str,
+ api_url: str,
+ base_url: str,
+ model_id: str,
+ tokenizer: PreTrainedTokenizerBase,
+ input_requests: List[SampleRequest],
+ request_rate: float,
+ burstiness: float,
+ disable_tqdm: bool,
+ profile: bool,
+ selected_percentile_metrics: List[str],
+    selected_percentiles: List[float],
+ ignore_eos: bool,
+ max_concurrency: Optional[int],
+ guided_decoding_ratio: float,
+ guided_decoding_backend: str,
+):
+ if backend in ASYNC_REQUEST_FUNCS:
+ request_func = ASYNC_REQUEST_FUNCS[backend]
+ else:
+ raise ValueError(f"Unknown backend: {backend}")
+
+ def prepare_extra_body(request) -> dict:
+ extra_body = {}
+ # Add the schema to the extra_body
+ extra_body[request.structure_type] = request.schema
+ # Add the specific guided_decoding_backend
+ extra_body["guided_decoding_backend"] = guided_decoding_backend
+ return extra_body
+
+ print("Starting initial single prompt test run...")
+ guided_decoding_req_idx = random.sample(
+ range(len(input_requests)),
+ int(len(input_requests) * guided_decoding_ratio))
+
+ test_request = input_requests[0]
+ test_input = RequestFuncInput(
+ model=model_id,
+ prompt=test_request.prompt,
+ api_url=api_url,
+ prompt_len=test_request.prompt_len,
+ output_len=test_request.expected_output_len,
+ ignore_eos=ignore_eos,
+ extra_body=prepare_extra_body(test_request),
+ )
+ test_output = await request_func(request_func_input=test_input)
+ if not test_output.success:
+ raise ValueError(
+ "Initial test run failed - Please make sure benchmark arguments "
+ f"are correctly specified. Error: {test_output.error}")
+ else:
+ print("Initial test run completed. Starting main benchmark run...")
+
+ if profile:
+ print("Starting profiler...")
+ profile_input = RequestFuncInput(
+ model=model_id,
+ prompt=test_request.prompt,
+ api_url=base_url + "/start_profile",
+ prompt_len=test_request.prompt_len,
+ output_len=test_request.expected_output_len,
+ ignore_eos=ignore_eos,
+ extra_body=prepare_extra_body(test_request),
+ )
+ profile_output = await request_func(request_func_input=profile_input)
+ if profile_output.success:
+ print("Profiler started")
+
+ if burstiness == 1.0:
+ distribution = "Poisson process"
+ else:
+ distribution = "Gamma distribution"
+
+ print(f"Traffic request rate: {request_rate}")
+ print(f"Burstiness factor: {burstiness} ({distribution})")
+ print(f"Maximum request concurrency: {max_concurrency}")
+
+ pbar = None if disable_tqdm else tqdm(total=len(input_requests))
+
+ # This can be used once the minimum Python version is 3.10 or higher,
+ # and it will simplify the code in limited_request_func.
+ # semaphore = (asyncio.Semaphore(max_concurrency)
+ # if max_concurrency else contextlib.nullcontext())
+ semaphore = (asyncio.Semaphore(max_concurrency)
+ if max_concurrency else None)
+
+ async def limited_request_func(request_func_input, pbar):
+ if semaphore is None:
+ return await request_func(request_func_input=request_func_input,
+ pbar=pbar)
+ async with semaphore:
+ return await request_func(request_func_input=request_func_input,
+ pbar=pbar)
+
+ benchmark_start_time = time.perf_counter()
+ tasks: List[asyncio.Task] = []
+ expected: List[str] = []
+ async for i, request in get_request(input_requests, request_rate,
+ burstiness):
+ extra_body = prepare_extra_body(
+ request) if i in guided_decoding_req_idx else None
+ request_func_input = RequestFuncInput(
+ model=model_id,
+ prompt=request.prompt,
+ api_url=api_url,
+ prompt_len=request.prompt_len,
+ output_len=request.expected_output_len,
+ ignore_eos=ignore_eos,
+ extra_body=extra_body,
+ )
+ expected.append(request.completion)
+ tasks.append(
+ asyncio.create_task(
+ limited_request_func(request_func_input=request_func_input,
+ pbar=pbar)))
+ outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
+
+ if profile:
+ print("Stopping profiler...")
+ profile_input = RequestFuncInput(
+ model=model_id,
+ prompt=test_request.prompt,
+ api_url=base_url + "/stop_profile",
+ prompt_len=test_request.prompt_len,
+ output_len=test_request.expected_output_len,
+ extra_body={test_request.structure_type: test_request.schema},
+ )
+ profile_output = await request_func(request_func_input=profile_input)
+ if profile_output.success:
+ print("Profiler stopped")
+
+ if pbar is not None:
+ pbar.close()
+
+ benchmark_duration = time.perf_counter() - benchmark_start_time
+
+ metrics, actual_output_lens = calculate_metrics(
+ input_requests=input_requests,
+ outputs=outputs,
+ dur_s=benchmark_duration,
+ tokenizer=tokenizer,
+ selected_percentile_metrics=selected_percentile_metrics,
+ selected_percentiles=selected_percentiles,
+ )
+
+ print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
+ print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
+ print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
+ benchmark_duration))
+ print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
+ print("{:<40} {:<10}".format("Total generated tokens:",
+ metrics.total_output))
+ print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
+ metrics.request_throughput))
+ print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
+ metrics.output_throughput))
+ print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
+ metrics.total_token_throughput))
+
+ result = {
+ "duration":
+ benchmark_duration,
+ "completed":
+ metrics.completed,
+ "total_input_tokens":
+ metrics.total_input,
+ "total_output_tokens":
+ metrics.total_output,
+ "request_throughput":
+ metrics.request_throughput,
+ "output_throughput":
+ metrics.output_throughput,
+ "total_token_throughput":
+ metrics.total_token_throughput,
+ "ttft_description":
+ pd.Series([output.ttft for output in outputs]).describe().to_dict(),
+ "tpot_description":
+ pd.Series([output.tpot for output in outputs]).describe().to_dict(),
+ "input_lens": [output.prompt_len for output in outputs],
+ "output_lens":
+ actual_output_lens,
+ "ttfts": [output.ttft for output in outputs],
+ "itls": [output.itl for output in outputs],
+ "errors": [output.error for output in outputs],
+ }
+
+ ret = [{
+ 'generated': output.generated_text,
+ 'expected': gt
+ } for output, gt in zip(outputs, expected)]
+
+ def process_one_metric(
+ # E.g., "ttft"
+ metric_attribute_name: str,
+ # E.g., "TTFT"
+ metric_name: str,
+ # E.g., "Time to First Token"
+ metric_header: str,
+ ):
+ # This function prints and adds statistics of the specified
+ # metric.
+ if metric_attribute_name not in selected_percentile_metrics:
+ return
+ print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
+ print("{:<40} {:<10.2f}".format(
+ f"Mean {metric_name} (ms):",
+ getattr(metrics, f"mean_{metric_attribute_name}_ms")))
+ print("{:<40} {:<10.2f}".format(
+ f"Median {metric_name} (ms):",
+ getattr(metrics, f"median_{metric_attribute_name}_ms")))
+ result[f"mean_{metric_attribute_name}_ms"] = getattr(
+ metrics, f"mean_{metric_attribute_name}_ms")
+ result[f"median_{metric_attribute_name}_ms"] = getattr(
+ metrics, f"median_{metric_attribute_name}_ms")
+ result[f"std_{metric_attribute_name}_ms"] = getattr(
+ metrics, f"std_{metric_attribute_name}_ms")
+ for p, value in getattr(metrics,
+ f"percentiles_{metric_attribute_name}_ms"):
+ p_word = str(int(p)) if int(p) == p else str(p)
+ print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
+ value))
+ result[f"p{p_word}_{metric_attribute_name}_ms"] = value
+
+ process_one_metric("ttft", "TTFT", "Time to First Token")
+ process_one_metric("tpot", "TPOT",
+ "Time per Output Token (excl. 1st token)")
+ process_one_metric("itl", "ITL", "Inter-token Latency")
+ process_one_metric("e2el", "E2EL", "End-to-end Latency")
+
+ print("=" * 50)
+
+ return result, ret
+
+
+def evaluate(ret, args):
+
+ def _eval_correctness_json(expected, actual):
+ # extract json string from string using regex
+ import re
+ actual = actual.replace('\n', '').replace(' ', '').strip()
+ try:
+ actual = re.search(r'\{.*\}', actual).group()
+ actual = json.loads(actual)
+ except Exception:
+ return False
+
+ return True
+
+ def _eval_correctness_choice(expected, actual):
+ return actual in args.choice
+
+ def _eval_correctness_regex(expected, actual):
+ import re
+ return re.match(args.regex, actual) is not None
+
+ def _eval_correctness(expected, actual):
+ if args.structure_type == 'guided_json':
+ return _eval_correctness_json(expected, actual)
+ elif args.structure_type == 'guided_regex':
+ return _eval_correctness_regex(expected, actual)
+ elif args.structure_type == 'guided_choice':
+ return _eval_correctness_choice(expected, actual)
+ else:
+ return None
+
+ scores = []
+ for res in ret:
+ score = _eval_correctness(res['expected'], res['generated'])
+ res['correctness'] = score
+ scores.append(score)
+
+ not_none_scores = [score for score in scores if score is not None]
+
+ return (sum(not_none_scores) / len(not_none_scores) *
+ 100) if len(not_none_scores) > 0 else None
+
+
+def main(args: argparse.Namespace):
+ print(args)
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+
+ backend = args.backend
+ model_id = args.model
+ tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
+
+ if args.base_url is not None:
+ api_url = f"{args.base_url}{args.endpoint}"
+ base_url = f"{args.base_url}"
+ else:
+ api_url = f"http://{args.host}:{args.port}{args.endpoint}"
+ base_url = f"http://{args.host}:{args.port}"
+
+ tokenizer = get_tokenizer(tokenizer_id,
+ trust_remote_code=args.trust_remote_code)
+
+ if args.dataset == 'grammar':
+ args.structure_type = 'guided_grammar'
+ elif args.dataset == 'regex':
+ args.structure_type = 'guided_regex'
+ elif args.dataset == 'choice':
+ args.structure_type = 'guided_choice'
+ else:
+ args.structure_type = 'guided_json'
+
+ if args.no_guided_decoding:
+ args.guided_decoding_ratio = 0
+ if args.save_results:
+ result_file_name = f'{args.guided_decoding_ratio}guided'
+ result_file_name += f"_{backend}"
+ result_file_name += f"_{args.request_rate}qps"
+ result_file_name += f"_{args.model.split('/')[-1]}"
+ result_file_name += f"_{args.dataset}"
+ result_file_name += f"_{args.num_prompts}"
+ result_file_name += f"_out{args.output_len}"
+ result_file_name += ".txt"
+ else:
+ result_file_name = None
+
+ input_requests = sample_requests(tokenizer, args)
+
+ benchmark_result, ret = asyncio.run(
+ benchmark(
+ backend=backend,
+ api_url=api_url,
+ base_url=base_url,
+ model_id=model_id,
+ tokenizer=tokenizer,
+ input_requests=input_requests,
+ request_rate=args.request_rate,
+ burstiness=args.burstiness,
+ disable_tqdm=args.disable_tqdm,
+ profile=args.profile,
+ selected_percentile_metrics=args.percentile_metrics.split(","),
+ selected_percentiles=[
+ float(p) for p in args.metric_percentiles.split(",")
+ ],
+ ignore_eos=args.ignore_eos,
+ max_concurrency=args.max_concurrency,
+ guided_decoding_ratio=args.guided_decoding_ratio,
+ guided_decoding_backend=args.guided_decoding_backend,
+ ))
+
+ # Save config and results to json
+ score = evaluate(ret, args)
+ print("correct_rate(%)", score, '\n')
+ if args.save_results:
+ results = {
+ "backend":
+ backend,
+ "model_id":
+ model_id,
+ "tokenizer_id":
+ tokenizer_id,
+ "num_prompts":
+ args.num_prompts,
+ "request_rate":
+ args.request_rate if args.request_rate < float("inf") else "inf",
+ "burstiness":
+ args.burstiness,
+ "max_concurrency":
+ args.max_concurrency,
+ "correct_rate(%)":
+ score
+ }
+ results = {"outputs": ret, **results, **benchmark_result}
+
+ # Save to file
+ if args.result_filename:
+ result_file_name = args.result_filename
+ if args.result_dir:
+ result_file_name = os.path.join(args.result_dir, result_file_name)
+ with open(result_file_name, "w", encoding='utf-8') as outfile:
+ json.dump(results, outfile, indent=4)
+
+
+if __name__ == "__main__":
+ parser = FlexibleArgumentParser(
+ description="Benchmark the online serving throughput.")
+ parser.add_argument(
+ "--backend",
+ type=str,
+ default="vllm",
+ choices=list(ASYNC_REQUEST_FUNCS.keys()),
+ )
+ parser.add_argument(
+ "--base-url",
+ type=str,
+ default=None,
+ help="Server or API base url if not using http host and port.",
+ )
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=8000)
+ parser.add_argument(
+ "--endpoint",
+ type=str,
+ default="/v1/completions",
+ help="API endpoint.",
+ )
+ parser.add_argument(
+ "--dataset",
+ default='json',
+ choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench'])
+ parser.add_argument("--json_schema_path",
+ type=str,
+ default=None,
+ help="Path to json schema.")
+ parser.add_argument(
+ "--max-concurrency",
+ type=int,
+ default=None,
+ help="Maximum number of concurrent requests. This can be used "
+ "to help simulate an environment where a higher level component "
+ "is enforcing a maximum number of concurrent requests. While the "
+ "--request-rate argument controls the rate at which requests are "
+ "initiated, this argument will control how many are actually allowed "
+ "to execute at a time. This means that when used in combination, the "
+ "actual request rate may be lower than specified with --request-rate, "
+ "if the server is not processing requests fast enough to keep up.")
+ parser.add_argument(
+ "--model",
+ type=str,
+ required=True,
+ help="Name of the model.",
+ )
+ parser.add_argument(
+ "--tokenizer",
+ type=str,
+ help=
+ "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
+ )
+ parser.add_argument(
+ "--num-prompts",
+ type=int,
+ default=1000,
+ help="Number of prompts to process.",
+ )
+ parser.add_argument(
+ "--output-len",
+ type=int,
+ default=128,
+ help="Number of output tokens.",
+ )
+ parser.add_argument(
+ "--request-rate",
+ type=float,
+ default=float("inf"),
+ help="Number of requests per second. If this is inf, "
+ "then all the requests are sent at time 0. "
+ "Otherwise, we use Poisson process or gamma distribution "
+ "to synthesize the request arrival times.",
+ )
+ parser.add_argument(
+ "--burstiness",
+ type=float,
+ default=1.0,
+ help="Burstiness factor of the request generation. "
+ "Only take effect when request_rate is not inf. "
+ "Default value is 1, which follows Poisson process. "
+ "Otherwise, the request intervals follow a gamma distribution. "
+ "A lower burstiness value (0 < burstiness < 1) results in more "
+ "bursty requests. A higher burstiness value (burstiness > 1) "
+ "results in a more uniform arrival of requests.",
+ )
+ parser.add_argument("--seed", type=int, default=0)
+ parser.add_argument(
+ "--trust-remote-code",
+ action="store_true",
+ help="Trust remote code from huggingface",
+ )
+ parser.add_argument(
+ "--disable-tqdm",
+ action="store_true",
+ help="Specify to disable tqdm progress bar.",
+ )
+ parser.add_argument(
+ "--save-results",
+ action="store_true",
+ help="Specify to save benchmark results to a json file",
+ )
+ parser.add_argument(
+ "--profile",
+ action="store_true",
+ help="Use Torch Profiler. The endpoint must be launched with "
+ "VLLM_TORCH_PROFILER_DIR to enable profiler.",
+ )
+ parser.add_argument(
+ "--result-dir",
+ type=str,
+ default=None,
+ help="Specify directory to save benchmark json results."
+ "If not specified, results are saved in the current directory.",
+ )
+ parser.add_argument(
+ "--result-filename",
+ type=str,
+ default=None,
+ help="Specify the filename to save benchmark json results."
+ "If not specified, results will be saved in "
+ "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
+ " format.",
+ )
+ parser.add_argument(
+ "--ignore-eos",
+ action="store_true",
+ help="Set ignore_eos flag when sending the benchmark request."
+ "Warning: ignore_eos is not supported in deepspeed_mii and tgi.")
+ parser.add_argument(
+ "--percentile-metrics",
+ type=str,
+ default="ttft,tpot,itl",
+ help="Comma-seperated list of selected metrics to report percentils. "
+ "This argument specifies the metrics to report percentiles. "
+ "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
+ "Default value is \"ttft,tpot,itl\".")
+ parser.add_argument(
+ "--metric-percentiles",
+ type=str,
+ default="99",
+ help="Comma-seperated list of percentiles for selected metrics. "
+ "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
+ "Default value is \"99\". "
+ "Use \"--percentile-metrics\" to select metrics.",
+ )
+ parser.add_argument("--no-guided-decoding",
+ action='store_true',
+ default=False,
+ help="Whether to disable JSON decoding or not.")
+ parser.add_argument("--guided-decoding-ratio",
+ type=float,
+ default=1.0,
+ help="Ratio of Guided Decoding requests")
+ parser.add_argument("--guided-decoding-backend",
+ type=str,
+ choices=["outlines", "lm-format-enforcer", "xgrammar"],
+ default="xgrammar",
+ help="Backend to use for guided decoding")
+
+ args = parser.parse_args()
+ main(args)
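
A standalone sketch of the correctness checks evaluate() applies to guided-decoding outputs; the regex and choice values mirror the defaults built in sample_requests and are illustrative only.

    import json
    import re

    EMAIL_REGEX = r"\w+@\w+\.com\n"        # same pattern as the 'regex' dataset
    CHOICES = ["Positive", "Negative"]     # same options as the 'choice' dataset

    def is_valid_json(text: str) -> bool:
        # Same idea as _eval_correctness_json: pull out the first {...} span
        # and check that it parses; the schema itself is not validated.
        text = text.replace("\n", "").replace(" ", "")
        match = re.search(r"\{.*\}", text)
        if match is None:
            return False
        try:
            json.loads(match.group())
            return True
        except json.JSONDecodeError:
            return False

    print(is_valid_json('Sure: {"name": "Ada", "age": 36}'))          # True
    print(re.match(EMAIL_REGEX, "turing@enigma.com\n") is not None)    # True
    print("Positive" in CHOICES)                                       # True
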
diff --git a/vllm/benchmarks/benchmark_throughput.py b/vllm/benchmarks/benchmark_throughput.py
index ee41c8ea3..1e5967bd9 100644
--- a/vllm/benchmarks/benchmark_throughput.py
+++ b/vllm/benchmarks/benchmark_throughput.py
@@ -4,10 +4,11 @@
import json
import random
import time
-from typing import List, Optional, Tuple
+from typing import List, Optional
import torch
import uvloop
+from PIL import Image
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
@@ -15,16 +16,56 @@
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
+from vllm.inputs import TextPrompt
+from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
-def sample_requests(
- dataset_path: str,
- num_requests: int,
- tokenizer: PreTrainedTokenizerBase,
- fixed_output_len: Optional[int],
-) -> List[Tuple[str, int, int]]:
+@dataclasses.dataclass
+class SampleRequest:
+ """A class representing a single inference request for benchmarking.
+
+ Attributes:
+ prompt: The input text prompt for the model.
+ multi_modal_data: Optional dictionary containing multi-modal data (e.g.
+ images).
+ prompt_len: The length of the prompt in tokens.
+ expected_output_len: The expected length of the output in tokens.
+ """
+ prompt: str
+ prompt_len: int
+ expected_output_len: int
+ multi_modal_data: Optional[MultiModalDataDict] = None
+
+
+def _get_prompt_for_image_model(question: str, *, model: str) -> str:
+ """Prepend and append special tokens around the question to form a prompt.
+
+ Args:
+ question: The input question text to wrap with special tokens
+ model: The name of the model being used, to determine which special
+ tokens to add
+
+ Returns:
+ The formatted prompt string with appropriate special tokens for the
+ model
+
+ Raises:
+ ValueError: If an unsupported model name is provided
+ """
+ model = model.lower()
+ if "pixtral" in model:
+ return f"[INST]{question}\n[IMG][/INST]"
+ raise ValueError(f"Unsupported model {model}")
+
+
+def sample_requests(tokenizer: PreTrainedTokenizerBase,
+ args: argparse.Namespace) -> List[SampleRequest]:
+ dataset_path: str = args.dataset
+ num_requests: int = args.num_prompts
+ fixed_output_len: Optional[int] = args.output_len
+ model: str = args.model
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
@@ -33,23 +74,36 @@ def sample_requests(
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
- # Only keep the first two turns of each conversation.
- dataset = [(data["conversations"][0]["value"],
- data["conversations"][1]["value"]) for data in dataset]
-
# Shuffle the dataset.
random.shuffle(dataset)
# Filter out sequences that are too long or too short
- filtered_dataset: List[Tuple[str, int, int]] = []
- for i in range(len(dataset)):
+ filtered_dataset: List[SampleRequest] = []
+ for data in dataset:
if len(filtered_dataset) == num_requests:
break
+ # Only keep the first two turns of each conversation.
+ prompt = data["conversations"][0]["value"]
+ completion = data["conversations"][1]["value"]
+
+ multi_modal_data: Optional[MultiModalDataDict] = None
+ if "image" in data:
+ multi_modal_data = multi_modal_data or {}
+ image_path = data["image"]
+ # TODO(vllm-project/vllm/issues/9778): Support multiple images.
+ assert isinstance(image_path,
+ str), "Only support single image input"
+ try:
+ multi_modal_data["image"] = Image.open(image_path).convert(
+ "RGB")
+ except FileNotFoundError:
+ # Ignore datapoint where asset is missing
+ continue
+ prompt = _get_prompt_for_image_model(question=prompt, model=model)
+
# Tokenize the prompts and completions.
- prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
- completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
@@ -60,13 +114,17 @@ def sample_requests(
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
- filtered_dataset.append((prompt, prompt_len, output_len))
+ filtered_dataset.append(
+ SampleRequest(prompt=prompt,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ multi_modal_data=multi_modal_data))
return filtered_dataset
def run_vllm(
- requests: List[Tuple[str, int, int]],
+ requests: List[SampleRequest],
n: int,
engine_args: EngineArgs,
) -> float:
@@ -74,17 +132,19 @@ def run_vllm(
llm = LLM(**dataclasses.asdict(engine_args))
# Add the requests to the engine.
- prompts: List[str] = []
+ prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
- for prompt, _, output_len in requests:
- prompts.append(prompt)
+ for request in requests:
+ prompts.append(
+ TextPrompt(prompt=request.prompt,
+ multi_modal_data=request.multi_modal_data))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
- max_tokens=output_len,
+ max_tokens=request.expected_output_len,
))
use_beam_search = False
@@ -94,11 +154,11 @@ def run_vllm(
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
- prompts = [prompt for prompt, _, _ in requests]
+ prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
- for prompt, input_len, _output_len in requests:
- assert _output_len == output_len
+ for request in requests:
+ assert request.expected_output_len == output_len
start = time.perf_counter()
llm.beam_search(
prompts,
@@ -112,7 +172,7 @@ def run_vllm(
async def run_vllm_async(
- requests: List[Tuple[str, int, int]],
+ requests: List[SampleRequest],
n: int,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
@@ -123,17 +183,19 @@ async def run_vllm_async(
engine_args, disable_frontend_multiprocessing) as llm:
# Add the requests to the engine.
- prompts: List[str] = []
+ prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
- for prompt, _, output_len in requests:
- prompts.append(prompt)
+ for request in requests:
+ prompts.append(
+ TextPrompt(prompt=request.prompt,
+ multi_modal_data=request.multi_modal_data))
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
- max_tokens=output_len,
+ max_tokens=request.expected_output_len,
))
generators = []
@@ -149,7 +211,7 @@ async def run_vllm_async(
def run_hf(
- requests: List[Tuple[str, int, int]],
+ requests: List[SampleRequest],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
@@ -207,14 +269,14 @@ def run_hf(
def run_mii(
- requests: List[Tuple[str, int, int]],
+ requests: List[SampleRequest],
model: str,
tensor_parallel_size: int,
output_len: int,
) -> float:
from mii import client, serve
llm = serve(model, tensor_parallel=tensor_parallel_size)
- prompts = [prompt for prompt, _, _ in requests]
+ prompts = [request.prompt for request in requests]
start = time.perf_counter()
llm.generate(prompts, max_new_tokens=output_len)
@@ -232,23 +294,41 @@ def main(args: argparse.Namespace):
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
if args.dataset is None:
- # Synthesize a prompt with the given input length.
- # As tokenizer may add additional tokens like BOS, we need to try
- # different lengths to get the desired input length.
- for i in range(-10, 10):
- prompt = "hi " * (args.input_len + i)
- tokenized_prompt = tokenizer(prompt).input_ids
- if len(tokenized_prompt) == args.input_len:
- break
- else:
- raise ValueError(
- f"Failed to synthesize a prompt with {args.input_len} tokens.")
- requests = [(prompt, args.input_len, args.output_len)
- for _ in range(args.num_prompts)]
+ vocab_size = tokenizer.vocab_size
+ requests = []
+ for _ in range(args.num_prompts):
+ # Synthesize a prompt with the given input length.
+ candidate_ids = [
+ random.randint(0, vocab_size - 1)
+ for _ in range(args.input_len)
+ ]
+ # As tokenizer may add additional tokens like BOS, we need to try
+ # different lengths to get the desired input length.
+ for _ in range(5): # Max attempts to correct
+ candidate_prompt = tokenizer.decode(candidate_ids)
+ tokenized_len = len(tokenizer.encode(candidate_prompt))
+
+ if tokenized_len == args.input_len:
+ break
+
+ # Adjust length based on difference
+ diff = args.input_len - tokenized_len
+ if diff > 0:
+ candidate_ids.extend([
+ random.randint(100, vocab_size - 100)
+ for _ in range(diff)
+ ])
+ else:
+ candidate_ids = candidate_ids[:diff]
+ requests.append(
+ SampleRequest(prompt=candidate_prompt,
+ prompt_len=args.input_len,
+ expected_output_len=args.output_len))
else:
- requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
- args.output_len)
+ requests = sample_requests(tokenizer, args)
+ is_multi_modal = any(request.multi_modal_data is not None
+ for request in requests)
if args.backend == "vllm":
if args.async_engine:
elapsed_time = uvloop.run(
@@ -270,9 +350,15 @@ def main(args: argparse.Namespace):
args.output_len)
else:
raise ValueError(f"Unknown backend: {args.backend}")
- total_num_tokens = sum(prompt_len + output_len
- for _, prompt_len, output_len in requests)
- total_output_tokens = sum(output_len for _, _, output_len in requests)
+ total_num_tokens = sum(request.prompt_len + request.expected_output_len
+ for request in requests)
+ total_output_tokens = sum(request.expected_output_len
+ for request in requests)
+ if is_multi_modal:
+ print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
+ "following metrics are not accurate because image tokens are not"
+ " counted. See vllm-project/vllm/issues/9778 for details.")
+    # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
@@ -299,7 +385,9 @@ def main(args: argparse.Namespace):
parser.add_argument("--dataset",
type=str,
default=None,
- help="Path to the dataset.")
+ help="Path to the dataset. The dataset is expected to "
+ "be a json in form of List[Dict[..., conversations: "
+ "List[Dict[..., value: ]]]]")
parser.add_argument("--input-len",
type=int,
default=None,
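
A toy illustration of the length-correction loop used above for synthetic prompts; the whitespace encode/decode pair below stands in for a real tokenizer (it is not a vLLM API), so this is only a sketch of the adjust-by-diff idea.

    import random

    # Stand-in "tokenizer": one token per whitespace-separated word.
    def encode(text):
        return text.split()

    def decode(ids):
        return " ".join(str(i) for i in ids)

    def synthesize_prompt(target_len, vocab_size=1000, max_attempts=5, seed=0):
        rng = random.Random(seed)
        candidate_ids = [rng.randrange(vocab_size) for _ in range(target_len)]
        prompt = decode(candidate_ids)
        for _ in range(max_attempts):
            tokenized_len = len(encode(prompt))
            if tokenized_len == target_len:
                break
            diff = target_len - tokenized_len
            if diff > 0:     # too short: append more random ids
                candidate_ids.extend(
                    rng.randrange(vocab_size) for _ in range(diff))
            else:            # too long: drop the excess from the tail
                candidate_ids = candidate_ids[:diff]
            prompt = decode(candidate_ids)
        return prompt

    print(len(encode(synthesize_prompt(32))))   # 32
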
diff --git a/vllm/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/vllm/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
new file mode 100644
index 000000000..2924ea4a4
--- /dev/null
+++ b/vllm/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
@@ -0,0 +1,144 @@
+#!/bin/bash
+
+# Benchmark the overhead of disaggregated prefill.
+# Methodology:
+# - Send all requests to the prefill vLLM instance. It will buffer the KV cache.
+# - Then send all requests to the decode instance.
+# - The TTFT of the decode instance is the overhead of disaggregated prefill.
+
+set -ex
+
+kill_gpu_processes() {
+ # kill all processes on GPU.
+ pkill -f pt_main_thread
+ sleep 10
+
+ # remove vllm config file
+ rm -rf ~/.config/vllm
+
+ # Print the GPU memory usage
+ # so that we know if all GPU processes are killed.
+ gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0)
+ # The memory usage should be 0 MB.
+ echo "GPU 0 Memory Usage: $gpu_memory_usage MB"
+}
+
+wait_for_server() {
+ # wait for vllm server to start
+ # return 1 if vllm server crashes
+ local port=$1
+ timeout 1200 bash -c "
+ until curl -s localhost:${port}/v1/completions > /dev/null; do
+ sleep 1
+ done" && return 0 || return 1
+}
+
+
+benchmark() {
+
+ export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
+
+ # compare chunked prefill with disaggregated prefill
+
+ results_folder="./results"
+ model="meta-llama/Meta-Llama-3.1-8B-Instruct"
+ dataset_name="sonnet"
+ dataset_path="../sonnet_4x.txt"
+ num_prompts=10
+ qps=$1
+ prefix_len=50
+ input_len=2048
+ output_len=$2
+
+
+ CUDA_VISIBLE_DEVICES=0 python3 \
+ -m vllm.entrypoints.openai.api_server \
+ --model meta-llama/Meta-Llama-3.1-8B-Instruct \
+ --port 8100 \
+ --max-model-len 10000 \
+ --gpu-memory-utilization 0.6 \
+ --kv-transfer-config \
+ '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
+
+
+ CUDA_VISIBLE_DEVICES=1 python3 \
+ -m vllm.entrypoints.openai.api_server \
+ --model meta-llama/Meta-Llama-3.1-8B-Instruct \
+ --port 8200 \
+ --max-model-len 10000 \
+ --gpu-memory-utilization 0.6 \
+ --kv-transfer-config \
+ '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
+
+ wait_for_server 8100
+ wait_for_server 8200
+
+ # let the prefill instance finish prefill
+ python3 ../benchmark_serving.py \
+ --backend vllm \
+ --model $model \
+ --dataset-name $dataset_name \
+ --dataset-path $dataset_path \
+ --sonnet-input-len $input_len \
+ --sonnet-output-len "$output_len" \
+ --sonnet-prefix-len $prefix_len \
+ --num-prompts $num_prompts \
+ --port 8100 \
+ --save-result \
+ --result-dir $results_folder \
+ --result-filename disagg_prefill_2xtp4.json \
+ --request-rate "inf"
+
+
+ # send the request to decode.
+ # The TTFT of this command will be the overhead of disagg prefill impl.
+ python3 ../benchmark_serving.py \
+ --backend vllm \
+ --model $model \
+ --dataset-name $dataset_name \
+ --dataset-path $dataset_path \
+ --sonnet-input-len $input_len \
+ --sonnet-output-len "$output_len" \
+ --sonnet-prefix-len $prefix_len \
+ --num-prompts $num_prompts \
+ --port 8200 \
+ --save-result \
+ --result-dir $results_folder \
+ --result-filename disagg_prefill_2xtp4.json \
+ --request-rate "$qps"
+ kill_gpu_processes
+
+}
+
+
+main() {
+
+ (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
+ (which jq) || (apt-get -y install jq)
+ (which socat) || (apt-get -y install socat)
+
+ pip install quart httpx
+
+ cd "$(dirname "$0")"
+
+ cd ..
+    # create sonnet_4x.txt
+ echo "" > sonnet_4x.txt
+ for _ in {1..4}
+ do
+ cat sonnet.txt >> sonnet_4x.txt
+ done
+ cd disagg_benchmarks
+
+ rm -rf results
+ mkdir results
+
+ default_qps=1
+ default_output_len=1
+ benchmark $default_qps $default_output_len
+
+}
+
+
+main "$@"
diff --git a/vllm/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/vllm/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
new file mode 100644
index 000000000..d8d9e976d
--- /dev/null
+++ b/vllm/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
@@ -0,0 +1,164 @@
+#!/bin/bash
+
+# Requirement: 8x H100 GPUs.
+
+
+# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV
+# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests
+# Resource: 8x H100
+# Approaches:
+# 1. Chunked prefill: 1 vllm instance with tp=8
+# 2. Chunked prefill: 2 vllm instances with tp=4, equivalent to 1 tp=4 instance with QPS 4
+# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance
+# Prefilling instance: max_output_token=1
+# Decoding instance: force the input tokens to be the same across requests to bypass prefilling
+
+set -ex
+
+kill_gpu_processes() {
+ # kill all processes on GPU.
+ pgrep pt_main_thread | xargs -r kill -9
+ pgrep python3 | xargs -r kill -9
+ for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done
+ sleep 1
+}
+
+wait_for_server() {
+ # wait for vllm server to start
+ # return 1 if vllm server crashes
+ local port=$1
+ timeout 1200 bash -c "
+ until curl -s localhost:${port}/v1/completions > /dev/null; do
+ sleep 1
+ done" && return 0 || return 1
+}
+
+
+launch_chunked_prefill() {
+ model="meta-llama/Meta-Llama-3.1-8B-Instruct"
+  # chunked prefill: two regular instances behind a round-robin proxy
+ CUDA_VISIBLE_DEVICES=0 python3 \
+ -m vllm.entrypoints.openai.api_server \
+ --model $model \
+ --port 8100 \
+ --max-model-len 10000 \
+ --enable-chunked-prefill \
+ --gpu-memory-utilization 0.6 &
+ CUDA_VISIBLE_DEVICES=1 python3 \
+ -m vllm.entrypoints.openai.api_server \
+ --model $model \
+ --port 8200 \
+ --max-model-len 10000 \
+ --enable-chunked-prefill \
+ --gpu-memory-utilization 0.6 &
+ wait_for_server 8100
+ wait_for_server 8200
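+  # round_robin_proxy.py listens on port 8000 and alternates requests between
+  # the two instances on ports 8100 and 8200.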
+ python3 round_robin_proxy.py &
+ sleep 1
+}
+
+
+launch_disagg_prefill() {
+ model="meta-llama/Meta-Llama-3.1-8B-Instruct"
+  # disagg prefill: a prefill (kv_producer) instance and a decode (kv_consumer) instance
+ CUDA_VISIBLE_DEVICES=0 python3 \
+ -m vllm.entrypoints.openai.api_server \
+ --model $model \
+ --port 8100 \
+ --max-model-len 10000 \
+ --gpu-memory-utilization 0.6 \
+ --kv-transfer-config \
+ '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
+
+ CUDA_VISIBLE_DEVICES=1 python3 \
+ -m vllm.entrypoints.openai.api_server \
+ --model $model \
+ --port 8200 \
+ --max-model-len 10000 \
+ --gpu-memory-utilization 0.6 \
+ --kv-transfer-config \
+ '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
+
+ wait_for_server 8100
+ wait_for_server 8200
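+  # disagg_prefill_proxy_server.py listens on port 8000: it first forwards each
+  # request to the prefill instance (port 8100) with max_tokens=1, then streams
+  # the full response from the decode instance (port 8200).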
+ python3 disagg_prefill_proxy_server.py &
+ sleep 1
+}
+
+
+benchmark() {
+ results_folder="./results"
+ model="meta-llama/Meta-Llama-3.1-8B-Instruct"
+ dataset_name="sonnet"
+ dataset_path="../sonnet_4x.txt"
+ num_prompts=100
+ qps=$1
+ prefix_len=50
+ input_len=1024
+ output_len=$2
+ tag=$3
+
+ python3 ../benchmark_serving.py \
+ --backend vllm \
+ --model $model \
+ --dataset-name $dataset_name \
+ --dataset-path $dataset_path \
+ --sonnet-input-len $input_len \
+ --sonnet-output-len "$output_len" \
+ --sonnet-prefix-len $prefix_len \
+ --num-prompts $num_prompts \
+ --port 8000 \
+ --save-result \
+ --result-dir $results_folder \
+ --result-filename "$tag"-qps-"$qps".json \
+ --request-rate "$qps"
+
+ sleep 2
+
+}
+
+
+main() {
+
+ (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
+ (which jq) || (apt-get -y install jq)
+ (which socat) || (apt-get -y install socat)
+
+ pip install quart httpx matplotlib aiohttp
+
+ cd "$(dirname "$0")"
+
+ cd ..
+  # create sonnet_4x.txt so that we can sample 2048-token inputs
+ echo "" > sonnet_4x.txt
+ for _ in {1..4}
+ do
+ cat sonnet.txt >> sonnet_4x.txt
+ done
+ cd disagg_benchmarks
+
+ rm -rf results
+ mkdir results
+
+ default_output_len=6
+
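+  # VLLM_HOST_IP is read by vLLM's distributed setup, including the PyNccl
+  # KV-transfer group used for disaggregated prefill.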
+ export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
+
+ launch_chunked_prefill
+ for qps in 2 4 6 8; do
+ benchmark $qps $default_output_len chunked_prefill
+ done
+ kill_gpu_processes
+
+ launch_disagg_prefill
+ for qps in 2 4 6 8; do
+ benchmark $qps $default_output_len disagg_prefill
+ done
+ kill_gpu_processes
+
+ python3 visualize_benchmark_results.py
+
+}
+
+
+main "$@"
diff --git a/vllm/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/vllm/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
new file mode 100644
index 000000000..4058b1c0a
--- /dev/null
+++ b/vllm/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
@@ -0,0 +1,61 @@
+import os
+
+import aiohttp
+from quart import Quart, make_response, request
+
+AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
+
+app = Quart(__name__)
+
+
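+# Forward `data` to `url` and yield the upstream response in 1 KiB chunks so
+# the caller can stream it straight back to the client.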
+async def forward_request(url, data):
+ async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+ headers = {
+ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
+ }
+ async with session.post(url=url, json=data,
+ headers=headers) as response:
+ if response.status == 200:
+                # Stream the body in chunks regardless of Transfer-Encoding;
+                # the original header check is kept below for reference.
+                # if response.headers.get('Transfer-Encoding') == 'chunked':
+                if True:
+ async for chunk_bytes in response.content.iter_chunked(
+ 1024):
+ yield chunk_bytes
+ else:
+ content = await response.read()
+ yield content
+
+
+@app.route('/v1/completions', methods=['POST'])
+async def handle_request():
+ try:
+ original_request_data = await request.get_json()
+
+ prefill_request = original_request_data.copy()
+        # set max_tokens to 1 so this request only performs prefill
+ prefill_request['max_tokens'] = 1
+
+ # finish prefill
+ async for _ in forward_request('http://localhost:8100/v1/completions',
+ prefill_request):
+ continue
+
+        # stream the decode instance's response back to the client
+ generator = forward_request('http://localhost:8200/v1/completions',
+ original_request_data)
+ response = await make_response(generator)
+ response.timeout = None
+
+ return response
+
+ except Exception as e:
+ import sys
+ import traceback
+ exc_info = sys.exc_info()
+ print("Error occurred in disagg prefill proxy server")
+ print(e)
+ print("".join(traceback.format_exception(*exc_info)))
+
+
+if __name__ == '__main__':
+ app.run(port=8000)
diff --git a/vllm/benchmarks/disagg_benchmarks/round_robin_proxy.py b/vllm/benchmarks/disagg_benchmarks/round_robin_proxy.py
new file mode 100644
index 000000000..6eb5f6398
--- /dev/null
+++ b/vllm/benchmarks/disagg_benchmarks/round_robin_proxy.py
@@ -0,0 +1,60 @@
+import asyncio
+import itertools
+
+import aiohttp
+from aiohttp import web
+
+
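+# Minimal round-robin HTTP proxy: cycles each incoming request across the
+# target ports and streams the upstream response back to the client.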
+class RoundRobinProxy:
+
+ def __init__(self, target_ports):
+ self.target_ports = target_ports
+ self.port_cycle = itertools.cycle(self.target_ports)
+
+ async def handle_request(self, request):
+ target_port = next(self.port_cycle)
+ target_url = f"http://localhost:{target_port}{request.path_qs}"
+
+ async with aiohttp.ClientSession() as session:
+ try:
+ # Forward the request
+ async with session.request(
+ method=request.method,
+ url=target_url,
+ headers=request.headers,
+ data=request.content,
+ ) as response:
+ # Start sending the response
+ resp = web.StreamResponse(status=response.status,
+ headers=response.headers)
+ await resp.prepare(request)
+
+ # Stream the response content
+ async for chunk in response.content.iter_any():
+ await resp.write(chunk)
+
+ await resp.write_eof()
+ return resp
+
+ except Exception as e:
+ return web.Response(text=f"Error: {str(e)}", status=500)
+
+
+async def main():
+ proxy = RoundRobinProxy([8100, 8200])
+ app = web.Application()
+ app.router.add_route('*', '/{path:.*}', proxy.handle_request)
+
+ runner = web.AppRunner(app)
+ await runner.setup()
+ site = web.TCPSite(runner, 'localhost', 8000)
+ await site.start()
+
+ print("Proxy server started on http://localhost:8000")
+
+ # Keep the server running
+ await asyncio.Event().wait()
+
+
+if __name__ == '__main__':
+ asyncio.run(main())
diff --git a/vllm/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/vllm/benchmarks/disagg_benchmarks/visualize_benchmark_results.py
new file mode 100644
index 000000000..e59d8bb0e
--- /dev/null
+++ b/vllm/benchmarks/disagg_benchmarks/visualize_benchmark_results.py
@@ -0,0 +1,46 @@
+import json
+
+import matplotlib.pyplot as plt
+import pandas as pd
+
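+# Read the per-QPS JSON results written by benchmark_serving.py and plot
+# disagg_prefill vs. chunked_prefill for each TTFT/ITL metric.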
+if __name__ == "__main__":
+
+ data = []
+ for name in ['disagg_prefill', 'chunked_prefill']:
+ for qps in [2, 4, 6, 8]:
+ with open(f"results/{name}-qps-{qps}.json") as f:
+ x = json.load(f)
+ x['name'] = name
+ x['qps'] = qps
+ data.append(x)
+
+ df = pd.DataFrame.from_dict(data)
+ dis_df = df[df['name'] == 'disagg_prefill']
+ chu_df = df[df['name'] == 'chunked_prefill']
+
+ plt.style.use('bmh')
+ plt.rcParams['font.size'] = 20
+
+ for key in [
+ 'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms',
+ 'median_itl_ms', 'p99_itl_ms'
+ ]:
+
+ fig, ax = plt.subplots(figsize=(11, 7))
+ plt.plot(dis_df['qps'],
+ dis_df[key],
+ label='disagg_prefill',
+ marker='o',
+ linewidth=4)
+ plt.plot(chu_df['qps'],
+ chu_df[key],
+ label='chunked_prefill',
+ marker='o',
+ linewidth=4)
+ ax.legend()
+
+ ax.set_xlabel('QPS')
+ ax.set_ylabel(key)
+ ax.set_ylim(bottom=0)
+ fig.savefig(f'results/{key}.png')
+ plt.close(fig)
diff --git a/vllm/benchmarks/kernels/benchmark_layernorm.py b/vllm/benchmarks/kernels/benchmark_layernorm.py
index 92f6053cc..7acea6087 100644
--- a/vllm/benchmarks/kernels/benchmark_layernorm.py
+++ b/vllm/benchmarks/kernels/benchmark_layernorm.py
@@ -3,8 +3,8 @@
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
- seed_everything)
+from vllm.platforms import current_platform
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
@torch.inference_mode()
@@ -16,7 +16,7 @@ def main(num_tokens: int,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device("cuda")
layer = RMSNorm(hidden_size).to(dtype=dtype)
diff --git a/vllm/benchmarks/kernels/benchmark_machete.py b/vllm/benchmarks/kernels/benchmark_machete.py
index b70c4b94c..46bab74ae 100644
--- a/vllm/benchmarks/kernels/benchmark_machete.py
+++ b/vllm/benchmarks/kernels/benchmark_machete.py
@@ -2,8 +2,10 @@
import copy
import itertools
import math
+import os
import pickle as pkl
import time
+from dataclasses import dataclass
from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple
@@ -15,11 +17,12 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
- GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales)
+ GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales,
+ marlin_zero_points)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
- gptq_pack, pack_rows, quantize_weights)
+ pack_rows, quantize_weights)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser
@@ -27,149 +30,350 @@
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]
+NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False)
+
+if NVTX_PROFILE:
+ import nvtx
+
+
+def terse_type_name(dt):
+ return {
+ torch.bfloat16: "bf16",
+ torch.float16: "fp16",
+ torch.int8: "int8",
+ torch.float8_e4m3fn: "fp8",
+ torch.float: "float",
+ torch.int: "int",
+ }[dt]
+
+
+@dataclass
+class BenchmarkTensors:
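+    # w_ref: dequantized reference weights (used by the torch.matmul baseline),
+    # a: activations, w_q: packed quantized weights, w_g_s/w_g_zp: per-group
+    # scales/zero-points, w_ch_s/w_tok_s: optional per-channel/per-token scales.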
+ w_ref: torch.Tensor
+ a: torch.Tensor
+
+ w_q: torch.Tensor
+ group_size: Optional[int]
+ wtype: ScalarType
+ w_g_s: torch.Tensor
+ w_g_zp: Optional[torch.Tensor]
+ w_ch_s: Optional[torch.Tensor]
+ w_tok_s: Optional[torch.Tensor]
+
+
+@dataclass
+class TypeConfig:
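+    # Dtypes used for activations, quantized weights, scales/zero-points and
+    # the kernel output in one benchmark configuration.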
+ act_type: torch.dtype
+ weight_type: ScalarType
+ output_type: Optional[torch.dtype]
+ group_scale_type: Optional[torch.dtype]
+ group_zero_type: Optional[torch.dtype]
+ channel_scale_type: Optional[torch.dtype]
+ token_scale_type: Optional[torch.dtype]
+
+
+def rand_data(shape, dtype=torch.float16, scale=1):
+ if dtype.is_floating_point:
+ return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype)
+ else:
+ return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
+
+
+def quantize_and_pack(atype: torch.dtype,
+ w: torch.Tensor,
+ wtype: ScalarType,
+ stype: Optional[torch.dtype],
+ group_size: Optional[int],
+ zero_points: bool = False):
+ assert wtype.is_integer(), "TODO: support floating point weights"
+
+ w_ref, w_q, w_s, w_zp = quantize_weights(
+ w,
+ wtype,
+ group_size=group_size,
+ zero_points=zero_points,
+ # to match how the kernel applies zps
+ ref_zero_points_after_scales=True)
-def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor:
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
- w_q = w_q.t().contiguous().t() # make col major
- return ops.machete_prepack_B(w_q, wtype)
+ return w_ref, w_q, w_s, w_zp
-def make_bench_tensors(
- atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int,
- k: int
-) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor,
- torch.tensor]]]:
- assert wtype.is_integer(), "TODO: support floating point weights"
+def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig,
+ group_size: Optional[int]) -> List[BenchmarkTensors]:
+ m, n, k = shape
# we want to make sure that weights don't fit into L2 cache between runs so
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
# so we target total weight size > 2*50mb
- num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits))
-
- a = torch.randn((m, k), device="cuda", dtype=atype) * 5
- weights = [
- torch.randn((k, n), device="cuda", dtype=atype)
- for _ in range(num_weights)
- ]
- quanitized_weights = [
- quantize_weights(w, wtype, group_size) for w in weights
- ]
-
- return a, quanitized_weights
+ num_weights = math.ceil(2 * 50 * 1024**2 * 8 /
+ (k * n * types.weight_type.size_bits))
+
+ a = rand_data((m, k), types.act_type, scale=5)
+
+ benchmark_tensors: List[BenchmarkTensors] = []
+ for _ in range(num_weights):
+ w = rand_data((k, n), types.act_type, scale=5)
+
+ if types.group_scale_type is not None:
+ w = w.to(types.group_scale_type)
+ if w.dtype.itemsize == 1:
+ w = w.to(torch.float16)
+
+ w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
+ a.dtype, w, types.weight_type, types.group_scale_type, group_size,
+ types.group_zero_type is not None)
+
+ if not a.dtype.is_floating_point:
+ aiinfo = torch.iinfo(a.dtype)
+ w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)
+
+ w_ref = w_ref.to(torch.float32)
+
+ w_ch_s = None if types.channel_scale_type is None else\
+ rand_data((n,), types.channel_scale_type)
+ w_tok_s = None if types.token_scale_type is None else\
+ rand_data((m,), types.token_scale_type)
+
+ benchmark_tensors.append(
+ BenchmarkTensors(w_ref=w_ref,
+ a=a,
+ w_q=w_q_packed,
+ wtype=types.weight_type,
+ w_g_s=w_s,
+ w_g_zp=w_zp,
+ group_size=group_size,
+ w_ch_s=w_ch_s,
+ w_tok_s=w_tok_s))
+
+ return benchmark_tensors
+
+
+def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable:
+ a = bt.a
+ w = bt.w_ref.to(bt.a.dtype) # use float reference tensor
+ if a.dtype not in [torch.float16, torch.bfloat16]:
+ a = a.to(torch.float16)
+ w = w.to(torch.float16)
+ return lambda: torch.matmul(a, w)
+
+
+def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
+ if bt.w_ch_s is not None and bt.w_tok_s is not None:
+ scale_a = bt.w_tok_s.to(torch.float32)
+ scale_b = bt.w_ch_s.to(torch.float32)
+ else:
+ scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
+ scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
+ w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
+ return lambda: ops.cutlass_scaled_mm(
+ bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16)
+
+
+def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
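+    # Dispatch on the activation dtype: floating-point activations use
+    # gptq_marlin_gemm, int8 activations use marlin_qqq_gemm.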
+ device = bt.a.device
+
+ workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
+ GPTQ_MARLIN_MAX_PARALLEL)
+
+ if bt.w_g_zp is None:
+ w_zp = torch.empty(0, dtype=torch.int, device=device)
+ else:
+ w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0],
+ bt.w_ref.shape[1], bt.wtype.size_bits)
+
+ if bt.group_size is None:
+ w_s = torch.tensor([], device="cuda", dtype=torch.half)
+ else:
+ w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0],
+ bt.w_ref.shape[1], bt.group_size)
+
+ sort_indices = torch.empty(0, dtype=torch.int, device=device)
+ g_idx = torch.empty(0, dtype=torch.int, device=device)
+ w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0],
+ bt.w_ref.shape[1], bt.wtype.size_bits)
+
+ if bt.a.dtype.is_floating_point:
+ assert bt.w_ch_s is None
+ assert bt.w_tok_s is None
+ assert bt.group_size is not None
+
+ fn = lambda: ops.gptq_marlin_gemm(a=bt.a,
+ b_q_weight=w_q,
+ b_scales=w_s,
+ b_zeros=w_zp,
+ g_idx=g_idx,
+ perm=sort_indices,
+ workspace=workspace.scratch,
+ b_q_type=bt.wtype,
+ size_m=bt.a.shape[0],
+ size_n=bt.w_ref.shape[1],
+ size_k=bt.w_ref.shape[0],
+ is_k_full=True,
+ is_zp_float=False)
+ else:
+ assert bt.a.dtype == torch.int8
+ assert bt.wtype == scalar_types.uint4b8
+
+ if bt.w_ch_s is not None:
+ s_ch = bt.w_ch_s.to(torch.float32)
+ else:
+ s_ch = torch.ones(bt.w_ref.shape[1],
+ dtype=torch.float32,
+ device=device)
+
+ if bt.w_tok_s is not None:
+ s_tok = bt.w_tok_s.to(torch.float32)
+ else:
+ s_tok = torch.ones(bt.a.shape[0],
+ dtype=torch.float32,
+ device=device)
+
+ fn = lambda: ops.marlin_qqq_gemm(a=bt.a,
+ b_q_weight=w_q,
+ s_group=w_s,
+ s_tok=s_tok,
+ s_ch=s_ch,
+ workspace=workspace.scratch,
+ size_m=bt.a.shape[0],
+ size_n=bt.w_ref.shape[1],
+ size_k=bt.w_ref.shape[0])
+
+ return fn
+
+
+def machete_create_bench_fn(bt: BenchmarkTensors,
+ out_type=torch.dtype,
+ schedule=None) -> Callable:
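+    # Pre-pack the quantized weights into Machete's B layout here, outside the
+    # timed lambda.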
+ w_q = bt.w_q.t().contiguous().t() # make col major
+ w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype,
+ None if bt.w_g_s is None else bt.w_g_s.dtype)
+
+ w_g_zp = bt.w_g_zp
+ if w_g_zp is not None:
+ w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype))
+
+ return lambda: ops.machete_mm(
+ a=bt.a,
+        b_q=w_q,  # use the prepacked weights computed above
+ b_type=bt.wtype,
+ b_group_scales=bt.w_g_s,
+ b_group_zeros=w_g_zp,
+ b_group_size=bt.group_size,
+ b_channel_scales=bt.w_ch_s,
+ a_token_scales=bt.w_tok_s,
+ out_type=out_type,
+ schedule=schedule,
+ )
# impl
-
# bench
-def bench_fn(label: str, sub_label: str, description: str,
- fn: Callable) -> TMeasurement:
- min_run_time = 1
- return TBenchmark.Timer(
- stmt="fn()",
+
+def bench_fns(label: str, sub_label: str, description: str,
+ fns: List[Callable]):
+
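+    # Time all fns back-to-back with torch.utils.benchmark; when NVTX_PROFILE
+    # is set, also run the first fn once inside an NVTX range for profiling.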
+ min_run_time = 1 if not NVTX_PROFILE else 0.1
+ res = TBenchmark.Timer(
+ stmt="""
+ for fn in fns:
+ fn()
+ """,
globals={
- "fn": fn
+ "fns": fns
},
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
+ if NVTX_PROFILE:
+ with nvtx.annotate("mm-bench"), nvtx.annotate(
+ f"{label}|{sub_label}|{description}"):
+ fns[0]()
-def loop_over_weights(
- a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor,
- torch.tensor, torch.tensor]],
- fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor],
- None]):
- for w_ref, w_q, w_s, _ in weights:
- fn(a, w_ref, w_q, w_s)
+ return res
_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
-def bench(atype: torch.dtype,
- wtype: ScalarType,
+def bench(types: TypeConfig,
group_size: int,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
- benchmark_marlinv1: bool = True,
- sweep_schedules: bool = True) -> Iterable[TMeasurement]:
- global _SWEEP_SCHEDULES_RESULTS
-
- a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
- sub_label += f", L={len(weights)}"
-
- weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp)
- for w_ref, w_q, w_s, w_zp in weights]
+ sweep_schedules: bool = True) -> List[TMeasurement]:
+ benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
+ sub_label += f", L={len(benchmark_tensors)}"
+
+ name_type_string = f"W{types.weight_type}"+\
+ f"-A{terse_type_name(types.act_type)}"
+ if types.group_scale_type is not None:
+ name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
+ if types.group_zero_type is not None:
+ name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}"
+ if group_size is not None:
+ name_type_string += f"-G{group_size}"
+ if types.channel_scale_type is not None:
+ name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}"
+ if types.token_scale_type is not None:
+ name_type_string += f"-TS{terse_type_name(types.token_scale_type)}"
timers = []
# pytorch impl
timers.append(
- bench_fn(
- label, sub_label, "torch.matmul", lambda: loop_over_weights(
- a,
- weights,
- lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref),
- )))
+ bench_fns(
+ label, sub_label, "torch.matmul (fp16)",
+ [torch_matmul_f16_create_bench_fn(bt)
+ for bt in benchmark_tensors]))
- if benchmark_marlinv1:
- w_ref = weights[0][0]
-
- w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device)
- sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device)
- g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device)
-
- def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor:
- w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape)
- return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape,
- wtype.size_bits)
-
- def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
- return marlin_permute_scales(w_s, *w_ref.shape, group_size)
-
- weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q),
- marlinv1_permute_scales(w_s), w_zp)
- for w_ref, w_q, w_s, w_zp in weights]
-
- workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
- GPTQ_MARLIN_MAX_PARALLEL)
-
- # marlinv1
+ if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
+ timers.append(
+ bench_fns(
+ label, sub_label,
+ f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [
+ cutlass_scaled_mm_create_bench_fn(bt)
+ for bt in benchmark_tensors
+ ]))
+
+ if types.act_type != torch.float8_e4m3fn:
timers.append(
- bench_fn(
- label, sub_label, "marlin_orig", lambda: loop_over_weights(
- a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops.
- gptq_marlin_gemm(a,
- w_q,
- w_s,
- w_zp_empty,
- g_idx,
- sort_indices,
- workspace.scratch,
- wtype,
- size_m=a.shape[0],
- size_n=w_ref.shape[1],
- size_k=w_ref.shape[0],
- is_k_full=True))))
+ bench_fns(label, sub_label, f"marlin ({name_type_string})",
+ [marlin_create_bench_fn(bt)
+ for bt in benchmark_tensors]))
# machete
timers.append(
- bench_fn(
- label, sub_label, "machete_heuristic", lambda: loop_over_weights(
- a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm(
- a, w_q, wtype, b_scales=w_s, b_group_size=group_size))))
+ bench_fns(label, sub_label, f"machete ({name_type_string})", [
+ machete_create_bench_fn(bt, out_type=types.output_type)
+ for bt in benchmark_tensors
+ ]))
if sweep_schedules:
+ global _SWEEP_SCHEDULES_RESULTS
+
print("Finding best schedule for machete")
best = None
best_schedule = None
- schedules = ops.machete_supported_schedules(wtype)
+ schedules = ops.machete_supported_schedules(
+ a_type=types.act_type,
+ b_type=types.weight_type,
+ group_scales_type=types.group_scale_type,
+ group_zeros_type=types.group_zero_type,
+ token_scales_type=types.token_scale_type,
+ channel_scales_type=types.channel_scale_type,
+ out_type=types.output_type)
+
+ if schedules is None or len(schedules) == 0:
+ raise ValueError("No schedules found to sweep")
+
for schedule in reversed(schedules):
schedule_M = int(schedule.split("_")[0].split("x")[1])
@@ -177,16 +381,11 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
continue
- def run(a, _, w_q, w_s, schedule=schedule):
- ops.machete_gemm(a,
- w_q,
- wtype,
- w_s,
- b_group_size=group_size,
- schedule=schedule)
-
- res = bench_fn(label, sub_label, "machete_best",
- lambda: loop_over_weights(a, weights_machete, run))
+ res = bench_fns(label, sub_label, "machete_best", [
+ machete_create_bench_fn(
+ bt, out_type=types.output_type, schedule=schedule)
+ for bt in benchmark_tensors
+ ])
results_row = {
"M": m,
@@ -213,25 +412,33 @@ def run(a, _, w_q, w_s, schedule=schedule):
# runner
-def print_timers(timers: Iterable[TMeasurement]):
+def print_timers(timers: List[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
-def run(dtype: torch.dtype, sweep_schedules: bool,
- MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
+def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
+ types = TypeConfig(
+ act_type=args.act_type,
+ weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
+ else scalar_types.uint4,
+ output_type=args.out_type,
+ group_scale_type=args.group_scale_type,
+ group_zero_type=args.group_zero_type,
+ channel_scale_type=args.channel_scale_type,
+ token_scale_type=args.token_scale_type,
+ )
- results = []
+ results: List[TMeasurement] = []
for m, k, n in MKNs:
- timers = bench(dtype,
- scalar_types.uint4b8,
- 128,
+ timers = bench(types,
+ args.group_size,
m,
k,
n,
- f"{dtype}-gemm",
+ f"{args.act_type}-gemm",
f"MKN=({m}x{k}x{n})",
- sweep_schedules=sweep_schedules)
+ sweep_schedules=args.sweep_schedules)
print_timers(timers)
results.extend(timers)
@@ -240,7 +447,7 @@ def run(dtype: torch.dtype, sweep_schedules: bool,
# output makers
def make_output(
- data: Iterable[TMeasurement],
+ data: List[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]],
base_description: str,
timestamp=None,
@@ -262,17 +469,16 @@ def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
-
data = run(args.dtype, args.sweep_schedules, MKNs)
make_output(data, MKNs, f"square_bench-{args.dtype}")
def run_range_bench(args):
- m_start, k_start, n_start = [int(x) for x in args.dim_start.split(",")]
- m_end, k_end, n_end = [int(x) for x in args.dim_end.split(",")]
+ m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
+ m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
m_increment, k_increment, n_increment = \
- [int(x) for x in args.dim_increment.split(",")]
+ (int(x) for x in args.dim_increment.split(","))
Ms = list(range(m_start, m_end + 1, m_increment))
Ks = list(range(k_start, k_end + 1, k_increment))
Ns = list(range(n_start, n_end + 1, n_increment))
@@ -306,33 +512,49 @@ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
for k, n in KNs:
MKNs.append((m, k, n))
- data = run(args.dtype, args.sweep_schedules, MKNs)
+ data = run(args, MKNs)
model_bench_data.append(data)
+ type_string = f"{args.act_type}"
+
# Print all results
for data, model_tp in zip(model_bench_data, models_tps):
model, tp_size = model_tp
- print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
+ print(f"== Results {type_string} {model}-TP{tp_size} ====")
print_timers(data)
- timestamp = int(time.time())
+ timestr = time.strftime("%Y%m%d-%H%M%S")
- all_data = []
+ all_results = []
for d in model_bench_data:
- all_data.extend(d)
+ all_results.extend(d)
+
# pickle all data
- with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
- pkl.dump(all_data, f)
+ with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
+ args_dict = vars(args)
+ args_dict.pop("func")
+ pkl.dump({
+ "args": args_dict,
+ "results": all_results,
+ }, f)
if __name__ == "__main__":
def to_torch_dtype(dt):
- if dt == "bfloat16":
- return torch.bfloat16
- if dt == "float16":
- return torch.float16
- raise ValueError("unsupported dtype")
+ return {
+ "bfloat16": torch.bfloat16,
+ "float16": torch.float16,
+ "int8": torch.int8,
+ "float8_e4m3fn": torch.float8_e4m3fn,
+ "int": torch.int,
+ "float": torch.float,
+ }[dt]
+
+ class ToTorchDtype(argparse.Action):
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ setattr(namespace, self.dest, to_torch_dtype(values))
parser = FlexibleArgumentParser(
description="""
@@ -352,12 +574,42 @@ def to_torch_dtype(dt):
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter,
)
-
parser.add_argument(
- "--dtype",
- type=to_torch_dtype,
+ "--act-type",
+ action=ToTorchDtype,
required=True,
- help="Available options are ['bfloat16', 'float16']",
+ choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'],
+ )
+ parser.add_argument(
+ "--group-scale-type",
+ action=ToTorchDtype,
+ choices=['bfloat16', 'float16'],
+ )
+ parser.add_argument(
+ "--group-zero-type",
+ type=to_torch_dtype,
+ choices=['bfloat16', 'float16'],
+ )
+ parser.add_argument(
+ "--channel-scale-type",
+ action=ToTorchDtype,
+ choices=['float'],
+ )
+ parser.add_argument(
+ "--token-scale-type",
+ action=ToTorchDtype,
+ choices=['float'],
+ )
+ parser.add_argument(
+ "--out-type",
+ action=ToTorchDtype,
+ choices=['bfloat16', 'float16'],
+ )
+ parser.add_argument(
+ "--group-size",
+ type=int,
+ help="Available options are ['None', '-1', '128'], default=128",
+ default=128,
)
parser.add_argument(
"--sweep-schedules",
diff --git a/vllm/benchmarks/kernels/benchmark_marlin.py b/vllm/benchmarks/kernels/benchmark_marlin.py
index 536c133bb..8fb44e3a3 100644
--- a/vllm/benchmarks/kernels/benchmark_marlin.py
+++ b/vllm/benchmarks/kernels/benchmark_marlin.py
@@ -131,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
- "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501
+ "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
@@ -141,7 +141,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
- "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501
+ "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
diff --git a/vllm/benchmarks/kernels/benchmark_moe.py b/vllm/benchmarks/kernels/benchmark_moe.py
index 4f88e8e6e..8f538c21f 100644
--- a/vllm/benchmarks/kernels/benchmark_moe.py
+++ b/vllm/benchmarks/kernels/benchmark_moe.py
@@ -10,7 +10,8 @@
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import *
-from vllm.utils import FlexibleArgumentParser, seed_everything
+from vllm.platforms import current_platform
+from vllm.utils import FlexibleArgumentParser
class BenchmarkConfig(TypedDict):
@@ -167,7 +168,7 @@ class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
- seed_everything(seed)
+ current_platform.seed_everything(seed)
self.seed = seed
def benchmark(
@@ -181,7 +182,7 @@ def benchmark(
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
) -> Tuple[Dict[str, int], float]:
- seed_everything(self.seed)
+ current_platform.seed_everything(self.seed)
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)
diff --git a/vllm/benchmarks/kernels/benchmark_paged_attention.py b/vllm/benchmarks/kernels/benchmark_paged_attention.py
index 87864d038..14eef00b8 100644
--- a/vllm/benchmarks/kernels/benchmark_paged_attention.py
+++ b/vllm/benchmarks/kernels/benchmark_paged_attention.py
@@ -5,8 +5,9 @@
import torch
from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
- create_kv_caches_with_random, seed_everything)
+ create_kv_caches_with_random)
NUM_BLOCKS = 1024
PARTITION_SIZE = 512
@@ -28,7 +29,7 @@ def main(
device: str = "cuda",
kv_cache_dtype: Optional[str] = None,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
scale = float(1.0 / (head_size**0.5))
query = torch.empty(num_seqs,
diff --git a/vllm/benchmarks/kernels/benchmark_quant.py b/vllm/benchmarks/kernels/benchmark_quant.py
index 743a5744e..1d6248344 100644
--- a/vllm/benchmarks/kernels/benchmark_quant.py
+++ b/vllm/benchmarks/kernels/benchmark_quant.py
@@ -3,8 +3,8 @@
import torch
from vllm import _custom_ops as ops
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
- seed_everything)
+from vllm.platforms import current_platform
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
@torch.inference_mode()
@@ -17,7 +17,7 @@ def main(num_tokens: int,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device("cuda")
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
diff --git a/vllm/benchmarks/kernels/benchmark_rope.py b/vllm/benchmarks/kernels/benchmark_rope.py
index 784b1cf98..250d50516 100644
--- a/vllm/benchmarks/kernels/benchmark_rope.py
+++ b/vllm/benchmarks/kernels/benchmark_rope.py
@@ -6,7 +6,8 @@
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
get_rope)
-from vllm.utils import FlexibleArgumentParser, seed_everything
+from vllm.platforms import current_platform
+from vllm.utils import FlexibleArgumentParser
def benchmark_rope_kernels_multi_lora(
@@ -22,7 +23,7 @@ def benchmark_rope_kernels_multi_lora(
max_position: int = 8192,
base: int = 10000,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
diff --git a/vllm/benchmarks/kernels/graph_machete_bench.py b/vllm/benchmarks/kernels/graph_machete_bench.py
index de608fd05..7d0bd8415 100644
--- a/vllm/benchmarks/kernels/graph_machete_bench.py
+++ b/vllm/benchmarks/kernels/graph_machete_bench.py
@@ -20,10 +20,11 @@
args = parser.parse_args()
with open(args.filename, 'rb') as f:
- data: List[TMeasurement] = pickle.load(f)
+ data = pickle.load(f)
+ raw_results: List[TMeasurement] = data["results"]
results = defaultdict(lambda: list())
- for v in data:
+ for v in raw_results:
result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
if result is not None:
KN = result.group(1)
diff --git a/vllm/benchmarks/kernels/weight_shapes.py b/vllm/benchmarks/kernels/weight_shapes.py
index 25ec9d602..51f24f3ba 100644
--- a/vllm/benchmarks/kernels/weight_shapes.py
+++ b/vllm/benchmarks/kernels/weight_shapes.py
@@ -40,4 +40,10 @@
([8192, 57344], 1),
([28672, 8192], 0),
],
+ "meta-llama/Llama-3.1-405b-hf": [
+ ([16384, 18432], 1),
+ ([16384, 16384], 0),
+ ([16384, 106496], 1),
+ ([53248, 16384], 0),
+ ],
}
diff --git a/vllm/benchmarks/launch_tgi_server.sh b/vllm/benchmarks/launch_tgi_server.sh
index 8c5cd454f..ba7383d88 100755
--- a/vllm/benchmarks/launch_tgi_server.sh
+++ b/vllm/benchmarks/launch_tgi_server.sh
@@ -4,13 +4,13 @@ PORT=8000
MODEL=$1
TOKENS=$2
-docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
- -v $PWD/data:/data \
+docker run -e "HF_TOKEN=$HF_TOKEN" --gpus all --shm-size 1g -p $PORT:80 \
+ -v "$PWD/data:/data" \
ghcr.io/huggingface/text-generation-inference:2.2.0 \
- --model-id $MODEL \
+ --model-id "$MODEL" \
--sharded false \
--max-input-length 1024 \
--max-total-tokens 2048 \
--max-best-of 5 \
--max-concurrent-requests 5000 \
- --max-batch-total-tokens $TOKENS
+ --max-batch-total-tokens "$TOKENS"
diff --git a/vllm/benchmarks/structured_schemas/structured_schema_1.json b/vllm/benchmarks/structured_schemas/structured_schema_1.json
new file mode 100644
index 000000000..600369846
--- /dev/null
+++ b/vllm/benchmarks/structured_schemas/structured_schema_1.json
@@ -0,0 +1,113 @@
+{
+ "$schema":
+ "https://json-schema.org/draft/2020-12/schema",
+ "title":
+ "User Profile",
+ "type":
+ "object",
+ "properties": {
+ "userId": {
+ "type": "string",
+ "description": "Unique identifier for the user."
+ },
+ "personalInfo": {
+ "type": "object",
+ "properties": {
+ "firstName": {
+ "type": "string",
+ "description": "The user's first name."
+ },
+ "lastName": {
+ "type": "string",
+ "description": "The user's last name."
+ },
+ "age": {
+ "type": "integer",
+ "minimum": 0,
+ "description": "The user's age."
+ },
+ "phoneNumbers": {
+ "type":
+ "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "enum": ["home", "work", "mobile"],
+ "description": "Type of phone number."
+ },
+ "number": {
+ "type": "string",
+ "pattern": "^\\+?[1-9]\\d{1,14}$",
+ "description": "Phone number in E.164 format."
+ }
+ },
+ "required": ["type", "number"]
+ },
+ "description":
+ "List of phone numbers associated with the user."
+ }
+ },
+ "required": ["firstName", "lastName"]
+ },
+ "address": {
+ "type": "object",
+ "properties": {
+ "street": {
+ "type": "string",
+ "description": "Street address."
+ },
+ "city": {
+ "type": "string",
+ "description": "City name."
+ },
+ "state": {
+ "type": "string",
+ "description": "State or province."
+ },
+ "postalCode": {
+ "type": "string",
+ "pattern": "^\\d{5}(-\\d{4})?$",
+ "description": "Postal code."
+ },
+ "country": {
+ "type": "string",
+ "description": "Country name."
+ }
+ },
+ "required": ["street", "city", "state", "postalCode", "country"]
+ },
+ "preferences": {
+ "type": "object",
+ "properties": {
+ "newsletterSubscribed": {
+ "type":
+ "boolean",
+ "description":
+ "Indicates if the user is subscribed to the newsletter."
+ },
+ "favoriteCategories": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ },
+ "description": "List of user's favorite categories."
+ }
+ },
+ "required": ["newsletterSubscribed"]
+ },
+ "accountStatus": {
+ "type": "string",
+ "enum": ["active", "inactive", "suspended"],
+ "description": "Current status of the user's account."
+ },
+ "registrationDate": {
+ "type": "string",
+ "format": "date-time",
+ "description": "ISO 8601 formatted date-time of user registration."
+ }
+ },
+ "required":
+ ["userId", "personalInfo", "address", "accountStatus", "registrationDate"]
+}
\ No newline at end of file
diff --git a/vllm/cmake/cpu_extension.cmake b/vllm/cmake/cpu_extension.cmake
index 7237d246d..68f7ca1af 100644
--- a/vllm/cmake/cpu_extension.cmake
+++ b/vllm/cmake/cpu_extension.cmake
@@ -16,6 +16,12 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc")
#
# Check the compile flags
#
+
+if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
+ list(APPEND CXX_COMPILE_FLAGS
+ "-mf16c"
+ )
+endif()
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-DVLLM_CPU_EXTENSION")
@@ -52,6 +58,8 @@ find_isa(${CPUINFO} "avx2" AVX2_FOUND)
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
+find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support
+find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support
if (AVX512_FOUND AND NOT AVX512_DISABLED)
list(APPEND CXX_COMPILE_FLAGS
@@ -71,9 +79,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
else()
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
endif()
+
elseif (AVX2_FOUND)
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
message(WARNING "vLLM CPU backend using AVX2 ISA")
+
elseif (POWER9_FOUND OR POWER10_FOUND)
message(STATUS "PowerPC detected")
# Check for PowerPC VSX support
@@ -81,8 +91,20 @@ elseif (POWER9_FOUND OR POWER10_FOUND)
"-mvsx"
"-mcpu=native"
"-mtune=native")
+
+elseif (ASIMD_FOUND)
+ message(STATUS "ARMv8 or later architecture detected")
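+  # Baseline is ARMv8.2-a with dot-product and fp16; the bf16 extension is
+  # added only when the CPU reports it.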
+ if(ARM_BF16_FOUND)
+ message(STATUS "BF16 extension detected")
+ set(MARCH_FLAGS "-march=armv8.2-a+bf16+dotprod+fp16")
+ add_compile_definitions(ARM_BF16_SUPPORT)
+ else()
+ message(WARNING "BF16 functionality is not available")
+ set(MARCH_FLAGS "-march=armv8.2-a+dotprod+fp16")
+ endif()
+ list(APPEND CXX_COMPILE_FLAGS ${MARCH_FLAGS})
else()
- message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.")
+ message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA or ARMv8 support.")
endif()
#
@@ -92,7 +114,7 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
- GIT_TAG v3.5.3
+ GIT_TAG v3.6
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
@@ -152,4 +174,4 @@ define_gpu_extension_target(
WITH_SOABI
)
-message(STATUS "Enabling C extension.")
+message(STATUS "Enabling C extension.")
\ No newline at end of file
diff --git a/vllm/collect_env.py b/vllm/collect_env.py
index 80403d576..254c19b19 100644
--- a/vllm/collect_env.py
+++ b/vllm/collect_env.py
@@ -1,17 +1,19 @@
# ruff: noqa
# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py
-# Unlike the rest of the PyTorch this file must be python2 compliant.
-# This script outputs relevant system environment info
-# Run it with `python collect_env.py` or `python -m torch.utils.collect_env`
import datetime
import locale
import os
import re
import subprocess
import sys
+# Unlike the rest of PyTorch, this file must be Python 2 compliant.
+# This script outputs relevant system environment info.
+# Run it with `python collect_env.py` or `python -m torch.utils.collect_env`.
from collections import namedtuple
+from vllm.envs import environment_variables
+
try:
import torch
TORCH_AVAILABLE = True
@@ -52,6 +54,7 @@
'vllm_version', # vllm specific field
'vllm_build_flags', # vllm specific field
'gpu_topo', # vllm specific field
+ 'env_vars',
])
DEFAULT_CONDA_PATTERNS = {
@@ -512,6 +515,22 @@ def is_xnnpack_available():
else:
return "N/A"
+def get_env_vars():
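+    # Collect vLLM-related and common framework environment variables,
+    # skipping anything whose name looks like a secret.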
+ env_vars = ''
+    secret_terms = ('secret', 'token', 'api', 'access', 'password')
+ report_prefix = ("TORCH", "NCCL", "PYTORCH",
+ "CUDA", "CUBLAS", "CUDNN",
+ "OMP_", "MKL_",
+ "NVIDIA")
+ for k, v in os.environ.items():
+ if any(term in k.lower() for term in secret_terms):
+ continue
+ if k in environment_variables:
+ env_vars = env_vars + "{}={}".format(k, v) + "\n"
+ if k.startswith(report_prefix):
+ env_vars = env_vars + "{}={}".format(k, v) + "\n"
+
+ return env_vars
def get_env_info():
run_lambda = run
@@ -583,6 +602,7 @@ def get_version_or_na(cfg, prefix):
vllm_version=vllm_version,
vllm_build_flags=vllm_build_flags,
gpu_topo=gpu_topo,
+ env_vars=get_env_vars(),
)
@@ -631,6 +651,8 @@ def get_version_or_na(cfg, prefix):
{vllm_build_flags}
GPU Topology:
{gpu_topo}
+
+{env_vars}
""".strip()
diff --git a/vllm/csrc/attention/attention_kernels.cu b/vllm/csrc/attention/attention_kernels.cuh
similarity index 64%
rename from vllm/csrc/attention/attention_kernels.cu
rename to vllm/csrc/attention/attention_kernels.cuh
index bcd170411..563e1438f 100644
--- a/vllm/csrc/attention/attention_kernels.cu
+++ b/vllm/csrc/attention/attention_kernels.cuh
@@ -670,332 +670,6 @@ __global__ void paged_attention_v2_reduce_kernel(
} // namespace vllm
-#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
- VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
-      ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE,       \
-                                              BLOCK_SIZE, NUM_THREADS,     \
-                                              KV_DTYPE, IS_BLOCK_SPARSE>), \
- shared_mem_size); \
-  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,      \
-                                  NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
-      <<<grid, block, shared_mem_size, stream>>>(                         \
- out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
- scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
- alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
- k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
- blocksparse_vert_stride, blocksparse_block_size, \
- blocksparse_head_sliding_step);
-
-// TODO(woosuk): Tune NUM_THREADS.
-template <typename T, typename CACHE_T, int BLOCK_SIZE,
-          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
-          int NUM_THREADS = 128>
-void paged_attention_v1_launcher(
- torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
- torch::Tensor& value_cache, int num_kv_heads, float scale,
- torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
- float v_scale, const int tp_rank, const int blocksparse_local_blocks,
- const int blocksparse_vert_stride, const int blocksparse_block_size,
- const int blocksparse_head_sliding_step) {
- int num_seqs = query.size(0);
- int num_heads = query.size(1);
- int head_size = query.size(2);
- int max_num_blocks_per_seq = block_tables.size(1);
- int q_stride = query.stride(0);
- int kv_block_stride = key_cache.stride(0);
- int kv_head_stride = key_cache.stride(1);
-
- [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
- assert(head_size % thread_group_size == 0);
-
- // NOTE: alibi_slopes is optional.
- const float* alibi_slopes_ptr =
- alibi_slopes
-          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
- : nullptr;
-
-  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
-  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
-  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
-  int* block_tables_ptr = block_tables.data_ptr<int>();
-  int* seq_lens_ptr = seq_lens.data_ptr<int>();
-
- constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
- int padded_max_seq_len =
- DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
- int logits_size = padded_max_seq_len * sizeof(float);
- int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
- // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
- // Keep that in sync with the logic here!
- int shared_mem_size = std::max(logits_size, outputs_size);
-
- dim3 grid(num_heads, num_seqs, 1);
- dim3 block(NUM_THREADS);
- const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- switch (head_size) {
- // NOTE(woosuk): To reduce the compilation time, we only compile for the
- // head sizes that we use in the model. However, we can easily extend this
- // to support any head size which is a multiple of 16.
- case 64:
- LAUNCH_PAGED_ATTENTION_V1(64);
- break;
- case 80:
- LAUNCH_PAGED_ATTENTION_V1(80);
- break;
- case 96:
- LAUNCH_PAGED_ATTENTION_V1(96);
- break;
- case 112:
- LAUNCH_PAGED_ATTENTION_V1(112);
- break;
- case 120:
- LAUNCH_PAGED_ATTENTION_V1(120);
- break;
- case 128:
- LAUNCH_PAGED_ATTENTION_V1(128);
- break;
- case 192:
- LAUNCH_PAGED_ATTENTION_V1(192);
- break;
- case 256:
- LAUNCH_PAGED_ATTENTION_V1(256);
- break;
- default:
- TORCH_CHECK(false, "Unsupported head size: ", head_size);
- break;
- }
-}
-
-#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
-  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,          \
-                              IS_BLOCK_SPARSE>(                          \
- out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
- seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
- blocksparse_local_blocks, blocksparse_vert_stride, \
- blocksparse_block_size, blocksparse_head_sliding_step);
-
-#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
- switch (is_block_sparse) { \
- case true: \
- CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
- break; \
- case false: \
- CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
- break; \
- }
-
-// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
-// 1, 2, 4, 64, 128, 256.
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
- switch (block_size) { \
- case 8: \
- CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
- break; \
- case 16: \
- CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
- break; \
- case 32: \
- CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
- break; \
- default: \
- TORCH_CHECK(false, "Unsupported block size: ", block_size); \
- break; \
- }
-
-void paged_attention_v1(
- torch::Tensor& out, // [num_seqs, num_heads, head_size]
- torch::Tensor& query, // [num_seqs, num_heads, head_size]
- torch::Tensor&
- key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
- torch::Tensor&
- value_cache, // [num_blocks, num_heads, head_size, block_size]
- int64_t num_kv_heads, // [num_heads]
- double scale,
- torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
- torch::Tensor& seq_lens, // [num_seqs]
- int64_t block_size, int64_t max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes,
- const std::string& kv_cache_dtype, double k_scale, double v_scale,
- const int64_t tp_rank, const int64_t blocksparse_local_blocks,
- const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
- const int64_t blocksparse_head_sliding_step) {
- const bool is_block_sparse = (blocksparse_vert_stride > 1);
-
- DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
- CALL_V1_LAUNCHER_BLOCK_SIZE)
-}
-
-#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
-  vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,       \
-                                  NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE,  \
-                                  PARTITION_SIZE>                          \
-      <<<grid, block, shared_mem_size, stream>>>(                          \
- exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
- value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
- seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
- kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
- blocksparse_local_blocks, blocksparse_vert_stride, \
- blocksparse_block_size, blocksparse_head_sliding_step); \
-  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,        \
-                                         PARTITION_SIZE>                   \
-      <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(            \
- out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
- max_num_partitions);
-
-template <typename T, typename CACHE_T, int BLOCK_SIZE,
-          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
-          int NUM_THREADS = 128, int PARTITION_SIZE = 512>
-void paged_attention_v2_launcher(
- torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
- torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
- torch::Tensor& value_cache, int num_kv_heads, float scale,
- torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
- float v_scale, const int tp_rank, const int blocksparse_local_blocks,
- const int blocksparse_vert_stride, const int blocksparse_block_size,
- const int blocksparse_head_sliding_step) {
- int num_seqs = query.size(0);
- int num_heads = query.size(1);
- int head_size = query.size(2);
- int max_num_blocks_per_seq = block_tables.size(1);
- int q_stride = query.stride(0);
- int kv_block_stride = key_cache.stride(0);
- int kv_head_stride = key_cache.stride(1);
-
- [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
- assert(head_size % thread_group_size == 0);
-
- // NOTE: alibi_slopes is optional.
- const float* alibi_slopes_ptr =
- alibi_slopes
-          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
- : nullptr;
-
-  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
-  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
-  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
-  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
-  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
-  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
-  int* block_tables_ptr = block_tables.data_ptr<int>();
-  int* seq_lens_ptr = seq_lens.data_ptr<int>();
-
- constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
- int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
- int logits_size = PARTITION_SIZE * sizeof(float);
- int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
-
- // For paged attention v2 kernel.
- dim3 grid(num_heads, num_seqs, max_num_partitions);
- int shared_mem_size = std::max(logits_size, outputs_size);
- // For paged attention v2 reduce kernel.
- dim3 reduce_grid(num_heads, num_seqs);
- int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
-
- dim3 block(NUM_THREADS);
- const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- switch (head_size) {
- // NOTE(woosuk): To reduce the compilation time, we only compile for the
- // head sizes that we use in the model. However, we can easily extend this
- // to support any head size which is a multiple of 16.
- case 64:
- LAUNCH_PAGED_ATTENTION_V2(64);
- break;
- case 80:
- LAUNCH_PAGED_ATTENTION_V2(80);
- break;
- case 96:
- LAUNCH_PAGED_ATTENTION_V2(96);
- break;
- case 112:
- LAUNCH_PAGED_ATTENTION_V2(112);
- break;
- case 120:
- LAUNCH_PAGED_ATTENTION_V2(120);
- break;
- case 128:
- LAUNCH_PAGED_ATTENTION_V2(128);
- break;
- case 192:
- LAUNCH_PAGED_ATTENTION_V2(192);
- break;
- case 256:
- LAUNCH_PAGED_ATTENTION_V2(256);
- break;
- default:
- TORCH_CHECK(false, "Unsupported head size: ", head_size);
- break;
- }
-}
-
-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
-  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,            \
-                              IS_BLOCK_SPARSE>(                            \
- out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
- num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
- k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
- blocksparse_vert_stride, blocksparse_block_size, \
- blocksparse_head_sliding_step);
-
-#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
- switch (is_block_sparse) { \
- case true: \
- CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
- break; \
- case false: \
- CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
- break; \
- }
-
-// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
-// 1, 2, 4, 64, 128, 256.
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
- switch (block_size) { \
- case 8: \
- CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
- break; \
- case 16: \
- CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
- break; \
- case 32: \
- CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
- break; \
- default: \
- TORCH_CHECK(false, "Unsupported block size: ", block_size); \
- break; \
- }
-
-void paged_attention_v2(
- torch::Tensor& out, // [num_seqs, num_heads, head_size]
- torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
- torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
- torch::Tensor&
- tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
- torch::Tensor& query, // [num_seqs, num_heads, head_size]
- torch::Tensor&
- key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
- torch::Tensor&
- value_cache, // [num_blocks, num_heads, head_size, block_size]
- int64_t num_kv_heads, // [num_heads]
- double scale,
- torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
- torch::Tensor& seq_lens, // [num_seqs]
- int64_t block_size, int64_t max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes,
- const std::string& kv_cache_dtype, double k_scale, double v_scale,
- const int64_t tp_rank, const int64_t blocksparse_local_blocks,
- const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
- const int64_t blocksparse_head_sliding_step) {
- const bool is_block_sparse = (blocksparse_vert_stride > 1);
- DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
- CALL_V2_LAUNCHER_BLOCK_SIZE)
-}
-
#undef WARP_SIZE
#undef MAX
#undef MIN
diff --git a/vllm/csrc/attention/paged_attention_v1.cu b/vllm/csrc/attention/paged_attention_v1.cu
new file mode 100644
index 000000000..741cd0c82
--- /dev/null
+++ b/vllm/csrc/attention/paged_attention_v1.cu
@@ -0,0 +1,196 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "attention_kernels.cuh"
+
+#ifndef USE_ROCM
+ #define WARP_SIZE 32
+#else
+ #define WARP_SIZE warpSize
+#endif
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
+ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
+      ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE,       \
+                                              BLOCK_SIZE, NUM_THREADS,     \
+                                              KV_DTYPE, IS_BLOCK_SPARSE>), \
+ shared_mem_size); \
+  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,      \
+                                  NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
+      <<<grid, block, shared_mem_size, stream>>>(                         \
+ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
+ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
+ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
+ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
+ blocksparse_vert_stride, blocksparse_block_size, \
+ blocksparse_head_sliding_step);
+
+// TODO(woosuk): Tune NUM_THREADS.
+template <typename T, typename CACHE_T, int BLOCK_SIZE,
+          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
+          int NUM_THREADS = 128>
+void paged_attention_v1_launcher(
+ torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+ torch::Tensor& value_cache, int num_kv_heads, float scale,
+ torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
+    const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
+ float v_scale, const int tp_rank, const int blocksparse_local_blocks,
+ const int blocksparse_vert_stride, const int blocksparse_block_size,
+ const int blocksparse_head_sliding_step) {
+ int num_seqs = query.size(0);
+ int num_heads = query.size(1);
+ int head_size = query.size(2);
+ int max_num_blocks_per_seq = block_tables.size(1);
+ int q_stride = query.stride(0);
+ int kv_block_stride = key_cache.stride(0);
+ int kv_head_stride = key_cache.stride(1);
+
+ [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+ assert(head_size % thread_group_size == 0);
+
+ // NOTE: alibi_slopes is optional.
+ const float* alibi_slopes_ptr =
+ alibi_slopes
+          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+ : nullptr;
+
+  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+  int* block_tables_ptr = block_tables.data_ptr<int>();
+  int* seq_lens_ptr = seq_lens.data_ptr<int>();
+
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+ int padded_max_seq_len =
+ DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
+ int logits_size = padded_max_seq_len * sizeof(float);
+ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+ // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
+ // Keep that in sync with the logic here!
+ int shared_mem_size = std::max(logits_size, outputs_size);
+
+ dim3 grid(num_heads, num_seqs, 1);
+ dim3 block(NUM_THREADS);
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ switch (head_size) {
+ // NOTE(woosuk): To reduce the compilation time, we only compile for the
+ // head sizes that we use in the model. However, we can easily extend this
+ // to support any head size which is a multiple of 16.
+ case 32:
+ LAUNCH_PAGED_ATTENTION_V1(32);
+ break;
+ case 64:
+ LAUNCH_PAGED_ATTENTION_V1(64);
+ break;
+ case 80:
+ LAUNCH_PAGED_ATTENTION_V1(80);
+ break;
+ case 96:
+ LAUNCH_PAGED_ATTENTION_V1(96);
+ break;
+ case 112:
+ LAUNCH_PAGED_ATTENTION_V1(112);
+ break;
+ case 120:
+ LAUNCH_PAGED_ATTENTION_V1(120);
+ break;
+ case 128:
+ LAUNCH_PAGED_ATTENTION_V1(128);
+ break;
+ case 192:
+ LAUNCH_PAGED_ATTENTION_V1(192);
+ break;
+ case 256:
+ LAUNCH_PAGED_ATTENTION_V1(256);
+ break;
+ default:
+ TORCH_CHECK(false, "Unsupported head size: ", head_size);
+ break;
+ }
+}
+
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
+  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE>( \
+ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
+ seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
+ blocksparse_local_blocks, blocksparse_vert_stride, \
+ blocksparse_block_size, blocksparse_head_sliding_step);
+
+#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
+ switch (is_block_sparse) { \
+ case true: \
+ CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
+ break; \
+ case false: \
+ CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
+ break; \
+ }
+
+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
+ switch (block_size) { \
+ case 8: \
+ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
+ break; \
+ case 16: \
+ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
+ break; \
+ case 32: \
+ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
+ break; \
+ default: \
+ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+ break; \
+ }
+
+void paged_attention_v1(
+ torch::Tensor& out, // [num_seqs, num_heads, head_size]
+ torch::Tensor& query, // [num_seqs, num_heads, head_size]
+ torch::Tensor&
+ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+ torch::Tensor&
+ value_cache, // [num_blocks, num_heads, head_size, block_size]
+ int64_t num_kv_heads, // [num_heads]
+ double scale,
+ torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
+ torch::Tensor& seq_lens, // [num_seqs]
+ int64_t block_size, int64_t max_seq_len,
+    const c10::optional<torch::Tensor>& alibi_slopes,
+ const std::string& kv_cache_dtype, double k_scale, double v_scale,
+ const int64_t tp_rank, const int64_t blocksparse_local_blocks,
+ const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+ const int64_t blocksparse_head_sliding_step) {
+ const bool is_block_sparse = (blocksparse_vert_stride > 1);
+
+ DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
+ CALL_V1_LAUNCHER_BLOCK_SIZE)
+}
+
+#undef WARP_SIZE
+#undef MAX
+#undef MIN
+#undef DIVIDE_ROUND_UP
\ No newline at end of file
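
Aside (not part of the patch): a minimal, self-contained sketch of how the v1 launcher above sizes its dynamic shared memory. NUM_THREADS = 128 and WARP_SIZE = 32 match the defaults used in this file; the sequence length, block size and head size below are illustrative assumptions.

    // Illustrative only -- mirrors the shared_mem_size computation in
    // paged_attention_v1_launcher above.
    #include <algorithm>
    #include <cstdio>

    int main() {
      constexpr int NUM_THREADS = 128, WARP_SIZE = 32, BLOCK_SIZE = 16;
      constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
      int max_seq_len = 4096, head_size = 128;  // assumed example values
      int padded_max_seq_len =
          (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
      int logits_size = padded_max_seq_len * sizeof(float);            // softmax logits
      int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);  // partial outputs
      int shared_mem_size = std::max(logits_size, outputs_size);
      std::printf("v1 dynamic shared memory: %d bytes\n", shared_mem_size);
    }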
diff --git a/vllm/csrc/attention/paged_attention_v2.cu b/vllm/csrc/attention/paged_attention_v2.cu
new file mode 100644
index 000000000..6de8d0bdd
--- /dev/null
+++ b/vllm/csrc/attention/paged_attention_v2.cu
@@ -0,0 +1,206 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "attention_kernels.cuh"
+
+#ifndef USE_ROCM
+ #define WARP_SIZE 32
+#else
+ #define WARP_SIZE warpSize
+#endif
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+
+#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
+  vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE> \
+      <<<grid, block, shared_mem_size, stream>>>( \
+ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
+ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
+ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
+ kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
+ blocksparse_local_blocks, blocksparse_vert_stride, \
+ blocksparse_block_size, blocksparse_head_sliding_step); \
+  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
+      <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
+ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
+ max_num_partitions);
+
+template <typename T, typename CACHE_T, int BLOCK_SIZE,
+          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
+          int NUM_THREADS = 128, int PARTITION_SIZE = 512>
+void paged_attention_v2_launcher(
+ torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+ torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+ torch::Tensor& value_cache, int num_kv_heads, float scale,
+ torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
+    const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
+ float v_scale, const int tp_rank, const int blocksparse_local_blocks,
+ const int blocksparse_vert_stride, const int blocksparse_block_size,
+ const int blocksparse_head_sliding_step) {
+ int num_seqs = query.size(0);
+ int num_heads = query.size(1);
+ int head_size = query.size(2);
+ int max_num_blocks_per_seq = block_tables.size(1);
+ int q_stride = query.stride(0);
+ int kv_block_stride = key_cache.stride(0);
+ int kv_head_stride = key_cache.stride(1);
+
+ [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+ assert(head_size % thread_group_size == 0);
+
+ // NOTE: alibi_slopes is optional.
+ const float* alibi_slopes_ptr =
+ alibi_slopes
+          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+ : nullptr;
+
+  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
+  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
+  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
+  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+  int* block_tables_ptr = block_tables.data_ptr<int>();
+  int* seq_lens_ptr = seq_lens.data_ptr<int>();
+
+ constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+ int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
+ int logits_size = PARTITION_SIZE * sizeof(float);
+ int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+
+ // For paged attention v2 kernel.
+ dim3 grid(num_heads, num_seqs, max_num_partitions);
+ int shared_mem_size = std::max(logits_size, outputs_size);
+ // For paged attention v2 reduce kernel.
+ dim3 reduce_grid(num_heads, num_seqs);
+ int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
+
+ dim3 block(NUM_THREADS);
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ switch (head_size) {
+ // NOTE(woosuk): To reduce the compilation time, we only compile for the
+ // head sizes that we use in the model. However, we can easily extend this
+ // to support any head size which is a multiple of 16.
+ case 32:
+ LAUNCH_PAGED_ATTENTION_V2(32);
+ break;
+ case 64:
+ LAUNCH_PAGED_ATTENTION_V2(64);
+ break;
+ case 80:
+ LAUNCH_PAGED_ATTENTION_V2(80);
+ break;
+ case 96:
+ LAUNCH_PAGED_ATTENTION_V2(96);
+ break;
+ case 112:
+ LAUNCH_PAGED_ATTENTION_V2(112);
+ break;
+ case 120:
+ LAUNCH_PAGED_ATTENTION_V2(120);
+ break;
+ case 128:
+ LAUNCH_PAGED_ATTENTION_V2(128);
+ break;
+ case 192:
+ LAUNCH_PAGED_ATTENTION_V2(192);
+ break;
+ case 256:
+ LAUNCH_PAGED_ATTENTION_V2(256);
+ break;
+ default:
+ TORCH_CHECK(false, "Unsupported head size: ", head_size);
+ break;
+ }
+}
+
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE>( \
+ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
+ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
+ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
+ blocksparse_vert_stride, blocksparse_block_size, \
+ blocksparse_head_sliding_step);
+
+#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
+ switch (is_block_sparse) { \
+ case true: \
+ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
+ break; \
+ case false: \
+ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
+ break; \
+ }
+
+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
+ switch (block_size) { \
+ case 8: \
+ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
+ break; \
+ case 16: \
+ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
+ break; \
+ case 32: \
+ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
+ break; \
+ default: \
+ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+ break; \
+ }
+
+void paged_attention_v2(
+ torch::Tensor& out, // [num_seqs, num_heads, head_size]
+ torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
+ torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
+ torch::Tensor&
+ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
+ torch::Tensor& query, // [num_seqs, num_heads, head_size]
+ torch::Tensor&
+ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+ torch::Tensor&
+ value_cache, // [num_blocks, num_heads, head_size, block_size]
+ int64_t num_kv_heads, // [num_heads]
+ double scale,
+ torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
+ torch::Tensor& seq_lens, // [num_seqs]
+ int64_t block_size, int64_t max_seq_len,
+    const c10::optional<torch::Tensor>& alibi_slopes,
+ const std::string& kv_cache_dtype, double k_scale, double v_scale,
+ const int64_t tp_rank, const int64_t blocksparse_local_blocks,
+ const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
+ const int64_t blocksparse_head_sliding_step) {
+ const bool is_block_sparse = (blocksparse_vert_stride > 1);
+ DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
+ CALL_V2_LAUNCHER_BLOCK_SIZE)
+}
+
+#undef WARP_SIZE
+#undef MAX
+#undef MIN
+#undef DIVIDE_ROUND_UP
\ No newline at end of file
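
Aside (not part of the patch): the v2 path above splits each sequence into fixed-size partitions and then runs a reduce kernel over the partial results. A small sketch of the sizing arithmetic, assuming NUM_THREADS = 128 and PARTITION_SIZE = 512 (the defaults this launcher is normally instantiated with); the other values are illustrative.

    #include <algorithm>
    #include <cstdio>

    int main() {
      constexpr int NUM_THREADS = 128, WARP_SIZE = 32, PARTITION_SIZE = 512;
      constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
      int max_seq_len = 8192, head_size = 128;  // assumed example values
      int max_num_partitions = (max_seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
      int logits_size = PARTITION_SIZE * sizeof(float);
      int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
      int shared_mem_size = std::max(logits_size, outputs_size);            // main kernel
      int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);  // reduce kernel
      std::printf("partitions=%d, main smem=%d B, reduce smem=%d B\n",
                  max_num_partitions, shared_mem_size, reduce_shared_mem_size);
    }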
diff --git a/vllm/csrc/cpu/attention.cpp b/vllm/csrc/cpu/attention.cpp
index abb4e3bea..e21832ba7 100644
--- a/vllm/csrc/cpu/attention.cpp
+++ b/vllm/csrc/cpu/attention.cpp
@@ -22,6 +22,24 @@ struct KernelVecType<float> {
using v_load_vec_type = vec_op::FP32Vec16;
};
+template <>
+struct KernelVecType<c10::Half> {
+#ifdef __powerpc64__
+ // Power architecture-specific vector types
+ using q_load_vec_type = vec_op::FP32Vec8;
+ using k_load_vec_type = vec_op::FP32Vec16;
+ using v_load_vec_type = vec_op::FP32Vec16;
+#else
+ // Fallback for other architectures, including x86
+ using q_load_vec_type = vec_op::FP16Vec8;
+ using k_load_vec_type = vec_op::FP16Vec16;
+ using v_load_vec_type = vec_op::FP16Vec16;
+#endif
+ using q_vec_type = vec_op::FP32Vec16;
+ using k_vec_type = vec_op::FP32Vec16;
+ using qk_acc_vec_type = vec_op::FP32Vec16;
+};
+
#ifdef __AVX512BF16__
template <>
 struct KernelVecType<c10::BFloat16> {
@@ -33,6 +51,21 @@ struct KernelVecType<c10::BFloat16> {
using v_load_vec_type = vec_op::BF16Vec16;
};
#else
+ #ifdef __aarch64__
+ #ifndef ARM_BF16_SUPPORT
+ // pass
+ #else
+template <>
+struct KernelVecType<c10::BFloat16> {
+ using q_load_vec_type = vec_op::BF16Vec8;
+ using q_vec_type = vec_op::FP32Vec16;
+ using k_load_vec_type = vec_op::BF16Vec16;
+ using k_vec_type = vec_op::FP32Vec16;
+ using qk_acc_vec_type = vec_op::FP32Vec16;
+ using v_load_vec_type = vec_op::BF16Vec16;
+};
+ #endif
+ #else
template <>
 struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
@@ -42,6 +75,7 @@ struct KernelVecType<c10::BFloat16> {
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
+ #endif
#endif
template
@@ -375,6 +409,9 @@ void paged_attention_v1_impl_launcher(
int* seq_lens_ptr = seq_lens.data_ptr();
switch (head_size) {
+ case 32:
+ LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
+ break;
case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
@@ -692,6 +729,9 @@ void paged_attention_v2_impl_launcher(
int* seq_lens_ptr = seq_lens.data_ptr();
switch (head_size) {
+ case 32:
+ LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
+ break;
case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
@@ -755,4 +795,4 @@ void paged_attention_v2(
CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl)
});
-}
+}
\ No newline at end of file
diff --git a/vllm/csrc/cpu/cpu_types.hpp b/vllm/csrc/cpu/cpu_types.hpp
index 0213be091..28db04797 100644
--- a/vllm/csrc/cpu/cpu_types.hpp
+++ b/vllm/csrc/cpu/cpu_types.hpp
@@ -1,4 +1,3 @@
-
#ifndef CPU_TYPES_HPP
#define CPU_TYPES_HPP
@@ -8,8 +7,11 @@
#elif defined(__POWER9_VECTOR__)
//ppc implementation
#include "cpu_types_vsx.hpp"
+#elif defined(__aarch64__)
+ //arm implementation
+ #include "cpu_types_arm.hpp"
#else
#warning "unsupported vLLM cpu implementation"
#endif
-#endif
+#endif
\ No newline at end of file
diff --git a/vllm/csrc/cpu/cpu_types_arm.hpp b/vllm/csrc/cpu/cpu_types_arm.hpp
new file mode 100644
index 000000000..73e0f8cb2
--- /dev/null
+++ b/vllm/csrc/cpu/cpu_types_arm.hpp
@@ -0,0 +1,515 @@
+#include <arm_neon.h>
+#include <torch/all.h>
+#include <cmath>
+
+namespace vec_op {
+
+#ifdef ARM_BF16_SUPPORT
+ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+#else
+ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
+ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
+#endif
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
+ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+#ifndef CPU_OP_GUARD
+#define CPU_KERNEL_GUARD_IN(NAME)
+#define CPU_KERNEL_GUARD_OUT(NAME)
+#else
+#define CPU_KERNEL_GUARD_IN(NAME) \
+ std::cout << #NAME << " invoked." << std::endl;
+#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
+#endif
+
+#define FORCE_INLINE __attribute__((always_inline)) inline
+
+namespace {
+  template <typename T, T... indexes, typename F>
+  constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
+    (f(std::integral_constant<T, indexes>{}), ...);
+ };
+};
+
+template <typename T, T count, typename F,
+          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
+constexpr void unroll_loop(F &&f) {
+  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
+}
+
+template <typename T> struct Vec {
+ constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; };
+};
+
+struct FP32Vec8;
+struct FP32Vec16;
+
+struct FP16Vec8 : public Vec<FP16Vec8> {
+ constexpr static int VEC_ELEM_NUM = 8;
+
+ float16x8_t reg;
+
+ explicit FP16Vec8(const void *ptr)
+      : reg(vld1q_f16(static_cast<const __fp16 *>(ptr))) {};
+
+ explicit FP16Vec8(const FP32Vec8 &);
+
+ void save(void *ptr) const {
+ vst1q_f16(static_cast<__fp16 *>(ptr), reg);
+ }
+};
+
+struct FP16Vec16 : public Vec<FP16Vec16> {
+ constexpr static int VEC_ELEM_NUM = 16;
+
+ float16x8x2_t reg;
+
+ explicit FP16Vec16(const void *ptr) {
+    reg.val[0] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr));
+    reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
+ }
+
+ explicit FP16Vec16(const FP32Vec16& vec);
+
+ void save(void *ptr) const {
+ vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
+ vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
+ }
+
+ void save(void *ptr, const int elem_num) const {
+ int full_blocks = elem_num / 8;
+ int remainder = elem_num % 8;
+
+ if (full_blocks > 0) {
+ vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
+ if (full_blocks > 1) {
+ vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
+ }
+ }
+
+ if (remainder > 0) {
+ float16x8_t temp = reg.val[full_blocks];
+ for (int i = 0; i < remainder; ++i) {
+ reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] = vgetq_lane_f16(temp, i);
+ }
+ }
+ }
+};
+
+
+#ifdef ARM_BF16_SUPPORT
+struct BF16Vec8 : public Vec<BF16Vec8> {
+ constexpr static int VEC_ELEM_NUM = 8;
+
+ bfloat16x8_t reg;
+
+ explicit BF16Vec8(const void *ptr)
+    : reg(*reinterpret_cast<const bfloat16x8_t *>(ptr)) {};
+
+ explicit BF16Vec8(bfloat16x8_t data) : reg(data) {};
+
+ explicit BF16Vec8(const FP32Vec8 &);
+
+ explicit BF16Vec8(float32x4x2_t v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};
+
+  void save(void *ptr) const { *reinterpret_cast<bfloat16x8_t *>(ptr) = reg; }
+};
+
+struct BF16Vec16 : public Vec<BF16Vec16> {
+ constexpr static int VEC_ELEM_NUM = 16;
+
+ bfloat16x8x2_t reg;
+
+ explicit BF16Vec16(const void *ptr)
+    : reg(*reinterpret_cast<const bfloat16x8x2_t *>(ptr)) {};
+
+ explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {};
+
+ explicit BF16Vec16(const FP32Vec16 &);
+
+ explicit BF16Vec16(float32x4x4_t v) : reg({
+ vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]),
+ vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])
+ }){};
+
+  void save(void *ptr) const { *reinterpret_cast<bfloat16x8x2_t *>(ptr) = reg; };
+};
+
+struct BF16Vec32 : public Vec<BF16Vec32> {
+ constexpr static int VEC_ELEM_NUM = 32;
+
+ bfloat16x8x4_t reg;
+
+ explicit BF16Vec32(const void *ptr)
+    : reg(*reinterpret_cast<const bfloat16x8x4_t *>(ptr)) {};
+
+ explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {};
+
+ explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({
+ vec8_data.reg,
+ vec8_data.reg,
+ vec8_data.reg,
+ vec8_data.reg
+ }) {};
+
+  void save(void *ptr) const { *reinterpret_cast<bfloat16x8x4_t *>(ptr) = reg; };
+};
+#endif
+
+struct FP32Vec4 : public Vec<FP32Vec4> {
+ constexpr static int VEC_ELEM_NUM = 4;
+
+ union AliasReg {
+ float32x4_t reg;
+ float values[VEC_ELEM_NUM];
+ };
+
+ float32x4_t reg;
+
+ explicit FP32Vec4(float v) : reg(vdupq_n_f32(v)) {};
+
+ explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {};
+
+ explicit FP32Vec4(const float *ptr) : reg(vld1q_f32(ptr)) {};
+
+ explicit FP32Vec4(float32x4_t data) : reg(data) {};
+
+ explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {};
+};
+
+struct FP32Vec8 : public Vec<FP32Vec8> {
+ constexpr static int VEC_ELEM_NUM = 8;
+ union AliasReg {
+ float32x4x2_t reg;
+ float values[VEC_ELEM_NUM];
+ };
+
+ float32x4x2_t reg;
+
+ explicit FP32Vec8(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v)}) {};
+
+ explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {};
+
+ explicit FP32Vec8(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};
+
+ explicit FP32Vec8(float32x4x2_t data) : reg(data) {};
+
+ explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {};
+
+ explicit FP32Vec8(const FP16Vec8 &v) {
+ reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg));
+ reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg));
+ };
+
+ explicit FP32Vec8(float16x8_t v) : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};
+
+ #ifdef ARM_BF16_SUPPORT
+
+ explicit FP32Vec8(bfloat16x8_t v) : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};
+
+ explicit FP32Vec8(const BF16Vec8 &v) : reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};
+
+ #endif
+
+ float reduce_sum() const {
+ AliasReg ar;
+ ar.reg = reg;
+ float answer = 0;
+    unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });
+
+ return answer;
+ }
+
+ FP32Vec8 exp() const {
+ AliasReg ar;
+ ar.reg = reg;
+
+ float32x2_t exp_vec0 = {expf(ar.values[0]), expf(ar.values[1])};
+ float32x2_t exp_vec1 = {expf(ar.values[2]), expf(ar.values[3])};
+ float32x2_t exp_vec2 = {expf(ar.values[4]), expf(ar.values[5])};
+ float32x2_t exp_vec3 = {expf(ar.values[6]), expf(ar.values[7])};
+
+ float32x4_t result0 = vcombine_f32(exp_vec0, exp_vec1);
+ float32x4_t result1 = vcombine_f32(exp_vec2, exp_vec3);
+
+ float32x4x2_t result;
+ result.val[0] = result0;
+ result.val[1] = result1;
+
+ return FP32Vec8(result);
+ }
+
+ FP32Vec8 tanh() const {
+ AliasReg ar;
+ ar.reg = reg;
+
+ float32x2_t tanh_vec0 = {tanhf(ar.values[0]), tanhf(ar.values[1])};
+ float32x2_t tanh_vec1 = {tanhf(ar.values[2]), tanhf(ar.values[3])};
+ float32x2_t tanh_vec2 = {tanhf(ar.values[4]), tanhf(ar.values[5])};
+ float32x2_t tanh_vec3 = {tanhf(ar.values[6]), tanhf(ar.values[7])};
+
+ float32x4_t result0 = vcombine_f32(tanh_vec0, tanh_vec1);
+ float32x4_t result1 = vcombine_f32(tanh_vec2, tanh_vec3);
+
+ float32x4x2_t result;
+ result.val[0] = result0;
+ result.val[1] = result1;
+
+ return FP32Vec8(result);
+ }
+
+ FP32Vec8 er() const {
+ AliasReg ar;
+ ar.reg = reg;
+
+    float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])), static_cast<float32_t>(erf(ar.values[1]))};
+    float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])), static_cast<float32_t>(erf(ar.values[3]))};
+    float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])), static_cast<float32_t>(erf(ar.values[5]))};
+    float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])), static_cast<float32_t>(erf(ar.values[7]))};
+
+ float32x4_t result0 = vcombine_f32(er_vec0, er_vec1);
+ float32x4_t result1 = vcombine_f32(er_vec2, er_vec3);
+
+ float32x4x2_t result;
+ result.val[0] = result0;
+ result.val[1] = result1;
+
+ return FP32Vec8(result);
+ }
+
+ FP32Vec8 operator*(const FP32Vec8 &b) const {
+ return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]), vmulq_f32(reg.val[1], b.reg.val[1])}));
+ }
+
+ FP32Vec8 operator+(const FP32Vec8 &b) const {
+ return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]), vaddq_f32(reg.val[1], b.reg.val[1])}));
+ }
+
+ FP32Vec8 operator-(const FP32Vec8 &b) const {
+ return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]), vsubq_f32(reg.val[1], b.reg.val[1])}));
+ }
+
+ FP32Vec8 operator/(const FP32Vec8 &b) const {
+ return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]), vdivq_f32(reg.val[1], b.reg.val[1])}));
+ }
+
+ void save(float *ptr) const {
+ vst1q_f32(ptr, reg.val[0]);
+ vst1q_f32(ptr + 4, reg.val[1]);
+ }
+};
+
+struct FP32Vec16 : public Vec<FP32Vec16> {
+ constexpr static int VEC_ELEM_NUM = 16;
+ union AliasReg {
+ float32x4x4_t reg;
+ float values[VEC_ELEM_NUM];
+ };
+
+ float32x4x4_t reg;
+
+ explicit FP32Vec16(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}
+
+ explicit FP32Vec16() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {}
+
+ explicit FP32Vec16(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8), vld1q_f32(ptr + 12)}) {}
+
+ explicit FP32Vec16(float32x4x4_t data) : reg(data) {}
+
+ explicit FP32Vec16(const FP32Vec8 &data) {
+ reg.val[0] = data.reg.val[0];
+ reg.val[1] = data.reg.val[1];
+ reg.val[2] = data.reg.val[0];
+ reg.val[3] = data.reg.val[1];
+ }
+
+ explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
+
+ explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v.reg)) {}
+
+ #ifdef ARM_BF16_SUPPORT
+ explicit FP32Vec16(bfloat16x8x2_t v) : reg({
+ vcvtq_low_f32_bf16(v.val[0]),
+ vcvtq_high_f32_bf16(v.val[0]),
+ vcvtq_low_f32_bf16(v.val[1]),
+ vcvtq_high_f32_bf16(v.val[1])
+ }) {};
+ #endif
+
+ explicit FP32Vec16(const FP32Vec4 &data) {
+ reg.val[0] = data.reg;
+ reg.val[1] = data.reg;
+ reg.val[2] = data.reg;
+ reg.val[3] = data.reg;
+ };
+
+ #ifdef ARM_BF16_SUPPORT
+ explicit FP32Vec16(const BF16Vec16 &v) : reg({
+ vcvtq_low_f32_bf16(v.reg.val[0]),
+ vcvtq_high_f32_bf16(v.reg.val[0]),
+ vcvtq_low_f32_bf16(v.reg.val[1]),
+ vcvtq_high_f32_bf16(v.reg.val[1])
+ }) {};
+
+ explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {};
+ #endif
+
+ explicit FP32Vec16(const FP16Vec16 &v) {
+ reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0]));
+ reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0]));
+ reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
+ reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
+ };
+
+ FP32Vec16 operator+(const FP32Vec16 &b) const {
+ return FP32Vec16(float32x4x4_t({
+ vaddq_f32(reg.val[0], b.reg.val[0]),
+ vaddq_f32(reg.val[1], b.reg.val[1]),
+ vaddq_f32(reg.val[2], b.reg.val[2]),
+ vaddq_f32(reg.val[3], b.reg.val[3])}));
+ };
+
+ FP32Vec16 operator*(const FP32Vec16 &b) const {
+ return FP32Vec16(float32x4x4_t({
+ vmulq_f32(reg.val[0], b.reg.val[0]),
+ vmulq_f32(reg.val[1], b.reg.val[1]),
+ vmulq_f32(reg.val[2], b.reg.val[2]),
+ vmulq_f32(reg.val[3], b.reg.val[3])}));
+ };
+
+ FP32Vec16 operator-(const FP32Vec16 &b) const {
+ return FP32Vec16(float32x4x4_t({
+ vsubq_f32(reg.val[0], b.reg.val[0]),
+ vsubq_f32(reg.val[1], b.reg.val[1]),
+ vsubq_f32(reg.val[2], b.reg.val[2]),
+ vsubq_f32(reg.val[3], b.reg.val[3])
+ }));
+ };
+
+ FP32Vec16 operator/(const FP32Vec16 &b) const {
+ return FP32Vec16(float32x4x4_t({
+ vdivq_f32(reg.val[0], b.reg.val[0]),
+ vdivq_f32(reg.val[1], b.reg.val[1]),
+ vdivq_f32(reg.val[2], b.reg.val[2]),
+ vdivq_f32(reg.val[3], b.reg.val[3])
+ }));
+ };
+
+ float reduce_sum() const {
+ AliasReg ar;
+ ar.reg = reg;
+ float answer = 0;
+    unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });
+
+ return answer;
+ };
+
+  template <int group_size> float reduce_sub_sum(int idx) {
+ static_assert(VEC_ELEM_NUM % group_size == 0);
+
+ AliasReg ar;
+ ar.reg = reg;
+ float answer = 0;
+ const int start = idx * group_size;
+    unroll_loop<int, group_size>(
+ [&answer, &start, ar](int i) { answer += ar.values[start + i]; });
+
+ return answer;
+ };
+
+ void save(float *ptr) const {
+ vst1q_f32(ptr, reg.val[0]);
+ vst1q_f32(ptr + 4, reg.val[1]);
+ vst1q_f32(ptr + 8, reg.val[2]);
+ vst1q_f32(ptr + 12, reg.val[3]);
+ };
+};
+
+template <typename T> struct VecType { using vec_type = void; };
+
+template <typename T> using vec_t = typename VecType<T>::vec_type;
+
+template <> struct VecType<float> { using vec_type = FP32Vec8; };
+
+template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };
+
+#ifdef ARM_BF16_SUPPORT
+template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
+#endif
+
+template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
+
+template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
+ *reinterpret_cast<__fp16 *>(ptr) = v;
+}
+
+inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) {
+ float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]);
+ float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]);
+ float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]);
+ float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]);
+
+ reg.val[0] = vcombine_f16(low_0, high_0);
+ reg.val[1] = vcombine_f16(low_1, high_1);
+};
+
+inline FP16Vec8 :: FP16Vec8(const FP32Vec8 &v) {
+ float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]);
+ float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]);
+
+ reg = vcombine_f16(lower_half, upper_half);
+};
+
+inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
+
+ acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a.reg.val[0], b.reg.val[0]);
+ acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a.reg.val[1], b.reg.val[1]);
+ acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a.reg.val[2], b.reg.val[2]);
+ acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a.reg.val[3], b.reg.val[3]);
+};
+
+#ifdef ARM_BF16_SUPPORT
+inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
+
+ float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0]));
+ float32x4_t a0_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[0]));
+ float32x4_t a1_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[1]));
+ float32x4_t a1_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[1]));
+
+ float32x4_t b0_low = vcvt_f32_bf16(vget_low_bf16(b.reg.val[0]));
+ float32x4_t b0_high = vcvt_f32_bf16(vget_high_bf16(b.reg.val[0]));
+ float32x4_t b1_low = vcvt_f32_bf16(vget_low_bf16(b.reg.val[1]));
+ float32x4_t b1_high = vcvt_f32_bf16(vget_high_bf16(b.reg.val[1]));
+
+ acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a0_low, b0_low);
+ acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a0_high, b0_high);
+ acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a1_low, b1_low);
+ acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a1_high, b1_high);
+};
+#endif
+
+#ifdef ARM_BF16_SUPPORT
+inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {};
+
+inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) : reg({
+ vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]),
+ vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]), v.reg.val[3])
+ }){};
+#endif
+
+inline void prefetch(const void *addr) {
+ __builtin_prefetch(addr, 0, 1);
+};
+
+#ifdef ARM_BF16_SUPPORT
+template <>
+inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
+ *reinterpret_cast<__bf16 *>(ptr) = vcvth_bf16_f32(v);
+};
+#endif
+};
\ No newline at end of file
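
Aside (not part of the patch): a usage sketch of the NEON wrappers defined above, assuming the header and its torch dependency are available in the build; the include path and function name below are illustrative.

    #include "cpu_types_arm.hpp"  // hypothetical include path for this sketch

    // Scale 8 contiguous floats: load into FP32Vec8 (two float32x4_t lanes),
    // broadcast the scalar, multiply element-wise, and store back.
    void scale8(const float* in, float* out, float factor) {
      vec_op::FP32Vec8 v(in);
      vec_op::FP32Vec8 s(factor);
      (v * s).save(out);
    }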
diff --git a/vllm/csrc/cpu/cpu_types_x86.hpp b/vllm/csrc/cpu/cpu_types_x86.hpp
index a325153b4..4bb4eb0f4 100644
--- a/vllm/csrc/cpu/cpu_types_x86.hpp
+++ b/vllm/csrc/cpu/cpu_types_x86.hpp
@@ -11,10 +11,10 @@ static_assert(false, "AVX2 must be supported for the current implementation.");
namespace vec_op {
-// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
- AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
+ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
@@ -50,37 +50,37 @@ template <typename T> struct Vec {
struct FP32Vec8;
struct FP32Vec16;
-#ifdef __AVX512FP16__
 struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
- __m128h reg;
+ __m128i reg;
- explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}
+ explicit FP16Vec8(const void *ptr)
+ : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
- explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}
+ explicit FP16Vec8(const FP32Vec8 &);
- explicit FP16Vec8(__m128h data) : reg(data) {}
+ void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
+};
- FP16Vec8 operator*(const FP16Vec8 &b) const {
- return FP16Vec8(_mm_mul_ph(reg, b.reg));
- }
+struct FP16Vec16 : public Vec<FP16Vec16> {
+ constexpr static int VEC_ELEM_NUM = 16;
- FP16Vec8 operator+(const FP16Vec8 &b) const {
- return FP16Vec8(_mm_add_ph(reg, b.reg));
- }
+ __m256i reg;
- FP16Vec8 operator-(const FP16Vec8 &b) const {
- return FP16Vec8(_mm_sub_ph(reg, b.reg));
- }
+ explicit FP16Vec16(const void *ptr)
+ : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
- FP16Vec8 operator/(const FP16Vec8 &b) const {
- return FP16Vec8(_mm_div_ph(reg, b.reg));
- }
+ explicit FP16Vec16(const FP32Vec16 &);
- void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
+ void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
+
+ void save(void* ptr, const int elem_num) const {
+ constexpr uint32_t M = 0xFFFFFFFF;
+ __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
+ _mm256_mask_storeu_epi16(ptr, mask, reg);
+ }
};
-#endif
 struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
@@ -202,9 +202,7 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
-#ifdef __AVX512FP16__
- explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
-#endif
+ explicit FP32Vec8(const FP16Vec8 &v) : reg(_mm256_cvtph_ps(v.reg)) {}
explicit FP32Vec8(const BF16Vec8 &v)
: reg(_mm256_castsi256_ps(
@@ -323,6 +321,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
: reg(_mm512_castsi512_ps(
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
+ explicit FP32Vec16(const FP16Vec16 &v) : reg(_mm512_cvtph_ps(v.reg)) {}
+
+ explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
+
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const INT32Vec16 &v)
@@ -430,6 +432,16 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(const FP32Vec8 &data)
: reg_low(data.reg), reg_high(data.reg) {}
+ explicit FP32Vec16(const FP16Vec16 &v) {
+ __m128i low = _mm256_extractf128_si256(v.reg, 0);
+ __m128i high = _mm256_extractf128_si256(v.reg, 1);
+
+ reg_low = _mm256_cvtph_ps(low);
+ reg_high = _mm256_cvtph_ps(high);
+ }
+
+ explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
+
explicit FP32Vec16(const BF16Vec16 &v) {
__m128i low = _mm256_extractf128_si256(v.reg, 0);
__m128i high = _mm256_extractf128_si256(v.reg, 1);
@@ -534,24 +546,34 @@ template <typename T> using vec_t = typename VecType<T>::vec_type;
 template <> struct VecType<float> { using vec_type = FP32Vec8; };
-#ifdef __AVX512FP16__
-template <> struct VecType<c10::Half> { using vec_type = FP16Vec16; };
-#endif
+template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };
 template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
 template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
-#ifdef __AVX512FP16__
-template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
- *reinterpret_cast<_Float16 *>(ptr) = v;
-}
-#endif
-
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
acc = acc + a * b;
}
+template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
+  *reinterpret_cast<unsigned short *>(ptr) =
+ _cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+}
+
+inline FP16Vec8::FP16Vec8(const FP32Vec8 &v)
+ : reg(_mm256_cvtps_ph(v.reg,
+ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
+
+#ifdef __AVX512F__
+inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
+ : reg(_mm512_cvtps_ph(v.reg,
+ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
+#else
+inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
+    : reg(_mm256_insertf128_si256(_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg), FP16Vec8(FP32Vec8(v.reg_high)).reg, 1)) {}
+#endif
+
#ifdef __AVX512BF16__
template <> inline void storeFP32(float v, c10::BFloat16 *ptr) {
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
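Aside (not part of the patch): the new FP16Vec8/FP32Vec8 conversions above reduce to the F16C pack/unpack intrinsics. A standalone round-trip sketch (compile with -mavx2 -mf16c; the values are illustrative):

    #include <immintrin.h>
    #include <cstdio>

    int main() {
      float in[8] = {0.5f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}, out[8];
      __m256 v = _mm256_loadu_ps(in);
      // FP16Vec8(const FP32Vec8&): narrow fp32 -> fp16, round to nearest even.
      __m128i h = _mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
      // FP32Vec8(const FP16Vec8&): widen fp16 -> fp32.
      __m256 back = _mm256_cvtph_ps(h);
      _mm256_storeu_ps(out, back);
      std::printf("%.1f ... %.1f\n", out[0], out[7]);
    }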
diff --git a/vllm/csrc/cpu/dnnl_helper.hpp b/vllm/csrc/cpu/dnnl_helper.hpp
index 024ad4ae4..8b5011dc0 100644
--- a/vllm/csrc/cpu/dnnl_helper.hpp
+++ b/vllm/csrc/cpu/dnnl_helper.hpp
@@ -2,6 +2,7 @@
#define DNNL_HELPER_HPP
 #include <c10/util/BFloat16.h>
+#include <c10/util/Half.h>
#include "oneapi/dnnl/dnnl.hpp"
@@ -32,6 +33,11 @@ struct DNNLType<c10::BFloat16> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
};
+template <>
+struct DNNLType<c10::Half> {
+ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
+};
+
 template <typename T>
constexpr inline dnnl::memory::data_type get_dnnl_type() {
   return DNNLType<std::decay_t<T>>::type;
diff --git a/vllm/csrc/cpu/quant.cpp b/vllm/csrc/cpu/quant.cpp
index b493fd793..d9aed657a 100644
--- a/vllm/csrc/cpu/quant.cpp
+++ b/vllm/csrc/cpu/quant.cpp
@@ -23,6 +23,19 @@ struct KernelVecType {
using cvt_vec_type = vec_op::FP32Vec16;
};
+template <>
+struct KernelVecType<c10::Half> {
+#ifdef __powerpc64__
+ // Power architecture-specific vector type
+ using load_vec_type = vec_op::FP32Vec16;
+#else
+ // Fallback for other architectures
+ using load_vec_type = vec_op::FP16Vec16;
+#endif
+ using azp_adj_load_vec_type = vec_op::INT32Vec16;
+ using cvt_vec_type = vec_op::FP32Vec16;
+};
+
#ifdef __AVX512F__
 template <typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
diff --git a/vllm/csrc/custom_all_reduce.cu b/vllm/csrc/custom_all_reduce.cu
index 9b82bec44..123278bfe 100644
--- a/vllm/csrc/custom_all_reduce.cu
+++ b/vllm/csrc/custom_all_reduce.cu
@@ -5,32 +5,29 @@
#include "custom_all_reduce.cuh"
-// fake pointer type, must match fptr_t type in ops.h
+// Fake pointer type, must match fptr_t type in ops.h.
+// We use this type alias to indicate when pointers are passed in as int64_t.
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
-fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
-                      const std::vector<std::string>& handles,
-                      const std::vector<int64_t>& offsets, int64_t rank,
+fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
+ torch::Tensor& rank_data, int64_t rank,
bool full_nvlink) {
- int world_size = offsets.size();
+ int world_size = fake_ipc_ptrs.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
- if (world_size != handles.size())
- throw std::invalid_argument(
- "handles length should equal to offsets length");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");
- cudaIpcMemHandle_t ipc_handles[8];
+ vllm::Signal* ipc_ptrs[8];
for (int i = 0; i < world_size; i++) {
- std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
+    ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
}
- return (fptr_t) new vllm::CustomAllreduce(
-      reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
- rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
+ return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
+ rank_data.numel(), rank, world_size,
+ full_nvlink);
}
/**
@@ -55,26 +52,48 @@ bool _is_weak_contiguous(torch::Tensor& t) {
t.numel() * t.element_size());
}
-void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
- cudaStream_t stream) {
+/**
+ * Performs an out-of-place allreduce and stores result in out.
+ *
+ * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered.
+ * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
+ * copied into _reg_buffer.
+ */
+void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
+ fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
+ auto stream = c10::cuda::getCurrentCUDAStream().stream();
+
+ TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
+ TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(_is_weak_contiguous(out));
+ TORCH_CHECK(_is_weak_contiguous(inp));
+ auto input_size = inp.numel() * inp.element_size();
+  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
+ if (reg_buffer) {
+ TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
+ AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
+ cudaMemcpyDeviceToDevice, stream));
+ } else {
+ reg_buffer = inp.data_ptr();
+ }
switch (out.scalar_type()) {
case at::ScalarType::Float: {
-      fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
+      fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
                            reinterpret_cast<float*>(out.data_ptr()),
out.numel());
break;
}
case at::ScalarType::Half: {
-      fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
+      fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
                           reinterpret_cast<half*>(out.data_ptr()), out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
       fa->allreduce<nv_bfloat16>(
-          stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
+          stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
           reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break;
}
@@ -85,57 +104,41 @@ void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
}
}
-void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
- const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
- auto stream = c10::cuda::getCurrentCUDAStream().stream();
- TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
- TORCH_CHECK_EQ(inp.numel(), out.numel());
- _all_reduce(_fa, inp, out, stream);
-}
-
-void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
- torch::Tensor& out) {
- const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
- auto stream = c10::cuda::getCurrentCUDAStream().stream();
-
- auto input_size = inp.numel() * inp.element_size();
- TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
- TORCH_CHECK_EQ(inp.numel(), out.numel());
- TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
- "registered buffer is too small to contain the input");
- AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
- input_size, cudaMemcpyDeviceToDevice, stream));
- _all_reduce(_fa, reg_buffer, out, stream);
-}
-
void dispose(fptr_t _fa) {
-  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
- delete fa;
+  delete reinterpret_cast<vllm::CustomAllreduce*>(_fa);
}
int64_t meta_size() { return sizeof(vllm::Signal); }
-void register_buffer(fptr_t _fa, torch::Tensor& t,
-                     const std::vector<std::string>& handles,
-                     const std::vector<int64_t>& offsets) {
+void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
- fa->register_buffer(handles, offsets, t.data_ptr());
+ TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
+ void* ipc_ptrs[8];
+ for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
+    ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
+ }
+ fa->register_buffer(ipc_ptrs);
}
-std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
-    fptr_t _fa) {
+// Use vector<int64_t> to represent byte data for python binding compatibility.
+std::tuple<std::vector<int64_t>, std::vector<int64_t>>
+get_graph_buffer_ipc_meta(fptr_t _fa) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
- auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
- auto options =
- torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
- auto handles =
-      torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
- std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
- return {handles, std::move(offsets)};
+ auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
+  std::vector<int64_t> bytes(handle.begin(), handle.end());
+ return std::make_tuple(bytes, offsets);
}
-void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
+// Use vector<int64_t> to represent byte data for python binding compatibility.
+void register_graph_buffers(fptr_t _fa,
+                            const std::vector<std::vector<int64_t>>& handles,
                             const std::vector<std::vector<int64_t>>& offsets) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
- fa->register_graph_buffers(handles, offsets);
+  std::vector<std::string> bytes;
+ bytes.reserve(handles.size());
+ for (int i = 0; i < handles.size(); i++) {
+ bytes.emplace_back(handles[i].begin(), handles[i].end());
+ }
+ bytes.reserve(handles.size());
+ fa->register_graph_buffers(bytes, offsets);
}
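
Aside (not part of the patch): how the reworked pointer-based API above is meant to be driven. The caller (normally the Python side) opens and exchanges the IPC pointers beforehand; the helper below only re-declares the functions defined in this file, and the surrounding names are illustrative.

    #include <torch/all.h>
    #include <vector>

    using fptr_t = int64_t;

    // Declarations matching the definitions in custom_all_reduce.cu above.
    fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
                          torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
    void all_reduce(fptr_t fa, torch::Tensor& inp, torch::Tensor& out,
                    fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
    void dispose(fptr_t fa);

    void run_once(const std::vector<fptr_t>& signal_ptrs, torch::Tensor& rank_data,
                  int64_t rank, torch::Tensor& inp, torch::Tensor& out,
                  fptr_t reg_buffer, int64_t reg_buffer_sz) {
      fptr_t fa = init_custom_ar(signal_ptrs, rank_data, rank, /*full_nvlink=*/true);
      // With a registered staging buffer, inp is copied into it before the reduce.
      all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz);
      dispose(fa);
    }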
diff --git a/vllm/csrc/custom_all_reduce.cuh b/vllm/csrc/custom_all_reduce.cuh
index a2f7e4330..6be4d4f2b 100644
--- a/vllm/csrc/custom_all_reduce.cuh
+++ b/vllm/csrc/custom_all_reduce.cuh
@@ -285,46 +285,52 @@ class CustomAllreduce {
int world_size_;
bool full_nvlink_;
- // below are device pointers
RankSignals sg_;
+  // Stores a map from a pointer to its peer pointers from all ranks.
   std::unordered_map<void*, RankData*> buffers_;
Signal* self_sg_;
- // stores the registered device pointers from all ranks
+ // Stores rank data from all ranks. This is mainly for cuda graph purposes.
+ // For cuda graph to work, all kernel arguments must be fixed during graph
+ // capture time. However, the peer pointers are not known during graph capture
+ // time. Therefore, during capture, we increment the rank data pointer and use
+ // that as the argument to the kernel. The kernel arguments are stored in
+ // graph_unreg_buffers_. The actual peer pointers will be filled in at the
+ // memory pointed to by the pointers in graph_unreg_buffers_ when
+ // the IPC handles are exchanged between ranks.
+ //
+ // The overall process looks like this:
+ // 1. Graph capture.
+ // 2. Each rank obtains the IPC handles for each addresses used during cuda
+ // graph capture using get_graph_buffer_ipc_meta.
+ // 3. (In Python) all gather the IPC handles.
+ // 4. Obtain the peer pointers by opening the IPC handles, and store them in
+ // the rank data array at corresponding positions.
RankData *d_rank_data_base_, *d_rank_data_end_;
   std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers
   std::map<IPC_KEY, char*> ipc_handles_;
/**
- * meta is a pointer to device metadata and temporary buffer for allreduce.
+ * Signals are an array of ipc-enabled buffers from all ranks.
+ * For each of the buffer, the layout is as follows:
+ * | -- sizeof(Signal) -- | ------ a few MB ----- |
+ * The first section is for allreduce synchronization, and the second section
+ * is for storing the intermediate results required by some allreduce algos.
*
- * There's a total of sizeof(Signal) of prefix before the actual data,
- * so meta + 1 points to actual temporary buffer.
- *
- * note: this class does not own any device memory. Any required buffers
- * are passed in from the constructor
+ * Note: this class does not own any device memory. Any required buffers
+ * are passed in from the constructor.
*/
- CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
- const cudaIpcMemHandle_t* handles,
-                  const std::vector<int64_t>& offsets, int rank,
- bool full_nvlink = true)
+ CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
+ int rank, int world_size, bool full_nvlink = true)
: rank_(rank),
- world_size_(offsets.size()),
+ world_size_(world_size),
full_nvlink_(full_nvlink),
- self_sg_(meta),
+ self_sg_(signals[rank]),
         d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) {
- Signal* rank_sg;
- if (i != rank_) {
- char* handle = open_ipc_handle(&handles[i]);
- handle += offsets[i];
- rank_sg = (Signal*)handle;
- } else {
- rank_sg = self_sg_;
- }
- sg_.signals[i] = rank_sg;
+ sg_.signals[i] = signals[i];
}
}
@@ -341,11 +347,10 @@ class CustomAllreduce {
return it->second;
}
-  std::pair<std::vector<uint8_t>, std::vector<int64_t>>
-  get_graph_buffer_ipc_meta() {
+  std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
auto num_buffers = graph_unreg_buffers_.size();
auto handle_sz = sizeof(cudaIpcMemHandle_t);
-    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
+    std::string handles(handle_sz * num_buffers, static_cast<char>(0));
     std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i];
@@ -370,26 +375,22 @@ class CustomAllreduce {
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
-  void register_buffer(const std::vector<std::string>& handles,
-                       const std::vector<int64_t>& offsets, void* self) {
+ /**
+ * Register already-shared IPC pointers.
+ */
+ void register_buffer(void** ptrs) {
check_rank_data_capacity();
RankData data;
for (int i = 0; i < world_size_; i++) {
- if (i != rank_) {
- char* handle = open_ipc_handle(handles[i].data());
- handle += offsets[i];
- data.ptrs[i] = handle;
- } else {
- data.ptrs[i] = self;
- }
+ data.ptrs[i] = ptrs[i];
}
auto d_data = d_rank_data_base_++;
CUDACHECK(
cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
- buffers_[self] = d_data;
+ buffers_[ptrs[rank_]] = d_data;
}
- // note: when registering graph buffers, we intentionally choose to not
+ // Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
@@ -424,11 +425,13 @@ class CustomAllreduce {
}
/**
- * This is the result after careful grid search. Using 36 blocks give the best
- * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
- * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
- * Not quite sure the underlying reason, but my guess is that too many SMs
- * will cause contention on NVLink bus.
+ * Performs allreduce, assuming input has already been registered.
+ *
+ * Block and grid default configs are results after careful grid search. Using
+ * 36 blocks give the best or close to the best runtime on the devices I
+ * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
+ * take a small amount of SMs. Not quite sure the underlying reason, but my
+ * guess is that too many SMs will cause contention on NVLink bus.
*/
   template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, int size,
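
Aside (not part of the patch): the CUDA-graph flow described in the comment block above, summarized against the reworked API; step numbering mirrors that comment, and the variable names are illustrative.

    // 1. Capture the graph; each allreduce() call made during capture records
    //    its (not yet registered) input pointer in graph_unreg_buffers_.
    // 2. auto [handles, offsets] = fa.get_graph_buffer_ipc_meta();
    // 3. All-gather handles/offsets across ranks (done on the Python side).
    // 4. fa.register_graph_buffers(all_handles, all_offsets);
    //    This opens the peer IPC handles and fills in the RankData slots that the
    //    captured kernel arguments point to, so graph replay sees valid pointers.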
diff --git a/vllm/csrc/custom_all_reduce_test.cu b/vllm/csrc/custom_all_reduce_test.cu
index 376687e91..b59ea40d9 100644
--- a/vllm/csrc/custom_all_reduce_test.cu
+++ b/vllm/csrc/custom_all_reduce_test.cu
@@ -135,24 +135,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
void* rank_data;
size_t rank_data_sz = 16 * 1024 * 1024;
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
-  std::vector<int64_t> offsets(nRanks, 0);
- vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
- offsets, myRank);
+ vllm::Signal* ipc_ptrs[8];
+ for (int i = 0; i < nRanks; i++) {
+ if (i == myRank)
+ ipc_ptrs[i] = buffer;
+ else
+ CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptrs[i], data_handles[i],
+ cudaIpcMemLazyEnablePeerAccess));
+ }
+ vllm::CustomAllreduce fa(ipc_ptrs, rank_data, rank_data_sz, myRank, nRanks);
auto* self_data =
       reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
sizeof(vllm::Signal) + data_size * sizeof(T));
// hack buffer registration
{
-    std::vector<std::string> handles;
- handles.reserve(nRanks);
+ void* data[8];
for (int i = 0; i < nRanks; i++) {
- char* begin = (char*)&data_handles[i];
- char* end = (char*)&data_handles[i + 1];
- handles.emplace_back(begin, end);
+ data[i] =
+ ((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T);
}
-    std::vector<int64_t> offsets(nRanks,
- sizeof(vllm::Signal) + data_size * sizeof(T));
- fa.register_buffer(handles, offsets, self_data);
+ fa.register_buffer(data);
}
double* ground_truth;
diff --git a/vllm/csrc/cutlass_extensions/cute_utils.cuh b/vllm/csrc/cutlass_extensions/cute_utils.cuh
index 1842fab8b..f61fe3ceb 100644
--- a/vllm/csrc/cutlass_extensions/cute_utils.cuh
+++ b/vllm/csrc/cutlass_extensions/cute_utils.cuh
@@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout<Shape, Stride> l) {
// is the layout f(x) = x
 template <typename Layout>
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
-  if constexpr (std::is_same_v<Layout, void>)
+  if constexpr (std::is_same_v<Layout, void>) {
return true;
- else {
+ } else {
constexpr auto coalesced_layout = coalesce(Layout{});
if constexpr (rank(coalesced_layout) == 1 &&
stride<0>(coalesced_layout) == 1) {
diff --git a/vllm/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/vllm/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
similarity index 99%
rename from vllm/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
rename to vllm/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
index d407d66ab..7aa87feb4 100644
--- a/vllm/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp
+++ b/vllm/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
@@ -52,6 +52,7 @@
// clang-format off
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
+#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cute/tensor.hpp"
namespace cutlass::epilogue::threadblock {
diff --git a/vllm/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/vllm/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
similarity index 100%
rename from vllm/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
rename to vllm/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
diff --git a/vllm/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/vllm/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
new file mode 100644
index 000000000..c69e87999
--- /dev/null
+++ b/vllm/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
@@ -0,0 +1,317 @@
+#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
+
+/*
+ This file defines custom epilogues for fusing channel scales, token scales,
+ bias, and activation zero-points onto a GEMM operation using the
+ CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
+
+ Epilogues must contain a public type named EVTCompute of type Sm80EVT,
+ as well as a static prepare_args function that constructs an
+ EVTCompute::Arguments struct.
+*/
+
+namespace vllm::c2x {
+
+using namespace cute;
+
+/*
+ * This class provides the common load descriptors for the
+ * ScaledEpilogue[...] classes
+ */
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBase {
+ protected:
+ using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
+
+  template <typename T>
+  using ColOrScalarLoad =
+      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
+          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
+
+  template <typename T>
+  using RowOrScalarLoad =
+      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
+          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+  template <typename T>
+  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
+      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
+
+  template <typename T>
+  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+  template <typename T>
+  using RowOrZeroLoad =
+      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
+          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+ // This utility function constructs the arguments for the load descriptors
+ // from a tensor. It can handle both row and column, as well as row/column or
+ // scalar cases.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(torch::Tensor const& tensor) {
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
+    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
+                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
+      return Arguments{data_ptr, tensor.numel() != 1};
+    } else {
+      // it would technically work but no use case as data_ptr is never nullptr
+      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
+ return Arguments{data_ptr};
+ }
+ }
+
+ // This overload handles the case where there might not be a tensor, in which
+ // case a nullptr is passed and a constant (0) is used.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
+    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
+ return Arguments{data_ptr};
+ }
+};
+
+/*
+ This epilogue function defines a quantized GEMM operation similar to
+ torch._scaled_mm.
+
+ A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
+ per-row. B can be quantized per-tensor or per-column.
+ Any combination of per-tensor and per-row or column is supported.
+ A and B must have symmetric quantization (zero point == 0).
+
+ So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
+ scales are applied elementwise with numpy-style broadcasting.
+
+ ScaleA and ScaleB define the epilogue functions that apply the scales for
+ the A and B operands respectively. These scales may be either per-tensor or
+ per row or column.
+*/
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogue
+    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+
+ using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, float, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ using EVTCompute0 =
+      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
+
+ using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
+ cutlass::multiplies, ElementD, float,
+ cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+ using EVTCompute =
+      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
+ using ArgumentType = typename EVTCompute::Arguments;
+
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+
+ typename EVTCompute0::Arguments evt0_args{b_args};
+ return ArgumentType{a_args, evt0_args};
+ }
+};
+
+/*
+ * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
+ * This bias can also be used in the per-tensor azp case, where the activation
+ * zero point (azp) is used to compute an azp correction term,
+ * which is folded into the bias.
+ *
+ * The bias tensor must be per-output channel.
+ * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
+ */
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBias
+    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ protected:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+ using Bias = typename SUPER::template RowLoad