From 6fc14d427f7d823d2497fc0d17a5c3c4c7ed7879 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 00:34:03 -0700 Subject: [PATCH 001/303] add idealized disagg prefill benchmark --- .../disagg_benchmarks/disagg_benchmark.sh | 129 ++++++++++++++++++ .../disagg_benchmarks/round_robin_proxy.sh | 20 +++ 2 files changed, 149 insertions(+) create mode 100644 benchmarks/disagg_benchmarks/disagg_benchmark.sh create mode 100644 benchmarks/disagg_benchmarks/round_robin_proxy.sh diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh new file mode 100644 index 0000000000000..72c732704a31e --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -0,0 +1,129 @@ +#!/bin/bash + +# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV +# Query: 2048 input tokens, 128 output tokens, QPS 8, 1000 requests +# Resource: 8x H100 +# Approaches: +# 1. Chunked prefill: 1 vllm instance with tp=8 +# 2. Chunked prefill: 2 vllm instance with tp=4 +# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance +# Prefilling instance: max_output_token=1 +# Decoding instance: force the input tokens be the same across requests to bypass prefilling + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pkill pt_main_thread + sleep 10 + + # remove vllm config file + rm -rf ~/.config/vllm + + # Print the GPU memory usage + # so that we know if all GPU processes are killed. + gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) + # The memory usage should be 0 MB. + echo "GPU 0 Memory Usage: $gpu_memory_usage MB" +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl localhost:${port}/v1/completions; do + sleep 1 + done" && return 0 || return 1 +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + cd "$(dirname "$0")" + + mkdir results + results_folder="./results" + model="neuralmagic/Meta-Llama-3-70B-Instruct-FP8" + dataset_name="sonnet" + dataset_path="../sonnet.txt" + num_prompts=500 + qps=8 + prefix_len=64 + input_len=2048 + output_len=128 + + + # chunked prefill with tp=8 + python3 -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8000 \ + --tp 8 \ + --enable-chunked-prefill & + wait_for_server 8000 + + python3 benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len $output_len \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8000 \ + --save-result \ + --result-dir $results_folder \ + --result-filename chunked_prefill_tp8.json \ + --request-rate $qps + + + # chunked prefill with tp=4 + CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + --tp 4 \ + --enable-chunked-prefill & + + CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + --tp 4 \ + --enable-chunked-prefill & + + wait_for_server 8100 + wait_for_server 8200 + # launch round robin proxy + bash ./round_robin_proxy.sh & + + python3 benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len $output_len \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8000 \ + --save-result \ + --result-dir $results_folder \ + --result-filename chunked_prefill_tp8.json \ + --request-rate $qps + + kill_gpu_processes + pkill -f round_robin_proxy.sh + + + # disaggregated prefill + + + + +} \ No newline at end of file diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.sh b/benchmarks/disagg_benchmarks/round_robin_proxy.sh new file mode 100644 index 0000000000000..e996756bc89d6 --- /dev/null +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Define the ports to forward to +PORTS=(8100 8200) +NUM_PORTS=${#PORTS[@]} +CURRENT=0 + +# Function to handle the round-robin logic +get_next_port() { + NEXT_PORT=${PORTS[$CURRENT]} + CURRENT=$(( (CURRENT + 1) % NUM_PORTS )) + echo $NEXT_PORT +} + +# Start the proxy +while true; do + NEXT_PORT=$(get_next_port) + echo "Forwarding to port $NEXT_PORT" + socat TCP4-LISTEN:8000,reuseaddr,fork TCP4:localhost:$NEXT_PORT +done \ No newline at end of file From 69d151487d238a25989451b60dfbc533cef51edd Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 16:30:46 -0700 Subject: [PATCH 002/303] add main --- benchmarks/disagg_benchmarks/disagg_benchmark.sh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index 72c732704a31e..9ae0071a9130e 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -1,5 +1,8 @@ #!/bin/bash +# Requirement: 8x H100 GPUs. + + # Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV # Query: 2048 input tokens, 128 output tokens, QPS 8, 1000 requests # Resource: 8x H100 @@ -123,7 +126,7 @@ main() { # disaggregated prefill +} - -} \ No newline at end of file +main "$@" \ No newline at end of file From 2bc8e7931db2358dfa9aa05c991465fb8878f8b4 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 16:31:50 -0700 Subject: [PATCH 003/303] fix typo --- benchmarks/disagg_benchmarks/disagg_benchmark.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index 9ae0071a9130e..faace2082ee72 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -65,7 +65,7 @@ main() { python3 -m vllm.entrypoints.openai.api_server \ --model $model \ --port 8000 \ - --tp 8 \ + -tp 8 \ --enable-chunked-prefill & wait_for_server 8000 @@ -90,14 +90,14 @@ main() { -m vllm.entrypoints.openai.api_server \ --model $model \ --port 8100 \ - --tp 4 \ + -tp 4 \ --enable-chunked-prefill & CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ --port 8200 \ - --tp 4 \ + -tp 4 \ --enable-chunked-prefill & wait_for_server 8100 From 3ea715dbf6872766c424587378b29d1e9a503c84 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 16:32:20 -0700 Subject: [PATCH 004/303] use mkdir -p to avoid error --- benchmarks/disagg_benchmarks/disagg_benchmark.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index faace2082ee72..8c0ed61d284a6 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -49,7 +49,7 @@ main() { cd "$(dirname "$0")" - mkdir results + mkdir -p results results_folder="./results" model="neuralmagic/Meta-Llama-3-70B-Instruct-FP8" dataset_name="sonnet" From 3656f8aa7aa0cc2e1ede9b3743263051b4382421 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 16:38:31 -0700 Subject: [PATCH 005/303] fix bug --- benchmarks/disagg_benchmarks/disagg_benchmark.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index 8c0ed61d284a6..ae5034cc3b614 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -69,7 +69,7 @@ main() { --enable-chunked-prefill & wait_for_server 8000 - python3 benchmark_serving.py \ + python3 ../benchmark_serving.py \ --backend vllm \ --model $model \ --dataset-name $dataset_name \ @@ -105,7 +105,7 @@ main() { # launch round robin proxy bash ./round_robin_proxy.sh & - python3 benchmark_serving.py \ + python3 ../benchmark_serving.py \ --backend vllm \ --model $model \ --dataset-name $dataset_name \ From f8cb6fcb91cfffb78d6eaf3668cd549abceebf8c Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 16:45:27 -0700 Subject: [PATCH 006/303] disable log request from vllm server, and mute curl --- benchmarks/disagg_benchmarks/disagg_benchmark.sh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index ae5034cc3b614..d1d5cc721c3d5 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -35,7 +35,7 @@ wait_for_server() { # return 1 if vllm server crashes local port=$1 timeout 1200 bash -c " - until curl localhost:${port}/v1/completions; do + until curl -s localhost:${port}/v1/completions > /dev/null; do sleep 1 done" && return 0 || return 1 } @@ -66,6 +66,8 @@ main() { --model $model \ --port 8000 \ -tp 8 \ + --disable-log-stats \ + --disable-log-requests \ --enable-chunked-prefill & wait_for_server 8000 @@ -91,6 +93,8 @@ main() { --model $model \ --port 8100 \ -tp 4 \ + --disable-log-stats \ + --disable-log-requests \ --enable-chunked-prefill & CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ @@ -98,6 +102,8 @@ main() { --model $model \ --port 8200 \ -tp 4 \ + --disable-log-stats \ + --disable-log-requests \ --enable-chunked-prefill & wait_for_server 8100 From d4b23c079e6a745a4c46a27ab4d9e098d49cca84 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 16:57:47 -0700 Subject: [PATCH 007/303] add disaggregated prefilling benchmark --- .../disagg_benchmarks/disagg_benchmark.sh | 56 ++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index d1d5cc721c3d5..8f229ef26f632 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -85,6 +85,7 @@ main() { --result-dir $results_folder \ --result-filename chunked_prefill_tp8.json \ --request-rate $qps + kill_gpu_processes # chunked prefill with tp=4 @@ -125,13 +126,66 @@ main() { --result-dir $results_folder \ --result-filename chunked_prefill_tp8.json \ --request-rate $qps - kill_gpu_processes pkill -f round_robin_proxy.sh # disaggregated prefill + # prefill with tp=4 + python3 -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8000 \ + -tp 4 \ + --disable-log-stats \ + --disable-log-requests & + wait_for_server 8000 + + # set output-len to 1 so that it only do prefilling + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len 1 \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8000 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_tp4.json \ + --request-rate $qps + kill_gpu_processes + + # decode with tp=4, enable APC + python3 -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8000 \ + -tp 4 \ + --enable-prefix-caching \ + --disable-log-stats \ + --disable-log-requests & + wait_for_server 8000 + + # skip prefilling + # by enabling APC and force the input tokens be the same + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len $output_len \ + --sonnet-prefix-len $input_len \ + --num-prompts $num_prompts \ + --port 8000 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_decode_tp4.json \ + --request-rate $qps + kill_gpu_processes + } From a9426631083ee96df1dee55f58dfd0e5261112f6 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 18:34:03 -0700 Subject: [PATCH 008/303] do not launch 2 vllm instance --- .../disagg_benchmarks/disagg_benchmark.sh | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index 8f229ef26f632..f2150fab3b634 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -98,19 +98,19 @@ main() { --disable-log-requests \ --enable-chunked-prefill & - CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model $model \ - --port 8200 \ - -tp 4 \ - --disable-log-stats \ - --disable-log-requests \ - --enable-chunked-prefill & + # CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + # -m vllm.entrypoints.openai.api_server \ + # --model $model \ + # --port 8200 \ + # -tp 4 \ + # --disable-log-stats \ + # --disable-log-requests \ + # --enable-chunked-prefill & wait_for_server 8100 - wait_for_server 8200 - # launch round robin proxy - bash ./round_robin_proxy.sh & + # wait_for_server 8200 + # # launch round robin proxy + # bash ./round_robin_proxy.sh & python3 ../benchmark_serving.py \ --backend vllm \ @@ -121,13 +121,13 @@ main() { --sonnet-output-len $output_len \ --sonnet-prefix-len $prefix_len \ --num-prompts $num_prompts \ - --port 8000 \ + --port 8100 \ --save-result \ --result-dir $results_folder \ --result-filename chunked_prefill_tp8.json \ - --request-rate $qps + --request-rate $((qps / 2)) kill_gpu_processes - pkill -f round_robin_proxy.sh + # pkill -f round_robin_proxy.sh # disaggregated prefill From 540d36260c6b8322fa444a2b42a7ebb49e456ecf Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 18:38:55 -0700 Subject: [PATCH 009/303] reduce # of prompt to half --- benchmarks/disagg_benchmarks/disagg_benchmark.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index f2150fab3b634..cbc372c5f3d3e 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -120,7 +120,7 @@ main() { --sonnet-input-len $input_len \ --sonnet-output-len $output_len \ --sonnet-prefix-len $prefix_len \ - --num-prompts $num_prompts \ + --num-prompts $((num_prompts / 2)) \ --port 8100 \ --save-result \ --result-dir $results_folder \ From 4b0a7ff77f8059b98f7387c75c66e0ca151ac416 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 18:48:58 -0700 Subject: [PATCH 010/303] reduce input len by 1 --- benchmarks/disagg_benchmarks/disagg_benchmark.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index cbc372c5f3d3e..9a53d96470b86 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -177,7 +177,7 @@ main() { --dataset-path $dataset_path \ --sonnet-input-len $input_len \ --sonnet-output-len $output_len \ - --sonnet-prefix-len $input_len \ + --sonnet-prefix-len $((input_len - 1)) \ --num-prompts $num_prompts \ --port 8000 \ --save-result \ From 298965614e2ffd0264a2e80afed281894f4fa297 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 19:39:09 -0700 Subject: [PATCH 011/303] adjust filename --- benchmarks/disagg_benchmarks/disagg_benchmark.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index 9a53d96470b86..fe172b075e0cb 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -57,7 +57,7 @@ main() { num_prompts=500 qps=8 prefix_len=64 - input_len=2048 + input_len=8192 output_len=128 @@ -124,7 +124,7 @@ main() { --port 8100 \ --save-result \ --result-dir $results_folder \ - --result-filename chunked_prefill_tp8.json \ + --result-filename chunked_prefill_tp4.json \ --request-rate $((qps / 2)) kill_gpu_processes # pkill -f round_robin_proxy.sh From 69f729c0384818a2e1cac4b858c01414a7fc7978 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 19:47:06 -0700 Subject: [PATCH 012/303] create 4x sonnet --- benchmarks/disagg_benchmarks/disagg_benchmark.sh | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index fe172b075e0cb..725d050aac537 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -8,7 +8,7 @@ # Resource: 8x H100 # Approaches: # 1. Chunked prefill: 1 vllm instance with tp=8 -# 2. Chunked prefill: 2 vllm instance with tp=4 +# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 # 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance # Prefilling instance: max_output_token=1 # Decoding instance: force the input tokens be the same across requests to bypass prefilling @@ -49,11 +49,21 @@ main() { cd "$(dirname "$0")" + cd .. + # create sonnet-4x.txt + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + mkdir -p results results_folder="./results" model="neuralmagic/Meta-Llama-3-70B-Instruct-FP8" dataset_name="sonnet" - dataset_path="../sonnet.txt" + dataset_path="../sonnet_4x.txt" num_prompts=500 qps=8 prefix_len=64 From 43e1e5e1f1876524a0e65112b04807114638e7ab Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 21:34:33 -0700 Subject: [PATCH 013/303] adjust setup --- benchmarks/disagg_benchmarks/disagg_benchmark.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index 725d050aac537..11c9e25ca120d 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -65,10 +65,10 @@ main() { dataset_name="sonnet" dataset_path="../sonnet_4x.txt" num_prompts=500 - qps=8 + qps=4 prefix_len=64 - input_len=8192 - output_len=128 + input_len=2048 + output_len=11 # chunked prefill with tp=8 From 29a7b88c4e708aebb9df903eeb2e0eb0bc260ccd Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Sat, 6 Jul 2024 22:46:37 -0700 Subject: [PATCH 014/303] add benchmark --- benchmarks/disagg_benchmarks/disagg_benchmark.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index 11c9e25ca120d..6c68011acd954 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -4,7 +4,7 @@ # Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV -# Query: 2048 input tokens, 128 output tokens, QPS 8, 1000 requests +# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests # Resource: 8x H100 # Approaches: # 1. Chunked prefill: 1 vllm instance with tp=8 From 4d31316f4e804181cede778823643ecaced22091 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 8 Jul 2024 16:12:02 -0700 Subject: [PATCH 015/303] allow prefix input len == sonnet input len --- benchmarks/benchmark_serving.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 42867fc40edd2..49dfe780812fd 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -123,9 +123,9 @@ def sample_sonnet_requests( prefix_len: int, tokenizer: PreTrainedTokenizerBase, ) -> List[Tuple[str, str, int, int]]: - assert ( - input_len > prefix_len - ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." + assert input_len >= prefix_len, ( + "'args.sonnet-input-len' must be greater than or equal to " + "'args.prefix-input-len'.") # Load the dataset. with open(dataset_path) as f: From 4e336fcdf68cba994a181e7b3c16188c0dfb2efe Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 8 Jul 2024 16:42:19 -0700 Subject: [PATCH 016/303] add parameter sweeping --- .../analyze_benchmark_result.py | 47 +++++++ .../disagg_benchmarks/disagg_benchmark.sh | 118 ++++++++---------- 2 files changed, 97 insertions(+), 68 deletions(-) create mode 100644 benchmarks/disagg_benchmarks/analyze_benchmark_result.py diff --git a/benchmarks/disagg_benchmarks/analyze_benchmark_result.py b/benchmarks/disagg_benchmarks/analyze_benchmark_result.py new file mode 100644 index 0000000000000..0f7ea8f2b6541 --- /dev/null +++ b/benchmarks/disagg_benchmarks/analyze_benchmark_result.py @@ -0,0 +1,47 @@ + +import argparse +import json +import yaml +import os +from pathlib import Path + +def load(path): + + with open(str(path), 'r') as f: + return json.loads(f.read()) + +def main(args): + + results = Path(args.results_folder) + + chunk = load(results / "chunked_prefill_tp4.json") + prefill = load(results / "disagg_prefill_tp4.json") + decode = load(results / "disagg_decode_tp4.json") + + ttft_ratio = chunk["mean_ttft_ms"] / prefill["mean_ttft_ms"] + itl_ratio = chunk["mean_itl_ms"] / decode["mean_itl_ms"] + prefill_decode_ratio = prefill["mean_ttft_ms"] / (decode["mean_itl_ms"] * args.output_len) + + with open(results / args.output_file, 'a') as f: + f.write(yaml.dump([{ + 'qps': args.qps, + 'output_len': args.output_len, + 'prefill_decode_ratio': prefill_decode_ratio, + 'ttft_ratio': ttft_ratio, + 'itl_ratio': itl_ratio, + "chunk_ttft": chunk["mean_ttft_ms"], + "chunk_itl": chunk["mean_itl_ms"], + "disagg_ttft": prefill["mean_ttft_ms"], + "disagg_itl": decode["mean_itl_ms"] + }])) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Analyze benchmark results") + parser.add_argument("--results-folder", required=True, help="Path to the results folder") + parser.add_argument("--output-len", type=int, required=True, help="Target output length") + parser.add_argument("--qps", type=int, required=True, help="Target QPS") + parser.add_argument("--output-file", type=str, default="chunk_vs_disagg.yaml") + + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index 6c68011acd954..7736ab2439b86 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -41,86 +41,31 @@ wait_for_server() { } -main() { - - (which wget && which curl) || (apt-get update && apt-get install -y wget curl) - (which jq) || (apt-get -y install jq) - (which socat) || (apt-get -y install socat) +benchmark() { - cd "$(dirname "$0")" + # compare chunked prefill with disaggregated prefill - cd .. - # create sonnet-4x.txt - echo "" > sonnet_4x.txt - for _ in {1..4} - do - cat sonnet.txt >> sonnet_4x.txt - done - cd disagg_benchmarks - - - mkdir -p results results_folder="./results" model="neuralmagic/Meta-Llama-3-70B-Instruct-FP8" dataset_name="sonnet" dataset_path="../sonnet_4x.txt" num_prompts=500 - qps=4 + qps=$1 prefix_len=64 input_len=2048 - output_len=11 - - - # chunked prefill with tp=8 - python3 -m vllm.entrypoints.openai.api_server \ - --model $model \ - --port 8000 \ - -tp 8 \ - --disable-log-stats \ - --disable-log-requests \ - --enable-chunked-prefill & - wait_for_server 8000 - - python3 ../benchmark_serving.py \ - --backend vllm \ - --model $model \ - --dataset-name $dataset_name \ - --dataset-path $dataset_path \ - --sonnet-input-len $input_len \ - --sonnet-output-len $output_len \ - --sonnet-prefix-len $prefix_len \ - --num-prompts $num_prompts \ - --port 8000 \ - --save-result \ - --result-dir $results_folder \ - --result-filename chunked_prefill_tp8.json \ - --request-rate $qps - kill_gpu_processes + output_len=$2 # chunked prefill with tp=4 CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ - --port 8100 \ + --port 8000 \ -tp 4 \ --disable-log-stats \ --disable-log-requests \ --enable-chunked-prefill & - - # CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ - # -m vllm.entrypoints.openai.api_server \ - # --model $model \ - # --port 8200 \ - # -tp 4 \ - # --disable-log-stats \ - # --disable-log-requests \ - # --enable-chunked-prefill & - - wait_for_server 8100 - # wait_for_server 8200 - # # launch round robin proxy - # bash ./round_robin_proxy.sh & + wait_for_server 8000 python3 ../benchmark_serving.py \ --backend vllm \ @@ -131,17 +76,15 @@ main() { --sonnet-output-len $output_len \ --sonnet-prefix-len $prefix_len \ --num-prompts $((num_prompts / 2)) \ - --port 8100 \ + --port 8000 \ --save-result \ --result-dir $results_folder \ --result-filename chunked_prefill_tp4.json \ --request-rate $((qps / 2)) kill_gpu_processes - # pkill -f round_robin_proxy.sh # disaggregated prefill - # prefill with tp=4 python3 -m vllm.entrypoints.openai.api_server \ --model $model \ @@ -150,7 +93,6 @@ main() { --disable-log-stats \ --disable-log-requests & wait_for_server 8000 - # set output-len to 1 so that it only do prefilling python3 ../benchmark_serving.py \ --backend vllm \ @@ -177,7 +119,6 @@ main() { --disable-log-stats \ --disable-log-requests & wait_for_server 8000 - # skip prefilling # by enabling APC and force the input tokens be the same python3 ../benchmark_serving.py \ @@ -187,7 +128,7 @@ main() { --dataset-path $dataset_path \ --sonnet-input-len $input_len \ --sonnet-output-len $output_len \ - --sonnet-prefix-len $((input_len - 1)) \ + --sonnet-prefix-len $input_len \ --num-prompts $num_prompts \ --port 8000 \ --save-result \ @@ -196,7 +137,48 @@ main() { --request-rate $qps kill_gpu_processes + python3 analyze_results.py \ + --results-folder $results_folder \ + --output-len $output_len \ + --qps $qps + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_qps=4 + default_output_len=12 + + for target_qps in 1 2 4 8 16 + do + benchmark $target_qps $default_output_len + done + + for output_len in 5 10 20 40 80 + do + benchmark $default_qps $output_len + done + } -main "$@" \ No newline at end of file +main "$@" From 2770c61dd5f96622be6c07f42032b58265837c10 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 8 Jul 2024 16:47:10 -0700 Subject: [PATCH 017/303] aadjust firmat --- benchmarks/disagg_benchmarks/analyze_benchmark_result.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarks/disagg_benchmarks/analyze_benchmark_result.py b/benchmarks/disagg_benchmarks/analyze_benchmark_result.py index 0f7ea8f2b6541..4b675c675d25f 100644 --- a/benchmarks/disagg_benchmarks/analyze_benchmark_result.py +++ b/benchmarks/disagg_benchmarks/analyze_benchmark_result.py @@ -44,4 +44,5 @@ def main(args): parser.add_argument("--output-file", type=str, default="chunk_vs_disagg.yaml") args = parser.parse_args() - main(args) \ No newline at end of file + main(args) + \ No newline at end of file From 80061d2f3276aa4490ea416c3d5d2a98d18f4457 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 8 Jul 2024 16:47:48 -0700 Subject: [PATCH 018/303] rename script --- benchmarks/disagg_benchmarks/disagg_benchmark.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index 7736ab2439b86..bf7ef9dfa3ef0 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -137,7 +137,7 @@ benchmark() { --request-rate $qps kill_gpu_processes - python3 analyze_results.py \ + python3 analyze_benchmark_results.py \ --results-folder $results_folder \ --output-len $output_len \ --qps $qps From 8c0a9dc0183b9110695e093c3d78069d05daaa72 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 8 Jul 2024 16:48:02 -0700 Subject: [PATCH 019/303] align naming --- .../{analyze_benchmark_result.py => analyze_benchmark_results.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename benchmarks/disagg_benchmarks/{analyze_benchmark_result.py => analyze_benchmark_results.py} (100%) diff --git a/benchmarks/disagg_benchmarks/analyze_benchmark_result.py b/benchmarks/disagg_benchmarks/analyze_benchmark_results.py similarity index 100% rename from benchmarks/disagg_benchmarks/analyze_benchmark_result.py rename to benchmarks/disagg_benchmarks/analyze_benchmark_results.py From 7d84965280d185f2eb8d4381e1936f146612121c Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 8 Jul 2024 17:01:52 -0700 Subject: [PATCH 020/303] adjust qps --- benchmarks/disagg_benchmarks/disagg_benchmark.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index bf7ef9dfa3ef0..c3e652b456859 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -168,14 +168,14 @@ main() { default_qps=4 default_output_len=12 - for target_qps in 1 2 4 8 16 + for target_qps in 2 4 8 16 do benchmark $target_qps $default_output_len done - for output_len in 5 10 20 40 80 + for target_output_len in 5 10 20 40 80 do - benchmark $default_qps $output_len + benchmark $default_qps $target_output_len done } From 5ac5249d5347d4cf0d384b40cfde1a55f068ce1a Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 8 Jul 2024 20:06:15 -0700 Subject: [PATCH 021/303] adjust swap range --- .../disagg_benchmarks/disagg_benchmark.sh | 10 +-- .../results/chunk_vs_disagg.yaml | 81 +++++++++++++++++++ .../visualize_benchmark_results.py | 73 +++++++++++++++++ 3 files changed, 159 insertions(+), 5 deletions(-) create mode 100644 benchmarks/disagg_benchmarks/results/chunk_vs_disagg.yaml create mode 100644 benchmarks/disagg_benchmarks/visualize_benchmark_results.py diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index c3e652b456859..c8a7cba02a706 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -166,17 +166,17 @@ main() { mkdir results default_qps=4 - default_output_len=12 + default_output_len=150 for target_qps in 2 4 8 16 do benchmark $target_qps $default_output_len done - for target_output_len in 5 10 20 40 80 - do - benchmark $default_qps $target_output_len - done + # for target_output_len in 5 10 20 40 80 + # do + # benchmark $default_qps $target_output_len + # done } diff --git a/benchmarks/disagg_benchmarks/results/chunk_vs_disagg.yaml b/benchmarks/disagg_benchmarks/results/chunk_vs_disagg.yaml new file mode 100644 index 0000000000000..cbbb8de8826df --- /dev/null +++ b/benchmarks/disagg_benchmarks/results/chunk_vs_disagg.yaml @@ -0,0 +1,81 @@ +- chunk_itl: 35.312214966863394 + chunk_ttft: 197.25125090777874 + disagg_itl: 21.324921273315947 + disagg_ttft: 152.8801853582263 + itl_ratio: 1.6559130284363517 + output_len: 12 + prefill_decode_ratio: 0.597423797407428 + qps: 2 + ttft_ratio: 1.290234247463678 +- chunk_itl: 41.504599403589964 + chunk_ttft: 229.26480463147163 + disagg_itl: 22.99003972671926 + disagg_ttft: 199.59523537009954 + itl_ratio: 1.8053296078193763 + output_len: 12 + prefill_decode_ratio: 0.723484451464895 + qps: 4 + ttft_ratio: 1.1486486849566186 +- chunk_itl: 63.580538438012205 + chunk_ttft: 379.72798639535904 + disagg_itl: 29.123107485473156 + disagg_ttft: 508.8736157938838 + itl_ratio: 2.183164638928957 + output_len: 12 + prefill_decode_ratio: 1.4560992390885088 + qps: 8 + ttft_ratio: 0.7462127620881912 +- chunk_itl: 438.7920691122612 + chunk_ttft: 4792.676218897104 + disagg_itl: 38.97295152582228 + disagg_ttft: 10359.5165893808 + itl_ratio: 11.258887303455348 + output_len: 12 + prefill_decode_ratio: 22.151082104804797 + qps: 16 + ttft_ratio: 0.4626351217787439 +- chunk_itl: 65.75006234049798 + chunk_ttft: 219.36342687904835 + disagg_itl: 28.58696384578943 + disagg_ttft: 199.52613697946072 + itl_ratio: 2.300001591466045 + output_len: 5 + prefill_decode_ratio: 1.395923946703762 + qps: 4 + ttft_ratio: 1.099422011571495 +- chunk_itl: 45.51790158599616 + chunk_ttft: 231.06786338984966 + disagg_itl: 24.55629511550069 + disagg_ttft: 200.25028175115585 + itl_ratio: 1.8536143734998467 + output_len: 10 + prefill_decode_ratio: 0.815474324645788 + qps: 4 + ttft_ratio: 1.1538953222397448 +- chunk_itl: 32.62334335371852 + chunk_ttft: 224.4068773239851 + disagg_itl: 22.356921216845514 + disagg_ttft: 199.03663477301598 + itl_ratio: 1.4592055425385428 + output_len: 20 + prefill_decode_ratio: 0.44513426701849645 + qps: 4 + ttft_ratio: 1.1274651904153308 +- chunk_itl: 28.700303505733608 + chunk_ttft: 238.22125577926636 + disagg_itl: 21.38163771480322 + disagg_ttft: 200.1644251793623 + itl_ratio: 1.3422874285192585 + output_len: 40 + prefill_decode_ratio: 0.23403776157050613 + qps: 4 + ttft_ratio: 1.1901278439752834 +- chunk_itl: 25.861735691688956 + chunk_ttft: 237.75536592304707 + disagg_itl: 22.10882957596332 + disagg_ttft: 200.66210460662842 + itl_ratio: 1.1697469376581466 + output_len: 80 + prefill_decode_ratio: 0.11345133847835387 + qps: 4 + ttft_ratio: 1.184854342025043 \ No newline at end of file diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py new file mode 100644 index 0000000000000..1d5c3536736d3 --- /dev/null +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -0,0 +1,73 @@ + +import matplotlib.pyplot as plt +import yaml +import pandas as pd +from tabulate import tabulate + + +def stringify(x): + return [str(i) for i in x] + + +if __name__ == "__main__": + + with open("results/chunk_vs_disagg.yaml", "r") as f: + data = yaml.load(f, Loader=yaml.FullLoader) + df = pd.DataFrame.from_dict(data) + + print_df = df.copy() + print_df.drop(columns=[ + "ttft_ratio", + "itl_ratio", + "prefill_decode_ratio", + ], inplace=True) + print_df.to_csv('results/chunk_vs_disagg.csv', index=False) + + df["chunk_e2e"] = df["chunk_ttft"] + df["chunk_itl"] * df["output_len"] + df["disagg_e2e"] = df["disagg_ttft"] + df["disagg_itl"] * df["output_len"] + df["e2e_ratio"] = df["chunk_e2e"] / df["disagg_e2e"] + + plt.rcParams['font.size'] = 20 + + + # qps vs performance + qps_df = df[df["output_len"] == 12].copy() + qps_df.drop(columns=[ + "chunk_itl", + "chunk_ttft", + "disagg_itl", + "disagg_ttft", + "output_len", + "prefill_decode_ratio", + ], inplace=True) + fig, ax = plt.subplots(figsize=(10, 7)) + qps_df.plot( + ax=ax, + kind="bar", + x="qps", + y=["ttft_ratio", "itl_ratio", "e2e_ratio"], + ylabel="$T_{chunked}~/~T_{disagg}$", + rot=0, + ) + ax.hlines(1, -1, 5, color='black') + fig.savefig('results/qps.png') + plt.close(fig) + + + # prefill decode ratio vs performance + tokens_df = df[df["output_len"] != 12] + fig, ax = plt.subplots(figsize=(10, 7)) + tokens_df.plot( + ax=ax, + kind="bar", + x="output_len", + xlabel="# of output tokens", + y=["ttft_ratio", "itl_ratio", "e2e_ratio", "prefill_decode_ratio"], + ylabel="$T_{chunked}~/~T_{disagg}$", + rot=0, + ) + ax.hlines(1, -1, 5, color='black') + fig.savefig('results/tokens.png') + plt.close(fig) + + \ No newline at end of file From 8f259853a4bca740988e0eaff7619c2f1863b4a3 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 8 Jul 2024 20:06:41 -0700 Subject: [PATCH 022/303] remove results --- .../results/chunk_vs_disagg.yaml | 81 ------------------- 1 file changed, 81 deletions(-) delete mode 100644 benchmarks/disagg_benchmarks/results/chunk_vs_disagg.yaml diff --git a/benchmarks/disagg_benchmarks/results/chunk_vs_disagg.yaml b/benchmarks/disagg_benchmarks/results/chunk_vs_disagg.yaml deleted file mode 100644 index cbbb8de8826df..0000000000000 --- a/benchmarks/disagg_benchmarks/results/chunk_vs_disagg.yaml +++ /dev/null @@ -1,81 +0,0 @@ -- chunk_itl: 35.312214966863394 - chunk_ttft: 197.25125090777874 - disagg_itl: 21.324921273315947 - disagg_ttft: 152.8801853582263 - itl_ratio: 1.6559130284363517 - output_len: 12 - prefill_decode_ratio: 0.597423797407428 - qps: 2 - ttft_ratio: 1.290234247463678 -- chunk_itl: 41.504599403589964 - chunk_ttft: 229.26480463147163 - disagg_itl: 22.99003972671926 - disagg_ttft: 199.59523537009954 - itl_ratio: 1.8053296078193763 - output_len: 12 - prefill_decode_ratio: 0.723484451464895 - qps: 4 - ttft_ratio: 1.1486486849566186 -- chunk_itl: 63.580538438012205 - chunk_ttft: 379.72798639535904 - disagg_itl: 29.123107485473156 - disagg_ttft: 508.8736157938838 - itl_ratio: 2.183164638928957 - output_len: 12 - prefill_decode_ratio: 1.4560992390885088 - qps: 8 - ttft_ratio: 0.7462127620881912 -- chunk_itl: 438.7920691122612 - chunk_ttft: 4792.676218897104 - disagg_itl: 38.97295152582228 - disagg_ttft: 10359.5165893808 - itl_ratio: 11.258887303455348 - output_len: 12 - prefill_decode_ratio: 22.151082104804797 - qps: 16 - ttft_ratio: 0.4626351217787439 -- chunk_itl: 65.75006234049798 - chunk_ttft: 219.36342687904835 - disagg_itl: 28.58696384578943 - disagg_ttft: 199.52613697946072 - itl_ratio: 2.300001591466045 - output_len: 5 - prefill_decode_ratio: 1.395923946703762 - qps: 4 - ttft_ratio: 1.099422011571495 -- chunk_itl: 45.51790158599616 - chunk_ttft: 231.06786338984966 - disagg_itl: 24.55629511550069 - disagg_ttft: 200.25028175115585 - itl_ratio: 1.8536143734998467 - output_len: 10 - prefill_decode_ratio: 0.815474324645788 - qps: 4 - ttft_ratio: 1.1538953222397448 -- chunk_itl: 32.62334335371852 - chunk_ttft: 224.4068773239851 - disagg_itl: 22.356921216845514 - disagg_ttft: 199.03663477301598 - itl_ratio: 1.4592055425385428 - output_len: 20 - prefill_decode_ratio: 0.44513426701849645 - qps: 4 - ttft_ratio: 1.1274651904153308 -- chunk_itl: 28.700303505733608 - chunk_ttft: 238.22125577926636 - disagg_itl: 21.38163771480322 - disagg_ttft: 200.1644251793623 - itl_ratio: 1.3422874285192585 - output_len: 40 - prefill_decode_ratio: 0.23403776157050613 - qps: 4 - ttft_ratio: 1.1901278439752834 -- chunk_itl: 25.861735691688956 - chunk_ttft: 237.75536592304707 - disagg_itl: 22.10882957596332 - disagg_ttft: 200.66210460662842 - itl_ratio: 1.1697469376581466 - output_len: 80 - prefill_decode_ratio: 0.11345133847835387 - qps: 4 - ttft_ratio: 1.184854342025043 \ No newline at end of file From 2363fa09f2b0ce4f18b8bd935d4f547dd01623bc Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 9 Jul 2024 16:26:44 -0700 Subject: [PATCH 023/303] adjust benchmark results so that there are 150 output tokens by default. Much more realistic --- benchmarks/disagg_benchmarks/visualize_benchmark_results.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py index 1d5c3536736d3..8686fb2abf9b9 100644 --- a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -31,7 +31,7 @@ def stringify(x): # qps vs performance - qps_df = df[df["output_len"] == 12].copy() + qps_df = df[df["output_len"] == 150].copy() qps_df.drop(columns=[ "chunk_itl", "chunk_ttft", From 3db988c713691052f0517d877f76d100a9c015f6 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 16 Jul 2024 23:49:59 -0700 Subject: [PATCH 024/303] add example usage for disaggregated prefill --- examples/disaggregated_prefill_example.sh | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 examples/disaggregated_prefill_example.sh diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh new file mode 100644 index 0000000000000..7adcc6ec35b63 --- /dev/null +++ b/examples/disaggregated_prefill_example.sh @@ -0,0 +1,22 @@ + + +# prefilling instance +VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ + --port 8100 \ + -tp 4 \ + --disable-log-stats \ + --disable-log-requests \ + --enable-chunked-prefill & + +sleep 2 + +VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ + --port 8200 \ + -tp 4 \ + --disable-log-stats \ + --disable-log-requests \ + --enable-chunked-prefill & \ No newline at end of file From 00e46de261589b53ec9969d986b4dc8f2b726a63 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 16 Jul 2024 23:50:26 -0700 Subject: [PATCH 025/303] add environment variable for disaggregated prefill --- vllm/config.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 1ea2888796808..9f9fa938465c6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -662,7 +662,15 @@ def __init__( self.ray_workers_use_nsight = ray_workers_use_nsight self.placement_group = placement_group - self.world_size = pipeline_parallel_size * self.tensor_parallel_size + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + # Disaggregated prefilling is enabled + # There will be 2 copies of vLLM + # One for prefilling and one for decoding + self.disagg_prefill_size = 2 + else: + self.disagg_prefill_size = 1 + + self.world_size = pipeline_parallel_size * tensor_parallel_size * self.disagg_prefill_size if worker_use_ray: if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" From de434d977e3181da74e44efe2be9925e99481e4c Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 16 Jul 2024 23:50:59 -0700 Subject: [PATCH 026/303] add a new distributed group for disaggregated prefill NCCL communication --- vllm/distributed/parallel_state.py | 55 ++++++++++++++++++++++++++++-- vllm/envs.py | 5 +++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 66ffe6e8a9fa9..eb3d32aeebb46 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -754,6 +754,14 @@ def get_pp_group() -> GroupCoordinator: "pipeline model parallel group is not initialized") return _PP + +_DISAGG: Optional[GroupCoordinator] = None + +def get_disagg_group() -> GroupCoordinator: + assert _DISAGG is not None, ( + "disaggregated prefilling parallel group is not initialized") + return _DISAGG + # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group @@ -827,6 +835,28 @@ def init_distributed_environment( else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized with a different world size") + + +def extend_distributed_group_with_offset( + groups: List[List[int]], + offset: int, +) -> List[List[int]]: + """ + Extend original distributed group. + The extended part will be the original distributed group plus an offset. + + Arguments: + groups: original distributed group + offset: the offset we want to apply to the duplicated group. + Typically world_size // 2 + """ + + new_groups = [] + for group in groups: + new_groups.append([rank for rank in group]) + new_groups.append([rank + offset for rank in group]) + + return new_groups def initialize_model_parallel( @@ -862,15 +892,24 @@ def initialize_model_parallel( backend = backend or torch.distributed.get_backend( get_world_group().device_group) + + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + # Disaggregated prefilling is enabled + # There will be 2 copies of vLLM + # One for prefilling and one for decoding + disagg_prefill_size = 2 + else: + disagg_prefill_size = 1 + if (world_size != - tensor_model_parallel_size * pipeline_model_parallel_size): + tensor_model_parallel_size * pipeline_model_parallel_size * disagg_prefill_size): raise RuntimeError( f"world_size ({world_size}) is not equal to " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") # Build the tensor model-parallel groups. - num_tensor_model_parallel_groups: int = (world_size // + num_tensor_model_parallel_groups: int = (world_size // disagg_prefill_size // tensor_model_parallel_size) global _TP assert _TP is None, ("tensor model parallel group is already initialized") @@ -880,6 +919,12 @@ def initialize_model_parallel( range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) group_ranks.append(ranks) + # extend the distributed group if disaggregated prefilling is enabled + if disagg_prefill_size > 1: + group_ranks = extend_distributed_group_with_offset( + group_ranks, + world_size // disagg_prefill_size + ) _TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend) @@ -893,6 +938,12 @@ def initialize_model_parallel( for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group_ranks.append(ranks) + # extend the distributed group if disaggregated prefilling is enabled + if disagg_prefill_size > 1: + group_ranks = extend_distributed_group_with_offset( + group_ranks, + world_size // disagg_prefill_size + ) # pipeline parallel does not need custom allreduce _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, diff --git a/vllm/envs.py b/vllm/envs.py index c624510c7ea1a..0c8d9ae2b8d4b 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -251,6 +251,11 @@ lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"), "VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")), + + # Specify the role of current vllm instance + # Value can be "prefill", "decode" or None. + "VLLM_DISAGG_PREFILL_ROLE": + lambda: os.getenv("VLLM_DISAGG_PREFILL_ROLE", None), } # end-env-vars-definition From f157f6b42d6348a98a4685bc2b360ea35241b19e Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 17:37:25 -0700 Subject: [PATCH 027/303] only inflate the world size inside parallel_state.py --- examples/disaggregated_prefill_example.sh | 3 ++- vllm/config.py | 10 +--------- vllm/distributed/parallel_state.py | 12 +++++++++++- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index 7adcc6ec35b63..01b30433697c3 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -1,4 +1,5 @@ +export VLLM_LOGGING_LEVEL=DEBUG # prefilling instance VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ @@ -10,7 +11,7 @@ VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ --disable-log-requests \ --enable-chunked-prefill & -sleep 2 +sleep 10 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ -m vllm.entrypoints.openai.api_server \ diff --git a/vllm/config.py b/vllm/config.py index 9f9fa938465c6..5586132c8be6c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -662,15 +662,7 @@ def __init__( self.ray_workers_use_nsight = ray_workers_use_nsight self.placement_group = placement_group - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - # Disaggregated prefilling is enabled - # There will be 2 copies of vLLM - # One for prefilling and one for decoding - self.disagg_prefill_size = 2 - else: - self.disagg_prefill_size = 1 - - self.world_size = pipeline_parallel_size * tensor_parallel_size * self.disagg_prefill_size + self.world_size = pipeline_parallel_size * tensor_parallel_size if worker_use_ray: if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index eb3d32aeebb46..428ffb01dc21b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -808,6 +808,16 @@ def init_distributed_environment( "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", world_size, rank, local_rank, distributed_init_method, backend) + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + # Disaggregated prefilling is enabled + # There will be 2 copies of vLLM + # One for prefilling and one for decoding + world_size = world_size * 2 + logger.debug( + "Disaggregated prefill enabled, " + "increase world size to %d", world_size) + else: + disagg_prefill_size = 1 if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " @@ -929,7 +939,7 @@ def initialize_model_parallel( get_world_group().local_rank, backend) # Build the pipeline model-parallel groups. - num_pipeline_model_parallel_groups: int = (world_size // + num_pipeline_model_parallel_groups: int = (world_size // disagg_prefill_size // pipeline_model_parallel_size) global _PP assert _PP is None, ( From de82c3cbfb716649eb4817aad5f9a9e5b13f8ce2 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 17:49:13 -0700 Subject: [PATCH 028/303] add more log information --- examples/disaggregated_prefill_example.sh | 4 +++- vllm/distributed/parallel_state.py | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index 01b30433697c3..ee8706227f9e1 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -20,4 +20,6 @@ VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ -tp 4 \ --disable-log-stats \ --disable-log-requests \ - --enable-chunked-prefill & \ No newline at end of file + --enable-chunked-prefill & + + diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 428ffb01dc21b..0bb5a66f37f2e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -861,10 +861,15 @@ def extend_distributed_group_with_offset( Typically world_size // 2 """ + logger.debug("Extend the distributed groups with offset %d", offset) + logger.debug("Before extension:\n%s", str(groups)) + new_groups = [] for group in groups: new_groups.append([rank for rank in group]) new_groups.append([rank + offset for rank in group]) + + logger.debug("After extension:\n%s", str(new_groups)) return new_groups From 69ce0e0d7c1ec7432b64b3fa49a30a8693af957e Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 17:55:40 -0700 Subject: [PATCH 029/303] specify vllm port --- examples/disaggregated_prefill_example.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index ee8706227f9e1..ebfaa6e43000a 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -1,5 +1,6 @@ export VLLM_LOGGING_LEVEL=DEBUG +export VLLM_PORT=12345 # prefilling instance VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ From e3dc2e9ca8df12682558006a2b432838af147441 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 20:47:43 -0700 Subject: [PATCH 030/303] avoid switching to unused ports in disaggregated prefilling --- vllm/distributed/parallel_state.py | 4 ++++ vllm/utils.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 0bb5a66f37f2e..96bfc27610f97 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -9,6 +9,10 @@ - call `init_distributed_environment` to initialize the distributed environment. - call `initialize_model_parallel` or `ensure_model_parallel_initialized` to initialize the model parallel groups. + - In disaggregated prefilling, we will modify: + - World size: 2 * tp * pp + - Rank: [0, tp * pp) for prefilling, [tp * pp, 2 * tp * pp) for decoding + - Local rank: unchanged - any code dealing with the distributed stuff diff --git a/vllm/utils.py b/vllm/utils.py index a3d15d7979228..505f2a895ef93 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -364,6 +364,10 @@ def get_distributed_init_method(ip: str, port: int) -> str: def get_open_port() -> int: port = envs.VLLM_PORT if port is not None: + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + # The prefill and decode instance shares the same port + # Skip the binding check as the port may be binded by prefill + return port while True: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: From 18fe19c4744f88ec80f73bc159ab46f53360fd84 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 22:21:15 -0700 Subject: [PATCH 031/303] adjust parallel state to include _DISAGG distributed group --- vllm/distributed/parallel_state.py | 103 ++++++++++++++++++++--------- 1 file changed, 71 insertions(+), 32 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 8ce8a332bdf3e..aabe660cf1da3 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -8,11 +8,8 @@ - call `init_distributed_environment` to initialize the distributed environment. - call `initialize_model_parallel` or `ensure_model_parallel_initialized` to - initialize the model parallel groups. - - In disaggregated prefilling, we will modify: - - World size: 2 * tp * pp - - Rank: [0, tp * pp) for prefilling, [tp * pp, 2 * tp * pp) for decoding - - Local rank: unchanged + initialize the model parallel groups and disaggregated prefilling parallel + groups. - any code dealing with the distributed stuff @@ -703,8 +700,8 @@ def destroy(self): self.ca_comm = None if self.mq_broadcaster is not None: self.mq_broadcaster = None - - + + _WORLD: Optional[GroupCoordinator] = None @@ -855,7 +852,7 @@ def init_distributed_environment( "world group already initialized with a different world size") -def extend_distributed_group_with_offset( +def offset_distributed_groups( groups: List[List[int]], offset: int, ) -> List[List[int]]: @@ -869,15 +866,14 @@ def extend_distributed_group_with_offset( Typically world_size // 2 """ - logger.debug("Extend the distributed groups with offset %d", offset) - logger.debug("Before extension:\n%s", str(groups)) + logger.debug("Offset distributed groups with offset %d", offset) + logger.debug("Before offset:\n%s", str(groups)) new_groups = [] for group in groups: - new_groups.append([rank for rank in group]) new_groups.append([rank + offset for rank in group]) - logger.debug("After extension:\n%s", str(new_groups)) + logger.debug("After offset:\n%s", str(new_groups)) return new_groups @@ -908,31 +904,51 @@ def initialize_model_parallel( are on the same DGX box. For example if we are using 2 DGX-1 boxes with a total of 16 GPUs, rank 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box. + + Disaggregated prefilling will also initialize using this function. + Why: disaggregated prefilling is similar to pipeline parallel + except that disaggregated prefilling does not partition model + Methodology: + - Only change variables in this file + - Any variable outside this file should be unchanged + Modifications: + - World size in vLLM variables (like in `ParallelConfig`): unchanged + - World size in `torch.distributed`: doubled (2 * tp * pp) + - Rank: + - [0, tp * pp) for prefilling + - [tp * pp, 2 * tp * pp) for decoding + - Parallel groups + - Unchanged for prefilling + - Offseted by tp * pp for decoding + - Add a new parallel group `_DISAGG` for disaggregated prefilling + - [0, tp * pp], [1, tp * pp + 1], .. + - Local rank: unchanged + - Thanks to PP implementation, distributed operations only rely on + local rank. This guarantees the communications inside the + prefilling instance and decoding instance are unchanged. """ # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() backend = backend or torch.distributed.get_backend( - get_world_group().device_group) - - + get_world_group().device_group) if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - # Disaggregated prefilling is enabled - # There will be 2 copies of vLLM - # One for prefilling and one for decoding - disagg_prefill_size = 2 - else: - disagg_prefill_size = 1 + # Keep the semantics of world_size the same (`tp * pp`) + logger.debug("Disaggregated prefilling enabled") + world_size = world_size // 2 + logger.debug("Shrink the world size from %d to %d", + world_size * 2, + world_size) if (world_size != - tensor_model_parallel_size * pipeline_model_parallel_size * disagg_prefill_size): + tensor_model_parallel_size * pipeline_model_parallel_size): raise RuntimeError( f"world_size ({world_size}) is not equal to " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") # Build the tensor model-parallel groups. - num_tensor_model_parallel_groups: int = (world_size // disagg_prefill_size // + num_tensor_model_parallel_groups: int = (world_size // tensor_model_parallel_size) global _TP assert _TP is None, ("tensor model parallel group is already initialized") @@ -942,11 +958,12 @@ def initialize_model_parallel( range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) group_ranks.append(ranks) - # extend the distributed group if disaggregated prefilling is enabled - if disagg_prefill_size > 1: - group_ranks = extend_distributed_group_with_offset( + if envs.VLLM_DISAGG_PREFILL_ROLE == "decoding": + logger.debug("Current instance is decoding instance") + logger.debug("Offset the distributed group ranks by %d", world_size) + group_ranks = offset_distributed_groups( group_ranks, - world_size // disagg_prefill_size + world_size ) # message queue broadcaster is only used in tensor model parallel group @@ -956,7 +973,7 @@ def initialize_model_parallel( use_message_queue_broadcaster=True) # Build the pipeline model-parallel groups. - num_pipeline_model_parallel_groups: int = (world_size // disagg_prefill_size // + num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size) global _PP assert _PP is None, ( @@ -965,17 +982,34 @@ def initialize_model_parallel( for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group_ranks.append(ranks) - # extend the distributed group if disaggregated prefilling is enabled - if disagg_prefill_size > 1: - group_ranks = extend_distributed_group_with_offset( + if envs.VLLM_DISAGG_PREFILL_ROLE == "decoding": + logger.debug("Current instance is decoding instance") + logger.debug("Offset the distributed group ranks by %d", world_size) + group_ranks = offset_distributed_groups( group_ranks, - world_size // disagg_prefill_size + world_size ) # pipeline parallel does not need custom allreduce _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefilling", "decoding"], ( + "VLLM_DISAGG_PREFILL_ROLE should be either prefilling or decoding") + logger.debug("Disaggregated prefilling enabled, create distributed group") + group_ranks = [] + for i in range(world_size): + # prefilling local rank: i + # decoding global rank: i + world_size + group_ranks.append([i, i + world_size]) + logger.debug("Distributed group is %s", str(group_ranks)) + _DISAGG = init_model_parallel_group( + group_ranks, + int(envs.VLLM_DISAGG_PREFILL_ROLE == "decoding"), + backend, + use_custom_allreduce=False) def ensure_model_parallel_initialized( @@ -1061,6 +1095,11 @@ def destroy_model_parallel(): _PP.destroy() _PP = None + global _DISAGG + if _DISAGG: + _DISAGG.destroy() + _DISAGG = None + def destroy_distributed_environment(): global _WORLD From 94cadb85855458c96535e774d7261585fe975e73 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 22:37:55 -0700 Subject: [PATCH 032/303] offset global rank for decoding instances --- vllm/distributed/parallel_state.py | 32 +++++++++++++++++++----------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index aabe660cf1da3..9c1247e066236 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -813,26 +813,34 @@ def init_distributed_environment( "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", world_size, rank, local_rank, distributed_init_method, backend) - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - # Disaggregated prefilling is enabled - # There will be 2 copies of vLLM - # One for prefilling and one for decoding - world_size = world_size * 2 - logger.debug( - "Disaggregated prefill enabled, " - "increase world size to %d", world_size) - else: - disagg_prefill_size = 1 if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " "distributed environment") # this backend is used for WORLD + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + # Disaggregated prefilling is enabled + # world_size in vLLM is tp * pp + # for prefill, the ranks are [0, world_size) + # for decode, the ranks are [world_size, 2 * world_size) + maybe_disagg_world_size = world_size * 2 + logger.debug( + "Disaggregated prefill enabled, handle torch-related changes on world size and ranks. This change is only inside `vllm/distributed/parallel_state.py`) and the other files are unchanged.") + assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefilling", "decoding"], ( + "VLLM_DISAGG_PREFILL_ROLE should be either prefilling or decoding") + if envs.VLLM_DISAGG_PREFILL_ROLE == "prefilling": + maybe_disagg_rank = rank + else: + # offset global rank by tp * pp (which is world_size) + maybe_disagg_rank = rank + world_size + else: + maybe_disagg_world_size = world_size + maybe_disagg_rank = rank torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, - world_size=world_size, - rank=rank) + world_size=maybe_disagg_world_size, + rank=maybe_disagg_rank) # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 From ded5d92564e68d4442bdd698708c24d91b3c461f Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 22:42:04 -0700 Subject: [PATCH 033/303] adjust naming: use prefill and decode instead of prefilling and decoding --- vllm/distributed/parallel_state.py | 54 +++++++++++++++--------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 9c1247e066236..01ccbbb989178 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -8,7 +8,7 @@ - call `init_distributed_environment` to initialize the distributed environment. - call `initialize_model_parallel` or `ensure_model_parallel_initialized` to - initialize the model parallel groups and disaggregated prefilling parallel + initialize the model parallel groups and disaggregated prefill parallel groups. - any code dealing with the distributed stuff @@ -764,7 +764,7 @@ def get_pp_group() -> GroupCoordinator: def get_disagg_group() -> GroupCoordinator: assert _DISAGG is not None, ( - "disaggregated prefilling parallel group is not initialized") + "disaggregated prefill parallel group is not initialized") return _DISAGG @@ -819,16 +819,16 @@ def init_distributed_environment( "distributed environment") # this backend is used for WORLD if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - # Disaggregated prefilling is enabled + # Disaggregated prefill is enabled # world_size in vLLM is tp * pp # for prefill, the ranks are [0, world_size) # for decode, the ranks are [world_size, 2 * world_size) maybe_disagg_world_size = world_size * 2 logger.debug( "Disaggregated prefill enabled, handle torch-related changes on world size and ranks. This change is only inside `vllm/distributed/parallel_state.py`) and the other files are unchanged.") - assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefilling", "decoding"], ( - "VLLM_DISAGG_PREFILL_ROLE should be either prefilling or decoding") - if envs.VLLM_DISAGG_PREFILL_ROLE == "prefilling": + assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( + "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") + if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": maybe_disagg_rank = rank else: # offset global rank by tp * pp (which is world_size) @@ -913,9 +913,9 @@ def initialize_model_parallel( with a total of 16 GPUs, rank 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box. - Disaggregated prefilling will also initialize using this function. - Why: disaggregated prefilling is similar to pipeline parallel - except that disaggregated prefilling does not partition model + Disaggregated prefill will also initialize using this function. + Why: disaggregated prefill is similar to pipeline parallel + except that disaggregated prefill does not partition model Methodology: - Only change variables in this file - Any variable outside this file should be unchanged @@ -923,17 +923,17 @@ def initialize_model_parallel( - World size in vLLM variables (like in `ParallelConfig`): unchanged - World size in `torch.distributed`: doubled (2 * tp * pp) - Rank: - - [0, tp * pp) for prefilling - - [tp * pp, 2 * tp * pp) for decoding + - [0, tp * pp) for prefill + - [tp * pp, 2 * tp * pp) for decode - Parallel groups - - Unchanged for prefilling - - Offseted by tp * pp for decoding - - Add a new parallel group `_DISAGG` for disaggregated prefilling + - Unchanged for prefill + - Offseted by tp * pp for decode + - Add a new parallel group `_DISAGG` for disaggregated prefill - [0, tp * pp], [1, tp * pp + 1], .. - Local rank: unchanged - Thanks to PP implementation, distributed operations only rely on local rank. This guarantees the communications inside the - prefilling instance and decoding instance are unchanged. + prefill instance and decode instance are unchanged. """ # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() @@ -942,7 +942,7 @@ def initialize_model_parallel( get_world_group().device_group) if envs.VLLM_DISAGG_PREFILL_ROLE is not None: # Keep the semantics of world_size the same (`tp * pp`) - logger.debug("Disaggregated prefilling enabled") + logger.debug("Disaggregated prefill enabled") world_size = world_size // 2 logger.debug("Shrink the world size from %d to %d", world_size * 2, @@ -966,8 +966,8 @@ def initialize_model_parallel( range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) group_ranks.append(ranks) - if envs.VLLM_DISAGG_PREFILL_ROLE == "decoding": - logger.debug("Current instance is decoding instance") + if envs.VLLM_DISAGG_PREFILL_ROLE == "decode": + logger.debug("Current instance is decode instance") logger.debug("Offset the distributed group ranks by %d", world_size) group_ranks = offset_distributed_groups( group_ranks, @@ -990,8 +990,8 @@ def initialize_model_parallel( for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group_ranks.append(ranks) - if envs.VLLM_DISAGG_PREFILL_ROLE == "decoding": - logger.debug("Current instance is decoding instance") + if envs.VLLM_DISAGG_PREFILL_ROLE == "decode": + logger.debug("Current instance is decode instance") logger.debug("Offset the distributed group ranks by %d", world_size) group_ranks = offset_distributed_groups( group_ranks, @@ -1004,18 +1004,18 @@ def initialize_model_parallel( use_custom_allreduce=False) if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefilling", "decoding"], ( - "VLLM_DISAGG_PREFILL_ROLE should be either prefilling or decoding") - logger.debug("Disaggregated prefilling enabled, create distributed group") + assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( + "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") + logger.debug("Disaggregated prefill enabled, create distributed group") group_ranks = [] for i in range(world_size): - # prefilling local rank: i - # decoding global rank: i + world_size + # prefill local rank: i + # decode global rank: i + world_size group_ranks.append([i, i + world_size]) logger.debug("Distributed group is %s", str(group_ranks)) _DISAGG = init_model_parallel_group( group_ranks, - int(envs.VLLM_DISAGG_PREFILL_ROLE == "decoding"), + int(envs.VLLM_DISAGG_PREFILL_ROLE == "decode"), backend, use_custom_allreduce=False) @@ -1060,7 +1060,7 @@ def model_parallel_is_initialized(): def patch_tensor_parallel_group(tp_group: GroupCoordinator): """Patch the tp group temporarily until this function ends. - This method is for draft workers of speculative decoding to run draft model + This method is for draft workers of speculative decode to run draft model with different tp degree from that of target model workers. Args: From 709ae054ba098bdcfd5ab24ffd1639c5ba82a35a Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 22:42:59 -0700 Subject: [PATCH 034/303] adjust the example: let the decode process in foreground for debugging --- examples/disaggregated_prefill_example.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index ebfaa6e43000a..6835a378ff565 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -21,6 +21,6 @@ VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ -tp 4 \ --disable-log-stats \ --disable-log-requests \ - --enable-chunked-prefill & + --enable-chunked-prefill From 2ab44d4aea31b45f57be26b2fac1724a17b73c62 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 22:52:06 -0700 Subject: [PATCH 035/303] adjust logger format --- vllm/distributed/parallel_state.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 01ccbbb989178..ccf0bb465d751 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -22,6 +22,7 @@ """ import contextlib import pickle +import logging from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass @@ -793,6 +794,13 @@ def graph_capture(): logger = init_logger(__name__) +if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + # disaggregated prefill enabled + # indicating if the current instance is prefill or decode + class CustomAdapter(logging.LoggerAdapter): + def process(self, msg, kwargs): + return f"[{envs.VLLM_DISAGG_PREFILL_ROLE}] {msg}", kwargs + logger = CustomAdapter(logger) _ENABLE_CUSTOM_ALL_REDUCE = True @@ -942,11 +950,8 @@ def initialize_model_parallel( get_world_group().device_group) if envs.VLLM_DISAGG_PREFILL_ROLE is not None: # Keep the semantics of world_size the same (`tp * pp`) - logger.debug("Disaggregated prefill enabled") + logger.debug("Disaggregated prefill enabled, the world size obtained from torch.distributed (2 * tp * pp) should be decreased to align with vLLM world size (tp * pp)") world_size = world_size // 2 - logger.debug("Shrink the world size from %d to %d", - world_size * 2, - world_size) if (world_size != tensor_model_parallel_size * pipeline_model_parallel_size): From 22138814ffad83a1e2a5ddf12bbe2ec28126f285 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 23:03:21 -0700 Subject: [PATCH 036/303] test if the P2P cache stucks when no disaggregated prefilling --- examples/disaggregated_prefill_example.sh | 33 ++++++++++++++--------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index 6835a378ff565..7b3247dcc4a60 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -2,25 +2,34 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 -# prefilling instance -VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ -m vllm.entrypoints.openai.api_server \ --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ --port 8100 \ -tp 4 \ --disable-log-stats \ --disable-log-requests \ - --enable-chunked-prefill & + --enable-chunked-prefill -sleep 10 +# # prefilling instance +# VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ +# -m vllm.entrypoints.openai.api_server \ +# --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ +# --port 8100 \ +# -tp 4 \ +# --disable-log-stats \ +# --disable-log-requests \ +# --enable-chunked-prefill & -VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ - --port 8200 \ - -tp 4 \ - --disable-log-stats \ - --disable-log-requests \ - --enable-chunked-prefill +# sleep 10 + +# VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ +# -m vllm.entrypoints.openai.api_server \ +# --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ +# --port 8200 \ +# -tp 4 \ +# --disable-log-stats \ +# --disable-log-requests \ +# --enable-chunked-prefill From 544f5cb243a86baeba6c5c84ddb2727cacbd535a Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 23:07:05 -0700 Subject: [PATCH 037/303] let decode instance sleep, to avoid generating P2P cache simultaneously --- vllm/distributed/parallel_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index ccf0bb465d751..3f770c520ee0e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -20,6 +20,7 @@ parallelism, you can skip the model parallel initialization and destruction steps. """ +import time import contextlib import pickle import logging @@ -849,6 +850,8 @@ def init_distributed_environment( init_method=distributed_init_method, world_size=maybe_disagg_world_size, rank=maybe_disagg_rank) + if envs.VLLM_DISAGG_PREFILL_ROLE == "decode": + time.sleep(60) # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -883,13 +886,10 @@ def offset_distributed_groups( """ logger.debug("Offset distributed groups with offset %d", offset) - logger.debug("Before offset:\n%s", str(groups)) new_groups = [] for group in groups: new_groups.append([rank + offset for rank in group]) - - logger.debug("After offset:\n%s", str(new_groups)) return new_groups From 04d319a08ae484934813ee525c8d73230f49d99c Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 23:09:34 -0700 Subject: [PATCH 038/303] continue disaggregated prefill debugging --- examples/disaggregated_prefill_example.sh | 34 ++++++++--------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index 7b3247dcc4a60..360ad4c69c00a 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -2,34 +2,24 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 -CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ +# prefilling instance +VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ -m vllm.entrypoints.openai.api_server \ --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ --port 8100 \ -tp 4 \ --disable-log-stats \ --disable-log-requests \ - --enable-chunked-prefill + --enable-chunked-prefill & -# # prefilling instance -# VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ -# -m vllm.entrypoints.openai.api_server \ -# --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ -# --port 8100 \ -# -tp 4 \ -# --disable-log-stats \ -# --disable-log-requests \ -# --enable-chunked-prefill & - -# sleep 10 - -# VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ -# -m vllm.entrypoints.openai.api_server \ -# --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ -# --port 8200 \ -# -tp 4 \ -# --disable-log-stats \ -# --disable-log-requests \ -# --enable-chunked-prefill +# decoding instance +VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ + --port 8200 \ + -tp 4 \ + --disable-log-stats \ + --disable-log-requests \ + --enable-chunked-prefill From 2e0f02cea3e86309e364a31ac7293dbdfc0a85cf Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 23:17:46 -0700 Subject: [PATCH 039/303] offset world group for decoding instance --- vllm/distributed/parallel_state.py | 54 ++++++++++++++++-------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 3f770c520ee0e..60d999ab75aab 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -809,6 +809,29 @@ def process(self, msg, kwargs): def set_custom_all_reduce(enable: bool): global _ENABLE_CUSTOM_ALL_REDUCE _ENABLE_CUSTOM_ALL_REDUCE = enable + + +def offset_distributed_groups( + groups: List[List[int]], + offset: int, +) -> List[List[int]]: + """ + Extend original distributed group. + The extended part will be the original distributed group plus an offset. + + Arguments: + groups: original distributed group + offset: the offset we want to apply to the duplicated group. + Typically world_size // 2 + """ + + logger.debug("Offset distributed groups with offset %d", offset) + + new_groups = [] + for group in groups: + new_groups.append([rank + offset for rank in group]) + + return new_groups def init_distributed_environment( @@ -850,8 +873,6 @@ def init_distributed_environment( init_method=distributed_init_method, world_size=maybe_disagg_world_size, rank=maybe_disagg_rank) - if envs.VLLM_DISAGG_PREFILL_ROLE == "decode": - time.sleep(60) # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -864,34 +885,17 @@ def init_distributed_environment( local_rank = rank global _WORLD if _WORLD is None: - ranks = list(range(torch.distributed.get_world_size())) + ranks = list(range(world_size)) + # offset the distributed group + if all( + [envs.VLLM_DISAGG_PREFILL_ROLE is not None], + [envs.VLLM_DISAGG_PREFILL_ROLE == "decode"]): + ranks = offset_distributed_groups(ranks, world_size) _WORLD = init_world_group(ranks, local_rank, backend) else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized with a different world size") - -def offset_distributed_groups( - groups: List[List[int]], - offset: int, -) -> List[List[int]]: - """ - Extend original distributed group. - The extended part will be the original distributed group plus an offset. - - Arguments: - groups: original distributed group - offset: the offset we want to apply to the duplicated group. - Typically world_size // 2 - """ - - logger.debug("Offset distributed groups with offset %d", offset) - - new_groups = [] - for group in groups: - new_groups.append([rank + offset for rank in group]) - - return new_groups def initialize_model_parallel( From fd5f1153c0b4a39b626026b32917fef1522f260c Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 23:19:31 -0700 Subject: [PATCH 040/303] a syntax fix --- vllm/distributed/parallel_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 60d999ab75aab..66bf43ec563e4 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -887,9 +887,9 @@ def init_distributed_environment( if _WORLD is None: ranks = list(range(world_size)) # offset the distributed group - if all( - [envs.VLLM_DISAGG_PREFILL_ROLE is not None], - [envs.VLLM_DISAGG_PREFILL_ROLE == "decode"]): + if all([ + envs.VLLM_DISAGG_PREFILL_ROLE is not None, + envs.VLLM_DISAGG_PREFILL_ROLE == "decode"]): ranks = offset_distributed_groups(ranks, world_size) _WORLD = init_world_group(ranks, local_rank, backend) else: From 8d90e6a4e902ab8868f3901aa6cff74fd68422e9 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 17 Jul 2024 23:22:04 -0700 Subject: [PATCH 041/303] bug fix --- vllm/distributed/parallel_state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 66bf43ec563e4..86d3ef5135d3d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -890,7 +890,8 @@ def init_distributed_environment( if all([ envs.VLLM_DISAGG_PREFILL_ROLE is not None, envs.VLLM_DISAGG_PREFILL_ROLE == "decode"]): - ranks = offset_distributed_groups(ranks, world_size) + ranks = list(range(world_size, 2 * world_size)) + _WORLD = init_world_group(ranks, local_rank, backend) else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( From a9474a72b52649ffa96b2e3a64eee10bca8c54b4 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 18 Jul 2024 00:00:27 -0700 Subject: [PATCH 042/303] specify the source of get_open_port --- vllm/distributed/device_communicators/shm_broadcast.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 151b08c1b996c..213e734f58c1d 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -182,11 +182,11 @@ def __init__( max_chunks) self.local_socket = context.socket(PUB) - local_subscribe_port = get_open_port() + local_subscribe_port = get_open_port(is_for_dist_init = False) self.local_socket.bind(f"tcp://*:{local_subscribe_port}") self.local_sync_socket = context.socket(REP) - local_sync_port = get_open_port() + local_sync_port = get_open_port(is_for_dist_init = False) self.local_sync_socket.bind(f"tcp://*:{local_sync_port}") self.current_idx = 0 @@ -202,11 +202,11 @@ def __init__( # for remote readers, we will: # create a publish-subscribe socket to communicate large data self.remote_socket = context.socket(PUB) - remote_subscribe_port = get_open_port() + remote_subscribe_port = get_open_port(is_for_dist_init = False) self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}") self.remote_sync_socket = context.socket(REP) - remote_sync_port = get_open_port() + remote_sync_port = get_open_port(is_for_dist_init = False) self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}") else: remote_subscribe_port = None From 701b0878e69245cb0a1231efac4dc43a90a654e5 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 18 Jul 2024 00:01:53 -0700 Subject: [PATCH 043/303] document why specifying the source of get_open_port --- vllm/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 75c3c4de02c0c..f4c3809360f86 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -361,11 +361,12 @@ def get_distributed_init_method(ip: str, port: int) -> str: return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" -def get_open_port() -> int: +def get_open_port(is_for_dist_init: bool = True) -> int: port = envs.VLLM_PORT if port is not None: - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - # The prefill and decode instance shares the same port + if envs.VLLM_DISAGG_PREFILL_ROLE is not None and is_for_dist_init: + # When initializing distributed environment for disagg prefill + # The prefill and decode instance may share the same port # Skip the binding check as the port may be binded by prefill return port while True: From fa5d71fa5afa0d3c87ddbc07caa8d97e6d91c1ba Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 18 Jul 2024 19:02:50 -0700 Subject: [PATCH 044/303] add VLLM_TRACE_FUNCTION to track the call stack --- examples/disaggregated_prefill_example.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index 360ad4c69c00a..bbcb11e863cdb 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -1,6 +1,7 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 +export VLLM_TRACE_FUNCTION=1 # prefilling instance VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ From e2faedee295196ae8ce6fe2e950fbd250dfc332a Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 18 Jul 2024 19:06:27 -0700 Subject: [PATCH 045/303] fix customadapter bug --- vllm/distributed/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 86d3ef5135d3d..52cadcea90ec1 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -801,7 +801,7 @@ def graph_capture(): class CustomAdapter(logging.LoggerAdapter): def process(self, msg, kwargs): return f"[{envs.VLLM_DISAGG_PREFILL_ROLE}] {msg}", kwargs - logger = CustomAdapter(logger) + logger = CustomAdapter(logger, extra=None) _ENABLE_CUSTOM_ALL_REDUCE = True From 76b6c5e2277fcd3251c989cdce7d24f2e53788b3 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 18 Jul 2024 19:58:52 -0700 Subject: [PATCH 046/303] add parallel state logs for debugging --- vllm/distributed/parallel_state.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 52cadcea90ec1..4a4ef7c50ca25 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -140,6 +140,7 @@ def __init__( ): self.rank = torch.distributed.get_rank() + logger.debug("My rank is %d", self.rank) self.local_rank = local_rank self.device_group = None self.cpu_group = None @@ -147,9 +148,11 @@ def __init__( for ranks in group_ranks: device_group = torch.distributed.new_group( ranks, backend=torch_distributed_backend) + logger.debug("device group initialized") # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. cpu_group = torch.distributed.new_group(ranks, backend="gloo") + logger.debug("cpu group initialized") if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) @@ -180,6 +183,7 @@ def __init__( group=self.cpu_group, device=self.device, ) + logger.debug("Pynccl initialized") else: self.pynccl_comm = None @@ -954,7 +958,6 @@ def initialize_model_parallel( backend = backend or torch.distributed.get_backend( get_world_group().device_group) if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - # Keep the semantics of world_size the same (`tp * pp`) logger.debug("Disaggregated prefill enabled, the world size obtained from torch.distributed (2 * tp * pp) should be decreased to align with vLLM world size (tp * pp)") world_size = world_size // 2 @@ -978,7 +981,7 @@ def initialize_model_parallel( group_ranks.append(ranks) if envs.VLLM_DISAGG_PREFILL_ROLE == "decode": logger.debug("Current instance is decode instance") - logger.debug("Offset the distributed group ranks by %d", world_size) + logger.debug("Offset the _TP ranks by %d", world_size) group_ranks = offset_distributed_groups( group_ranks, world_size @@ -1002,7 +1005,7 @@ def initialize_model_parallel( group_ranks.append(ranks) if envs.VLLM_DISAGG_PREFILL_ROLE == "decode": logger.debug("Current instance is decode instance") - logger.debug("Offset the distributed group ranks by %d", world_size) + logger.debug("Offset the _PP ranks by %d", world_size) group_ranks = offset_distributed_groups( group_ranks, world_size @@ -1016,7 +1019,7 @@ def initialize_model_parallel( if envs.VLLM_DISAGG_PREFILL_ROLE is not None: assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") - logger.debug("Disaggregated prefill enabled, create distributed group") + logger.debug("Disaggregated prefill enabled, create _DISAGG group") group_ranks = [] for i in range(world_size): # prefill local rank: i From cb6d6a5a11d8a63e82ff663c0aa216c76b81bc2e Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 18 Jul 2024 20:08:58 -0700 Subject: [PATCH 047/303] add sleep when initializing parallel state --- vllm/distributed/parallel_state.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 4a4ef7c50ca25..6db042bade2b7 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -894,6 +894,9 @@ def init_distributed_environment( if all([ envs.VLLM_DISAGG_PREFILL_ROLE is not None, envs.VLLM_DISAGG_PREFILL_ROLE == "decode"]): + # sleep 10 seconds to avoid potential collisions + # when initializing distributed environment + time.sleep(10) ranks = list(range(world_size, 2 * world_size)) _WORLD = init_world_group(ranks, local_rank, backend) From fe8fb473aa73ed056ab8a31ab6a6c517ece1289c Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 18 Jul 2024 20:26:17 -0700 Subject: [PATCH 048/303] only log when rank%4==0 --- vllm/distributed/parallel_state.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6db042bade2b7..18b63639ce087 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -140,7 +140,6 @@ def __init__( ): self.rank = torch.distributed.get_rank() - logger.debug("My rank is %d", self.rank) self.local_rank = local_rank self.device_group = None self.cpu_group = None @@ -799,13 +798,16 @@ def graph_capture(): logger = init_logger(__name__) -if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + +original_logger = logger +def logger(*args, **kwargs): # disaggregated prefill enabled # indicating if the current instance is prefill or decode - class CustomAdapter(logging.LoggerAdapter): - def process(self, msg, kwargs): - return f"[{envs.VLLM_DISAGG_PREFILL_ROLE}] {msg}", kwargs - logger = CustomAdapter(logger, extra=None) + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() % 4 == 0: + original_logger(*args, **kwargs) + else: + original_logger(*args, **kwargs) _ENABLE_CUSTOM_ALL_REDUCE = True From cc89bfbb6d5728b7eb004ecb9d994162526247be Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 18 Jul 2024 20:33:41 -0700 Subject: [PATCH 049/303] only log when rank%4==0 --- vllm/distributed/parallel_state.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 18b63639ce087..56bdef65a882b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -798,16 +798,16 @@ def graph_capture(): logger = init_logger(__name__) - -original_logger = logger -def logger(*args, **kwargs): - # disaggregated prefill enabled - # indicating if the current instance is prefill or decode - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() % 4 == 0: - original_logger(*args, **kwargs) - else: - original_logger(*args, **kwargs) +class ConditionalLoggingHandler(logging.Handler): + def emit(self, record): + dist = torch.distributed + try: + if not dist.is_initialized() or (dist.is_initialized() and dist.get_rank() % 4 == 0): + msg = self.format(record) + print(msg) # You can replace this with any other logging mechanism you prefer + except Exception: + pass +logger.addHandler(handler) _ENABLE_CUSTOM_ALL_REDUCE = True From 531bdf3ca93e7005e621bac2f8127104359e2af0 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 18 Jul 2024 20:34:14 -0700 Subject: [PATCH 050/303] bug fix --- vllm/distributed/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 56bdef65a882b..260e7dcfd423f 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -807,7 +807,7 @@ def emit(self, record): print(msg) # You can replace this with any other logging mechanism you prefer except Exception: pass -logger.addHandler(handler) +logger.addHandler(ConditionalLoggingHandler()) _ENABLE_CUSTOM_ALL_REDUCE = True From 1804656ca72d77a16f81411feb3571932b3b16e9 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 18 Jul 2024 20:36:11 -0700 Subject: [PATCH 051/303] also only log when rank=4 in custom all reduce --- .../device_communicators/custom_all_reduce.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index a4f30808d32e1..b7d5af5a8a0a5 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,5 +1,6 @@ from contextlib import contextmanager from typing import Any, List, Optional, Union +import logging import torch import torch.distributed as dist @@ -22,6 +23,17 @@ logger = init_logger(__name__) +class ConditionalLoggingHandler(logging.Handler): + def emit(self, record): + dist = torch.distributed + try: + if not dist.is_initialized() or (dist.is_initialized() and dist.get_rank() % 4 == 0): + msg = self.format(record) + print(msg) # You can replace this with any other logging mechanism you prefer + except Exception: + pass +logger.addHandler(ConditionalLoggingHandler()) + def _can_p2p(rank: int, world_size: int) -> bool: for i in range(world_size): From 81c8640066b05c824af6db21ea8d1407f02e5a07 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 18 Jul 2024 21:44:33 -0700 Subject: [PATCH 052/303] add debuging statement around broadcast --- vllm/distributed/device_communicators/pynccl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 7319566545678..36f1e04aec79f 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -41,6 +41,7 @@ def __init__( self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) + # if world_size == 1, no need to create communicator if self.world_size == 1: self.available = False @@ -70,8 +71,11 @@ def __init__( self.unique_id = ncclUniqueId() tensor = torch.ByteTensor(list(self.unique_id.internal)) ranks = dist.get_process_group_ranks(group) + logger.debug("Group: %s, group rank: %s, world size: %s, src: %s", str(group), str(self.rank), str(self.world_size), ranks[0]) + # arg `src` in `broadcast` is the global rank dist.broadcast(tensor, src=ranks[0], group=group) + logger.debug("dist broadcast succeeded") byte_list = tensor.tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte From 5ba142c287fc8910441d8c9ce84f49f06585727c Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 18 Jul 2024 22:24:34 -0700 Subject: [PATCH 053/303] debug init_world_group --- vllm/distributed/parallel_state.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 260e7dcfd423f..de5e68822e122 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -145,6 +145,7 @@ def __init__( self.cpu_group = None for ranks in group_ranks: + logger.debug("initializing device group") device_group = torch.distributed.new_group( ranks, backend=torch_distributed_backend) logger.debug("device group initialized") @@ -889,6 +890,11 @@ def init_distributed_environment( local_rank = envs.LOCAL_RANK else: local_rank = rank + + if all([ + envs.VLLM_DISAGG_PREFILL_ROLE is not None, + envs.VLLM_DISAGG_PREFILL_ROLE == "prefill"]): + time.sleep(1000) global _WORLD if _WORLD is None: ranks = list(range(world_size)) @@ -898,7 +904,6 @@ def init_distributed_environment( envs.VLLM_DISAGG_PREFILL_ROLE == "decode"]): # sleep 10 seconds to avoid potential collisions # when initializing distributed environment - time.sleep(10) ranks = list(range(world_size, 2 * world_size)) _WORLD = init_world_group(ranks, local_rank, backend) From cc939cfeed19b95485b48199be61d941c7e5eaf1 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 18 Jul 2024 22:29:23 -0700 Subject: [PATCH 054/303] put the log inside a text file --- examples/disaggregated_prefill_example.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index bbcb11e863cdb..c06e90cb44326 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -11,7 +11,7 @@ VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ -tp 4 \ --disable-log-stats \ --disable-log-requests \ - --enable-chunked-prefill & + --enable-chunked-prefill > >(tee -a prefill.txt) 2> >(tee -a prefill.txt >&2) & # decoding instance VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ @@ -21,6 +21,6 @@ VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ -tp 4 \ --disable-log-stats \ --disable-log-requests \ - --enable-chunked-prefill + --enable-chunked-prefill > >(tee -a decode.txt) 2> >(tee -a decode.txt >&2) & From 8ac9266a711ca9e0ceff8abcc171b9896232ee70 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 01:08:49 -0700 Subject: [PATCH 055/303] init DISAGG first --- vllm/distributed/parallel_state.py | 34 ++++++++++++++++-------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index de5e68822e122..58862574d493f 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -996,6 +996,24 @@ def initialize_model_parallel( group_ranks, world_size ) + + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( + "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") + logger.debug("Disaggregated prefill enabled, create _DISAGG group") + group_ranks = [] + for i in range(world_size): + # prefill local rank: i + # decode global rank: i + world_size + group_ranks.append([i, i + world_size]) + logger.debug("Distributed group is %s", str(group_ranks)) + _DISAGG = init_model_parallel_group( + group_ranks, + int(envs.VLLM_DISAGG_PREFILL_ROLE == "decode"), + backend, + use_custom_allreduce=False) + + time.sleep(1000) # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group(group_ranks, @@ -1025,22 +1043,6 @@ def initialize_model_parallel( get_world_group().local_rank, backend, use_custom_allreduce=False) - - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( - "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") - logger.debug("Disaggregated prefill enabled, create _DISAGG group") - group_ranks = [] - for i in range(world_size): - # prefill local rank: i - # decode global rank: i + world_size - group_ranks.append([i, i + world_size]) - logger.debug("Distributed group is %s", str(group_ranks)) - _DISAGG = init_model_parallel_group( - group_ranks, - int(envs.VLLM_DISAGG_PREFILL_ROLE == "decode"), - backend, - use_custom_allreduce=False) def ensure_model_parallel_initialized( From 58849fa61b559d9785a50696d03f810aaa0788c4 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 01:11:31 -0700 Subject: [PATCH 056/303] init DISAGG before global --- vllm/distributed/parallel_state.py | 36 ++++++++++++++++-------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 58862574d493f..a83f286c3b46e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -962,6 +962,25 @@ def initialize_model_parallel( local rank. This guarantees the communications inside the prefill instance and decode instance are unchanged. """ + + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( + "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") + logger.debug("Disaggregated prefill enabled, create _DISAGG group") + group_ranks = [] + for i in range(world_size): + # prefill local rank: i + # decode global rank: i + world_size + group_ranks.append([i, i + world_size]) + logger.debug("Distributed group is %s", str(group_ranks)) + _DISAGG = init_model_parallel_group( + group_ranks, + int(envs.VLLM_DISAGG_PREFILL_ROLE == "decode"), + backend, + use_custom_allreduce=False) + + time.sleep(1000) + # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() @@ -997,23 +1016,6 @@ def initialize_model_parallel( world_size ) - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( - "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") - logger.debug("Disaggregated prefill enabled, create _DISAGG group") - group_ranks = [] - for i in range(world_size): - # prefill local rank: i - # decode global rank: i + world_size - group_ranks.append([i, i + world_size]) - logger.debug("Distributed group is %s", str(group_ranks)) - _DISAGG = init_model_parallel_group( - group_ranks, - int(envs.VLLM_DISAGG_PREFILL_ROLE == "decode"), - backend, - use_custom_allreduce=False) - - time.sleep(1000) # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group(group_ranks, From 08797e263e79b74434c6114bf39edeaae5799647 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 01:14:05 -0700 Subject: [PATCH 057/303] put it behind world_size --- vllm/distributed/parallel_state.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a83f286c3b46e..a9e4008e18b08 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -963,6 +963,18 @@ def initialize_model_parallel( prefill instance and decode instance are unchanged. """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + logger.debug("Disaggregated prefill enabled, the world size obtained from torch.distributed (2 * tp * pp) should be decreased to align with vLLM world size (tp * pp)") + world_size = world_size // 2 + + + + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") @@ -981,14 +993,9 @@ def initialize_model_parallel( time.sleep(1000) - # Get world size and rank. Ensure some consistencies. - assert torch.distributed.is_initialized() - world_size: int = torch.distributed.get_world_size() - backend = backend or torch.distributed.get_backend( - get_world_group().device_group) - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - logger.debug("Disaggregated prefill enabled, the world size obtained from torch.distributed (2 * tp * pp) should be decreased to align with vLLM world size (tp * pp)") - world_size = world_size // 2 + + + if (world_size != tensor_model_parallel_size * pipeline_model_parallel_size): From 4ff4cd69f61e69a89f52d5f355a7dd333f4d962e Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 01:18:16 -0700 Subject: [PATCH 058/303] add more debug information in pynccl --- vllm/distributed/parallel_state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a9e4008e18b08..a21a64ece37be 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -175,10 +175,11 @@ def __init__( from vllm.distributed.device_communicators.custom_all_reduce import ( CustomAllreduce) from vllm.distributed.device_communicators.pynccl import ( - PyNcclCommunicator) + PyNcclCommunicator)) self.pynccl_comm: Optional[PyNcclCommunicator] if use_pynccl and self.world_size > 1: + logger.debug("Before pynccl") self.pynccl_comm = PyNcclCommunicator( group=self.cpu_group, device=self.device, From b09e4e601cd39a5d56083d59529ab29ade210af2 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 01:18:41 -0700 Subject: [PATCH 059/303] typo fix --- vllm/distributed/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a21a64ece37be..eba1e4af9bc13 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -175,7 +175,7 @@ def __init__( from vllm.distributed.device_communicators.custom_all_reduce import ( CustomAllreduce) from vllm.distributed.device_communicators.pynccl import ( - PyNcclCommunicator)) + PyNcclCommunicator) self.pynccl_comm: Optional[PyNcclCommunicator] if use_pynccl and self.world_size > 1: From 583de9753dc09f45af7d656e1a928c6715c5c8d3 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 01:20:39 -0700 Subject: [PATCH 060/303] more debug --- vllm/distributed/parallel_state.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index eba1e4af9bc13..bc9890b0aff58 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -177,6 +177,8 @@ def __init__( from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) + logger.debug("Oh plz") + self.pynccl_comm: Optional[PyNcclCommunicator] if use_pynccl and self.world_size > 1: logger.debug("Before pynccl") From 74bcffffd922da8a1d1d9c1b645aabf7f80dcfcf Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 01:23:08 -0700 Subject: [PATCH 061/303] more debug info --- vllm/distributed/parallel_state.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index bc9890b0aff58..e5fcde36f33f9 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -145,14 +145,14 @@ def __init__( self.cpu_group = None for ranks in group_ranks: - logger.debug("initializing device group") + logger.debug("initializing device group, rank %d", self.rank) device_group = torch.distributed.new_group( ranks, backend=torch_distributed_backend) - logger.debug("device group initialized") + logger.debug("device group initialized, rank %d", self.rank) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. cpu_group = torch.distributed.new_group(ranks, backend="gloo") - logger.debug("cpu group initialized") + logger.debug("cpu group initialized, rank %d", self.rank) if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) @@ -162,11 +162,15 @@ def __init__( assert self.cpu_group is not None assert self.device_group is not None + + logger.debug("Here 166 , rank %d", self.rank) if torch.cuda.is_available(): self.device = torch.device(f"cuda:{local_rank}") else: self.device = torch.device("cpu") + + logger.debug("Here 173 , rank %d", self.rank) self.use_pynccl = use_pynccl self.use_custom_allreduce = use_custom_allreduce From 21758259da88fd94c77155acaaf9c24cbefa67e1 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 01:24:38 -0700 Subject: [PATCH 062/303] put every output --- examples/disaggregated_prefill_example.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index c06e90cb44326..5c9b652f9e618 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -11,7 +11,7 @@ VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ -tp 4 \ --disable-log-stats \ --disable-log-requests \ - --enable-chunked-prefill > >(tee -a prefill.txt) 2> >(tee -a prefill.txt >&2) & + --enable-chunked-prefill & # decoding instance VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ @@ -21,6 +21,6 @@ VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ -tp 4 \ --disable-log-stats \ --disable-log-requests \ - --enable-chunked-prefill > >(tee -a decode.txt) 2> >(tee -a decode.txt >&2) & + --enable-chunked-prefill & From 3e0777009f9742e41aad42660ba10d14dfd097df Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 01:25:42 -0700 Subject: [PATCH 063/303] remove unnecessary sleep --- vllm/distributed/parallel_state.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e5fcde36f33f9..01dfd37a653d0 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -898,10 +898,6 @@ def init_distributed_environment( else: local_rank = rank - if all([ - envs.VLLM_DISAGG_PREFILL_ROLE is not None, - envs.VLLM_DISAGG_PREFILL_ROLE == "prefill"]): - time.sleep(1000) global _WORLD if _WORLD is None: ranks = list(range(world_size)) @@ -909,7 +905,6 @@ def init_distributed_environment( if all([ envs.VLLM_DISAGG_PREFILL_ROLE is not None, envs.VLLM_DISAGG_PREFILL_ROLE == "decode"]): - # sleep 10 seconds to avoid potential collisions # when initializing distributed environment ranks = list(range(world_size, 2 * world_size)) From a22e5cdfa59704de07fcf99fd558a2539eb7ca68 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 01:37:20 -0700 Subject: [PATCH 064/303] add sucess statement --- vllm/distributed/parallel_state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 01dfd37a653d0..f06091d04ab68 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -993,6 +993,7 @@ def initialize_model_parallel( backend, use_custom_allreduce=False) + logger.debug("Success") time.sleep(1000) From 2c0c27dc6ef45b3ad172f75c4bbfe2efc9140e69 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 01:45:10 -0700 Subject: [PATCH 065/303] add debug statement --- vllm/distributed/parallel_state.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index f06091d04ab68..637318f36f2c5 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -145,14 +145,17 @@ def __init__( self.cpu_group = None for ranks in group_ranks: - logger.debug("initializing device group, rank %d", self.rank) + if self.rank in ranks: + logger.debug("initializing device group, rank %d", self.rank) device_group = torch.distributed.new_group( ranks, backend=torch_distributed_backend) - logger.debug("device group initialized, rank %d", self.rank) + if self.rank in ranks: + logger.debug("device group initialized, rank %d", self.rank) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. cpu_group = torch.distributed.new_group(ranks, backend="gloo") - logger.debug("cpu group initialized, rank %d", self.rank) + if self.rank in ranks: + logger.debug("cpu group initialized, rank %d", self.rank) if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) From a783787ae4cf931b7d235c1776e1cf6e75816329 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 10:56:08 -0700 Subject: [PATCH 066/303] log rank in success message --- vllm/distributed/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 637318f36f2c5..be3021b10d21d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -996,7 +996,7 @@ def initialize_model_parallel( backend, use_custom_allreduce=False) - logger.debug("Success") + logger.debug("Success, rank %d", torch.distributed.get_rank()) time.sleep(1000) From 79f0b06500e40d6c79d2c6b24e9e72658364dd3d Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 10:58:26 -0700 Subject: [PATCH 067/303] sleep based on rank to avoid message overlapping --- vllm/distributed/parallel_state.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index be3021b10d21d..a874ea8ca605f 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -976,6 +976,8 @@ def initialize_model_parallel( if envs.VLLM_DISAGG_PREFILL_ROLE is not None: logger.debug("Disaggregated prefill enabled, the world size obtained from torch.distributed (2 * tp * pp) should be decreased to align with vLLM world size (tp * pp)") world_size = world_size // 2 + + time.sleep(torch.distributed.get_rank()) From b17f20f34dbe59f00bab812c32186b1521746616 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 16:07:41 -0700 Subject: [PATCH 068/303] increase torch debug level --- examples/disaggregated_prefill_example.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index 5c9b652f9e618..cfc183675917f 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -2,6 +2,7 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 export VLLM_TRACE_FUNCTION=1 +export TORCH_DISTRIBUTED_DEBUG=DETAIL # prefilling instance VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ From 025f20941aa8b59453781cc86126a5b39cdd2044 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 16:11:05 -0700 Subject: [PATCH 069/303] sleep --- vllm/distributed/parallel_state.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a874ea8ca605f..534d8bbcedce8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -900,6 +900,8 @@ def init_distributed_environment( local_rank = envs.LOCAL_RANK else: local_rank = rank + + global _WORLD if _WORLD is None: @@ -915,6 +917,11 @@ def init_distributed_environment( else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized with a different world size") + + + time.sleep(torch.distributed.get_rank()) + logger.debug("Success initialized _WORLD for rank %d", torch.distributed.get_rank()) + time.sleep(100) From 32292f1272c06d5fec7c2393694056b1cdd93c60 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 16:17:21 -0700 Subject: [PATCH 070/303] set gloo debugging level to trace --- examples/disaggregated_prefill_example.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index cfc183675917f..653bb72c4dce3 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -3,6 +3,7 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 export VLLM_TRACE_FUNCTION=1 export TORCH_DISTRIBUTED_DEBUG=DETAIL +export GLOO_LOGGING_LEVEL=TRACE # prefilling instance VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ From 389fb24c5b1922a8a8ca0ef9dd646eacb5f4f947 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 16:20:40 -0700 Subject: [PATCH 071/303] reduce debugging commands --- vllm/distributed/parallel_state.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 534d8bbcedce8..35844996b6174 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -145,12 +145,12 @@ def __init__( self.cpu_group = None for ranks in group_ranks: - if self.rank in ranks: - logger.debug("initializing device group, rank %d", self.rank) device_group = torch.distributed.new_group( ranks, backend=torch_distributed_backend) if self.rank in ranks: - logger.debug("device group initialized, rank %d", self.rank) + import time + time.sleep(self.rank) + logger.debug("initializing cpu group, rank %d", self.rank) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. cpu_group = torch.distributed.new_group(ranks, backend="gloo") @@ -166,14 +166,12 @@ def __init__( assert self.cpu_group is not None assert self.device_group is not None - logger.debug("Here 166 , rank %d", self.rank) if torch.cuda.is_available(): self.device = torch.device(f"cuda:{local_rank}") else: self.device = torch.device("cpu") - logger.debug("Here 173 , rank %d", self.rank) self.use_pynccl = use_pynccl self.use_custom_allreduce = use_custom_allreduce @@ -184,16 +182,13 @@ def __init__( from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) - logger.debug("Oh plz") self.pynccl_comm: Optional[PyNcclCommunicator] if use_pynccl and self.world_size > 1: - logger.debug("Before pynccl") self.pynccl_comm = PyNcclCommunicator( group=self.cpu_group, device=self.device, ) - logger.debug("Pynccl initialized") else: self.pynccl_comm = None @@ -919,7 +914,6 @@ def init_distributed_environment( "world group already initialized with a different world size") - time.sleep(torch.distributed.get_rank()) logger.debug("Success initialized _WORLD for rank %d", torch.distributed.get_rank()) time.sleep(100) From 1b38b298b26afdf8e3a3fcec1804ab9f91d4ef0e Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 16:24:22 -0700 Subject: [PATCH 072/303] avoid initializing NCCL first --- vllm/distributed/parallel_state.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 35844996b6174..92ad807392b77 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -145,8 +145,9 @@ def __init__( self.cpu_group = None for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend) + # device_group = torch.distributed.new_group( + # ranks, backend=torch_distributed_backend) + device_group = 233 if self.rank in ranks: import time time.sleep(self.rank) From bb8c08a42c3bbd4939a8349fde78d620fa6a1e18 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 16:31:49 -0700 Subject: [PATCH 073/303] check --- vllm/distributed/parallel_state.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 92ad807392b77..b6a2a13acc6d2 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -897,6 +897,10 @@ def init_distributed_environment( else: local_rank = rank + + logger.debug("My rank is %d", torch.distributed.get_rank()) + time.sleep(20) + global _WORLD From 25a7cf332a2bf505deb078e44ffbcf85e1436add Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 16:44:09 -0700 Subject: [PATCH 074/303] locate the hanging line --- vllm/distributed/parallel_state.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b6a2a13acc6d2..5680a7d6cf58c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -899,7 +899,11 @@ def init_distributed_environment( logger.debug("My rank is %d", torch.distributed.get_rank()) - time.sleep(20) + + cpu_group = torch.distributed.new_group(list(range(8)), backend="gloo") + + logger.debug("CPU group initialized") + time.sleep(1000) @@ -919,8 +923,6 @@ def init_distributed_environment( "world group already initialized with a different world size") - logger.debug("Success initialized _WORLD for rank %d", torch.distributed.get_rank()) - time.sleep(100) From 999bd729d83920a0fa391193b94123dc85f75483 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 16:45:40 -0700 Subject: [PATCH 075/303] add rank to CPU group --- vllm/distributed/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 5680a7d6cf58c..cfcf5b6c5c49c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -902,7 +902,7 @@ def init_distributed_environment( cpu_group = torch.distributed.new_group(list(range(8)), backend="gloo") - logger.debug("CPU group initialized") + logger.debug("CPU group initialized, rank %d", torch.distributed.get_rank()) time.sleep(1000) From 3428ea667a8b41ba3b9f168e408af753f8f4e82a Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 16:49:59 -0700 Subject: [PATCH 076/303] narrow case --- vllm/distributed/parallel_state.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index cfcf5b6c5c49c..f44381deb37c6 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -147,16 +147,16 @@ def __init__( for ranks in group_ranks: # device_group = torch.distributed.new_group( # ranks, backend=torch_distributed_backend) - device_group = 233 if self.rank in ranks: import time time.sleep(self.rank) - logger.debug("initializing cpu group, rank %d", self.rank) + logger.debug("initializing cpu group, rank %d, group %s", self.rank, ranks) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. cpu_group = torch.distributed.new_group(ranks, backend="gloo") if self.rank in ranks: - logger.debug("cpu group initialized, rank %d", self.rank) + logger.debug("cpu group initialized, rank %d, group %s", self.rank, ranks) + time.sleep(1000) if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) @@ -898,12 +898,6 @@ def init_distributed_environment( local_rank = rank - logger.debug("My rank is %d", torch.distributed.get_rank()) - - cpu_group = torch.distributed.new_group(list(range(8)), backend="gloo") - - logger.debug("CPU group initialized, rank %d", torch.distributed.get_rank()) - time.sleep(1000) From 91e3ed2dbdfc89b262d950d0c522426f150b8469 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 19 Jul 2024 17:42:39 -0700 Subject: [PATCH 077/303] bug fix: need to align the distributed groups between prefill and decode instances --- examples/disaggregated_prefill_example.sh | 6 +- vllm/distributed/parallel_state.py | 173 +++++++++------------- 2 files changed, 73 insertions(+), 106 deletions(-) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index 653bb72c4dce3..c0009ae0b17d9 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -1,9 +1,9 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 -export VLLM_TRACE_FUNCTION=1 -export TORCH_DISTRIBUTED_DEBUG=DETAIL -export GLOO_LOGGING_LEVEL=TRACE +# export VLLM_TRACE_FUNCTION=1 +# export TORCH_DISTRIBUTED_DEBUG=DETAIL +# export GLOO_LOGGING_LEVEL=TRACE # prefilling instance VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index f44381deb37c6..7be0fefc3d330 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -145,18 +145,11 @@ def __init__( self.cpu_group = None for ranks in group_ranks: - # device_group = torch.distributed.new_group( - # ranks, backend=torch_distributed_backend) - if self.rank in ranks: - import time - time.sleep(self.rank) - logger.debug("initializing cpu group, rank %d, group %s", self.rank, ranks) + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. cpu_group = torch.distributed.new_group(ranks, backend="gloo") - if self.rank in ranks: - logger.debug("cpu group initialized, rank %d, group %s", self.rank, ranks) - time.sleep(1000) if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) @@ -722,10 +715,10 @@ def get_world_group() -> GroupCoordinator: return _WORLD -def init_world_group(ranks: List[int], local_rank: int, +def init_world_group(ranks: List[List[int]], local_rank: int, backend: str) -> GroupCoordinator: return GroupCoordinator( - group_ranks=[ranks], + group_ranks=ranks, local_rank=local_rank, torch_distributed_backend=backend, use_pynccl=False, @@ -824,27 +817,33 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable -def offset_distributed_groups( +def include_decoding_groups_if_disagg_enabled( groups: List[List[int]], - offset: int, + world_size: int, ) -> List[List[int]]: """ - Extend original distributed group. - The extended part will be the original distributed group plus an offset. + Include the distributed group for decode + Only for disaggregated prefill + Example: + Original group: [ [0,1], [2,3] ], world_size = 4 + Extended: [ [0,1], [2,3], [4,5], [6,7] ] Arguments: groups: original distributed group - offset: the offset we want to apply to the duplicated group. - Typically world_size // 2 + world_size: the vLLM world size, which is half of torch.distributed.get_world_size() """ - - logger.debug("Offset distributed groups with offset %d", offset) - - new_groups = [] - for group in groups: - new_groups.append([rank + offset for rank in group]) - - return new_groups + + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( + "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") + new_groups = [] + for group in groups: + new_groups.append([rank for rank in group]) + for group in groups: + new_groups.append([rank + world_size for rank in group]) + return new_groups + else: + return groups def init_distributed_environment( @@ -863,24 +862,21 @@ def init_distributed_environment( "distributed_init_method must be provided when initializing " "distributed environment") # this backend is used for WORLD + maybe_disagg_world_size = world_size + maybe_disagg_rank = rank if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - # Disaggregated prefill is enabled - # world_size in vLLM is tp * pp - # for prefill, the ranks are [0, world_size) - # for decode, the ranks are [world_size, 2 * world_size) maybe_disagg_world_size = world_size * 2 logger.debug( - "Disaggregated prefill enabled, handle torch-related changes on world size and ranks. This change is only inside `vllm/distributed/parallel_state.py`) and the other files are unchanged.") + "Disaggregated prefill enabled.") assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": + # for prefill, the ranks are [0, world_size) maybe_disagg_rank = rank else: # offset global rank by tp * pp (which is world_size) maybe_disagg_rank = rank + world_size - else: - maybe_disagg_world_size = world_size - maybe_disagg_rank = rank + torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, @@ -896,22 +892,18 @@ def init_distributed_environment( local_rank = envs.LOCAL_RANK else: local_rank = rank - - - global _WORLD if _WORLD is None: - ranks = list(range(world_size)) + ranks = [[i for i in range(world_size)]] # offset the distributed group - if all([ - envs.VLLM_DISAGG_PREFILL_ROLE is not None, - envs.VLLM_DISAGG_PREFILL_ROLE == "decode"]): - # when initializing distributed environment - ranks = list(range(world_size, 2 * world_size)) + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + ranks = include_decoding_groups_if_disagg_enabled(ranks, world_size) _WORLD = init_world_group(ranks, local_rank, backend) + logger.debug("_WORLD initialized for rank %d", torch.distributed.get_rank()) + time.sleep(5) else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized with a different world size") @@ -947,27 +939,21 @@ def initialize_model_parallel( with a total of 16 GPUs, rank 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box. - Disaggregated prefill will also initialize using this function. - Why: disaggregated prefill is similar to pipeline parallel - except that disaggregated prefill does not partition model - Methodology: - - Only change variables in this file - - Any variable outside this file should be unchanged - Modifications: - - World size in vLLM variables (like in `ParallelConfig`): unchanged - - World size in `torch.distributed`: doubled (2 * tp * pp) - - Rank: + + Disaggregated prefill will also initialize its process group using this function. + Changes: + - vLLM world size: unchanged (tp * pp) + - torch.distributed.get_world_size(): + - 2 * tp * pp + - Why: torch.distributed package sees 2 vLLM instances (prefill and decode) + - Global rank: - [0, tp * pp) for prefill - [tp * pp, 2 * tp * pp) for decode - Parallel groups - - Unchanged for prefill - - Offseted by tp * pp for decode + - Extend _WORLD, _TP and _PP using `include_decoding_groups_if_disagg_enabled` - Add a new parallel group `_DISAGG` for disaggregated prefill - - [0, tp * pp], [1, tp * pp + 1], .. + - [ [0, tp * pp], [1, tp * pp + 1], .. ] - Local rank: unchanged - - Thanks to PP implementation, distributed operations only rely on - local rank. This guarantees the communications inside the - prefill instance and decode instance are unchanged. """ # Get world size and rank. Ensure some consistencies. @@ -976,37 +962,11 @@ def initialize_model_parallel( backend = backend or torch.distributed.get_backend( get_world_group().device_group) if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - logger.debug("Disaggregated prefill enabled, the world size obtained from torch.distributed (2 * tp * pp) should be decreased to align with vLLM world size (tp * pp)") + # Disaggregated prefill enabled + # The world_size for this vLLM instance is tp * pp, but torch.distributed contains 2 vLLM instances, its world size is 2 * tp * pp + # Adjust the world_size to match. world_size = world_size // 2 - time.sleep(torch.distributed.get_rank()) - - - - - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( - "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") - logger.debug("Disaggregated prefill enabled, create _DISAGG group") - group_ranks = [] - for i in range(world_size): - # prefill local rank: i - # decode global rank: i + world_size - group_ranks.append([i, i + world_size]) - logger.debug("Distributed group is %s", str(group_ranks)) - _DISAGG = init_model_parallel_group( - group_ranks, - int(envs.VLLM_DISAGG_PREFILL_ROLE == "decode"), - backend, - use_custom_allreduce=False) - - logger.debug("Success, rank %d", torch.distributed.get_rank()) - time.sleep(1000) - - - - - if (world_size != tensor_model_parallel_size * pipeline_model_parallel_size): raise RuntimeError( @@ -1025,20 +985,14 @@ def initialize_model_parallel( range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) group_ranks.append(ranks) - if envs.VLLM_DISAGG_PREFILL_ROLE == "decode": - logger.debug("Current instance is decode instance") - logger.debug("Offset the _TP ranks by %d", world_size) - group_ranks = offset_distributed_groups( - group_ranks, - world_size - ) - - + group_ranks = include_decoding_groups_if_disagg_enabled(group_ranks, world_size) # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_message_queue_broadcaster=True) + logger.debug("_TP initialized for rank %d", torch.distributed.get_rank()) + time.sleep(5) # Build the pipeline model-parallel groups. num_pipeline_model_parallel_groups: int = (world_size // @@ -1050,18 +1004,31 @@ def initialize_model_parallel( for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group_ranks.append(ranks) - if envs.VLLM_DISAGG_PREFILL_ROLE == "decode": - logger.debug("Current instance is decode instance") - logger.debug("Offset the _PP ranks by %d", world_size) - group_ranks = offset_distributed_groups( - group_ranks, - world_size - ) + group_ranks = include_decoding_groups_if_disagg_enabled(group_ranks, world_size) # pipeline parallel does not need custom allreduce _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + logger.debug("_PP initialized for rank %d", torch.distributed.get_rank()) + time.sleep(5) + + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + global _DISAGG + logger.debug("Disaggregated prefill enabled, create _DISAGG group") + group_ranks = [] + for i in range(world_size): + # prefill local rank: i + # decode global rank: i + world_size + group_ranks.append([i, i + world_size]) + logger.debug("Distributed group is %s", str(group_ranks)) + _DISAGG = init_model_parallel_group( + group_ranks, + int(envs.VLLM_DISAGG_PREFILL_ROLE == "decode"), + backend, + use_custom_allreduce=False) + logger.debug("_DISAGG initialized for rank %d", torch.distributed.get_rank()) + time.sleep(5) def ensure_model_parallel_initialized( From 3dd2275cea617f85925ac854a0674a2f2d83d4e8 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 18:56:01 -0700 Subject: [PATCH 078/303] add disaggregated prefilling for flashinfer --- vllm/attention/backends/flashinfer.py | 55 +++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 7 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index daff76051a956..0a76962bdd91a 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -23,11 +23,18 @@ from vllm.sequence import SequenceGroupMetadata from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad +# This group is used for KV cache transfer in disaggregated prefilling +from vllm.distributed import get_disagg_group + +# To identify if the VLLM_DISAGG_PREFILL_ROLE is set or no +import vllm.envs as envs + if TYPE_CHECKING: from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder) + class FlashInferBackend(AttentionBackend): @staticmethod @@ -479,6 +486,20 @@ def forward( if attn_metadata.num_decode_tokens > 0: assert attn_metadata.num_prefill_tokens == 0, ( "Chunked prefill is not supported with flashinfer yet.") + + prefill_meta = attn_metadata.prefill_metadata + + if all([ + kv_cache is not None, # we are not in profile run + prefill_meta is not None, # during prefill stage + envs.VLLM_DISAGG_PREFILL_ROLE is not None, # disagg prefill enabled + ]): + if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": + get_disagg_group().send(key) + get_disagg_group().send(value) + else: + key = get_disagg_group().recv(key.shape, key.dtype) + value = get_disagg_group().recv(value.shape, value.dtype) if kv_cache is not None: # Use the same reshape and cache kernel as flash attention. @@ -493,7 +514,7 @@ def forward( query = query.contiguous( ) # Flashinfer requires query to be contiguous - if prefill_meta := attn_metadata.prefill_metadata: + if prefill_meta is not None: # We will use flash attention for prefill # when kv_cache is not provided. # This happens when vllm runs the profiling to @@ -515,11 +536,26 @@ def forward( else: assert prefill_meta is not None assert prefill_meta.prefill_wrapper is not None - output = prefill_meta.prefill_wrapper.forward( - query, - kv_cache, - logits_soft_cap=attn_metadata.logits_soft_cap, - causal=True) + + if not all([ + envs.VLLM_DISAGG_PREFILL_ROLE is not None, + envs.VLLM_DISAGG_PREFILL_ROLE == "decode", + ]): # Only skip prefill for disagg decode instance + output = prefill_meta.prefill_wrapper.forward( + query, + kv_cache, + logits_soft_cap=attn_metadata.logits_soft_cap, + causal=True) + output = output.view(num_tokens, hidden_size) + + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + # communication for disaggregated prefill. + if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": + get_disagg_group().send(output) + else: + # Kuntai: This assume that output has the same dtype as key + # Is this assumption true? + output = get_disagg_group().recv([num_tokens, hidden_size], key.dtype) else: assert attn_metadata.decode_metadata is not None assert attn_metadata.decode_metadata.decode_wrapper is not None @@ -528,4 +564,9 @@ def forward( kv_cache, sm_scale=self.scale, logits_soft_cap=attn_metadata.logits_soft_cap) - return output.view(num_tokens, hidden_size) + output = output.view(num_tokens, hidden_size) + + + + + return output From 2b13f3ca1803444262542d49315bd01a78092c8b Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 18:56:38 -0700 Subject: [PATCH 079/303] adjust comments --- vllm/distributed/parallel_state.py | 12 +++++------- vllm/envs.py | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 7be0fefc3d330..5188b71d95b9d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -764,7 +764,11 @@ def get_pp_group() -> GroupCoordinator: "pipeline model parallel group is not initialized") return _PP - + +# kept for backward compatibility +get_pipeline_model_parallel_group = get_pp_group + + _DISAGG: Optional[GroupCoordinator] = None def get_disagg_group() -> GroupCoordinator: @@ -773,9 +777,6 @@ def get_disagg_group() -> GroupCoordinator: return _DISAGG -# kept for backward compatibility -get_pipeline_model_parallel_group = get_pp_group - @contextmanager def graph_capture(): @@ -992,7 +993,6 @@ def initialize_model_parallel( backend, use_message_queue_broadcaster=True) logger.debug("_TP initialized for rank %d", torch.distributed.get_rank()) - time.sleep(5) # Build the pipeline model-parallel groups. num_pipeline_model_parallel_groups: int = (world_size // @@ -1011,7 +1011,6 @@ def initialize_model_parallel( backend, use_custom_allreduce=False) logger.debug("_PP initialized for rank %d", torch.distributed.get_rank()) - time.sleep(5) if envs.VLLM_DISAGG_PREFILL_ROLE is not None: global _DISAGG @@ -1028,7 +1027,6 @@ def initialize_model_parallel( backend, use_custom_allreduce=False) logger.debug("_DISAGG initialized for rank %d", torch.distributed.get_rank()) - time.sleep(5) def ensure_model_parallel_initialized( diff --git a/vllm/envs.py b/vllm/envs.py index 0952b299a4a67..d409bba2292fa 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -297,7 +297,7 @@ def get_default_config_root(): lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")), # Specify the role of current vllm instance - # Value can be "prefill", "decode" or None. + # Value can be "prefill", "decode". "VLLM_DISAGG_PREFILL_ROLE": lambda: os.getenv("VLLM_DISAGG_PREFILL_ROLE", None), From 8c3f209f35b30df74dc5373c68ffe7be7dd1dc43 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 19:30:57 -0700 Subject: [PATCH 080/303] add logging for send and recv --- vllm/distributed/parallel_state.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 5188b71d95b9d..7889b4829648e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -666,6 +666,8 @@ def barrier(self): def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" + + logger.debug("Sending tensor: ", tensor.shape, tensor.dtype, dst) if dst is None: dst = (self.rank_in_group + 1) % self.world_size @@ -681,6 +683,8 @@ def recv(self, src: Optional[int] = None) -> torch.Tensor: """Receives a tensor from the src rank.""" """NOTE: `src` is the local rank of the destination rank.""" + + logger.debug("Recving tensor: ", size, dtype, src) if src is None: src = (self.rank_in_group - 1) % self.world_size From c6a5e5759c1b843db2a4223b626c345b37d5ef9f Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 19:46:23 -0700 Subject: [PATCH 081/303] turn off chunked prefill to use flashinfer kernel --- examples/disaggregated_prefill_example.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index c0009ae0b17d9..47ed5434a0ce1 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -13,7 +13,7 @@ VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ -tp 4 \ --disable-log-stats \ --disable-log-requests \ - --enable-chunked-prefill & + --enable-prefix-caching & # decoding instance VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ @@ -23,6 +23,6 @@ VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ -tp 4 \ --disable-log-stats \ --disable-log-requests \ - --enable-chunked-prefill & + --enable-prefix-caching & From b3c47f3c5db69636db77ccf01617e1d0c0c55513 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 19:49:50 -0700 Subject: [PATCH 082/303] confirm which backend is being used --- vllm/attention/backends/flash_attn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index b8a64205b362b..ddf99802fd999 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -473,6 +473,7 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. + print("I am in flash attn") ops.reshape_and_cache_flash( key, value, From f05540c65d46b176c84b1dd645c989921042ea01 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 19:54:33 -0700 Subject: [PATCH 083/303] remove debugging from parallel_state, its too much... --- vllm/distributed/parallel_state.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 7889b4829648e..5188b71d95b9d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -666,8 +666,6 @@ def barrier(self): def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" - - logger.debug("Sending tensor: ", tensor.shape, tensor.dtype, dst) if dst is None: dst = (self.rank_in_group + 1) % self.world_size @@ -683,8 +681,6 @@ def recv(self, src: Optional[int] = None) -> torch.Tensor: """Receives a tensor from the src rank.""" """NOTE: `src` is the local rank of the destination rank.""" - - logger.debug("Recving tensor: ", size, dtype, src) if src is None: src = (self.rank_in_group - 1) % self.world_size From eb96fe739cf16a41dae1ff0966d9af2488590a8d Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 19:58:00 -0700 Subject: [PATCH 084/303] add disagg prefill for flash attn backend --- vllm/attention/backends/flash_attn.py | 68 +++++++++++++++++++++------ 1 file changed, 54 insertions(+), 14 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ddf99802fd999..69ae57364fdb3 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -16,10 +16,18 @@ from vllm.sequence import SequenceGroupMetadata from vllm.utils import make_tensor_with_pad +# This group is used for KV cache transfer in disaggregated prefilling +from vllm.distributed import get_disagg_group + +# To identify if the VLLM_DISAGG_PREFILL_ROLE is set or no +import vllm.envs as envs +from vllm.logger import init_logger + if TYPE_CHECKING: from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder) +logger = init_logger(__name__) class FlashAttentionBackend(AttentionBackend): @@ -466,6 +474,21 @@ def forward( key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) + if all([ + kv_cache is not None, # we are not in profile run + prefill_meta is not None, # during prefill stage + envs.VLLM_DISAGG_PREFILL_ROLE is not None, # disagg prefill enabled + ]): + if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": + logger.debug("Sending key & value, ", key.shape, key.dtype, value.shape, value.dtype) + get_disagg_group().send(key) + get_disagg_group().send(value) + else: + logger.debug("Recving key & value, ", key.shape, key.dtype, value.shape, value.dtype) + key = get_disagg_group().recv(key.shape, key.dtype) + value = get_disagg_group().recv(value.shape, value.dtype) + + if kv_cache is not None: key_cache = kv_cache[0] value_cache = kv_cache[1] @@ -473,7 +496,6 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - print("I am in flash attn") ops.reshape_and_cache_flash( key, value, @@ -525,19 +547,37 @@ def forward( # prefix-enabled attention assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) - output[:num_prefill_tokens] = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, - ) + + if not all([ + envs.VLLM_DISAGG_PREFILL_ROLE is not None, + envs.VLLM_DISAGG_PREFILL_ROLE == "decode", + ]): # Only skip prefill for disagg decode instance + output[:num_prefill_tokens] = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, + ) + output.view(num_tokens, hidden_size) + + if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + # communication for disaggregated prefill. + if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": + logger.info("Sending output, " , output.shape, output.dtype) + get_disagg_group().send(output) + else: + logger.info("Recv output, " , output.shape, output.dtype) + # Kuntai: This assume that output has the same dtype as key + # Is this assumption true? + output = get_disagg_group().recv([num_tokens, hidden_size], key.dtype) + if decode_meta := attn_metadata.decode_metadata: # Decoding run. From 09d5588b06285a43912d8ca5372a67a04c30ec4e Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 20:00:12 -0700 Subject: [PATCH 085/303] edit flash attn to assign prefill_meta first --- vllm/attention/backends/flash_attn.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 69ae57364fdb3..00846efac521d 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -474,6 +474,8 @@ def forward( key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) + prefill_meta = attn_metadata.prefill_metadata + if all([ kv_cache is not None, # we are not in profile run prefill_meta is not None, # during prefill stage @@ -521,7 +523,7 @@ def forward( assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens - if prefill_meta := attn_metadata.prefill_metadata: + if prefill_meta is not None: # Prompt run. if (kv_cache is None or prefill_meta.block_tables is None or prefill_meta.block_tables.numel() == 0): From 43077e7d3dad51139c694211f9ef8f7110c0a96c Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 20:04:02 -0700 Subject: [PATCH 086/303] use print instead of attn --- vllm/attention/backends/flash_attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 00846efac521d..9ca3909fa4e85 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -482,11 +482,11 @@ def forward( envs.VLLM_DISAGG_PREFILL_ROLE is not None, # disagg prefill enabled ]): if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": - logger.debug("Sending key & value, ", key.shape, key.dtype, value.shape, value.dtype) + print("Sending key & value, ", key.shape, key.dtype, value.shape, value.dtype) get_disagg_group().send(key) get_disagg_group().send(value) else: - logger.debug("Recving key & value, ", key.shape, key.dtype, value.shape, value.dtype) + print("Recving key & value, ", key.shape, key.dtype, value.shape, value.dtype) key = get_disagg_group().recv(key.shape, key.dtype) value = get_disagg_group().recv(value.shape, value.dtype) @@ -572,10 +572,10 @@ def forward( if envs.VLLM_DISAGG_PREFILL_ROLE is not None: # communication for disaggregated prefill. if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": - logger.info("Sending output, " , output.shape, output.dtype) + print("Sending output, " , output.shape, output.dtype) get_disagg_group().send(output) else: - logger.info("Recv output, " , output.shape, output.dtype) + print("Recv output, " , output.shape, output.dtype) # Kuntai: This assume that output has the same dtype as key # Is this assumption true? output = get_disagg_group().recv([num_tokens, hidden_size], key.dtype) From f7167379656fa6bf5b7f1483b2cfc63a2fd6a300 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 20:05:12 -0700 Subject: [PATCH 087/303] make data contiguous --- vllm/attention/backends/flash_attn.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 9ca3909fa4e85..3892cec9aa199 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -482,6 +482,8 @@ def forward( envs.VLLM_DISAGG_PREFILL_ROLE is not None, # disagg prefill enabled ]): if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": + key = key.contiguous() + value = value.contiguous() print("Sending key & value, ", key.shape, key.dtype, value.shape, value.dtype) get_disagg_group().send(key) get_disagg_group().send(value) @@ -567,7 +569,7 @@ def forward( alibi_slopes=self.alibi_slopes, block_table=prefill_meta.block_tables, ) - output.view(num_tokens, hidden_size) + output = output.view(num_tokens, hidden_size).contiguous() if envs.VLLM_DISAGG_PREFILL_ROLE is not None: # communication for disaggregated prefill. From 0d072519caad47881f9e3d04da336af2fd629ed3 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 20:09:59 -0700 Subject: [PATCH 088/303] add more debug message --- vllm/attention/backends/flash_attn.py | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 3892cec9aa199..c30d323de7f8c 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -474,25 +474,6 @@ def forward( key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) - prefill_meta = attn_metadata.prefill_metadata - - if all([ - kv_cache is not None, # we are not in profile run - prefill_meta is not None, # during prefill stage - envs.VLLM_DISAGG_PREFILL_ROLE is not None, # disagg prefill enabled - ]): - if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": - key = key.contiguous() - value = value.contiguous() - print("Sending key & value, ", key.shape, key.dtype, value.shape, value.dtype) - get_disagg_group().send(key) - get_disagg_group().send(value) - else: - print("Recving key & value, ", key.shape, key.dtype, value.shape, value.dtype) - key = get_disagg_group().recv(key.shape, key.dtype) - value = get_disagg_group().recv(value.shape, value.dtype) - - if kv_cache is not None: key_cache = kv_cache[0] value_cache = kv_cache[1] @@ -525,7 +506,7 @@ def forward( assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens - if prefill_meta is not None: + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if (kv_cache is None or prefill_meta.block_tables is None or prefill_meta.block_tables.numel() == 0): @@ -556,6 +537,7 @@ def forward( envs.VLLM_DISAGG_PREFILL_ROLE is not None, envs.VLLM_DISAGG_PREFILL_ROLE == "decode", ]): # Only skip prefill for disagg decode instance + logger.debug("Do prefill") output[:num_prefill_tokens] = flash_attn_varlen_func( q=query, k=key_cache, From 2177737e44f0d855bc8c9d0d4e808463ad28122e Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 20:44:16 -0700 Subject: [PATCH 089/303] turn on logging --- examples/disaggregated_prefill_example.sh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index 47ed5434a0ce1..601420fabd75e 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -11,8 +11,6 @@ VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ --port 8100 \ -tp 4 \ - --disable-log-stats \ - --disable-log-requests \ --enable-prefix-caching & # decoding instance @@ -21,8 +19,6 @@ VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ --port 8200 \ -tp 4 \ - --disable-log-stats \ - --disable-log-requests \ --enable-prefix-caching & From a293bd08948bf7e2c028e01a598c8233527b4bee Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 20:49:17 -0700 Subject: [PATCH 090/303] more debug prints in flash_attn --- vllm/attention/backends/flash_attn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c30d323de7f8c..91b080c254a58 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -513,6 +513,7 @@ def forward( # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. + print("profile run ", end="") out = flash_attn_varlen_func( q=query, k=key, @@ -532,6 +533,7 @@ def forward( # prefix-enabled attention assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) + print("non promt_run") if not all([ envs.VLLM_DISAGG_PREFILL_ROLE is not None, From cc7f646a511771f505f4e52ae99765c7d27d6f63 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 21:01:08 -0700 Subject: [PATCH 091/303] remove enforce eager --- examples/disaggregated_prefill_example.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index 601420fabd75e..1df1aba0e4a67 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -12,6 +12,7 @@ VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ --port 8100 \ -tp 4 \ --enable-prefix-caching & + # decoding instance VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ From 68f3d16511afc575038c3cc0c9936ca8209ec92e Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 21:01:25 -0700 Subject: [PATCH 092/303] adjust printing order in flash attn --- vllm/attention/backends/flash_attn.py | 69 +++++++++++++-------------- 1 file changed, 34 insertions(+), 35 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 91b080c254a58..1594bbb76d04d 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -505,15 +505,19 @@ def forward( assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens - - if prefill_meta := attn_metadata.prefill_metadata: + + prefill_meta = attn_metadata.prefill_metadata + if (prefill_meta is not None) and ( + (envs.VLLM_DISAGG_PREFILL_ROLE is None) + or + (envs.VLLM_DISAGG_PREFILL_ROLE == "prefill") + ): # during prefilling, and this instance is not disagg decode instance # Prompt run. if (kv_cache is None or prefill_meta.block_tables is None or prefill_meta.block_tables.numel() == 0): # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - print("profile run ", end="") out = flash_attn_varlen_func( q=query, k=key, @@ -533,38 +537,33 @@ def forward( # prefix-enabled attention assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) - print("non promt_run") - - if not all([ - envs.VLLM_DISAGG_PREFILL_ROLE is not None, - envs.VLLM_DISAGG_PREFILL_ROLE == "decode", - ]): # Only skip prefill for disagg decode instance - logger.debug("Do prefill") - output[:num_prefill_tokens] = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, - ) - output = output.view(num_tokens, hidden_size).contiguous() - - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - # communication for disaggregated prefill. - if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": - print("Sending output, " , output.shape, output.dtype) - get_disagg_group().send(output) - else: - print("Recv output, " , output.shape, output.dtype) - # Kuntai: This assume that output has the same dtype as key - # Is this assumption true? - output = get_disagg_group().recv([num_tokens, hidden_size], key.dtype) + + output[:num_prefill_tokens] = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, + ) + + if (prefill_meta is not None) and \ + (envs.VLLM_DISAGG_PREFILL_ROLE is not None): + # communication for disaggregated prefill. + if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": + output = output.view(num_tokens, hidden_size).contiguous() + print("Sending output, " , output.shape, output.dtype) + get_disagg_group().send(output) + else: + print("Recv output, " , output.shape, output.dtype) + # Kuntai: This assume that output has the same dtype as key + # Is this assumption true? + output = get_disagg_group().recv([num_tokens, hidden_size], key.dtype) if decode_meta := attn_metadata.decode_metadata: From 21a61b9e027f074ab3d707e0212a5d08bd417162 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 21:05:01 -0700 Subject: [PATCH 093/303] avoid sending & receiving output tensor during profile run --- vllm/attention/backends/flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 1594bbb76d04d..3df3f4ac402ce 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -552,7 +552,7 @@ def forward( block_table=prefill_meta.block_tables, ) - if (prefill_meta is not None) and \ + if (prefill_meta is not None) and (kv_cache is not None) and \ (envs.VLLM_DISAGG_PREFILL_ROLE is not None): # communication for disaggregated prefill. if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": From 691cad78164f73e127de7dec0c515e513289fd44 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 21:17:14 -0700 Subject: [PATCH 094/303] also log the device --- vllm/attention/backends/flash_attn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 3df3f4ac402ce..ffb8f954e28cd 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -557,14 +557,15 @@ def forward( # communication for disaggregated prefill. if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": output = output.view(num_tokens, hidden_size).contiguous() - print("Sending output, " , output.shape, output.dtype) + print("Sending output, " , output.shape, output.dtype, output.device) get_disagg_group().send(output) else: - print("Recv output, " , output.shape, output.dtype) + print("Recv output, " , output.shape, output.dtype, output.device) # Kuntai: This assume that output has the same dtype as key # Is this assumption true? output = get_disagg_group().recv([num_tokens, hidden_size], key.dtype) - + import time + time.sleep(10) if decode_meta := attn_metadata.decode_metadata: # Decoding run. From c057f1935949067d800342764b942291532bd9ac Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 21:51:28 -0700 Subject: [PATCH 095/303] adjust implementation --- vllm/attention/backends/flash_attn.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ffb8f954e28cd..6928e05530c8b 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -556,14 +556,13 @@ def forward( (envs.VLLM_DISAGG_PREFILL_ROLE is not None): # communication for disaggregated prefill. if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": - output = output.view(num_tokens, hidden_size).contiguous() - print("Sending output, " , output.shape, output.dtype, output.device) - get_disagg_group().send(output) + out = output[:num_prefill_tokens].contiguous() + print("Sending out, " , out.shape, out.dtype, out.device) + get_disagg_group().send(out) else: - print("Recv output, " , output.shape, output.dtype, output.device) - # Kuntai: This assume that output has the same dtype as key - # Is this assumption true? - output = get_disagg_group().recv([num_tokens, hidden_size], key.dtype) + + print("Recv out, " , output[:num_prefill_tokens].shape, output.dtype, output.device) + output[:num_prefill_tokens] = get_disagg_group().recv(output[:num_prefill_tokens].shape, output.dtype) import time time.sleep(10) From 82b73bbe0e8fc3693492449d28f9d9ca5270f7fc Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 21:54:53 -0700 Subject: [PATCH 096/303] finish adjustment --- vllm/attention/backends/flash_attn.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 6928e05530c8b..0563e33b6c3c5 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -563,9 +563,7 @@ def forward( print("Recv out, " , output[:num_prefill_tokens].shape, output.dtype, output.device) output[:num_prefill_tokens] = get_disagg_group().recv(output[:num_prefill_tokens].shape, output.dtype) - import time - time.sleep(10) - + if decode_meta := attn_metadata.decode_metadata: # Decoding run. output[num_prefill_tokens:] = flash_attn_with_kvcache( From 6db1d48d3cc5cfe33db5d69035db504ab17fe442 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 23:53:59 -0700 Subject: [PATCH 097/303] fall back to original flashinfer --- vllm/attention/backends/flashinfer.py | 123 +++++++++----------------- 1 file changed, 41 insertions(+), 82 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 0a76962bdd91a..8271efe330c91 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -20,19 +20,11 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) -from vllm.sequence import SequenceGroupMetadata +from vllm.attention.ops.paged_attn import PagedAttention from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad -# This group is used for KV cache transfer in disaggregated prefilling -from vllm.distributed import get_disagg_group - -# To identify if the VLLM_DISAGG_PREFILL_ROLE is set or no -import vllm.envs as envs - if TYPE_CHECKING: - from vllm.worker.model_runner import (GPUModelRunnerBase, - ModelInputForGPUBuilder) - + from vllm.worker.model_runner import ModelInputForGPUBuilder class FlashInferBackend(AttentionBackend): @@ -68,14 +60,14 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - raise NotImplementedError + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - raise NotImplementedError + PagedAttention.copy_blocks(kv_caches, src_to_dists) @staticmethod def get_supported_head_sizes() -> List[int]: @@ -222,6 +214,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.num_prefill_tokens = 0 self.num_decode_tokens = 0 + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size self.use_v2_block_manager = ( @@ -244,26 +239,24 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): # paged_kv_last_page_len is the length of the last page of each request self.paged_kv_last_page_len: List[int] = [] - def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, - token_lens: List[int], seq_lens: List[int], - curr_seq_lens: List[int], query_lens: List[int], - context_lens: List[int], - curr_sliding_window_blocks: List[int], - prefix_cache_hit: bool, chunked_prefill_enabled: bool): + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): """Add a sequence group to the metadata. Specifically update/append 1. context length. 2. block table. 3. slot mapping. """ - is_prompt = seq_group_metadata.is_prompt - block_tables = seq_group_metadata.block_tables - computed_block_nums = seq_group_metadata.computed_block_nums + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + computed_block_nums = inter_data.computed_block_nums for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, curr_sliding_window_block) in zip( - seq_group_metadata.seq_data.keys(), token_lens, seq_lens, - curr_seq_lens, query_lens, context_lens, - curr_sliding_window_blocks): + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: self.num_prefills += 1 @@ -281,7 +274,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, # only allowing multiple of block_size chunk size. # NOTE: This only works for oooooooxxx style attention. block_table = [] - if prefix_cache_hit: + if inter_data.prefix_cache_hit: block_table = computed_block_nums elif ((chunked_prefill_enabled or not is_prompt) and block_tables is not None): @@ -296,8 +289,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, self.use_v2_block_manager) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, - self.block_size, - seq_group_metadata.block_tables) + self.block_size, inter_data.block_tables) # It is not necessary to add paged_kv_indices, paged_kv_indptr, # and paged_kv_last_page_len for profile run because we will @@ -323,9 +315,13 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, last_page_len = self.block_size self.paged_kv_last_page_len.append(last_page_len) - def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, + def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): - device = runner.device + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 max_query_len = max(query_lens) @@ -339,7 +335,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, # The shape of graph_block_tables is # [max batch size, max context len // block size]. - input_block_tables = runner.graph_block_tables[:batch_size] + input_block_tables = self.runner.graph_block_tables[:batch_size] for i, block_table in enumerate(self.block_tables): if block_table: input_block_tables[i, :len(block_table)] = block_table @@ -350,11 +346,8 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, cuda_graph_pad_size) self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) else: - max_block_table_len = max( - len(block_table) for block_table in self.block_tables) block_tables = make_tensor_with_pad( self.block_tables, - max_len=max_block_table_len, pad=0, dtype=torch.int, device=device, @@ -386,7 +379,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, dtype=torch.long, device=device) - logits_soft_cap = getattr(runner.model_config.hf_config, + logits_soft_cap = getattr(self.runner.model_config.hf_config, "attn_logit_softcapping", None) if len(self.paged_kv_indptr) > 0: @@ -403,8 +396,8 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, paged_kv_indptr_tensor = None paged_kv_last_page_len_tensor = None - kv_cache_dtype = get_kv_cache_torch_dtype(runner.kv_cache_dtype, - runner.model_config.dtype) + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) return FlashInferMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, @@ -415,11 +408,11 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indices=paged_kv_indices_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor, - num_qo_heads=runner.model_config.get_num_attention_heads( - runner.parallel_config), - num_kv_heads=runner.model_config.get_num_kv_heads( - runner.parallel_config), - head_dim=runner.model_config.get_head_size(), + num_qo_heads=self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config), + num_kv_heads=self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config), + head_dim=self.runner.model_config.get_head_size(), page_size=self.block_size, seq_start_loc=seq_start_loc, query_start_loc=query_start_loc, @@ -486,20 +479,6 @@ def forward( if attn_metadata.num_decode_tokens > 0: assert attn_metadata.num_prefill_tokens == 0, ( "Chunked prefill is not supported with flashinfer yet.") - - prefill_meta = attn_metadata.prefill_metadata - - if all([ - kv_cache is not None, # we are not in profile run - prefill_meta is not None, # during prefill stage - envs.VLLM_DISAGG_PREFILL_ROLE is not None, # disagg prefill enabled - ]): - if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": - get_disagg_group().send(key) - get_disagg_group().send(value) - else: - key = get_disagg_group().recv(key.shape, key.dtype) - value = get_disagg_group().recv(value.shape, value.dtype) if kv_cache is not None: # Use the same reshape and cache kernel as flash attention. @@ -514,7 +493,7 @@ def forward( query = query.contiguous( ) # Flashinfer requires query to be contiguous - if prefill_meta is not None: + if prefill_meta := attn_metadata.prefill_metadata: # We will use flash attention for prefill # when kv_cache is not provided. # This happens when vllm runs the profiling to @@ -536,26 +515,11 @@ def forward( else: assert prefill_meta is not None assert prefill_meta.prefill_wrapper is not None - - if not all([ - envs.VLLM_DISAGG_PREFILL_ROLE is not None, - envs.VLLM_DISAGG_PREFILL_ROLE == "decode", - ]): # Only skip prefill for disagg decode instance - output = prefill_meta.prefill_wrapper.forward( - query, - kv_cache, - logits_soft_cap=attn_metadata.logits_soft_cap, - causal=True) - output = output.view(num_tokens, hidden_size) - - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - # communication for disaggregated prefill. - if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": - get_disagg_group().send(output) - else: - # Kuntai: This assume that output has the same dtype as key - # Is this assumption true? - output = get_disagg_group().recv([num_tokens, hidden_size], key.dtype) + output = prefill_meta.prefill_wrapper.forward( + query, + kv_cache, + logits_soft_cap=attn_metadata.logits_soft_cap, + causal=True) else: assert attn_metadata.decode_metadata is not None assert attn_metadata.decode_metadata.decode_wrapper is not None @@ -564,9 +528,4 @@ def forward( kv_cache, sm_scale=self.scale, logits_soft_cap=attn_metadata.logits_soft_cap) - output = output.view(num_tokens, hidden_size) - - - - - return output + return output.view(num_tokens, hidden_size) \ No newline at end of file From dbaade746704b674fccc41bf76596a5a3234e528 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 23:56:18 -0700 Subject: [PATCH 098/303] add space --- vllm/attention/backends/flashinfer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 8271efe330c91..f7e467e121e7b 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -528,4 +528,5 @@ def forward( kv_cache, sm_scale=self.scale, logits_soft_cap=attn_metadata.logits_soft_cap) - return output.view(num_tokens, hidden_size) \ No newline at end of file + return output.view(num_tokens, hidden_size) + \ No newline at end of file From f572db8590fd3f16a2f8e62f2a4ba108ad04f161 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 23:57:34 -0700 Subject: [PATCH 099/303] clean config.py --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index a2ca527173ca0..c87974d0df16d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -661,7 +661,7 @@ def __init__( self.ray_workers_use_nsight = ray_workers_use_nsight self.placement_group = placement_group - self.world_size = pipeline_parallel_size * tensor_parallel_size + self.world_size = pipeline_parallel_size * self.tensor_parallel_size if worker_use_ray: if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" From 9ebf3ad8d8f4397cff3bc7c061a170531c65ccf5 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 23:58:56 -0700 Subject: [PATCH 100/303] keep flashattn implementation --- vllm/attention/backends/flash_attn.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 0563e33b6c3c5..db16a6561270f 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -554,14 +554,17 @@ def forward( if (prefill_meta is not None) and (kv_cache is not None) and \ (envs.VLLM_DISAGG_PREFILL_ROLE is not None): - # communication for disaggregated prefill. + # transfer the output if + # 1). during prefilling + # 2). disaggregated prefill enabled + # 3). not in the profile run (kv_cache is not None) + # no need to transfer kv cache, as it is already the input of this function if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": out = output[:num_prefill_tokens].contiguous() - print("Sending out, " , out.shape, out.dtype, out.device) + logger.debug("Send output, " , out.shape, out.dtype, out.device) get_disagg_group().send(out) else: - - print("Recv out, " , output[:num_prefill_tokens].shape, output.dtype, output.device) + logger.debug("Recv output, " , output[:num_prefill_tokens].shape, output.dtype, output.device) output[:num_prefill_tokens] = get_disagg_group().recv(output[:num_prefill_tokens].shape, output.dtype) if decode_meta := attn_metadata.decode_metadata: From 67b1c2eac7234e6d608a5a3063129811092803a1 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 22 Jul 2024 23:59:21 -0700 Subject: [PATCH 101/303] commit changes that will be merged --- vllm/core/block/prefix_caching_block.py | 5 +++++ vllm/model_executor/models/llama.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index f272e23ee6088..2664657e3c26b 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -11,6 +11,10 @@ from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor from vllm.utils import cdiv +import vllm.envs as envs +from vllm.distributed import get_disagg_group + + PrefixHash = int # By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME @@ -163,6 +167,7 @@ def allocate_immutable_block(self, # No cached block => Allocate a new block block = self.allocate_mutable_block(prev_block) block.append_token_ids(token_ids) + return block def allocate_immutable_blocks( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 4c434e54cf743..5cc27dbbc34e2 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -137,7 +137,7 @@ def __init__( quant_config=quant_config, ) - self.rotary_emb = get_rope( + self.rotry_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=max_position_embeddings, From 3abca470ba30ea48655034b2e568e6682c2ea428 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 00:02:28 -0700 Subject: [PATCH 102/303] revert custom allreduce changes --- .../device_communicators/custom_all_reduce.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index b7d5af5a8a0a5..a4f30808d32e1 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,6 +1,5 @@ from contextlib import contextmanager from typing import Any, List, Optional, Union -import logging import torch import torch.distributed as dist @@ -23,17 +22,6 @@ logger = init_logger(__name__) -class ConditionalLoggingHandler(logging.Handler): - def emit(self, record): - dist = torch.distributed - try: - if not dist.is_initialized() or (dist.is_initialized() and dist.get_rank() % 4 == 0): - msg = self.format(record) - print(msg) # You can replace this with any other logging mechanism you prefer - except Exception: - pass -logger.addHandler(ConditionalLoggingHandler()) - def _can_p2p(rank: int, world_size: int) -> bool: for i in range(world_size): From 0ce251b90d3421538ee73973fa49f4a46bf1427f Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 00:03:17 -0700 Subject: [PATCH 103/303] remove debug logs from the file --- vllm/distributed/device_communicators/pynccl.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 36f1e04aec79f..f159d43a47c9a 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -71,11 +71,8 @@ def __init__( self.unique_id = ncclUniqueId() tensor = torch.ByteTensor(list(self.unique_id.internal)) ranks = dist.get_process_group_ranks(group) - logger.debug("Group: %s, group rank: %s, world size: %s, src: %s", str(group), str(self.rank), str(self.world_size), ranks[0]) - # arg `src` in `broadcast` is the global rank dist.broadcast(tensor, src=ranks[0], group=group) - logger.debug("dist broadcast succeeded") byte_list = tensor.tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte From 1f3ac2bcf96299e3414794156d80c329440227f4 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 00:04:29 -0700 Subject: [PATCH 104/303] revert changes to prefix_caching_block --- unnecessary --- vllm/core/block/prefix_caching_block.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index a7eacadd03b06..d102ad4045591 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -11,10 +11,6 @@ from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor from vllm.utils import cdiv -import vllm.envs as envs -from vllm.distributed import get_disagg_group - - PrefixHash = int # By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME @@ -167,7 +163,6 @@ def allocate_immutable_block(self, # No cached block => Allocate a new block block = self.allocate_mutable_block(prev_block) block.append_token_ids(token_ids) - return block def allocate_immutable_blocks( From c93bf33810c7e6d45b7b7c206e3b8d7421f8ce9e Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 00:05:14 -0700 Subject: [PATCH 105/303] revert changes --- vllm/distributed/device_communicators/pynccl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index f159d43a47c9a..7319566545678 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -41,7 +41,6 @@ def __init__( self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) - # if world_size == 1, no need to create communicator if self.world_size == 1: self.available = False From 8dcaf43df0ee8da38e70dd2359cb7da58ddedcbd Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 00:06:42 -0700 Subject: [PATCH 106/303] fix typos --- vllm/model_executor/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 88b6fc62e963f..2052c443a8885 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -148,7 +148,7 @@ def __init__( prefix=f"{prefix}.o_proj", ) - self.rotry_emb = get_rope( + self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=max_position_embeddings, From 4d83813d7934e8edc88f656ef079e9be74a0e560 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 00:15:08 -0700 Subject: [PATCH 107/303] add example usage to disaggregated prefill --- examples/disaggregated_prefill_example.sh | 46 +++++++++++++++++++++-- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index 1df1aba0e4a67..571bc7072a8ef 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -1,9 +1,19 @@ +#!/bin/bash +# This file demonstrates the example usage of disaggregated prefilling +# We will launch 2 vllm instances (1 for prefill and 1 for decode), +# and then transfer the KV cache between them. export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 -# export VLLM_TRACE_FUNCTION=1 -# export TORCH_DISTRIBUTED_DEBUG=DETAIL -# export GLOO_LOGGING_LEVEL=TRACE + +# a function that waits vLLM server to start +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} # prefilling instance VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ @@ -12,7 +22,6 @@ VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ --port 8100 \ -tp 4 \ --enable-prefix-caching & - # decoding instance VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ @@ -23,3 +32,32 @@ VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ --enable-prefix-caching & +wait_for_server 8100 +wait_for_server 8200 + +# sending an example request +# in disaggregated prefilling, there are two steps of sending a request: +# 1. send the request to prefill instance, with max_tokens set to 1 +# 2. send the request again to decode instance, no modification + +# send to prefill instance +curl http://localhost:8100/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "neuralmagic/Meta-Llama-3-70B-Instruct-FP8", +"prompt": "San Francisco is a", +"max_tokens": 1, +"temperature": 0 +}' & + +# send to decode instance +curl http://localhost:8200/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "neuralmagic/Meta-Llama-3-70B-Instruct-FP8", +"prompt": "San Francisco is a", +"max_tokens": 5, +"temperature": 0 +}' + + From 11c3ace8525e83e2a6adcc0c44f6784873720350 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 00:17:54 -0700 Subject: [PATCH 108/303] can only use print instead of log.debug... --- vllm/attention/backends/flash_attn.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 992bd7c2366bb..49b8d38c02baa 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -20,13 +20,10 @@ # To identify if the VLLM_DISAGG_PREFILL_ROLE is set or no import vllm.envs as envs -from vllm.logger import init_logger if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder -logger = init_logger(__name__) - class FlashAttentionBackend(AttentionBackend): @staticmethod @@ -559,10 +556,10 @@ def forward( # no need to transfer kv cache, as it is already the input of this function if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": out = output[:num_prefill_tokens].contiguous() - logger.debug("Send output, " , out.shape, out.dtype, out.device) + print("Send output, " , out.shape, out.dtype, out.device) get_disagg_group().send(out) else: - logger.debug("Recv output, " , output[:num_prefill_tokens].shape, output.dtype, output.device) + print("Recv output, " , output[:num_prefill_tokens].shape, output.dtype, output.device) output[:num_prefill_tokens] = get_disagg_group().recv(output[:num_prefill_tokens].shape, output.dtype) if decode_meta := attn_metadata.decode_metadata: From 0bd0cc9defae39d657303a030f07a64ebe6a69ca Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 00:18:51 -0700 Subject: [PATCH 109/303] kill vllm instance after run --- examples/disaggregated_prefill_example.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/disaggregated_prefill_example.sh b/examples/disaggregated_prefill_example.sh index 571bc7072a8ef..fabba139ae262 100644 --- a/examples/disaggregated_prefill_example.sh +++ b/examples/disaggregated_prefill_example.sh @@ -60,4 +60,5 @@ curl http://localhost:8200/v1/completions \ "temperature": 0 }' - +# gracefully kill all vllm instances +ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 From 39973bb7cbc121f381a2709f6517e817a5266cd0 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 21:27:53 -0700 Subject: [PATCH 110/303] add proxy server for disaggregated prefilling --- .../disagg_prefill_example.sh} | 0 .../disagg_prefill/disagg_proxy_server.py | 62 +++++++++++++++++++ 2 files changed, 62 insertions(+) rename examples/{disaggregated_prefill_example.sh => disagg_prefill/disagg_prefill_example.sh} (100%) create mode 100644 examples/disagg_prefill/disagg_proxy_server.py diff --git a/examples/disaggregated_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh similarity index 100% rename from examples/disaggregated_prefill_example.sh rename to examples/disagg_prefill/disagg_prefill_example.sh diff --git a/examples/disagg_prefill/disagg_proxy_server.py b/examples/disagg_prefill/disagg_proxy_server.py new file mode 100644 index 0000000000000..eb45778d4ad21 --- /dev/null +++ b/examples/disagg_prefill/disagg_proxy_server.py @@ -0,0 +1,62 @@ +import http.server +import socketserver +import requests +import json +import argparse + +class ProxyHTTPRequestHandler(http.server.BaseHTTPRequestHandler): + def __init__(self, *args, **kwargs): + self.prefill_port = kwargs.pop('prefill_port', 8100) + self.decode_port = kwargs.pop('decode_port', 8200) + super().__init__(*args, **kwargs) + + def do_POST(self): + # Read the content length to get the data size + content_length = int(self.headers['Content-Length']) + post_data = self.rfile.read(content_length) + + # Parse the JSON payload + data = json.loads(post_data) + + # Change the max_tokens to 1 for the request to prefill_port + data_prefill = data.copy() + data_prefill["max_tokens"] = 1 + post_data_prefill = json.dumps(data_prefill) + + # Forward the request to prefill_port with modified max_tokens + response_prefill = requests.post(f"http://localhost:{self.prefill_port}/v1/completions", + headers={"Content-Type": "application/json"}, + data=post_data_prefill) + + # Check if the response from prefill_port is successful + if response_prefill.status_code == 200: + # Forward the original request to decode_port + response_decode = requests.post(f"http://localhost:{self.decode_port}/v1/completions", + headers={"Content-Type": "application/json"}, + data=post_data) + + # Send the response back to the client + self.send_response(response_decode.status_code) + self.send_header('Content-type', 'application/json') + self.end_headers() + self.wfile.write(response_decode.content) + else: + # Send an error response back to the client + self.send_response(response_prefill.status_code) + self.send_header('Content-type', 'application/json') + self.end_headers() + self.wfile.write(response_prefill.content) + +def run_server(port_8000, prefill_port, decode_port): + handler = lambda *args, **kwargs: ProxyHTTPRequestHandler(*args, prefill_port=prefill_port, decode_port=decode_port, **kwargs) + with socketserver.TCPServer(("", port_8000), handler) as httpd: + print(f"Serving at port {port_8000}") + httpd.serve_forever() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Proxy server") + parser.add_argument('prefill_port', type=int, help='Port to forward the first request to (with max_tokens=1)') + parser.add_argument('decode_port', type=int, help='Port to forward the second request to') + args = parser.parse_args() + + run_server(8000, args.prefill_port, args.decode_port) From 13a6d12be2def55976a3a78b514d916bcf222e7f Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 21:49:56 -0700 Subject: [PATCH 111/303] update disagg proxy server --- .../disagg_prefill/disagg_prefill_example.sh | 2 - .../disagg_prefill/disagg_proxy_server.py | 102 +++++++++--------- 2 files changed, 48 insertions(+), 56 deletions(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index fabba139ae262..a505d8b3ca85d 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -60,5 +60,3 @@ curl http://localhost:8200/v1/completions \ "temperature": 0 }' -# gracefully kill all vllm instances -ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 diff --git a/examples/disagg_prefill/disagg_proxy_server.py b/examples/disagg_prefill/disagg_proxy_server.py index eb45778d4ad21..6ec2737367c8c 100644 --- a/examples/disagg_prefill/disagg_proxy_server.py +++ b/examples/disagg_prefill/disagg_proxy_server.py @@ -1,62 +1,56 @@ -import http.server -import socketserver -import requests -import json import argparse +import aiohttp +import asyncio +from aiohttp import web +import json -class ProxyHTTPRequestHandler(http.server.BaseHTTPRequestHandler): - def __init__(self, *args, **kwargs): - self.prefill_port = kwargs.pop('prefill_port', 8100) - self.decode_port = kwargs.pop('decode_port', 8200) - super().__init__(*args, **kwargs) +async def handle_post(request): + prefill_port = request.app['prefill_port'] + decode_port = request.app['decode_port'] + + # Read and parse the request payload + try: + payload = await request.json() + except Exception as e: + return web.json_response({'error': str(e)}, status=400) + + # Modify max_tokens for prefill request + payload_prefill = payload.copy() + payload_prefill["max_tokens"] = 1 - def do_POST(self): - # Read the content length to get the data size - content_length = int(self.headers['Content-Length']) - post_data = self.rfile.read(content_length) - - # Parse the JSON payload - data = json.loads(post_data) - - # Change the max_tokens to 1 for the request to prefill_port - data_prefill = data.copy() - data_prefill["max_tokens"] = 1 - post_data_prefill = json.dumps(data_prefill) - - # Forward the request to prefill_port with modified max_tokens - response_prefill = requests.post(f"http://localhost:{self.prefill_port}/v1/completions", - headers={"Content-Type": "application/json"}, - data=post_data_prefill) - - # Check if the response from prefill_port is successful - if response_prefill.status_code == 200: - # Forward the original request to decode_port - response_decode = requests.post(f"http://localhost:{self.decode_port}/v1/completions", - headers={"Content-Type": "application/json"}, - data=post_data) - - # Send the response back to the client - self.send_response(response_decode.status_code) - self.send_header('Content-type', 'application/json') - self.end_headers() - self.wfile.write(response_decode.content) - else: - # Send an error response back to the client - self.send_response(response_prefill.status_code) - self.send_header('Content-type', 'application/json') - self.end_headers() - self.wfile.write(response_prefill.content) + async with aiohttp.ClientSession() as session: + # Forward request to prefill port + async with session.post(f"http://localhost:{prefill_port}/v1/completions", json=payload_prefill) as response_prefill: + if response_prefill.status != 200: + return web.json_response(await response_prefill.json(), status=response_prefill.status) -def run_server(port_8000, prefill_port, decode_port): - handler = lambda *args, **kwargs: ProxyHTTPRequestHandler(*args, prefill_port=prefill_port, decode_port=decode_port, **kwargs) - with socketserver.TCPServer(("", port_8000), handler) as httpd: - print(f"Serving at port {port_8000}") - httpd.serve_forever() + # Forward original request to decode port + async with session.post(f"http://localhost:{decode_port}/v1/completions", json=payload) as response_decode: + if 'stream' in payload and payload['stream']: + # If streaming, set up a streaming response + response = web.StreamResponse(status=response_decode.status, reason=response_decode.reason, headers=response_decode.headers) + await response.prepare(request) + + async for data, _ in response_decode.content.iter_chunks(): + await response.write(data) + await response.write_eof() + return response + else: + # Return non-streaming response as JSON + return web.json_response(await response_decode.json(), status=response_decode.status) -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Proxy server") +async def init_app(prefill_port, decode_port): + app = web.Application() + app['prefill_port'] = prefill_port + app['decode_port'] = decode_port + app.router.add_post('/v1/completions', handle_post) + return app + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Async Proxy server") parser.add_argument('prefill_port', type=int, help='Port to forward the first request to (with max_tokens=1)') parser.add_argument('decode_port', type=int, help='Port to forward the second request to') args = parser.parse_args() - - run_server(8000, args.prefill_port, args.decode_port) + + app = asyncio.run(init_app(args.prefill_port, args.decode_port)) + web.run_app(app, port=8000) From 81cad25d3dfca617d2be4ce299af1948b01fbc9d Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 22:00:37 -0700 Subject: [PATCH 112/303] add debug message for proxy server --- .../disagg_prefill/disagg_proxy_server.py | 37 ++++++++++++++++--- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/examples/disagg_prefill/disagg_proxy_server.py b/examples/disagg_prefill/disagg_proxy_server.py index 6ec2737367c8c..6655ff6379b22 100644 --- a/examples/disagg_prefill/disagg_proxy_server.py +++ b/examples/disagg_prefill/disagg_proxy_server.py @@ -3,29 +3,48 @@ import asyncio from aiohttp import web import json +import logging + +# Configure logging +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) async def handle_post(request): prefill_port = request.app['prefill_port'] decode_port = request.app['decode_port'] + logger.debug(f"Received request to {request.path} with method {request.method}") + # Read and parse the request payload try: payload = await request.json() + logger.debug(f"Request payload: {json.dumps(payload, indent=2)}") except Exception as e: + logger.error(f"Error parsing request payload: {str(e)}") return web.json_response({'error': str(e)}, status=400) # Modify max_tokens for prefill request payload_prefill = payload.copy() payload_prefill["max_tokens"] = 1 + logger.debug(f"Modified prefill payload: {json.dumps(payload_prefill, indent=2)}") async with aiohttp.ClientSession() as session: # Forward request to prefill port async with session.post(f"http://localhost:{prefill_port}/v1/completions", json=payload_prefill) as response_prefill: + try: + response_prefill_data = await response_prefill.json() + logger.debug(f"Prefill response data: {json.dumps(response_prefill_data, indent=2)}") + except aiohttp.ContentTypeError: + response_prefill_data = await response_prefill.text() + logger.debug(f"Prefill response text: {response_prefill_data}") + if response_prefill.status != 200: - return web.json_response(await response_prefill.json(), status=response_prefill.status) + logger.error(f"Prefill request failed with status {response_prefill.status}") + return web.json_response(response_prefill_data, status=response_prefill.status) # Forward original request to decode port async with session.post(f"http://localhost:{decode_port}/v1/completions", json=payload) as response_decode: + logger.debug(f"Forwarding request to decode port {decode_port}") if 'stream' in payload and payload['stream']: # If streaming, set up a streaming response response = web.StreamResponse(status=response_decode.status, reason=response_decode.reason, headers=response_decode.headers) @@ -33,11 +52,19 @@ async def handle_post(request): async for data, _ in response_decode.content.iter_chunks(): await response.write(data) + logger.debug(f"Streaming chunk: {data}") await response.write_eof() + logger.debug("Finished streaming response") return response else: - # Return non-streaming response as JSON - return web.json_response(await response_decode.json(), status=response_decode.status) + # Handle non-streaming response + try: + response_decode_data = await response_decode.json() + logger.debug(f"Decode response data: {json.dumps(response_decode_data, indent=2)}") + except aiohttp.ContentTypeError: + response_decode_data = await response_decode.text() + logger.debug(f"Decode response text: {response_decode_data}") + return web.json_response(response_decode_data, status=response_decode.status) async def init_app(prefill_port, decode_port): app = web.Application() @@ -49,8 +76,8 @@ async def init_app(prefill_port, decode_port): if __name__ == '__main__': parser = argparse.ArgumentParser(description="Async Proxy server") parser.add_argument('prefill_port', type=int, help='Port to forward the first request to (with max_tokens=1)') - parser.add_argument('decode_port', type=int, help='Port to forward the second request to') + parser.add_argument 'decode_port', type=int, help='Port to forward the second request to') args = parser.parse_args() app = asyncio.run(init_app(args.prefill_port, args.decode_port)) - web.run_app(app, port=8000) + web.run_app(app, port=8000) \ No newline at end of file From 198931befb8ff7fe6611189a5d22000a7c794114 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 22:01:35 -0700 Subject: [PATCH 113/303] fix bug --- examples/disagg_prefill/disagg_proxy_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/disagg_prefill/disagg_proxy_server.py b/examples/disagg_prefill/disagg_proxy_server.py index 6655ff6379b22..f87a081941fc8 100644 --- a/examples/disagg_prefill/disagg_proxy_server.py +++ b/examples/disagg_prefill/disagg_proxy_server.py @@ -76,7 +76,7 @@ async def init_app(prefill_port, decode_port): if __name__ == '__main__': parser = argparse.ArgumentParser(description="Async Proxy server") parser.add_argument('prefill_port', type=int, help='Port to forward the first request to (with max_tokens=1)') - parser.add_argument 'decode_port', type=int, help='Port to forward the second request to') + parser.add_argument('decode_port', type=int, help='Port to forward the second request to') args = parser.parse_args() app = asyncio.run(init_app(args.prefill_port, args.decode_port)) From 7412767d5574640ec4ef406fe25c05844b54d045 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 22:12:48 -0700 Subject: [PATCH 114/303] increase nccl buff size --- examples/disagg_prefill/disagg_prefill_example.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index a505d8b3ca85d..cf1646b6b3076 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -5,6 +5,7 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 +export NCCL_BUFFSIZE=2147483648 # a function that waits vLLM server to start wait_for_server() { From bd6f41b5c9dee9762fbe0fbbc8842224d8cf2099 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 22:15:25 -0700 Subject: [PATCH 115/303] increase nccl buffer size --- examples/disagg_prefill/disagg_prefill_example.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index cf1646b6b3076..88e3920ac6e73 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -5,7 +5,7 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 -export NCCL_BUFFSIZE=2147483648 +export NCCL_BUFFSIZE=1073741824 # a function that waits vLLM server to start wait_for_server() { From 20f9de1155792b6b6e87331701e8682e8350f3b7 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 22:19:30 -0700 Subject: [PATCH 116/303] add debug flag --- examples/disagg_prefill/disagg_prefill_example.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index 88e3920ac6e73..3c1e33ca161de 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -6,6 +6,7 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 export NCCL_BUFFSIZE=1073741824 +export NCCL_DEBUG=INFO # a function that waits vLLM server to start wait_for_server() { From 11850d57a739c94b1e01a3b93baff1b2e27180a1 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 22:22:21 -0700 Subject: [PATCH 117/303] reduce gpu memory usage --- examples/disagg_prefill/disagg_prefill_example.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index 3c1e33ca161de..b6521e22336f1 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -23,7 +23,8 @@ VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ --port 8100 \ -tp 4 \ - --enable-prefix-caching & + --enable-prefix-caching \ + --gpu-memory-utilization 0.8 & # decoding instance VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ @@ -31,7 +32,7 @@ VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ --port 8200 \ -tp 4 \ - --enable-prefix-caching & + --enable-prefix-caching 0.8 & wait_for_server 8100 From d6ad9bdf53834306a0cf9694df93fd2e0bf7eb71 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 22:24:18 -0700 Subject: [PATCH 118/303] fix syntax bug --- examples/disagg_prefill/disagg_prefill_example.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index b6521e22336f1..28160fd67ff15 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -32,7 +32,8 @@ VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ --port 8200 \ -tp 4 \ - --enable-prefix-caching 0.8 & + --enable-prefix-caching \ + --gpu-memory-utilization 0.8 & wait_for_server 8100 From 57dd656fde6d775b8510441a5d84684fe997ea83 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 22:34:30 -0700 Subject: [PATCH 119/303] temporarily lift up nccl buffer size for send and recv --- examples/disagg_prefill/disagg_prefill_example.sh | 1 - vllm/distributed/parallel_state.py | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index 28160fd67ff15..d9ea28070053b 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -5,7 +5,6 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 -export NCCL_BUFFSIZE=1073741824 export NCCL_DEBUG=INFO # a function that waits vLLM server to start diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 5188b71d95b9d..1eab00db49d4c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -835,6 +835,9 @@ def include_decoding_groups_if_disagg_enabled( """ if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + # dirty fix: temporarily lift up NCCL buffer size to 1GB + import os + os.environ["NCCL_BUFFSIZE"] = "1073741824" assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") new_groups = [] From 9379fbbe80268b4dd7b20cf83c6d343d880469bb Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 22:37:27 -0700 Subject: [PATCH 120/303] reduce nccl buffer size and see if bug fixed --- vllm/distributed/parallel_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 1eab00db49d4c..15ffa8817b6a4 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -836,8 +836,8 @@ def include_decoding_groups_if_disagg_enabled( if envs.VLLM_DISAGG_PREFILL_ROLE is not None: # dirty fix: temporarily lift up NCCL buffer size to 1GB - import os - os.environ["NCCL_BUFFSIZE"] = "1073741824" + # import os + # os.environ["NCCL_BUFFSIZE"] = "1073741824" assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") new_groups = [] From c23d8419e5fbc531106577f669e1fdaf709188b0 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 22:38:23 -0700 Subject: [PATCH 121/303] fix --- vllm/distributed/parallel_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 15ffa8817b6a4..99a7cc932baa5 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -835,9 +835,6 @@ def include_decoding_groups_if_disagg_enabled( """ if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - # dirty fix: temporarily lift up NCCL buffer size to 1GB - # import os - # os.environ["NCCL_BUFFSIZE"] = "1073741824" assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") new_groups = [] @@ -1016,6 +1013,9 @@ def initialize_model_parallel( logger.debug("_PP initialized for rank %d", torch.distributed.get_rank()) if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + # dirty fix: temporarily lift up NCCL buffer size to 1GB + import os + os.environ["NCCL_BUFFSIZE"] = "1073741824" global _DISAGG logger.debug("Disaggregated prefill enabled, create _DISAGG group") group_ranks = [] From 7fc62b4162a2055933e7f1e609e80af9b0775e50 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 22:47:17 -0700 Subject: [PATCH 122/303] add debug info -- see which layer the prefill instance got stuck --- vllm/attention/backends/flash_attn.py | 6 ++++-- vllm/model_executor/models/llama.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 49b8d38c02baa..2d363af2590a3 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -556,10 +556,12 @@ def forward( # no need to transfer kv cache, as it is already the input of this function if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": out = output[:num_prefill_tokens].contiguous() - print("Send output, " , out.shape, out.dtype, out.device) + if torch.distributed.get_rank() % 4 == 0: + print("Send output, " , out.shape, out.dtype, out.device) get_disagg_group().send(out) else: - print("Recv output, " , output[:num_prefill_tokens].shape, output.dtype, output.device) + if torch.distributed.get_rank() % 4 == 0: + print("Recv output, " , output[:num_prefill_tokens].shape, output.dtype, output.device) output[:num_prefill_tokens] = get_disagg_group().recv(output[:num_prefill_tokens].shape, output.dtype) if decode_meta := attn_metadata.decode_metadata: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2052c443a8885..9e5fc0223554e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -318,6 +318,8 @@ def forward( residual = intermediate_tensors["residual"] for i in range(self.start_layer, self.end_layer): + if torch.distributed.get_rank() % 4 == 0: + print(f"Layer {i}") layer = self.layers[i] hidden_states, residual = layer( positions, From e54236690ad86b2c1a23ebb86010ec322f6b1038 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 22:48:48 -0700 Subject: [PATCH 123/303] remove nccl debug -- it is too loud --- examples/disagg_prefill/disagg_prefill_example.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index d9ea28070053b..7b58fe6f64ce5 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -5,7 +5,7 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 -export NCCL_DEBUG=INFO +# export NCCL_DEBUG=INFO # a function that waits vLLM server to start wait_for_server() { From e9f7dc2004d636535184f512d5e29d562548c619 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 23:51:23 -0700 Subject: [PATCH 124/303] change buffer size only for disagg communicator --- examples/disagg_prefill/disagg_prefill_example.sh | 2 +- vllm/distributed/parallel_state.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index 7b58fe6f64ce5..d9ea28070053b 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -5,7 +5,7 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 -# export NCCL_DEBUG=INFO +export NCCL_DEBUG=INFO # a function that waits vLLM server to start wait_for_server() { diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 99a7cc932baa5..a58be5ed14e1c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1016,6 +1016,8 @@ def initialize_model_parallel( # dirty fix: temporarily lift up NCCL buffer size to 1GB import os os.environ["NCCL_BUFFSIZE"] = "1073741824" + import time + time.sleep(20) global _DISAGG logger.debug("Disaggregated prefill enabled, create _DISAGG group") group_ranks = [] From 18ded4ca14633f7b8b248fc4b7896003472a3068 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 23 Jul 2024 23:55:37 -0700 Subject: [PATCH 125/303] disable nccl debug --- examples/disagg_prefill/disagg_prefill_example.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index d9ea28070053b..7b58fe6f64ce5 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -5,7 +5,7 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 -export NCCL_DEBUG=INFO +# export NCCL_DEBUG=INFO # a function that waits vLLM server to start wait_for_server() { From e814f82d7a44e967f3dbf57189622c82151c192a Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 24 Jul 2024 00:13:24 -0700 Subject: [PATCH 126/303] use isend and irecv --- vllm/distributed/parallel_state.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a58be5ed14e1c..06b4cce3e813c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -675,6 +675,16 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: else: torch.distributed.send(tensor, self.ranks[dst], self.device_group) + def isend(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + """NOTE: this function leverage pytorch's isend, to bypass PyNccl buffer limit""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + + torch.distributed.isend(tensor, self.ranks[dst], self.device_group) + + def recv(self, size: torch.Size, dtype: torch.dtype, @@ -692,6 +702,19 @@ def recv(self, torch.distributed.recv(tensor, self.ranks[src], self.device_group) return tensor + def irecv_wait(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the src rank asynchronously.""" + """NOTE: `src` is the local rank of the destination rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.irecv(tensor, self.ranks[src], self.device_group).wait() + return tensor + def destroy(self): if self.device_group is not None: torch.distributed.destroy_process_group(self.device_group) From a3399b336d1e348c894b97175caa4444893543b0 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 24 Jul 2024 00:22:19 -0700 Subject: [PATCH 127/303] try to increase the buffer size --- examples/disagg_prefill/disagg_prefill_example.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index 7b58fe6f64ce5..8a1f7a48e7e37 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -6,6 +6,7 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 # export NCCL_DEBUG=INFO +export NCCL_BUFFSIZE=536870912 # a function that waits vLLM server to start wait_for_server() { From e4e60d91969db33ceb31a632c96549cd169e60d7 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 31 Jul 2024 00:15:23 +0000 Subject: [PATCH 128/303] bug fix, now disaggregated prefill should work as expected --- .../disagg_prefill/disagg_prefill_example.sh | 4 +- vllm/attention/backends/flash_attn.py | 42 +++------ vllm/distributed/parallel_state.py | 79 +++++++++++++---- vllm/model_executor/models/llama.py | 6 +- vllm/worker/model_runner.py | 87 ++++++++++++++++--- 5 files changed, 153 insertions(+), 65 deletions(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index 8a1f7a48e7e37..0ebd3e7e97dad 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -6,7 +6,7 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 # export NCCL_DEBUG=INFO -export NCCL_BUFFSIZE=536870912 +# export NCCL_BUFFSIZE=536870912 # a function that waits vLLM server to start wait_for_server() { @@ -52,7 +52,7 @@ curl http://localhost:8100/v1/completions \ "prompt": "San Francisco is a", "max_tokens": 1, "temperature": 0 -}' & +}' # send to decode instance curl http://localhost:8200/v1/completions \ diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ebe5bf348489e..330addd3b52f2 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -15,15 +15,13 @@ is_block_tables_empty) from vllm.utils import make_tensor_with_pad -# This group is used for KV cache transfer in disaggregated prefilling from vllm.distributed import get_disagg_group - -# To identify if the VLLM_DISAGG_PREFILL_ROLE is set or no import vllm.envs as envs if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder + class FlashAttentionBackend(AttentionBackend): @staticmethod @@ -495,6 +493,15 @@ def forward( v_scale, ) + # send out the KV cache when current vllm is prefill instance + # the corresponding receive code is in vllm/worker/model_runner.py + if all([ + envs.VLLM_DISAGG_PREFILL_ROLE == "prefill", + attn_metadata.prefill_metadata is not None]): + + get_disagg_group().push(key) + get_disagg_group().push(value) + num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens @@ -510,13 +517,8 @@ def forward( assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens - - prefill_meta = attn_metadata.prefill_metadata - if (prefill_meta is not None) and ( - (envs.VLLM_DISAGG_PREFILL_ROLE is None) - or - (envs.VLLM_DISAGG_PREFILL_ROLE == "prefill") - ): # during prefilling, and this instance is not disagg decode instance + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if (kv_cache is None or prefill_meta.block_tables is None or prefill_meta.block_tables.numel() == 0): @@ -542,7 +544,6 @@ def forward( # prefix-enabled attention assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) - output[:num_prefill_tokens] = flash_attn_varlen_func( q=query, k=key_cache, @@ -557,23 +558,6 @@ def forward( block_table=prefill_meta.block_tables, ) - if (prefill_meta is not None) and (kv_cache is not None) and \ - (envs.VLLM_DISAGG_PREFILL_ROLE is not None): - # transfer the output if - # 1). during prefilling - # 2). disaggregated prefill enabled - # 3). not in the profile run (kv_cache is not None) - # no need to transfer kv cache, as it is already the input of this function - if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": - out = output[:num_prefill_tokens].contiguous() - if torch.distributed.get_rank() % 4 == 0: - print("Send output, " , out.shape, out.dtype, out.device) - get_disagg_group().send(out) - else: - if torch.distributed.get_rank() % 4 == 0: - print("Recv output, " , output[:num_prefill_tokens].shape, output.dtype, output.device) - output[:num_prefill_tokens] = get_disagg_group().recv(output[:num_prefill_tokens].shape, output.dtype) - if decode_meta := attn_metadata.decode_metadata: # Decoding run. output[num_prefill_tokens:] = flash_attn_with_kvcache( @@ -588,4 +572,4 @@ def forward( ).squeeze(1) # Reshape the output tensor. - return output.view(num_tokens, hidden_size) + return output.view(num_tokens, hidden_size) \ No newline at end of file diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 23d72fb7ea921..83cdb381490be 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -30,6 +30,8 @@ from multiprocessing import shared_memory from typing import Any, Dict, List, Optional, Tuple, Union from unittest.mock import patch +from concurrent.futures import ThreadPoolExecutor +import queue import torch import torch.distributed @@ -160,8 +162,8 @@ def __init__( assert self.cpu_group is not None assert self.device_group is not None - + if torch.cuda.is_available(): self.device = torch.device(f"cuda:{local_rank}") else: @@ -211,6 +213,12 @@ def __init__( self.mq_broadcaster = MessageQueue.create_from_process_group( self.cpu_group, 1 << 22, 6) + + # use a threadpool to buffer send request in disaggregated prefill + self.send_buffer = None + # use a list to cache send items. + self.send_queue = queue.Queue() + @property def first_rank(self): """Return the global rank of the first process in the group""" @@ -696,6 +704,7 @@ def barrier(self): def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: dst = (self.rank_in_group + 1) % self.world_size @@ -705,16 +714,6 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: else: torch.distributed.send(tensor, self.ranks[dst], self.device_group) - def isend(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: - """Sends a tensor to the destination rank in a non-blocking way""" - """NOTE: `dst` is the local rank of the destination rank.""" - """NOTE: this function leverage pytorch's isend, to bypass PyNccl buffer limit""" - if dst is None: - dst = (self.rank_in_group + 1) % self.world_size - - torch.distributed.isend(tensor, self.ranks[dst], self.device_group) - - def recv(self, size: torch.Size, dtype: torch.dtype, @@ -731,18 +730,60 @@ def recv(self, else: torch.distributed.recv(tensor, self.ranks[src], self.device_group) return tensor + + + def push(self, + tensor: torch.Tensor, + dst: Optional[int] = None, + enable_verification: bool = False) -> None: + """Push the KV cache send request into the send buffer""" + """NOTE: `dst` is the local rank of the destination rank.""" - def irecv_wait(self, + if self.send_buffer is None: + self.send_buffer = ThreadPoolExecutor(max_workers=1) + + if enable_verification: + # Send tensor, together with metadatas + # We will use this metadata to perform some sanity check + # But this transfer is VERY slow. + # So this is a good option for debugging but not for produciton + self.send_buffer.submit( + self.send_tensor_dict, + # tensor needs to be cloned, if not the mean doesn't match + {"tensor": tensor.clone(), "mean": tensor.mean()}, + dst + ) + else: + # only send tensor, use NCCL if available + # very fast but error-prone + self.send_buffer.submit( + self.send, + # tensor needs to be cloned, if not the mean doesn't match + tensor.clone(), + dst + ) + + + def fetch(self, size: torch.Size, dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: - """Receives a tensor from the src rank asynchronously.""" + src: Optional[int] = None, + enable_verification: bool = False) -> torch.Tensor: + """Receives a tensor from the src rank (blocking).""" + """This API should be used together with `push`""" """NOTE: `src` is the local rank of the destination rank.""" - if src is None: - src = (self.rank_in_group - 1) % self.world_size - tensor = torch.empty(size, dtype=dtype, device=self.device) - torch.distributed.irecv(tensor, self.ranks[src], self.device_group).wait() + if enable_verification: + # receive tensor and perform verifications + result = self.recv_tensor_dict(src) + tensor = result["tensor"] + mean = result["mean"] + assert tensor.shape == size + assert tensor.dtype == dtype + assert tensor.mean() == mean + else: + tensor = self.recv(size, dtype, src) + return tensor def destroy(self): @@ -1083,7 +1124,7 @@ def initialize_model_parallel( logger.debug("Distributed group is %s", str(group_ranks)) _DISAGG = init_model_parallel_group( group_ranks, - int(envs.VLLM_DISAGG_PREFILL_ROLE == "decode"), + get_world_group().local_rank, backend, use_custom_allreduce=False) logger.debug("_DISAGG initialized for rank %d", torch.distributed.get_rank()) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index ab884110b71cc..2052c443a8885 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -318,8 +318,6 @@ def forward( residual = intermediate_tensors["residual"] for i in range(self.start_layer, self.end_layer): - if torch.distributed.get_rank() % 4 == 0: - print(f"Layer {i}") layer = self.layers[i] hidden_states, residual = layer( positions, @@ -420,11 +418,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - input_embeds: Optional[torch.Tensor] = None ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - input_embeds) + attn_metadata, intermediate_tensors) return model_output def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 86d26b4a84c36..8493c0e6fc7db 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -27,7 +27,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) -from vllm.distributed import get_pp_group +from vllm.distributed import get_pp_group, get_disagg_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger @@ -59,6 +59,9 @@ _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) +import vllm.envs as envs +from vllm import _custom_ops as ops + if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -1351,19 +1354,82 @@ def execute_model( "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_seqlen_agnostic else {} - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **multi_modal_kwargs, - **seqlen_agnostic_kwargs) + + + # call `model_executable` + # and handle KV cache transfer for disaggregated prefilling + if any([ + prefill_meta is None, + envs.VLLM_DISAGG_PREFILL_ROLE != "decode", + kv_caches is None, + kv_caches[0] is None]): + + # model forwarding + # during forwarding the KV cache will be sent in prefill instance + # see vllm/attention/backends/flash_attn.py for sending impl + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **multi_modal_kwargs, + **seqlen_agnostic_kwargs) + + + if all([ + prefill_meta is not None, + envs.VLLM_DISAGG_PREFILL_ROLE == "prefill", + kv_caches is not None, + kv_caches[0] is not None,]): + # send hidden state if disaggregated prefilling enabled + + get_disagg_group().push(hidden_or_intermediate_states) + + else: + # receive KV cache from disaggregated prefill instance + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + # get kv cache + kv_cache = kv_caches[i - model_executable.model.start_layer] + # get corresponding layer + layer = model_executable.model.layers[i] + + # get kv cache shape (after sliced by tp) + _, _, num_head, head_size = kv_cache[0].shape + num_tokens = model_input.input_tokens.shape[0] + key = get_disagg_group().fetch( + torch.Size([num_tokens, num_head, head_size]), + kv_cache[0].dtype + ) + value = get_disagg_group().fetch( + torch.Size([num_tokens, num_head, head_size]), + kv_cache[0].dtype + ) + + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + model_input.attn_metadata.slot_mapping.flatten(), + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states = get_disagg_group().fetch( + torch.Size([num_tokens, model_executable.config.hidden_size]), + kv_cache[0].dtype + ) + # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: return hidden_or_intermediate_states - + logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) @@ -1376,6 +1442,7 @@ def execute_model( sampling_metadata=model_input.sampling_metadata, ) + if self.return_hidden_states: # we only need to pass hidden states of most recent token assert model_input.sampling_metadata is not None From 87fbfae154dfc21dd9b751ab01bc6f5a6c6c61d6 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 31 Jul 2024 06:35:46 +0000 Subject: [PATCH 129/303] add proxy server --- .../disagg_benchmarks/disagg_benchmark.sh | 180 +++++++++++------- .../disagg_prefill_proxy_server.py | 49 +++++ .../disagg_prefill/disagg_prefill_example.sh | 4 +- .../disagg_prefill/disagg_proxy_server.py | 83 -------- 4 files changed, 163 insertions(+), 153 deletions(-) create mode 100644 benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py delete mode 100644 examples/disagg_prefill/disagg_proxy_server.py diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index c8a7cba02a706..58a99de7e4a92 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -49,78 +49,121 @@ benchmark() { model="neuralmagic/Meta-Llama-3-70B-Instruct-FP8" dataset_name="sonnet" dataset_path="../sonnet_4x.txt" - num_prompts=500 + num_prompts=50 qps=$1 prefix_len=64 input_len=2048 output_len=$2 + # # chunked prefill with tp=4 + # CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + # -m vllm.entrypoints.openai.api_server \ + # --model $model \ + # --port 8000 \ + # -tp 4 \ + # --disable-log-stats \ + # --disable-log-requests \ + # --enable-chunked-prefill & + # wait_for_server 8000 + + # python3 ../benchmark_serving.py \ + # --backend vllm \ + # --model $model \ + # --dataset-name $dataset_name \ + # --dataset-path $dataset_path \ + # --sonnet-input-len $input_len \ + # --sonnet-output-len $output_len \ + # --sonnet-prefix-len $prefix_len \ + # --num-prompts $((num_prompts / 2)) \ + # --port 8000 \ + # --save-result \ + # --result-dir $results_folder \ + # --result-filename chunked_prefill_tp4.json \ + # --request-rate $((qps / 2)) + # kill_gpu_processes + + + # # disaggregated prefill + # # prefill with tp=4 + # python3 -m vllm.entrypoints.openai.api_server \ + # --model $model \ + # --port 8000 \ + # -tp 4 \ + # --disable-log-stats \ + # --disable-log-requests & + # wait_for_server 8000 + # # set output-len to 1 so that it only do prefilling + # python3 ../benchmark_serving.py \ + # --backend vllm \ + # --model $model \ + # --dataset-name $dataset_name \ + # --dataset-path $dataset_path \ + # --sonnet-input-len $input_len \ + # --sonnet-output-len 1 \ + # --sonnet-prefix-len $prefix_len \ + # --num-prompts $num_prompts \ + # --port 8000 \ + # --save-result \ + # --result-dir $results_folder \ + # --result-filename disagg_prefill_tp4.json \ + # --request-rate $qps + # kill_gpu_processes + + # # decode with tp=4, enable APC + # python3 -m vllm.entrypoints.openai.api_server \ + # --model $model \ + # --port 8000 \ + # -tp 4 \ + # --enable-prefix-caching \ + # --disable-log-stats \ + # --disable-log-requests & + # wait_for_server 8000 + # # skip prefilling + # # by enabling APC and force the input tokens be the same + # python3 ../benchmark_serving.py \ + # --backend vllm \ + # --model $model \ + # --dataset-name $dataset_name \ + # --dataset-path $dataset_path \ + # --sonnet-input-len $input_len \ + # --sonnet-output-len $output_len \ + # --sonnet-prefix-len $input_len \ + # --num-prompts $num_prompts \ + # --port 8000 \ + # --save-result \ + # --result-dir $results_folder \ + # --result-filename disagg_decode_tp4.json \ + # --request-rate $qps + # kill_gpu_processes + + + # chunked prefill with tp=4 - CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + export VLLM_PORT=12345 + VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + -tp 4 \ + --disable-log-stats \ + --disable-log-requests \ + --enable-chunked-prefill & + VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ - --port 8000 \ + --port 8200 \ -tp 4 \ --disable-log-stats \ --disable-log-requests \ --enable-chunked-prefill & - wait_for_server 8000 - - python3 ../benchmark_serving.py \ - --backend vllm \ - --model $model \ - --dataset-name $dataset_name \ - --dataset-path $dataset_path \ - --sonnet-input-len $input_len \ - --sonnet-output-len $output_len \ - --sonnet-prefix-len $prefix_len \ - --num-prompts $((num_prompts / 2)) \ - --port 8000 \ - --save-result \ - --result-dir $results_folder \ - --result-filename chunked_prefill_tp4.json \ - --request-rate $((qps / 2)) - kill_gpu_processes - + wait_for_server 8100 + wait_for_server 8200 - # disaggregated prefill - # prefill with tp=4 - python3 -m vllm.entrypoints.openai.api_server \ - --model $model \ - --port 8000 \ - -tp 4 \ - --disable-log-stats \ - --disable-log-requests & - wait_for_server 8000 - # set output-len to 1 so that it only do prefilling - python3 ../benchmark_serving.py \ - --backend vllm \ - --model $model \ - --dataset-name $dataset_name \ - --dataset-path $dataset_path \ - --sonnet-input-len $input_len \ - --sonnet-output-len 1 \ - --sonnet-prefix-len $prefix_len \ - --num-prompts $num_prompts \ - --port 8000 \ - --save-result \ - --result-dir $results_folder \ - --result-filename disagg_prefill_tp4.json \ - --request-rate $qps - kill_gpu_processes + # launch a proxy server that listen from port 8000 + python3 disagg_prefill_proxy_server.py & + sleep 5 - # decode with tp=4, enable APC - python3 -m vllm.entrypoints.openai.api_server \ - --model $model \ - --port 8000 \ - -tp 4 \ - --enable-prefix-caching \ - --disable-log-stats \ - --disable-log-requests & - wait_for_server 8000 - # skip prefilling - # by enabling APC and force the input tokens be the same python3 ../benchmark_serving.py \ --backend vllm \ --model $model \ @@ -128,19 +171,19 @@ benchmark() { --dataset-path $dataset_path \ --sonnet-input-len $input_len \ --sonnet-output-len $output_len \ - --sonnet-prefix-len $input_len \ + --sonnet-prefix-len $prefix_len \ --num-prompts $num_prompts \ --port 8000 \ --save-result \ --result-dir $results_folder \ - --result-filename disagg_decode_tp4.json \ + --result-filename disagg_prefill_2xtp4.json \ --request-rate $qps kill_gpu_processes - python3 analyze_benchmark_results.py \ - --results-folder $results_folder \ - --output-len $output_len \ - --qps $qps + # python3 analyze_benchmark_results.py \ + # --results-folder $results_folder \ + # --output-len $output_len \ + # --qps $qps } @@ -151,6 +194,8 @@ main() { (which jq) || (apt-get -y install jq) (which socat) || (apt-get -y install socat) + pip install quart httpx + cd "$(dirname "$0")" cd .. @@ -168,10 +213,11 @@ main() { default_qps=4 default_output_len=150 - for target_qps in 2 4 8 16 - do - benchmark $target_qps $default_output_len - done + # for target_qps in 2 4 8 16 + # do + # benchmark $target_qps $default_output_len + # done + benchmark 1 150 # for target_output_len in 5 10 20 40 80 # do diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py new file mode 100644 index 0000000000000..ea21e94c5f64c --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -0,0 +1,49 @@ +from quart import Quart, request, jsonify, Response +import httpx + +app = Quart(__name__) + +async def forward_request(url, data): + async with httpx.AsyncClient() as client: + async with client.stream('POST', url, json=data) as response: + if response.status_code == 200: + # Check if the response is streaming + if 'transfer-encoding' in response.headers and response.headers['transfer-encoding'] == 'chunked': + # Stream the response + async def stream_response(): + async for chunk in response.aiter_bytes(): + yield chunk + return Response(stream_response(), status=200, content_type=response.headers.get('content-type')) + else: + # Return the full response + response_data = await response.aread() + return Response(response_data, status=200, content_type=response.headers.get('content-type')) + else: + error_data = await response.aread() + return jsonify({'error': error_data.decode()}), response.status_code + +@app.route('/v1/completions', methods=['POST']) +async def handle_request(): + # Get the original request data + original_request_data = await request.get_json() + + # Modify the max_tokens to 1 for the request to port 8100 + modified_request_data_8100 = original_request_data.copy() + modified_request_data_8100['max_tokens'] = 1 + + # Forward the request to port 8100 + response_8100 = await forward_request('http://localhost:8100/v1/completions', modified_request_data_8100) + + if response_8100.status_code == 200: + # If the request to port 8100 is successful, forward the original request to port 8200 + response_8200 = await forward_request('http://localhost:8200/v1/completions', original_request_data) + + if response_8200.status_code == 200: + return response_8200 + else: + return jsonify({'error': 'Failed to get response from port 8200'}), response_8200.status_code + else: + return jsonify({'error': 'Failed to get response from port 8100'}), response_8100.status_code + +if __name__ == '__main__': + app.run(port=8000) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index 0ebd3e7e97dad..576a3ef4975f2 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -5,8 +5,6 @@ export VLLM_LOGGING_LEVEL=DEBUG export VLLM_PORT=12345 -# export NCCL_DEBUG=INFO -# export NCCL_BUFFSIZE=536870912 # a function that waits vLLM server to start wait_for_server() { @@ -35,7 +33,7 @@ VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ --enable-prefix-caching \ --gpu-memory-utilization 0.8 & - +# wait until prefill and decode instances are ready wait_for_server 8100 wait_for_server 8200 diff --git a/examples/disagg_prefill/disagg_proxy_server.py b/examples/disagg_prefill/disagg_proxy_server.py deleted file mode 100644 index f87a081941fc8..0000000000000 --- a/examples/disagg_prefill/disagg_proxy_server.py +++ /dev/null @@ -1,83 +0,0 @@ -import argparse -import aiohttp -import asyncio -from aiohttp import web -import json -import logging - -# Configure logging -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - -async def handle_post(request): - prefill_port = request.app['prefill_port'] - decode_port = request.app['decode_port'] - - logger.debug(f"Received request to {request.path} with method {request.method}") - - # Read and parse the request payload - try: - payload = await request.json() - logger.debug(f"Request payload: {json.dumps(payload, indent=2)}") - except Exception as e: - logger.error(f"Error parsing request payload: {str(e)}") - return web.json_response({'error': str(e)}, status=400) - - # Modify max_tokens for prefill request - payload_prefill = payload.copy() - payload_prefill["max_tokens"] = 1 - logger.debug(f"Modified prefill payload: {json.dumps(payload_prefill, indent=2)}") - - async with aiohttp.ClientSession() as session: - # Forward request to prefill port - async with session.post(f"http://localhost:{prefill_port}/v1/completions", json=payload_prefill) as response_prefill: - try: - response_prefill_data = await response_prefill.json() - logger.debug(f"Prefill response data: {json.dumps(response_prefill_data, indent=2)}") - except aiohttp.ContentTypeError: - response_prefill_data = await response_prefill.text() - logger.debug(f"Prefill response text: {response_prefill_data}") - - if response_prefill.status != 200: - logger.error(f"Prefill request failed with status {response_prefill.status}") - return web.json_response(response_prefill_data, status=response_prefill.status) - - # Forward original request to decode port - async with session.post(f"http://localhost:{decode_port}/v1/completions", json=payload) as response_decode: - logger.debug(f"Forwarding request to decode port {decode_port}") - if 'stream' in payload and payload['stream']: - # If streaming, set up a streaming response - response = web.StreamResponse(status=response_decode.status, reason=response_decode.reason, headers=response_decode.headers) - await response.prepare(request) - - async for data, _ in response_decode.content.iter_chunks(): - await response.write(data) - logger.debug(f"Streaming chunk: {data}") - await response.write_eof() - logger.debug("Finished streaming response") - return response - else: - # Handle non-streaming response - try: - response_decode_data = await response_decode.json() - logger.debug(f"Decode response data: {json.dumps(response_decode_data, indent=2)}") - except aiohttp.ContentTypeError: - response_decode_data = await response_decode.text() - logger.debug(f"Decode response text: {response_decode_data}") - return web.json_response(response_decode_data, status=response_decode.status) - -async def init_app(prefill_port, decode_port): - app = web.Application() - app['prefill_port'] = prefill_port - app['decode_port'] = decode_port - app.router.add_post('/v1/completions', handle_post) - return app - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description="Async Proxy server") - parser.add_argument('prefill_port', type=int, help='Port to forward the first request to (with max_tokens=1)') - parser.add_argument('decode_port', type=int, help='Port to forward the second request to') - args = parser.parse_args() - - app = asyncio.run(init_app(args.prefill_port, args.decode_port)) - web.run_app(app, port=8000) \ No newline at end of file From fa664c0dae92f0b1e4b88fea2ce96035488f3825 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 1 Aug 2024 10:07:04 +0000 Subject: [PATCH 130/303] startr slow -- using pp=1 and tp=1 --- .../disagg_prefill/disagg_prefill_example.sh | 74 ++++++++++++------- 1 file changed, 47 insertions(+), 27 deletions(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index 576a3ef4975f2..8cfe528ffb58c 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -4,7 +4,9 @@ # and then transfer the KV cache between them. export VLLM_LOGGING_LEVEL=DEBUG -export VLLM_PORT=12345 +export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') +# export NCCL_DEBUG=INFO +export NCCL_BUFFSIZE=67108864 # a function that waits vLLM server to start wait_for_server() { @@ -16,22 +18,24 @@ wait_for_server() { } # prefilling instance -VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ +VLLM_LOGGING_LEVEL=DEBUG VLLM_HOST_IP=$(hostname -I | awk '{print $1}') VLLM_PORT=2345 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ - --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8100 \ - -tp 4 \ + -tp 1 \ --enable-prefix-caching \ - --gpu-memory-utilization 0.8 & + --gpu-memory-utilization 0.8 \ + --max-model-len 10000 & # decoding instance -VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ +VLLM_LOGGING_LEVEL=DEBUG VLLM_HOST_IP=$(hostname -I | awk '{print $1}') VLLM_PORT=2345 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ - --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8200 \ - -tp 4 \ + -tp 1 \ --enable-prefix-caching \ - --gpu-memory-utilization 0.8 & + --gpu-memory-utilization 0.8 \ + --max-model-len 10000 & # wait until prefill and decode instances are ready wait_for_server 8100 @@ -42,23 +46,39 @@ wait_for_server 8200 # 1. send the request to prefill instance, with max_tokens set to 1 # 2. send the request again to decode instance, no modification -# send to prefill instance -curl http://localhost:8100/v1/completions \ --H "Content-Type: application/json" \ --d '{ -"model": "neuralmagic/Meta-Llama-3-70B-Instruct-FP8", -"prompt": "San Francisco is a", -"max_tokens": 1, -"temperature": 0 -}' -# send to decode instance -curl http://localhost:8200/v1/completions \ --H "Content-Type: application/json" \ --d '{ -"model": "neuralmagic/Meta-Llama-3-70B-Instruct-FP8", -"prompt": "San Francisco is a", -"max_tokens": 5, -"temperature": 0 -}' +for i in {0..0} +do + # send to prefill instance + curl -m 5 http://localhost:8100/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": "'$i' San Francisco is a", + "max_tokens": 1, + "temperature": 0 + }' + curl -m 5 http://localhost:8100/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": "'$i' San Francisco is a", + "max_tokens": 1, + "temperature": 0 + }' + + # # send to decode instance + # curl -m 60 http://localhost:8200/v1/completions \ + # -H "Content-Type: application/json" \ + # -d '{ + # "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + # "prompt": "'$i' San Francisco is a", + # "max_tokens": 5, + # "temperature": 0 + # }' + +done + +# kill command: +# ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 \ No newline at end of file From 6bf7583f02c01d1d195b87117cb02fbe6f922a4a Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 1 Aug 2024 10:07:25 +0000 Subject: [PATCH 131/303] adjust the API --- vllm/distributed/parallel_state.py | 109 ++++++++++++++++++----------- 1 file changed, 69 insertions(+), 40 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 83cdb381490be..e8c7bbe59e05b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -24,7 +24,7 @@ import contextlib import pickle import logging -from collections import namedtuple +from collections import namedtuple, defaultdict from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory @@ -215,9 +215,8 @@ def __init__( # use a threadpool to buffer send request in disaggregated prefill - self.send_buffer = None - # use a list to cache send items. - self.send_queue = queue.Queue() + self.input_hash_to_kv_sending_requests = {} + self.kv_sending_thread = ThreadPoolExecutor(max_workers=1) @property def first_rank(self): @@ -705,8 +704,11 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: dst = (self.rank_in_group + 1) % self.world_size + + print('Sending %.3f MB to %d' % (tensor.element_size() * tensor.numel() / 1024 / 1024, self.ranks[dst]), end=' ', flush=True) pynccl_comm = self.pynccl_comm if pynccl_comm is not None and not pynccl_comm.disabled: @@ -714,6 +716,8 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: else: torch.distributed.send(tensor, self.ranks[dst], self.device_group) + print(' End sending ', end=' ', flush=True) + def recv(self, size: torch.Size, dtype: torch.dtype, @@ -722,6 +726,8 @@ def recv(self, """NOTE: `src` is the local rank of the destination rank.""" if src is None: src = (self.rank_in_group - 1) % self.world_size + + print('Start receiving from %d',self.ranks[src]) tensor = torch.empty(size, dtype=dtype, device=self.device) pynccl_comm = self.pynccl_comm @@ -729,46 +735,74 @@ def recv(self, pynccl_comm.recv(tensor, src) else: torch.distributed.recv(tensor, self.ranks[src], self.device_group) + + print('End receiving') return tensor - def push(self, - tensor: torch.Tensor, - dst: Optional[int] = None, - enable_verification: bool = False) -> None: + def kv_cache_send(self, + input_hash: int, + tensor: torch.Tensor, + dst: Optional[int] = None, + enable_verification: bool = False) -> None: """Push the KV cache send request into the send buffer""" """NOTE: `dst` is the local rank of the destination rank.""" - if self.send_buffer is None: - self.send_buffer = ThreadPoolExecutor(max_workers=1) + print('Pushing %.3f MB' % (tensor.element_size() * tensor.numel() / 1024 / 1024), end=' ', flush=True) + if enable_verification: # Send tensor, together with metadatas # We will use this metadata to perform some sanity check # But this transfer is VERY slow. # So this is a good option for debugging but not for produciton - self.send_buffer.submit( + self.input_hash_to_kv_sending_requests[input_hash].append([ self.send_tensor_dict, # tensor needs to be cloned, if not the mean doesn't match {"tensor": tensor.clone(), "mean": tensor.mean()}, dst - ) + ]) else: # only send tensor, use NCCL if available # very fast but error-prone - self.send_buffer.submit( + self.input_hash_to_kv_sending_requests[input_hash].append([ self.send, - # tensor needs to be cloned, if not the mean doesn't match + # tensor needs to be cloned, if not the tensor may be freed tensor.clone(), dst - ) + ]) + + + def sending_kv_from_input_hash(self): + + # receive the input hash that the decode instance requires + input_hash_tensor = self.recv(torch.Size([1]), torch.long) + input_hash = input_hash_tensor.item() + + # execute corresponding send jobs in request queue + for request in input_hash_to_kv_sending_requests[input_hash]: + request[0](*request[1:]) + # free GPU memory occupied by sending + del input_hash_to_kv_sending_requests[input_hash] + + + def kv_cache_send_ready(self): + + self.kv_sending_thread.submit([self.sending_kv_from_input_hash]) + + + def kv_cache_recv_start(self, input_hash: int): + + input_hash_tensor = torch.tensor([input_hash]).long().to(self.device) + # notify the kv cache sender with the input hash id + torch.distributed.send(input_hash_tensor) - def fetch(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None, - enable_verification: bool = False) -> torch.Tensor: + def kv_cache_recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None, + enable_verification: bool = False) -> torch.Tensor: """Receives a tensor from the src rank (blocking).""" """This API should be used together with `push`""" """NOTE: `src` is the local rank of the destination rank.""" @@ -895,16 +929,6 @@ def graph_capture(): logger = init_logger(__name__) -class ConditionalLoggingHandler(logging.Handler): - def emit(self, record): - dist = torch.distributed - try: - if not dist.is_initialized() or (dist.is_initialized() and dist.get_rank() % 4 == 0): - msg = self.format(record) - print(msg) # You can replace this with any other logging mechanism you prefer - except Exception: - pass -logger.addHandler(ConditionalLoggingHandler()) _ENABLE_CUSTOM_ALL_REDUCE = True @@ -973,12 +997,15 @@ def init_distributed_environment( else: # offset global rank by tp * pp (which is world_size) maybe_disagg_rank = rank + world_size + + logger.debug(f"Before: world size {maybe_disagg_world_size}, rank {maybe_disagg_rank}") torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, world_size=maybe_disagg_world_size, rank=maybe_disagg_rank) + logger.debug("torch.distributed initialized") # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -1109,11 +1136,6 @@ def initialize_model_parallel( logger.debug("_PP initialized for rank %d", torch.distributed.get_rank()) if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - # dirty fix: temporarily lift up NCCL buffer size to 1GB - import os - os.environ["NCCL_BUFFSIZE"] = "1073741824" - import time - time.sleep(20) global _DISAGG logger.debug("Disaggregated prefill enabled, create _DISAGG group") group_ranks = [] @@ -1122,11 +1144,18 @@ def initialize_model_parallel( # decode global rank: i + world_size group_ranks.append([i, i + world_size]) logger.debug("Distributed group is %s", str(group_ranks)) - _DISAGG = init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_custom_allreduce=False) + _DISAGG = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False) + # follow by a warmup, to warmup nccl + # necessary, as NCCL may not be warmed up when tp and pp are both 1. + temp_tensor = torch.tensor([1.]).to(_DISAGG.device) + if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": + _DISAGG.send(temp_tensor) + else: + recv_tensor = _DISAGG.recv(temp_tensor.shape, temp_tensor.dtype) + assert torch.allclose(temp_tensor, recv_tensor) logger.debug("_DISAGG initialized for rank %d", torch.distributed.get_rank()) From 6aad5cc4991aa5b25eed4d480f7fd99297f7e7d8 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 2 Aug 2024 06:05:15 +0000 Subject: [PATCH 132/303] support batch size >1 --- .../disagg_benchmarks/disagg_benchmark.sh | 66 ++++--- .../disagg_prefill_proxy_server.py | 80 ++++---- vllm/distributed/parallel_state.py | 159 +++++++++------- vllm/worker/model_runner.py | 171 ++++++++++++++---- 4 files changed, 313 insertions(+), 163 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index 58a99de7e4a92..96e3cf35d49a4 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -43,16 +43,20 @@ wait_for_server() { benchmark() { + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + export VLLM_PORT=12345 + # compare chunked prefill with disaggregated prefill results_folder="./results" - model="neuralmagic/Meta-Llama-3-70B-Instruct-FP8" + model="meta-llama/Meta-Llama-3.1-70B-Instruct" dataset_name="sonnet" dataset_path="../sonnet_4x.txt" - num_prompts=50 + num_prompts=100 qps=$1 - prefix_len=64 - input_len=2048 + prefix_len=50 + input_len="100" output_len=$2 @@ -138,31 +142,47 @@ benchmark() { # kill_gpu_processes - - # chunked prefill with tp=4 - export VLLM_PORT=12345 - VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model $model \ - --port 8100 \ - -tp 4 \ - --disable-log-stats \ - --disable-log-requests \ - --enable-chunked-prefill & - VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ +# large model +VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ - --port 8200 \ + --port 8100 \ -tp 4 \ - --disable-log-stats \ - --disable-log-requests \ - --enable-chunked-prefill & + --max-model-len 10000 \ + --gpu-memory-utilization 0.8 & +VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + -tp 4 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.8 & + +# # Small Model +# # prefilling instance +# VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 python3 \ +# -m vllm.entrypoints.openai.api_server \ +# --model $model \ +# --port 8100 \ +# -tp 1 \ +# --gpu-memory-utilization 0.8 \ +# --max-model-len 10000 & + +# # decoding instance +# VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ +# -m vllm.entrypoints.openai.api_server \ +# --model $model \ +# --port 8200 \ +# -tp 1 \ +# --gpu-memory-utilization 0.8 \ +# --max-model-len 10000 & + wait_for_server 8100 wait_for_server 8200 # launch a proxy server that listen from port 8000 python3 disagg_prefill_proxy_server.py & - sleep 5 + sleep 1 python3 ../benchmark_serving.py \ --backend vllm \ @@ -210,14 +230,14 @@ main() { rm -rf results mkdir results - default_qps=4 + default_qps=10 default_output_len=150 # for target_qps in 2 4 8 16 # do # benchmark $target_qps $default_output_len # done - benchmark 1 150 + benchmark $default_qps $default_output_len # for target_output_len in 5 10 20 40 80 # do diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index ea21e94c5f64c..9028d9be86ec5 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -1,49 +1,51 @@ -from quart import Quart, request, jsonify, Response +from quart import Quart, request, Response, jsonify, make_response +import aiohttp +import sys import httpx +import traceback +import os + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) app = Quart(__name__) async def forward_request(url, data): - async with httpx.AsyncClient() as client: - async with client.stream('POST', url, json=data) as response: - if response.status_code == 200: - # Check if the response is streaming - if 'transfer-encoding' in response.headers and response.headers['transfer-encoding'] == 'chunked': - # Stream the response - async def stream_response(): - async for chunk in response.aiter_bytes(): - yield chunk - return Response(stream_response(), status=200, content_type=response.headers.get('content-type')) - else: - # Return the full response - response_data = await response.aread() - return Response(response_data, status=200, content_type=response.headers.get('content-type')) - else: - error_data = await response.aread() - return jsonify({'error': error_data.decode()}), response.status_code - + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + async with session.post(url=url, json=data, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + yield chunk_bytes + @app.route('/v1/completions', methods=['POST']) async def handle_request(): - # Get the original request data - original_request_data = await request.get_json() - - # Modify the max_tokens to 1 for the request to port 8100 - modified_request_data_8100 = original_request_data.copy() - modified_request_data_8100['max_tokens'] = 1 - - # Forward the request to port 8100 - response_8100 = await forward_request('http://localhost:8100/v1/completions', modified_request_data_8100) - - if response_8100.status_code == 200: - # If the request to port 8100 is successful, forward the original request to port 8200 - response_8200 = await forward_request('http://localhost:8200/v1/completions', original_request_data) - - if response_8200.status_code == 200: - return response_8200 - else: - return jsonify({'error': 'Failed to get response from port 8200'}), response_8200.status_code - else: - return jsonify({'error': 'Failed to get response from port 8100'}), response_8100.status_code + + try: + original_request_data = await request.get_json() + + prefill_request = original_request_data.copy() + prefill_request['max_tokens'] = 1 + + # finish prefill + async for data in forward_request('http://localhost:8100/v1/completions', prefill_request): + continue + + print(f"Request {prefill_request} prefill done. proceeding to decode.") + + # return decode + generator = forward_request('http://localhost:8200/v1/completions', original_request_data) + response = await make_response(generator) + response.timeout = None + + return response + + except Exception as e: + exc_info = sys.exc_info() + print(e) + print("".join(traceback.format_exception(*exc_info))) if __name__ == '__main__': app.run(port=8000) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e8c7bbe59e05b..b4f668bd537ec 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -215,8 +215,8 @@ def __init__( # use a threadpool to buffer send request in disaggregated prefill - self.input_hash_to_kv_sending_requests = {} - self.kv_sending_thread = ThreadPoolExecutor(max_workers=1) + self.input_hash_to_kv_sending_requests = defaultdict(list) + self.kv_sending_thread = None @property def first_rank(self): @@ -707,8 +707,6 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: if dst is None: dst = (self.rank_in_group + 1) % self.world_size - - print('Sending %.3f MB to %d' % (tensor.element_size() * tensor.numel() / 1024 / 1024, self.ranks[dst]), end=' ', flush=True) pynccl_comm = self.pynccl_comm if pynccl_comm is not None and not pynccl_comm.disabled: @@ -716,7 +714,6 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: else: torch.distributed.send(tensor, self.ranks[dst], self.device_group) - print(' End sending ', end=' ', flush=True) def recv(self, size: torch.Size, @@ -727,7 +724,6 @@ def recv(self, if src is None: src = (self.rank_in_group - 1) % self.world_size - print('Start receiving from %d',self.ranks[src]) tensor = torch.empty(size, dtype=dtype, device=self.device) pynccl_comm = self.pynccl_comm @@ -736,90 +732,129 @@ def recv(self, else: torch.distributed.recv(tensor, self.ranks[src], self.device_group) - print('End receiving') return tensor + + + def debug_send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """Will send several metadata. Useful for debugging.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + + self.send_tensor_dict( + { + "tensor": tensor, + "mean": tensor.float().mean(), + "shape": tensor.shape + }, + dst + ) + + def debug_recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the local rank of the destination rank.""" + + result = self.recv_tensor_dict(src) + tensor = result["tensor"] + assert torch.allclose(result["mean"], tensor.float().mean()) + assert result["shape"] == tensor.shape + assert result["shape"] == size, f"The shape sent by sender is {result['shape']} but trying to receive {size}" + return tensor + + + def kv_cache_send(self, input_hash: int, tensor: torch.Tensor, dst: Optional[int] = None, - enable_verification: bool = False) -> None: + enable_verification: bool = True) -> None: """Push the KV cache send request into the send buffer""" """NOTE: `dst` is the local rank of the destination rank.""" - print('Pushing %.3f MB' % (tensor.element_size() * tensor.numel() / 1024 / 1024), end=' ', flush=True) - - if enable_verification: - # Send tensor, together with metadatas - # We will use this metadata to perform some sanity check - # But this transfer is VERY slow. - # So this is a good option for debugging but not for produciton - self.input_hash_to_kv_sending_requests[input_hash].append([ - self.send_tensor_dict, - # tensor needs to be cloned, if not the mean doesn't match - {"tensor": tensor.clone(), "mean": tensor.mean()}, - dst - ]) + send_func = self.debug_send else: - # only send tensor, use NCCL if available - # very fast but error-prone - self.input_hash_to_kv_sending_requests[input_hash].append([ - self.send, - # tensor needs to be cloned, if not the tensor may be freed - tensor.clone(), - dst - ]) - - - def sending_kv_from_input_hash(self): - - # receive the input hash that the decode instance requires - input_hash_tensor = self.recv(torch.Size([1]), torch.long) - input_hash = input_hash_tensor.item() - - # execute corresponding send jobs in request queue - for request in input_hash_to_kv_sending_requests[input_hash]: - request[0](*request[1:]) - # free GPU memory occupied by sending - del input_hash_to_kv_sending_requests[input_hash] - + send_func = self.send - def kv_cache_send_ready(self): - - self.kv_sending_thread.submit([self.sending_kv_from_input_hash]) + self.input_hash_to_kv_sending_requests[input_hash].append([ + send_func, + # tensor needs to be cloned, if not the tensor may be freed + tensor.clone(), + dst + ]) - - def kv_cache_recv_start(self, input_hash: int): - - input_hash_tensor = torch.tensor([input_hash]).long().to(self.device) - # notify the kv cache sender with the input hash id - torch.distributed.send(input_hash_tensor) - def kv_cache_recv(self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None, - enable_verification: bool = False) -> torch.Tensor: + enable_verification: bool = True) -> torch.Tensor: """Receives a tensor from the src rank (blocking).""" """This API should be used together with `push`""" """NOTE: `src` is the local rank of the destination rank.""" if enable_verification: - # receive tensor and perform verifications - result = self.recv_tensor_dict(src) - tensor = result["tensor"] - mean = result["mean"] - assert tensor.shape == size - assert tensor.dtype == dtype - assert tensor.mean() == mean + recv_func = self.debug_recv else: - tensor = self.recv(size, dtype, src) + recv_func = self.recv + + tensor = recv_func(size, dtype, src) return tensor + + def recv_input_hash_and_send_kv(self): + + try: + + # receive the input hash that the decode instance requires + logger.debug('Waiting for input hash ...') + # FIXME(Kuntai): debug_recv guarantees correctness but hurts perf + input_hash_tensor = self.debug_recv(torch.Size([1]), torch.long) + input_hash = input_hash_tensor.item() + logger.debug('Receiving input hash %d', input_hash) + assert input_hash in self.input_hash_to_kv_sending_requests, \ + f"The KV cache of {input_hash} does not exist." + logger.debug('Input hash %d exists, start sending', input_hash) + + # execute corresponding kv cache sending jobs in request queue + for idx, request in enumerate( + self.input_hash_to_kv_sending_requests[input_hash]): + request[0](*request[1:]) + logger.debug('Finish input hash %d, free memory...' % input_hash) + # free GPU memory occupied by sending + del self.input_hash_to_kv_sending_requests[input_hash] + + except Exception: + import sys + import traceback + + + def kv_cache_send_finish(self): + + if self.kv_sending_thread is None: + self.kv_sending_thread = ThreadPoolExecutor(max_workers=1) + + job = self.kv_sending_thread.submit(self.recv_input_hash_and_send_kv) + logger.debug(f'Submit job {job} into kv cache sending thread') + + + def kv_cache_recv_start(self, input_hash: int): + + logger.debug('Requesting KV cache transfer for input hash %d', input_hash) + + input_hash_tensor = torch.tensor([input_hash]).long().to(self.device) + # notify the kv cache sender with the input hash id + # FIXME(Kuntai): debug_send guarantees correctness but hurts perf. + self.debug_send(input_hash_tensor) + + + def destroy(self): if self.device_group is not None: torch.distributed.destroy_process_group(self.device_group) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8493c0e6fc7db..349ec6514817d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1363,7 +1363,7 @@ def execute_model( envs.VLLM_DISAGG_PREFILL_ROLE != "decode", kv_caches is None, kv_caches[0] is None]): - + # model forwarding # during forwarding the KV cache will be sent in prefill instance # see vllm/attention/backends/flash_attn.py for sending impl @@ -1375,7 +1375,7 @@ def execute_model( intermediate_tensors=intermediate_tensors, **multi_modal_kwargs, **seqlen_agnostic_kwargs) - + if all([ prefill_meta is not None, @@ -1383,47 +1383,140 @@ def execute_model( kv_caches is not None, kv_caches[0] is not None,]): # send hidden state if disaggregated prefilling enabled - - get_disagg_group().push(hidden_or_intermediate_states) - else: - # receive KV cache from disaggregated prefill instance - for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): + _input_tokens_list = model_input.input_tokens.tolist() + seq_lens = model_input.seq_lens + query_lens = model_input.query_lens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + + # failed = False + # reason = "" + + # if sum(query_lens) != sum(seq_lens): + # logger.error("Query len sum is %d but seq len sum is %d", sum(query_lens), sum(seq_lens)) + # failed=True + # if sum(query_lens) != len(_input_tokens_list): + # logger.error("Input tokens len is %d, doesn't match with query lens sum %d", + # sum(query_lens), + # len(_input_tokens_list)) + # failed=True + # if slot_mapping.shape[0] != len(_input_tokens_list): + # logger.error("Slot mapping shape is %s, mismatch with input shape %s", + # slot_mapping.shape, + # len(_input_tokens_list)) + # failed=True + # if failed: + # import subprocess + # subprocess.run("ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9", shell=True) - # get kv cache - kv_cache = kv_caches[i - model_executable.model.start_layer] - # get corresponding layer - layer = model_executable.model.layers[i] - # get kv cache shape (after sliced by tp) - _, _, num_head, head_size = kv_cache[0].shape - num_tokens = model_input.input_tokens.shape[0] - key = get_disagg_group().fetch( - torch.Size([num_tokens, num_head, head_size]), - kv_cache[0].dtype - ) - value = get_disagg_group().fetch( - torch.Size([num_tokens, num_head, head_size]), - kv_cache[0].dtype - ) + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + if query_lens is not None: + for idx, qlen in enumerate(query_lens): + + + start_pos = sum(query_lens[:idx]) + end_pos = start_pos + qlen + input_hash = hash(tuple(_input_tokens_list[start_pos:end_pos])) + + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + kv_cache = kv_caches[i - model_executable.model.start_layer] + + _, _, num_heads, head_size = kv_cache[0].shape + + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping[start_pos:end_pos] + + get_disagg_group().kv_cache_send( + input_hash, + key_cache[current_slot_mapping]) + get_disagg_group().kv_cache_send( + input_hash, + value_cache[current_slot_mapping]) + + + get_disagg_group().kv_cache_send( + input_hash, + hidden_or_intermediate_states[start_pos:end_pos]) + get_disagg_group().kv_cache_send_finish() - key_cache, value_cache = kv_cache[0], kv_cache[1] - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - model_input.attn_metadata.slot_mapping.flatten(), - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - layer.self_attn.attn._v_scale, - ) - - hidden_or_intermediate_states = get_disagg_group().fetch( - torch.Size([num_tokens, model_executable.config.hidden_size]), - kv_cache[0].dtype - ) + else: + + # This is disagg decode instance, during prefill state + # Need to receive KV from the prefill instance + # FIXME(Kuntai): This impl assumes that all requests are prefill. + + _input_tokens_list = model_input.input_tokens.tolist() + query_lens = model_input.query_lens + seq_lens = model_input.seq_lens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + + hidden_or_intermediate_states_for_one_req = [] + + # enumerate different requests + logger.debug("My query lens is %s, seq len is %s, rank is %s", + str(query_lens), + str(seq_lens), + torch.distributed.get_rank()) + if query_lens is not None: + for idx, qlen in enumerate(query_lens): + + start_pos = sum(query_lens[:idx]) + end_pos = start_pos + qlen + input_hash = hash(tuple(_input_tokens_list[start_pos:end_pos])) + num_tokens = qlen + + # notify the prefill instance to start sending KVs associated with input_hash + get_disagg_group().kv_cache_recv_start(input_hash) + + # receive KV cache from disaggregated prefill instance + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + # get kv cache + kv_cache = kv_caches[i - model_executable.model.start_layer] + # get corresponding layer + layer = model_executable.model.layers[i] + + # get kv cache shape (after sliced by tp) + _, _, num_heads, head_size = kv_cache[0].shape + key = get_disagg_group().kv_cache_recv( + torch.Size([num_tokens, num_heads, head_size]), + kv_cache[0].dtype + ) + value = get_disagg_group().kv_cache_recv( + torch.Size([num_tokens, num_heads, head_size]), + kv_cache[0].dtype + ) + + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + + hidden_or_intermediate_states_for_one_req.append( + get_disagg_group().kv_cache_recv( + torch.Size([num_tokens, model_executable.config.hidden_size]), + kv_cache[0].dtype + ) + ) + + # concatenate hidden states from different requests + hidden_or_intermediate_states = torch.cat(hidden_or_intermediate_states_for_one_req, dim=0) + # Compute the logits in the last pipeline stage. From e9342867c66ec03382c7be27b3d4b51f1bfcda9e Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 2 Aug 2024 06:22:31 +0000 Subject: [PATCH 133/303] update model runner --- vllm/worker/model_runner.py | 152 +++++++++++++++++++----------------- 1 file changed, 79 insertions(+), 73 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 349ec6514817d..ed552d6104ece 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -27,7 +27,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) -from vllm.distributed import get_pp_group, get_disagg_group +from vllm.distributed import get_tp_group, get_pp_group, get_disagg_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger @@ -1387,6 +1387,8 @@ def execute_model( _input_tokens_list = model_input.input_tokens.tolist() seq_lens = model_input.seq_lens query_lens = model_input.query_lens + seq_lens = get_tp_group().broadcast_object(seq_lens) + query_lens = get_tp_group().broadcast_object(query_lens) slot_mapping = model_input.attn_metadata.slot_mapping.flatten() # failed = False @@ -1413,37 +1415,38 @@ def execute_model( # query_lens contains new KV caches that are added to vLLM. # so we will send them to decode instance # FIXME(Kuntai): This assume that all requests are prefill. - if query_lens is not None: - for idx, qlen in enumerate(query_lens): + for idx, qlen in enumerate(query_lens): - start_pos = sum(query_lens[:idx]) - end_pos = start_pos + qlen - input_hash = hash(tuple(_input_tokens_list[start_pos:end_pos])) + start_pos = sum(query_lens[:idx]) + end_pos = start_pos + qlen + input_hash = hash(tuple(_input_tokens_list[start_pos:end_pos])) + + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + kv_cache = kv_caches[i - model_executable.model.start_layer] - for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): - kv_cache = kv_caches[i - model_executable.model.start_layer] - - _, _, num_heads, head_size = kv_cache[0].shape - - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + _, _, num_heads, head_size = kv_cache[0].shape + + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) - current_slot_mapping = slot_mapping[start_pos:end_pos] + current_slot_mapping = slot_mapping[start_pos:end_pos] - get_disagg_group().kv_cache_send( - input_hash, - key_cache[current_slot_mapping]) - get_disagg_group().kv_cache_send( - input_hash, - value_cache[current_slot_mapping]) + get_disagg_group().kv_cache_send( + input_hash, + key_cache[current_slot_mapping]) + get_disagg_group().kv_cache_send( + input_hash, + value_cache[current_slot_mapping]) - get_disagg_group().kv_cache_send( - input_hash, - hidden_or_intermediate_states[start_pos:end_pos]) - get_disagg_group().kv_cache_send_finish() + get_disagg_group().kv_cache_send( + input_hash, + hidden_or_intermediate_states[start_pos:end_pos]) + get_disagg_group().kv_cache_send_finish() + + logger.error("\033[92mKV send DONE for rank %d\033[0m", torch.distributed.get_rank()) else: @@ -1452,8 +1455,10 @@ def execute_model( # FIXME(Kuntai): This impl assumes that all requests are prefill. _input_tokens_list = model_input.input_tokens.tolist() - query_lens = model_input.query_lens seq_lens = model_input.seq_lens + query_lens = model_input.query_lens + seq_lens = get_tp_group().broadcast_object(seq_lens) + query_lens = get_tp_group().broadcast_object(query_lens) slot_mapping = model_input.attn_metadata.slot_mapping.flatten() hidden_or_intermediate_states_for_one_req = [] @@ -1463,59 +1468,60 @@ def execute_model( str(query_lens), str(seq_lens), torch.distributed.get_rank()) - if query_lens is not None: - for idx, qlen in enumerate(query_lens): + for idx, qlen in enumerate(query_lens): - start_pos = sum(query_lens[:idx]) - end_pos = start_pos + qlen - input_hash = hash(tuple(_input_tokens_list[start_pos:end_pos])) - num_tokens = qlen - - # notify the prefill instance to start sending KVs associated with input_hash - get_disagg_group().kv_cache_recv_start(input_hash) + start_pos = sum(query_lens[:idx]) + end_pos = start_pos + qlen + input_hash = hash(tuple(_input_tokens_list[start_pos:end_pos])) + num_tokens = qlen + + # notify the prefill instance to start sending KVs associated with input_hash + get_disagg_group().kv_cache_recv_start(input_hash) - # receive KV cache from disaggregated prefill instance - for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): - - # get kv cache - kv_cache = kv_caches[i - model_executable.model.start_layer] - # get corresponding layer - layer = model_executable.model.layers[i] - - # get kv cache shape (after sliced by tp) - _, _, num_heads, head_size = kv_cache[0].shape - key = get_disagg_group().kv_cache_recv( - torch.Size([num_tokens, num_heads, head_size]), - kv_cache[0].dtype - ) - value = get_disagg_group().kv_cache_recv( - torch.Size([num_tokens, num_heads, head_size]), - kv_cache[0].dtype - ) - - key_cache, value_cache = kv_cache[0], kv_cache[1] - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - layer.self_attn.attn._v_scale, - ) + # receive KV cache from disaggregated prefill instance + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + # get kv cache + kv_cache = kv_caches[i - model_executable.model.start_layer] + # get corresponding layer + layer = model_executable.model.layers[i] + + # get kv cache shape (after sliced by tp) + _, _, num_heads, head_size = kv_cache[0].shape + key = get_disagg_group().kv_cache_recv( + torch.Size([num_tokens, num_heads, head_size]), + kv_cache[0].dtype + ) + value = get_disagg_group().kv_cache_recv( + torch.Size([num_tokens, num_heads, head_size]), + kv_cache[0].dtype + ) + + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) - hidden_or_intermediate_states_for_one_req.append( - get_disagg_group().kv_cache_recv( - torch.Size([num_tokens, model_executable.config.hidden_size]), - kv_cache[0].dtype - ) + hidden_or_intermediate_states_for_one_req.append( + get_disagg_group().kv_cache_recv( + torch.Size([num_tokens, model_executable.config.hidden_size]), + kv_cache[0].dtype ) + ) + + # concatenate hidden states from different requests + hidden_or_intermediate_states = torch.cat(hidden_or_intermediate_states_for_one_req, dim=0) - # concatenate hidden states from different requests - hidden_or_intermediate_states = torch.cat(hidden_or_intermediate_states_for_one_req, dim=0) + logger.error("\033[92mKV receive DONE for rank %d\033[0m", torch.distributed.get_rank()) From b68435ac66dafca1e2f36d75a7ffaf8891f43d86 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 4 Aug 2024 03:56:38 +0000 Subject: [PATCH 134/303] move group coordinator to a separate file, move disagg implementations to a separate file --- tests/distributed/test_parallel_state.py | 2 +- vllm/distributed/distributed_kv.py | 169 ++++++ vllm/distributed/group_coordinator.py | 714 +++++++++++++++++++++++ 3 files changed, 884 insertions(+), 1 deletion(-) create mode 100644 vllm/distributed/distributed_kv.py create mode 100644 vllm/distributed/group_coordinator.py diff --git a/tests/distributed/test_parallel_state.py b/tests/distributed/test_parallel_state.py index 3adcf6b61046d..cbb01239521a7 100644 --- a/tests/distributed/test_parallel_state.py +++ b/tests/distributed/test_parallel_state.py @@ -3,7 +3,7 @@ import pytest import torch -from vllm.distributed.parallel_state import (_split_tensor_dict, +from vllm.distributed.group_coordinator import (_split_tensor_dict, _update_nested_dict) diff --git a/vllm/distributed/distributed_kv.py b/vllm/distributed/distributed_kv.py new file mode 100644 index 0000000000000..1ed6c03872e64 --- /dev/null +++ b/vllm/distributed/distributed_kv.py @@ -0,0 +1,169 @@ +"""vLLM distributed KV cache transfer API. +These APIs are used in `vllm/worker/model_runner.py`. +""" +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch.distributed import Backend, ProcessGroup + +import vllm.envs as envs +from vllm.distributed.group_coordinator import GroupCoordinator + + +assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode"], \ + "VLLM_DISAGG_PREFILL_ROLE can only be prefill or decode." + +IS_DISTRIBUTED_KV_INSTANCE = (envs.VLLM_DISAGG_PREFILL_ROLE is not None) +IS_KV_PREFILL_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "prefill") +IS_KV_DECODE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "decode") + + +class DistributedKVCoordinator(GroupCoordinator): + """ + A class designated for distributed KV transfer + + Target use cases: + 1. Disaggregated prefill + 2. Remote KV cache storage + """ + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_pynccl: bool = True, + use_custom_allreduce: bool = False, + use_tpu_communicator: bool = True, + use_message_queue_broadcaster: bool = False, + use_cpu_verfication: bool = True, + ): + + super().__init__( + group_ranks, + local_rank, + torch_distributed_backend, + use_pynccl, + use_custom_allreduce, + use_tpu_communicator, + use_message_queue_broadcaster, + ) + + # if turned on, will use CPU-based communication to perform a series of sanity check. + # but it adds ~5ms delay, so please turn it off in performance-demanding usecases (e.g. disaggregated prefill) + self.use_cpu_verfication = use_cpu_verfication + + # use a threadpool to buffer send request in disaggregated prefill + self.input_hash_to_kv_sending_requests = defaultdict(list) + self.kv_sending_thread = None + + def debug_send(self, + tensor: torch.Tensor, + dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """Will send several metadata. Useful for debugging.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + self.send_tensor_dict( + { + "tensor": tensor, + "mean": tensor.float().mean(), + "shape": tensor.shape + }, dst) + + def debug_recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the local rank of the destination rank.""" + + result = self.recv_tensor_dict(src) + tensor = result["tensor"] + assert torch.allclose(result["mean"], tensor.float().mean()) + assert result["shape"] == tensor.shape + assert result[ + "shape"] == size, f"The shape sent by sender is {result['shape']} but trying to receive {size}" + return tensor + + def kv_cache_send(self, + input_hash: int, + tensor: torch.Tensor, + dst: Optional[int] = None, + enable_verification: bool = True) -> None: + """Push the KV cache send request into the send buffer""" + """NOTE: `dst` is the local rank of the destination rank.""" + + if enable_verification: + send_func = self.debug_send + else: + send_func = self.send + + self.input_hash_to_kv_sending_requests[input_hash].append([ + send_func, + # tensor needs to be cloned, if not the tensor may be freed + tensor.clone(), + dst + ]) + + def kv_cache_recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None, + enable_verification: bool = True) -> torch.Tensor: + """Receives a tensor from the src rank (blocking).""" + """This API should be used together with `push`""" + """NOTE: `src` is the local rank of the destination rank.""" + + if enable_verification: + recv_func = self.debug_recv + else: + recv_func = self.recv + + tensor = recv_func(size, dtype, src) + + return tensor + + def recv_input_hash_and_send_kv(self): + + try: + + # receive the input hash that the decode instance requires + logger.debug('Waiting for input hash ...') + # FIXME(Kuntai): debug_recv guarantees correctness but hurts perf + input_hash_tensor = self.debug_recv(torch.Size([1]), torch.long) + input_hash = input_hash_tensor.item() + logger.debug('Receiving input hash %d', input_hash) + assert input_hash in self.input_hash_to_kv_sending_requests, \ + f"The KV cache of {input_hash} does not exist." + logger.debug('Input hash %d exists, start sending', input_hash) + + # execute corresponding kv cache sending jobs in request queue + for idx, request in enumerate( + self.input_hash_to_kv_sending_requests[input_hash]): + request[0](*request[1:]) + logger.debug('Finish input hash %d, free memory...' % input_hash) + # free GPU memory occupied by sending + del self.input_hash_to_kv_sending_requests[input_hash] + + except Exception: + import sys + import traceback + + def kv_cache_send_finish(self): + + if self.kv_sending_thread is None: + self.kv_sending_thread = ThreadPoolExecutor(max_workers=1) + + job = self.kv_sending_thread.submit(self.recv_input_hash_and_send_kv) + logger.debug(f'Submit job {job} into kv cache sending thread') + + def kv_cache_recv_start(self, input_hash: int): + + logger.debug('Requesting KV cache transfer for input hash %d', + input_hash) + + input_hash_tensor = torch.tensor([input_hash]).long().to(self.device) + # notify the kv cache sender with the input hash id + # FIXME(Kuntai): debug_send guarantees correctness but hurts perf. + self.debug_send(input_hash_tensor) diff --git a/vllm/distributed/group_coordinator.py b/vllm/distributed/group_coordinator.py new file mode 100644 index 0000000000000..1202ed652a0bb --- /dev/null +++ b/vllm/distributed/group_coordinator.py @@ -0,0 +1,714 @@ +"""vLLM PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). +""" + +from dataclasses import dataclass +from contextlib import contextmanager, nullcontext +from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple, Union +import pickle + +import torch +from torch.distributed import Backend, ProcessGroup + + +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + assert "%" not in key, ( + "Avoid having '%' in key " + "as it is used as a separator for nested entries.") + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (prefix + key, TensorMetadata(device, value.dtype, + value.size()))) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%") + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + use_pynccl: bool # a hint of whether to use PyNccl + use_custom_allreduce: bool # a hint of whether to use CustomAllreduce + # communicators are only created for world size > 1 + pynccl_comm: Optional[Any] # PyNccl communicator + ca_comm: Optional[Any] # Custom allreduce communicator + mq_broadcaster: Optional[Any] # shared memory broadcaster + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_pynccl: bool, + use_custom_allreduce: bool, + use_tpu_communicator: bool, + use_message_queue_broadcaster: bool = False, + ): + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + if torch.cuda.is_available(): + self.device = torch.device(f"cuda:{local_rank}") + else: + self.device = torch.device("cpu") + + self.use_pynccl = use_pynccl + self.use_custom_allreduce = use_custom_allreduce + self.use_tpu_communicator = use_tpu_communicator + + # lazy import to avoid documentation build error + from vllm.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce) + from vllm.distributed.device_communicators.pynccl import ( + PyNcclCommunicator) + + self.pynccl_comm: Optional[PyNcclCommunicator] + if use_pynccl and self.world_size > 1: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + else: + self.pynccl_comm = None + + self.ca_comm: Optional[CustomAllreduce] + if use_custom_allreduce and self.world_size > 1: + # Initialize a custom fast all-reduce implementation. + self.ca_comm = CustomAllreduce( + group=self.cpu_group, + device=self.device, + ) + else: + self.ca_comm = None + + from vllm.distributed.device_communicators.tpu_communicator import ( + TpuCommunicator) + self.tpu_communicator: Optional[TpuCommunicator] + if use_tpu_communicator and self.world_size > 1: + self.tpu_communicator = TpuCommunicator(group=self.cpu_group) + + from vllm.distributed.device_communicators.shm_broadcast import ( + MessageQueue) + self.mq_broadcaster: Optional[MessageQueue] = None + if use_message_queue_broadcaster and self.world_size > 1: + self.mq_broadcaster = MessageQueue.create_from_process_group( + self.cpu_group, 1 << 22, 6) + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @contextmanager + def graph_capture( + self, graph_capture_context: Optional[GraphCaptureContext] = None): + if graph_capture_context is None: + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) + else: + stream = graph_capture_context.stream + + ca_comm = self.ca_comm + maybe_ca_context = nullcontext( + ) if ca_comm is None else ca_comm.capture() + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.cuda.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.cuda.stream(stream), maybe_ca_context: + # In graph mode, we have to be very careful about the collective + # operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the + # tensor size is too large, it will fallback to the next + # available option. + # In summary: When using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using + # CUDA graph, we use either custom all-reduce kernel or + # PyTorch NCCL. We always prioritize using custom all-reduce + # kernel but fall back to PyTorch or pynccl if it is + # disabled or not supported. + pynccl_comm = self.pynccl_comm + maybe_pynccl_context: Any + if not pynccl_comm: + maybe_pynccl_context = nullcontext() + else: + maybe_pynccl_context = pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream()) + with maybe_pynccl_context: + yield graph_capture_context + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + ca_comm = self.ca_comm + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + + # For TPUs, use TPU communicator. + tpu_comm = self.tpu_communicator + if tpu_comm is not None and not tpu_comm.disabled: + return tpu_comm.all_reduce(input_) + + if ca_comm is not None: + out = ca_comm.custom_all_reduce(input_) + if out is not None: + return out + pynccl_comm = self.pynccl_comm + if (pynccl_comm is not None and not pynccl_comm.disabled): + pynccl_comm.all_reduce(input_) + elif input_.is_cpu: + import intel_extension_for_pytorch as ipex + ipex.distributed.all_reduce(input_, group=self.device_group) + else: + torch.distributed.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + # For TPUs, use TPU communicator. + tpu_comm = self.tpu_communicator + if tpu_comm is not None and not tpu_comm.disabled: + return tpu_comm.all_gather(input_, dim) + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + torch.distributed.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather(input_, + gather_list, + dst=self.ranks[dst], + group=self.device_group) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast(input_, + src=self.ranks[src], + group=self.device_group) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.mq_broadcaster is not None: + assert src == 0, "Message queue broadcaster only supports src=0" + return self.mq_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list([obj], + src=self.ranks[src], + group=self.cpu_group) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list(recv, + src=self.ranks[src], + group=self.cpu_group) + return recv[0] + + def broadcast_object_list(self, + obj_list: List[Any], + src: int = 0, + group: Optional[ProcessGroup] = None): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, + src=self.ranks[src], + group=self.device_group) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank_in_group, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank.") + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor([object_tensor.numel()], + dtype=torch.long, + device="cpu") + + # Send object size + + torch.distributed.send(size_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert src != self.rank_in_group, ( + "Invalid source rank. Source rank is the same as the current rank." + ) + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv(size_tensor, + src=self.ranks[src], + group=self.cpu_group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu") + + rank_object = torch.distributed.recv(object_tensor, + src=self.ranks[src], + group=self.cpu_group) + + assert rank_object == rank_size, ( + "Received object sender rank does not match the size sender rank.") + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() or self.world_size == 1): + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=group) + return None + + def recv_tensor_dict( + self, + src: Optional[int] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. + """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.send(tensor, dst) + else: + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the local rank of the destination rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.recv(tensor, src) + else: + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + if self.pynccl_comm is not None: + self.pynccl_comm = None + if self.ca_comm is not None: + self.ca_comm = None + if self.mq_broadcaster is not None: + self.mq_broadcaster = None From e54f7a3b9413d6ec72b9c2a153987e7f4402f4bd Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 4 Aug 2024 04:06:54 +0000 Subject: [PATCH 135/303] no need to send during attention --- vllm/attention/backends/flash_attn.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 330addd3b52f2..9aaac2886510a 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -493,15 +493,6 @@ def forward( v_scale, ) - # send out the KV cache when current vllm is prefill instance - # the corresponding receive code is in vllm/worker/model_runner.py - if all([ - envs.VLLM_DISAGG_PREFILL_ROLE == "prefill", - attn_metadata.prefill_metadata is not None]): - - get_disagg_group().push(key) - get_disagg_group().push(value) - num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens @@ -572,4 +563,4 @@ def forward( ).squeeze(1) # Reshape the output tensor. - return output.view(num_tokens, hidden_size) \ No newline at end of file + return output.view(num_tokens, hidden_size) From 23c99496a4f9ee0236454161c89d742a0476bfcb Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 4 Aug 2024 07:44:00 +0000 Subject: [PATCH 136/303] debug tp --- vllm/distributed/group_coordinator.py | 11 +- vllm/distributed/parallel_state.py | 914 ++------------------------ vllm/worker/model_runner.py | 175 +---- 3 files changed, 78 insertions(+), 1022 deletions(-) diff --git a/vllm/distributed/group_coordinator.py b/vllm/distributed/group_coordinator.py index 1202ed652a0bb..411e59ce24404 100644 --- a/vllm/distributed/group_coordinator.py +++ b/vllm/distributed/group_coordinator.py @@ -637,9 +637,14 @@ def recv_tensor_dict( tensor_dict: Dict[str, Any] = {} for key, value in recv_metadata_list: if isinstance(value, TensorMetadata): + target_device = value.device + if 'cuda' in value.device: + # receiving a cuda tensor + # need to allocate buffer on LOCAL cuda device + target_device = self.device tensor = torch.empty(value.size, dtype=value.dtype, - device=value.device) + device=target_device) if tensor.numel() == 0: # Skip broadcasting empty tensors. _update_nested_dict(tensor_dict, key, tensor) @@ -652,8 +657,8 @@ def recv_tensor_dict( else: # use group for GPU tensors torch.distributed.recv(tensor, - src=self.ranks[src], - group=group) + src=self.ranks[src], + group=group) _update_nested_dict(tensor_dict, key, tensor) else: _update_nested_dict(tensor_dict, key, value) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b4f668bd537ec..d2f8d7ea2a21e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -24,13 +24,12 @@ import contextlib import pickle import logging -from collections import namedtuple, defaultdict +from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory from typing import Any, Dict, List, Optional, Tuple, Union from unittest.mock import patch -from concurrent.futures import ThreadPoolExecutor import queue import torch @@ -39,837 +38,9 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.distributed.group_coordinator import GroupCoordinator +import vllm.distributed.distributed_kv as dist_kv - -@dataclass -class GraphCaptureContext: - stream: torch.cuda.Stream - - -TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) - - -def _split_tensor_dict( - tensor_dict: Dict[str, Union[torch.Tensor, Any]], - prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: - """Split the tensor dictionary into two parts: - 1. A list of (key, value) pairs. If the value is a tensor, it is replaced - by its metadata. - 2. A list of tensors. - - If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its - metadata will be "key1%key2". - """ - metadata_list: List[Tuple[str, Any]] = [] - tensor_list = [] - for key, value in tensor_dict.items(): - assert "%" not in key, ( - "Avoid having '%' in key " - "as it is used as a separator for nested entries.") - if isinstance(value, torch.Tensor): - # Note: we cannot use `value.device` here, - # because it contains not only the device type but also the device - # index (e.g. "cuda:0"). We only need the device type. - # receiving side will set the device index. - device = value.device.type - metadata_list.append( - (prefix + key, TensorMetadata(device, value.dtype, - value.size()))) - tensor_list.append(value) - elif isinstance(value, dict): - if len(value) == 0: - metadata_list.append((prefix + key, value)) - inner_metadata_list, inner_tensor_list = _split_tensor_dict( - value, prefix + key + "%") - metadata_list.extend(inner_metadata_list) - tensor_list.extend(inner_tensor_list) - else: - metadata_list.append((prefix + key, value)) - return metadata_list, tensor_list - - -def _update_nested_dict(nested_dict, flattened_key, value): - key_splits = flattened_key.split("%") - cur_dict = nested_dict - for k in key_splits[:-1]: - if k not in cur_dict: - cur_dict[k] = {} - cur_dict = cur_dict[k] - cur_dict[key_splits[-1]] = value - - -class GroupCoordinator: - """ - PyTorch ProcessGroup wrapper for a group of processes. - PyTorch ProcessGroup is bound to one specific communication backend, - e.g. NCCL, Gloo, MPI, etc. - GroupCoordinator takes charge of all the communication operations among - the processes in the group. It can route the communication to - a specific implementation (e.g. switch allreduce implementation - based on the tensor size and cuda graph mode). - """ - - # available attributes: - rank: int # global rank - ranks: List[int] # global ranks in the group - world_size: int # size of the group - # difference between `local_rank` and `rank_in_group`: - # if we have a group of size 4 across two nodes: - # Process | Node | Rank | Local Rank | Rank in Group - # 0 | 0 | 0 | 0 | 0 - # 1 | 0 | 1 | 1 | 1 - # 2 | 1 | 2 | 0 | 2 - # 3 | 1 | 3 | 1 | 3 - local_rank: int # local rank used to assign devices - rank_in_group: int # rank inside the group - cpu_group: ProcessGroup # group for CPU communication - device_group: ProcessGroup # group for device communication - use_pynccl: bool # a hint of whether to use PyNccl - use_custom_allreduce: bool # a hint of whether to use CustomAllreduce - # communicators are only created for world size > 1 - pynccl_comm: Optional[Any] # PyNccl communicator - ca_comm: Optional[Any] # Custom allreduce communicator - mq_broadcaster: Optional[Any] # shared memory broadcaster - - def __init__( - self, - group_ranks: List[List[int]], - local_rank: int, - torch_distributed_backend: Union[str, Backend], - use_pynccl: bool, - use_custom_allreduce: bool, - use_tpu_communicator: bool, - use_message_queue_broadcaster: bool = False, - ): - - self.rank = torch.distributed.get_rank() - self.local_rank = local_rank - self.device_group = None - self.cpu_group = None - - for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend) - # a group with `gloo` backend, to allow direct coordination between - # processes through the CPU. - cpu_group = torch.distributed.new_group(ranks, backend="gloo") - if self.rank in ranks: - self.ranks = ranks - self.world_size = len(ranks) - self.rank_in_group = ranks.index(self.rank) - self.device_group = device_group - self.cpu_group = cpu_group - - assert self.cpu_group is not None - assert self.device_group is not None - - - if torch.cuda.is_available(): - self.device = torch.device(f"cuda:{local_rank}") - else: - self.device = torch.device("cpu") - - - self.use_pynccl = use_pynccl - self.use_custom_allreduce = use_custom_allreduce - self.use_tpu_communicator = use_tpu_communicator - - # lazy import to avoid documentation build error - from vllm.distributed.device_communicators.custom_all_reduce import ( - CustomAllreduce) - from vllm.distributed.device_communicators.pynccl import ( - PyNcclCommunicator) - - - self.pynccl_comm: Optional[PyNcclCommunicator] - if use_pynccl and self.world_size > 1: - self.pynccl_comm = PyNcclCommunicator( - group=self.cpu_group, - device=self.device, - ) - else: - self.pynccl_comm = None - - self.ca_comm: Optional[CustomAllreduce] - if use_custom_allreduce and self.world_size > 1: - # Initialize a custom fast all-reduce implementation. - self.ca_comm = CustomAllreduce( - group=self.cpu_group, - device=self.device, - ) - else: - self.ca_comm = None - - from vllm.distributed.device_communicators.tpu_communicator import ( - TpuCommunicator) - self.tpu_communicator: Optional[TpuCommunicator] - if use_tpu_communicator and self.world_size > 1: - self.tpu_communicator = TpuCommunicator(group=self.cpu_group) - - from vllm.distributed.device_communicators.shm_broadcast import ( - MessageQueue) - self.mq_broadcaster: Optional[MessageQueue] = None - if use_message_queue_broadcaster and self.world_size > 1: - self.mq_broadcaster = MessageQueue.create_from_process_group( - self.cpu_group, 1 << 22, 6) - - - # use a threadpool to buffer send request in disaggregated prefill - self.input_hash_to_kv_sending_requests = defaultdict(list) - self.kv_sending_thread = None - - @property - def first_rank(self): - """Return the global rank of the first process in the group""" - return self.ranks[0] - - @property - def last_rank(self): - """Return the global rank of the last process in the group""" - return self.ranks[-1] - - @property - def is_first_rank(self): - """Return whether the caller is the first process in the group""" - return self.rank == self.first_rank - - @property - def is_last_rank(self): - """Return whether the caller is the last process in the group""" - return self.rank == self.last_rank - - @property - def next_rank(self): - """Return the global rank of the process that follows the caller""" - rank_in_group = self.rank_in_group - world_size = self.world_size - return self.ranks[(rank_in_group + 1) % world_size] - - @property - def prev_rank(self): - """Return the global rank of the process that precedes the caller""" - rank_in_group = self.rank_in_group - world_size = self.world_size - return self.ranks[(rank_in_group - 1) % world_size] - - @contextmanager - def graph_capture( - self, graph_capture_context: Optional[GraphCaptureContext] = None): - if graph_capture_context is None: - stream = torch.cuda.Stream() - graph_capture_context = GraphCaptureContext(stream) - else: - stream = graph_capture_context.stream - - ca_comm = self.ca_comm - maybe_ca_context = nullcontext( - ) if ca_comm is None else ca_comm.capture() - - # ensure all initialization operations complete before attempting to - # capture the graph on another stream - curr_stream = torch.cuda.current_stream() - if curr_stream != stream: - stream.wait_stream(curr_stream) - - with torch.cuda.stream(stream), maybe_ca_context: - # In graph mode, we have to be very careful about the collective - # operations. The current status is: - # allreduce \ Mode | Eager | Graph | - # -------------------------------------------- - # custom allreduce | enabled | enabled | - # PyNccl | disabled| enabled | - # torch.distributed | enabled | disabled| - # - # Note that custom allreduce will have a runtime check, if the - # tensor size is too large, it will fallback to the next - # available option. - # In summary: When using CUDA graph, we use - # either custom all-reduce kernel or pynccl. When not using - # CUDA graph, we use either custom all-reduce kernel or - # PyTorch NCCL. We always prioritize using custom all-reduce - # kernel but fall back to PyTorch or pynccl if it is - # disabled or not supported. - pynccl_comm = self.pynccl_comm - maybe_pynccl_context: Any - if not pynccl_comm: - maybe_pynccl_context = nullcontext() - else: - maybe_pynccl_context = pynccl_comm.change_state( - enable=True, stream=torch.cuda.current_stream()) - with maybe_pynccl_context: - yield graph_capture_context - - def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: - """ - NOTE: This operation will be applied in-place or out-of-place. - Always assume this function modifies its input, but use the return - value as the output. - """ - ca_comm = self.ca_comm - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return input_ - - # For TPUs, use TPU communicator. - tpu_comm = self.tpu_communicator - if tpu_comm is not None and not tpu_comm.disabled: - return tpu_comm.all_reduce(input_) - - if ca_comm is not None: - out = ca_comm.custom_all_reduce(input_) - if out is not None: - return out - pynccl_comm = self.pynccl_comm - if (pynccl_comm is not None and not pynccl_comm.disabled): - pynccl_comm.all_reduce(input_) - elif input_.is_cpu: - import intel_extension_for_pytorch as ipex - ipex.distributed.all_reduce(input_, group=self.device_group) - else: - torch.distributed.all_reduce(input_, group=self.device_group) - return input_ - - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - - # For TPUs, use TPU communicator. - tpu_comm = self.tpu_communicator - if tpu_comm is not None and not tpu_comm.disabled: - return tpu_comm.all_gather(input_, dim) - - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - input_size = input_.size() - # Allocate output tensor. - output_tensor = torch.empty((world_size, ) + input_size, - dtype=input_.dtype, - device=input_.device) - # All-gather. - torch.distributed.all_gather_into_tensor(output_tensor, - input_, - group=self.device_group) - # Reshape - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (world_size * - input_size[dim], ) + - input_size[dim + 1:]) - return output_tensor - - def gather(self, - input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> torch.Tensor: - """ - NOTE: We assume that the input tensor is on the same device across - all the ranks. - NOTE: `dst` is the local rank of the destination rank. - """ - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - # Allocate output tensor. - if self.rank_in_group == dst: - gather_list = [torch.empty_like(input_) for _ in range(world_size)] - else: - gather_list = None - # Gather. - torch.distributed.gather(input_, - gather_list, - dst=self.ranks[dst], - group=self.device_group) - if self.rank_in_group == dst: - output_tensor = torch.cat(gather_list, dim=dim) - else: - output_tensor = None - return output_tensor - - def broadcast(self, input_: torch.Tensor, src: int = 0): - """Broadcast the input tensor. - NOTE: `src` is the local rank of the source rank. - """ - assert src < self.world_size, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return input_ - # Broadcast. - torch.distributed.broadcast(input_, - src=self.ranks[src], - group=self.device_group) - return input_ - - def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): - """Broadcast the input object. - NOTE: `src` is the local rank of the source rank. - """ - assert src < self.world_size, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return obj - if self.mq_broadcaster is not None: - assert src == 0, "Message queue broadcaster only supports src=0" - return self.mq_broadcaster.broadcast_object(obj) - if self.rank_in_group == src: - torch.distributed.broadcast_object_list([obj], - src=self.ranks[src], - group=self.cpu_group) - return obj - else: - recv = [None] - torch.distributed.broadcast_object_list(recv, - src=self.ranks[src], - group=self.cpu_group) - return recv[0] - - def broadcast_object_list(self, - obj_list: List[Any], - src: int = 0, - group: Optional[ProcessGroup] = None): - """Broadcast the input object list. - NOTE: `src` is the local rank of the source rank. - """ - assert src < self.world_size, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return obj_list - # Broadcast. - torch.distributed.broadcast_object_list(obj_list, - src=self.ranks[src], - group=self.device_group) - return obj_list - - def send_object(self, obj: Any, dst: int) -> None: - """Send the input object list to the destination rank.""" - """NOTE: `dst` is the local rank of the destination rank.""" - - assert dst < self.world_size, f"Invalid dst rank ({dst})" - - assert dst != self.rank_in_group, ( - "Invalid destination rank. Destination rank is the same " - "as the current rank.") - - # Serialize object to tensor and get the size as well - object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) - - size_tensor = torch.tensor([object_tensor.numel()], - dtype=torch.long, - device="cpu") - - # Send object size - - torch.distributed.send(size_tensor, - dst=self.ranks[dst], - group=self.cpu_group) - - # Send object - torch.distributed.send(object_tensor, - dst=self.ranks[dst], - group=self.cpu_group) - - return None - - def recv_object(self, src: int) -> Any: - """Receive the input object list from the source rank.""" - """NOTE: `src` is the local rank of the source rank.""" - - assert src < self.world_size, f"Invalid src rank ({src})" - - assert src != self.rank_in_group, ( - "Invalid source rank. Source rank is the same as the current rank." - ) - - size_tensor = torch.empty(1, dtype=torch.long, device="cpu") - - # Receive object size - rank_size = torch.distributed.recv(size_tensor, - src=self.ranks[src], - group=self.cpu_group) - - # Tensor to receive serialized objects into. - object_tensor = torch.empty( # type: ignore[call-overload] - size_tensor.item(), # type: ignore[arg-type] - dtype=torch.uint8, - device="cpu") - - rank_object = torch.distributed.recv(object_tensor, - src=self.ranks[src], - group=self.cpu_group) - - assert rank_object == rank_size, ( - "Received object sender rank does not match the size sender rank.") - - obj = pickle.loads(object_tensor.numpy().tobytes()) - - return obj - - def broadcast_tensor_dict( - self, - tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, - src: int = 0, - group: Optional[ProcessGroup] = None, - metadata_group: Optional[ProcessGroup] = None - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: - """Broadcast the input tensor dictionary. - NOTE: `src` is the local rank of the source rank. - """ - # Bypass the function if we are using only 1 GPU. - if (not torch.distributed.is_initialized() or self.world_size == 1): - return tensor_dict - - group = self.device_group - metadata_group = self.cpu_group - assert src < self.world_size, f"Invalid src rank ({src})" - - rank_in_group = self.rank_in_group - if rank_in_group == src: - metadata_list: List[Tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, - dict), (f"Expecting a dictionary, got {type(tensor_dict)}") - metadata_list, tensor_list = _split_tensor_dict(tensor_dict) - # `metadata_list` lives in CPU memory. - # `broadcast_object_list` has serialization & deserialization, - # all happening on CPU. Therefore, we can use the CPU group. - self.broadcast_object(metadata_list, src=src) - async_handles = [] - for tensor in tensor_list: - if tensor.numel() == 0: - # Skip broadcasting empty tensors. - continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - handle = torch.distributed.broadcast(tensor, - src=self.ranks[src], - group=metadata_group, - async_op=True) - else: - # use group for GPU tensors - handle = torch.distributed.broadcast(tensor, - src=self.ranks[src], - group=group, - async_op=True) - async_handles.append(handle) - for async_handle in async_handles: - async_handle.wait() - - else: - metadata_list = self.broadcast_object(None, src=src) - tensor_dict = {} - async_handles = [] - for key, value in metadata_list: - if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, - dtype=value.dtype, - device=value.device) - if tensor.numel() == 0: - # Skip broadcasting empty tensors. - _update_nested_dict(tensor_dict, key, tensor) - continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - handle = torch.distributed.broadcast( - tensor, - src=self.ranks[src], - group=metadata_group, - async_op=True) - else: - # use group for GPU tensors - handle = torch.distributed.broadcast( - tensor, - src=self.ranks[src], - group=group, - async_op=True) - async_handles.append(handle) - _update_nested_dict(tensor_dict, key, tensor) - else: - _update_nested_dict(tensor_dict, key, value) - for async_handle in async_handles: - async_handle.wait() - return tensor_dict - - def send_tensor_dict( - self, - tensor_dict: Dict[str, Union[torch.Tensor, Any]], - dst: Optional[int] = None - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: - """Send the input tensor dictionary. - NOTE: `dst` is the local rank of the source rank. - """ - # Bypass the function if we are using only 1 GPU. - if not torch.distributed.is_initialized() or self.world_size == 1: - return tensor_dict - - group = self.device_group - metadata_group = self.cpu_group - - if dst is None: - dst = (self.rank_in_group + 1) % self.world_size - assert dst < self.world_size, f"Invalid dst rank ({dst})" - - metadata_list: List[Tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, - dict), f"Expecting a dictionary, got {type(tensor_dict)}" - metadata_list, tensor_list = _split_tensor_dict(tensor_dict) - # `metadata_list` lives in CPU memory. - # `send_object_list` has serialization & deserialization, - # all happening on CPU. Therefore, we can use the CPU group. - self.send_object(metadata_list, dst=dst) - for tensor in tensor_list: - if tensor.numel() == 0: - # Skip sending empty tensors. - continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - torch.distributed.send(tensor, - dst=self.ranks[dst], - group=metadata_group) - else: - # use group for GPU tensors - torch.distributed.send(tensor, - dst=self.ranks[dst], - group=group) - return None - - def recv_tensor_dict( - self, - src: Optional[int] = None - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: - """Recv the input tensor dictionary. - NOTE: `src` is the local rank of the source rank. - """ - # Bypass the function if we are using only 1 GPU. - if not torch.distributed.is_initialized() or self.world_size == 1: - return None - - group = self.device_group - metadata_group = self.cpu_group - - if src is None: - src = (self.rank_in_group - 1) % self.world_size - assert src < self.world_size, f"Invalid src rank ({src})" - - recv_metadata_list = self.recv_object(src=src) - tensor_dict: Dict[str, Any] = {} - for key, value in recv_metadata_list: - if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, - dtype=value.dtype, - device=value.device) - if tensor.numel() == 0: - # Skip broadcasting empty tensors. - _update_nested_dict(tensor_dict, key, tensor) - continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - torch.distributed.recv(tensor, - src=self.ranks[src], - group=metadata_group) - else: - # use group for GPU tensors - torch.distributed.recv(tensor, - src=self.ranks[src], - group=group) - _update_nested_dict(tensor_dict, key, tensor) - else: - _update_nested_dict(tensor_dict, key, value) - return tensor_dict - - def barrier(self): - """Barrier synchronization among the group. - NOTE: don't use `device_group` here! `barrier` in NCCL is - terrible because it is internally a broadcast operation with - secretly created GPU tensors. It is easy to mess up the current - device. Use the CPU group instead. - """ - torch.distributed.barrier(group=self.cpu_group) - - def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: - """Sends a tensor to the destination rank in a non-blocking way""" - """NOTE: `dst` is the local rank of the destination rank.""" - - - if dst is None: - dst = (self.rank_in_group + 1) % self.world_size - - pynccl_comm = self.pynccl_comm - if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.send(tensor, dst) - else: - torch.distributed.send(tensor, self.ranks[dst], self.device_group) - - - def recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: - """Receives a tensor from the src rank.""" - """NOTE: `src` is the local rank of the destination rank.""" - if src is None: - src = (self.rank_in_group - 1) % self.world_size - - - tensor = torch.empty(size, dtype=dtype, device=self.device) - pynccl_comm = self.pynccl_comm - if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.recv(tensor, src) - else: - torch.distributed.recv(tensor, self.ranks[src], self.device_group) - - return tensor - - - def debug_send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: - """Sends a tensor to the destination rank in a non-blocking way""" - """Will send several metadata. Useful for debugging.""" - """NOTE: `dst` is the local rank of the destination rank.""" - - - self.send_tensor_dict( - { - "tensor": tensor, - "mean": tensor.float().mean(), - "shape": tensor.shape - }, - dst - ) - - def debug_recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: - """Receives a tensor from the src rank.""" - """NOTE: `src` is the local rank of the destination rank.""" - - result = self.recv_tensor_dict(src) - tensor = result["tensor"] - assert torch.allclose(result["mean"], tensor.float().mean()) - assert result["shape"] == tensor.shape - assert result["shape"] == size, f"The shape sent by sender is {result['shape']} but trying to receive {size}" - return tensor - - - - - - def kv_cache_send(self, - input_hash: int, - tensor: torch.Tensor, - dst: Optional[int] = None, - enable_verification: bool = True) -> None: - """Push the KV cache send request into the send buffer""" - """NOTE: `dst` is the local rank of the destination rank.""" - - if enable_verification: - send_func = self.debug_send - else: - send_func = self.send - - self.input_hash_to_kv_sending_requests[input_hash].append([ - send_func, - # tensor needs to be cloned, if not the tensor may be freed - tensor.clone(), - dst - ]) - - - def kv_cache_recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None, - enable_verification: bool = True) -> torch.Tensor: - """Receives a tensor from the src rank (blocking).""" - """This API should be used together with `push`""" - """NOTE: `src` is the local rank of the destination rank.""" - - if enable_verification: - recv_func = self.debug_recv - else: - recv_func = self.recv - - tensor = recv_func(size, dtype, src) - - return tensor - - - def recv_input_hash_and_send_kv(self): - - try: - - # receive the input hash that the decode instance requires - logger.debug('Waiting for input hash ...') - # FIXME(Kuntai): debug_recv guarantees correctness but hurts perf - input_hash_tensor = self.debug_recv(torch.Size([1]), torch.long) - input_hash = input_hash_tensor.item() - logger.debug('Receiving input hash %d', input_hash) - assert input_hash in self.input_hash_to_kv_sending_requests, \ - f"The KV cache of {input_hash} does not exist." - logger.debug('Input hash %d exists, start sending', input_hash) - - # execute corresponding kv cache sending jobs in request queue - for idx, request in enumerate( - self.input_hash_to_kv_sending_requests[input_hash]): - request[0](*request[1:]) - logger.debug('Finish input hash %d, free memory...' % input_hash) - # free GPU memory occupied by sending - del self.input_hash_to_kv_sending_requests[input_hash] - - except Exception: - import sys - import traceback - - - def kv_cache_send_finish(self): - - if self.kv_sending_thread is None: - self.kv_sending_thread = ThreadPoolExecutor(max_workers=1) - - job = self.kv_sending_thread.submit(self.recv_input_hash_and_send_kv) - logger.debug(f'Submit job {job} into kv cache sending thread') - - - def kv_cache_recv_start(self, input_hash: int): - - logger.debug('Requesting KV cache transfer for input hash %d', input_hash) - - input_hash_tensor = torch.tensor([input_hash]).long().to(self.device) - # notify the kv cache sender with the input hash id - # FIXME(Kuntai): debug_send guarantees correctness but hurts perf. - self.debug_send(input_hash_tensor) - - - - def destroy(self): - if self.device_group is not None: - torch.distributed.destroy_process_group(self.device_group) - self.device_group = None - if self.cpu_group is not None: - torch.distributed.destroy_process_group(self.cpu_group) - self.cpu_group = None - if self.pynccl_comm is not None: - self.pynccl_comm = None - if self.ca_comm is not None: - self.ca_comm = None - if self.mq_broadcaster is not None: - self.mq_broadcaster = None - - _WORLD: Optional[GroupCoordinator] = None @@ -933,16 +104,15 @@ def get_pp_group() -> GroupCoordinator: # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group +_DISAGG: Optional[dist_kv.DistributedKVCoordinator] = None -_DISAGG: Optional[GroupCoordinator] = None -def get_disagg_group() -> GroupCoordinator: +def get_disagg_group() -> dist_kv.DistributedKVCoordinator: assert _DISAGG is not None, ( "disaggregated prefill parallel group is not initialized") return _DISAGG - @contextmanager def graph_capture(): """ @@ -971,7 +141,7 @@ def graph_capture(): def set_custom_all_reduce(enable: bool): global _ENABLE_CUSTOM_ALL_REDUCE _ENABLE_CUSTOM_ALL_REDUCE = enable - + def include_decoding_groups_if_disagg_enabled( groups: List[List[int]], @@ -989,9 +159,7 @@ def include_decoding_groups_if_disagg_enabled( world_size: the vLLM world size, which is half of torch.distributed.get_world_size() """ - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( - "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: new_groups = [] for group in groups: new_groups.append([rank for rank in group]) @@ -1020,21 +188,21 @@ def init_distributed_environment( # this backend is used for WORLD maybe_disagg_world_size = world_size maybe_disagg_rank = rank - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: maybe_disagg_world_size = world_size * 2 - logger.debug( - "Disaggregated prefill enabled.") - assert envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"], ( - "VLLM_DISAGG_PREFILL_ROLE should be either prefill or decode") - if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": + logger.debug("Disaggregated prefill enabled.") + if dist_kv.IS_KV_PREFILL_INSTANCE: # for prefill, the ranks are [0, world_size) maybe_disagg_rank = rank else: + # this is decode instance. # offset global rank by tp * pp (which is world_size) maybe_disagg_rank = rank + world_size - logger.debug(f"Before: world size {maybe_disagg_world_size}, rank {maybe_disagg_rank}") - + logger.debug( + f"Before: world size {maybe_disagg_world_size}, rank {maybe_disagg_rank}" + ) + torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, @@ -1051,25 +219,23 @@ def init_distributed_environment( local_rank = envs.LOCAL_RANK else: local_rank = rank - - + global _WORLD if _WORLD is None: ranks = [[i for i in range(world_size)]] # offset the distributed group - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: - ranks = include_decoding_groups_if_disagg_enabled(ranks, world_size) - + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + ranks = include_decoding_groups_if_disagg_enabled( + ranks, world_size) + _WORLD = init_world_group(ranks, local_rank, backend) - logger.debug("_WORLD initialized for rank %d", torch.distributed.get_rank()) + logger.debug("_WORLD initialized for rank %d", + torch.distributed.get_rank()) time.sleep(5) else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized with a different world size") - - - def initialize_model_parallel( tensor_model_parallel_size: int = 1, @@ -1114,20 +280,20 @@ def initialize_model_parallel( - [ [0, tp * pp], [1, tp * pp + 1], .. ] - Local rank: unchanged """ - + # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() backend = backend or torch.distributed.get_backend( - get_world_group().device_group) - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + get_world_group().device_group) + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: # Disaggregated prefill enabled # The world_size for this vLLM instance is tp * pp, but torch.distributed contains 2 vLLM instances, its world size is 2 * tp * pp # Adjust the world_size to match. world_size = world_size // 2 - if (world_size != - tensor_model_parallel_size * pipeline_model_parallel_size): + if (world_size + != tensor_model_parallel_size * pipeline_model_parallel_size): raise RuntimeError( f"world_size ({world_size}) is not equal to " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " @@ -1144,7 +310,8 @@ def initialize_model_parallel( range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) group_ranks.append(ranks) - group_ranks = include_decoding_groups_if_disagg_enabled(group_ranks, world_size) + group_ranks = include_decoding_groups_if_disagg_enabled( + group_ranks, world_size) # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, @@ -1153,7 +320,7 @@ def initialize_model_parallel( logger.debug("_TP initialized for rank %d", torch.distributed.get_rank()) # Build the pipeline model-parallel groups. - num_pipeline_model_parallel_groups: int = (world_size // + num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size) global _PP assert _PP is None, ( @@ -1162,15 +329,16 @@ def initialize_model_parallel( for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group_ranks.append(ranks) - group_ranks = include_decoding_groups_if_disagg_enabled(group_ranks, world_size) + group_ranks = include_decoding_groups_if_disagg_enabled( + group_ranks, world_size) # pipeline parallel does not need custom allreduce _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) logger.debug("_PP initialized for rank %d", torch.distributed.get_rank()) - - if envs.VLLM_DISAGG_PREFILL_ROLE is not None: + + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: global _DISAGG logger.debug("Disaggregated prefill enabled, create _DISAGG group") group_ranks = [] @@ -1179,19 +347,21 @@ def initialize_model_parallel( # decode global rank: i + world_size group_ranks.append([i, i + world_size]) logger.debug("Distributed group is %s", str(group_ranks)) - _DISAGG = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - use_custom_allreduce=False) + _DISAGG = dist_kv.DistributedKVCoordinator( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + torch_distributed_backend=backend, + ) # follow by a warmup, to warmup nccl # necessary, as NCCL may not be warmed up when tp and pp are both 1. temp_tensor = torch.tensor([1.]).to(_DISAGG.device) - if envs.VLLM_DISAGG_PREFILL_ROLE == "prefill": + if dist_kv.IS_KV_PREFILL_INSTANCE: _DISAGG.send(temp_tensor) else: recv_tensor = _DISAGG.recv(temp_tensor.shape, temp_tensor.dtype) assert torch.allclose(temp_tensor, recv_tensor) - logger.debug("_DISAGG initialized for rank %d", torch.distributed.get_rank()) + logger.debug("_DISAGG initialized for rank %d", + torch.distributed.get_rank()) def ensure_model_parallel_initialized( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ed552d6104ece..3e097438449ca 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -12,6 +12,8 @@ import torch.distributed import torch.nn as nn +import vllm.distributed.distributed_kv as dist_kv + try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper @@ -1355,14 +1357,17 @@ def execute_model( "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_seqlen_agnostic else {} - - # call `model_executable` - # and handle KV cache transfer for disaggregated prefilling + # check if the current run is profiling + is_profile_run = (kv_caches is None) or (kv_caches[0] is None) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + # check if we can skip prefilling + # We can only skip during prefill phase in disaggregated decode instance if any([ - prefill_meta is None, - envs.VLLM_DISAGG_PREFILL_ROLE != "decode", - kv_caches is None, - kv_caches[0] is None]): + not is_prefill_run, + not dist_kv.IS_KV_DECODE_INSTANCE, + is_profile_run]): # model forwarding # during forwarding the KV cache will be sent in prefill instance @@ -1378,151 +1383,27 @@ def execute_model( if all([ - prefill_meta is not None, - envs.VLLM_DISAGG_PREFILL_ROLE == "prefill", - kv_caches is not None, - kv_caches[0] is not None,]): - # send hidden state if disaggregated prefilling enabled - - _input_tokens_list = model_input.input_tokens.tolist() - seq_lens = model_input.seq_lens - query_lens = model_input.query_lens - seq_lens = get_tp_group().broadcast_object(seq_lens) - query_lens = get_tp_group().broadcast_object(query_lens) - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - - # failed = False - # reason = "" - - # if sum(query_lens) != sum(seq_lens): - # logger.error("Query len sum is %d but seq len sum is %d", sum(query_lens), sum(seq_lens)) - # failed=True - # if sum(query_lens) != len(_input_tokens_list): - # logger.error("Input tokens len is %d, doesn't match with query lens sum %d", - # sum(query_lens), - # len(_input_tokens_list)) - # failed=True - # if slot_mapping.shape[0] != len(_input_tokens_list): - # logger.error("Slot mapping shape is %s, mismatch with input shape %s", - # slot_mapping.shape, - # len(_input_tokens_list)) - # failed=True - # if failed: - # import subprocess - # subprocess.run("ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9", shell=True) - + is_prefill_run, + dist_kv.IS_KV_PREFILL_INSTANCE, + not is_profile_run]): - # query_lens contains new KV caches that are added to vLLM. - # so we will send them to decode instance - # FIXME(Kuntai): This assume that all requests are prefill. - for idx, qlen in enumerate(query_lens): - - - start_pos = sum(query_lens[:idx]) - end_pos = start_pos + qlen - input_hash = hash(tuple(_input_tokens_list[start_pos:end_pos])) - - for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): - kv_cache = kv_caches[i - model_executable.model.start_layer] - - _, _, num_heads, head_size = kv_cache[0].shape - - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) - - current_slot_mapping = slot_mapping[start_pos:end_pos] - - get_disagg_group().kv_cache_send( - input_hash, - key_cache[current_slot_mapping]) - get_disagg_group().kv_cache_send( - input_hash, - value_cache[current_slot_mapping]) - - - get_disagg_group().kv_cache_send( - input_hash, - hidden_or_intermediate_states[start_pos:end_pos]) - get_disagg_group().kv_cache_send_finish() - - logger.error("\033[92mKV send DONE for rank %d\033[0m", torch.distributed.get_rank()) + # transfer KV cache and hidden state + dist_kv.buffer_kv_caches_send_and_listen_for_input_hash( + model_executable, + model_input, + kv_caches, + hidden_or_intermediate_states, + ) else: - # This is disagg decode instance, during prefill state - # Need to receive KV from the prefill instance - # FIXME(Kuntai): This impl assumes that all requests are prefill. - - _input_tokens_list = model_input.input_tokens.tolist() - seq_lens = model_input.seq_lens - query_lens = model_input.query_lens - seq_lens = get_tp_group().broadcast_object(seq_lens) - query_lens = get_tp_group().broadcast_object(query_lens) - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - - hidden_or_intermediate_states_for_one_req = [] - - # enumerate different requests - logger.debug("My query lens is %s, seq len is %s, rank is %s", - str(query_lens), - str(seq_lens), - torch.distributed.get_rank()) - for idx, qlen in enumerate(query_lens): - - start_pos = sum(query_lens[:idx]) - end_pos = start_pos + qlen - input_hash = hash(tuple(_input_tokens_list[start_pos:end_pos])) - num_tokens = qlen - - # notify the prefill instance to start sending KVs associated with input_hash - get_disagg_group().kv_cache_recv_start(input_hash) - - # receive KV cache from disaggregated prefill instance - for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): - - # get kv cache - kv_cache = kv_caches[i - model_executable.model.start_layer] - # get corresponding layer - layer = model_executable.model.layers[i] - - # get kv cache shape (after sliced by tp) - _, _, num_heads, head_size = kv_cache[0].shape - key = get_disagg_group().kv_cache_recv( - torch.Size([num_tokens, num_heads, head_size]), - kv_cache[0].dtype - ) - value = get_disagg_group().kv_cache_recv( - torch.Size([num_tokens, num_heads, head_size]), - kv_cache[0].dtype - ) - - key_cache, value_cache = kv_cache[0], kv_cache[1] - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - layer.self_attn.attn._v_scale, - ) - - - hidden_or_intermediate_states_for_one_req.append( - get_disagg_group().kv_cache_recv( - torch.Size([num_tokens, model_executable.config.hidden_size]), - kv_cache[0].dtype - ) + # skip prefill, receive KV cache and hidden state + hidden_or_intermediate_states = \ + dist_kv.send_input_hash_and_do_kv_caches_recv( + model_executable, + model_input, + kv_caches, ) - - # concatenate hidden states from different requests - hidden_or_intermediate_states = torch.cat(hidden_or_intermediate_states_for_one_req, dim=0) - - logger.error("\033[92mKV receive DONE for rank %d\033[0m", torch.distributed.get_rank()) - # Compute the logits in the last pipeline stage. From 06a526a08084e4c85b8dab40f1bd38c43daee8e9 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 5 Aug 2024 05:03:54 +0000 Subject: [PATCH 137/303] Fix several bugs: tensor device placement, misc performance optimizations, handle the case where 2 request has identical input, refactor the code --- .../disagg_benchmarks/disagg_benchmark.sh | 38 ++- .../device_communicators/shm_broadcast.py | 4 +- vllm/distributed/distributed_kv.py | 221 ++++++++++-------- vllm/distributed/group_coordinator.py | 12 +- vllm/envs.py | 2 +- vllm/executor/gpu_executor.py | 3 +- vllm/executor/multiproc_gpu_executor.py | 3 +- vllm/executor/ray_gpu_executor.py | 6 +- vllm/utils.py | 12 +- 9 files changed, 174 insertions(+), 127 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh index 96e3cf35d49a4..fe5afd9fa513b 100644 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_benchmark.sh @@ -53,10 +53,10 @@ benchmark() { model="meta-llama/Meta-Llama-3.1-70B-Instruct" dataset_name="sonnet" dataset_path="../sonnet_4x.txt" - num_prompts=100 + num_prompts=50 qps=$1 prefix_len=50 - input_len="100" + input_len=2048 output_len=$2 @@ -143,19 +143,19 @@ benchmark() { # large model -VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ +VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ --port 8100 \ -tp 4 \ - --max-model-len 10000 \ + --max-model-len 30000 \ --gpu-memory-utilization 0.8 & -VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ +VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ --port 8200 \ -tp 4 \ - --max-model-len 10000 \ + --max-model-len 30000 \ --gpu-memory-utilization 0.8 & # # Small Model @@ -181,8 +181,8 @@ VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ wait_for_server 8200 # launch a proxy server that listen from port 8000 - python3 disagg_prefill_proxy_server.py & - sleep 1 + # python3 disagg_prefill_proxy_server.py & + # sleep 1 python3 ../benchmark_serving.py \ --backend vllm \ @@ -193,7 +193,23 @@ VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ --sonnet-output-len $output_len \ --sonnet-prefix-len $prefix_len \ --num-prompts $num_prompts \ - --port 8000 \ + --port 8100 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate $qps + + + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len $output_len \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8200 \ --save-result \ --result-dir $results_folder \ --result-filename disagg_prefill_2xtp4.json \ @@ -230,8 +246,8 @@ main() { rm -rf results mkdir results - default_qps=10 - default_output_len=150 + default_qps=1 + default_output_len=1 # for target_qps in 2 4 8 16 # do diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 878996f6904d9..d4847542688c0 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -195,7 +195,7 @@ def __init__( # message. otherwise, we will only receive the first subscription # see http://api.zeromq.org/3-3:zmq-setsockopt for more details self.local_socket.setsockopt(XPUB_VERBOSE, True) - local_subscribe_port = get_open_port(is_for_dist_init = False) + local_subscribe_port = get_open_port() self.local_socket.bind(f"tcp://*:{local_subscribe_port}") self.current_idx = 0 @@ -211,7 +211,7 @@ def __init__( # create a publish-subscribe socket to communicate large data self.remote_socket = context.socket(XPUB) self.remote_socket.setsockopt(XPUB_VERBOSE, True) - remote_subscribe_port = get_open_port(is_for_dist_init = False) + remote_subscribe_port = get_open_port() self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}") else: diff --git a/vllm/distributed/distributed_kv.py b/vllm/distributed/distributed_kv.py index 4d8c1b7d37565..6c73fc2dd5814 100644 --- a/vllm/distributed/distributed_kv.py +++ b/vllm/distributed/distributed_kv.py @@ -2,7 +2,7 @@ These APIs are used in `vllm/worker/model_runner.py`. """ from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING -from collections import defaultdict +from collections import defaultdict, deque from concurrent.futures import ThreadPoolExecutor import torch @@ -43,7 +43,7 @@ def __init__( use_custom_allreduce: bool = False, use_tpu_communicator: bool = True, use_message_queue_broadcaster: bool = False, - use_cpu_comm_for_sanity_check: bool = True, + use_cpu_comm_for_sanity_check: bool = False, ): super().__init__( @@ -61,12 +61,14 @@ def __init__( self.use_cpu_comm_for_sanity_check = use_cpu_comm_for_sanity_check # use a threadpool to buffer send request in disaggregated prefill - self.input_hash_to_kv_sending_requests = defaultdict(list) + self.input_hash_to_kv_sending_requests = defaultdict(deque) self.kv_sending_thread = None - - + self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % + self.world_size] + self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % + self.world_size] + torch.set_default_device(self.device) - def debug_send(self, tensor: torch.Tensor, @@ -133,63 +135,87 @@ def kv_cache_recv(self, return tensor + def send_input_hash(self, input_hash: int) -> None: + + # KV cache send go through CPU, and the original `send` only use GPU. + # So create a new group for sending input hash. + input_hash_tensor = torch.tensor([input_hash], device="cpu").long() + torch.distributed.isend(input_hash_tensor, self.target_rank_for_send, + self.cpu_group) + + def recv_input_hash(self) -> int: + input_hash_tensor = torch.tensor([0], device="cpu").long() + torch.distributed.irecv(input_hash_tensor, self.target_rank_for_recv, + self.cpu_group).wait() + return input_hash_tensor.item() + def recv_input_hash_and_send_kv(self): try: # receive the input hash that the decode instance requires - logger.debug('Rank %d: Waiting for input hash from rank %d, my hashes are %s', - torch.distributed.get_rank(), - self.ranks[(self.rank_in_group - 1) % self.world_size], - list(self.input_hash_to_kv_sending_requests.keys())) - # FIXME(Kuntai): debug_recv guarantees correctness but hurts perf - input_hash_tensor = self.debug_recv(torch.Size([1]), torch.long) - input_hash = input_hash_tensor.item() - logger.debug('Successfully received input hash %d', input_hash) + logger.debug( + '[rank%d]: Waiting for input hash from rank %d', + torch.distributed.get_rank(), + self.target_rank_for_recv, + ) + input_hash = self.recv_input_hash() + logger.debug( + 'Successfully received input hash %d', + input_hash) assert input_hash in self.input_hash_to_kv_sending_requests, \ - f"The KV cache of {input_hash} does not exist." + f"The KV cache of {input_hash} does not exist. "\ + f"Existing input hash: {list(self.input_hash_to_kv_sending_requests.keys())}" logger.debug('Input hash %d exists, start sending', input_hash) # execute corresponding kv cache sending jobs in request queue - for idx, request in enumerate( - self.input_hash_to_kv_sending_requests[input_hash]): + while True: + request = self.input_hash_to_kv_sending_requests[ + input_hash].popleft() + # An empty request: the KV cahe of one request are all sent + if request == []: + break request[0](*request[1:]) - logger.debug('Finish input hash %d, free memory...' % input_hash) - # free GPU memory occupied by sending - del self.input_hash_to_kv_sending_requests[input_hash] + if len(self.input_hash_to_kv_sending_requests[input_hash]) == 0: + logger.debug('Finish input hash %d, free GPU memory...', + input_hash) + del self.input_hash_to_kv_sending_requests[input_hash] + else: + logger.debug( + 'The buffer for input hash %d is not empty, meaning that '\ + 'there are two jobs with identical input. Free GPU '\ + 'memory for one of the request.', + input_hash) except Exception as e: - import sys + # This function is executed in ThreadPoolExecutor + # and it will block all exceptions by default + # so log the potential error message here. import traceback - exc_info = traceback.format_exc() import time + exc_info = traceback.format_exc() + # avoid the output of different rank overlaps time.sleep(torch.distributed.get_rank()) logger.error("An error occured: %s, stack trace: %s", e, exc_info) - - def kv_cache_send_finish(self): + def kv_cache_send_finish(self, input_hash: int): if self.kv_sending_thread is None: self.kv_sending_thread = ThreadPoolExecutor(max_workers=1) + # append an empty job to signal that this is the end of a request + self.input_hash_to_kv_sending_requests[input_hash].append([]) job = self.kv_sending_thread.submit(self.recv_input_hash_and_send_kv) logger.debug(f'Submit job {job} into kv cache sending thread') def kv_cache_recv_start(self, input_hash: int): - logger.debug('Rank %d: Sending input hash %d to rank %d', - torch.distributed.get_rank(), - input_hash, self.ranks[(self.rank_in_group + 1) % self.world_size]) + logger.debug('[rank%d]: Sending input hash %d to rank %d', + torch.distributed.get_rank(), input_hash, + self.ranks[(self.rank_in_group + 1) % self.world_size]) - input_hash_tensor = torch.tensor([input_hash]).long().to(self.device) - logger.error("Rank %d: input hash tensor on device %s", - torch.distributed.get_rank(), - input_hash_tensor.device) # notify the kv cache sender with the input hash id - # FIXME(Kuntai): debug_send guarantees correctness but hurts perf. - self.debug_send(input_hash_tensor) - - + self.send_input_hash(input_hash) def buffer_kv_caches_send_and_listen_for_input_hash( @@ -198,17 +224,20 @@ def buffer_kv_caches_send_and_listen_for_input_hash( kv_caches: List[torch.Tensor], hidden_or_intermediate_states: torch.Tensor, ) -> None: - - _input_tokens_list = model_input.input_tokens.tolist() - seq_lens = model_input.seq_lens - query_lens = model_input.query_lens - seq_lens = ps.get_tp_group().broadcast_object(seq_lens) - query_lens = ps.get_tp_group().broadcast_object(query_lens) + + input_tokens_tuple = tuple(model_input.input_tokens.tolist()) + seq_query_obj = { + "seq_lens": model_input.seq_lens, + "query_lens": model_input.query_lens, + } + seq_query_obj = ps.get_tp_group().broadcast_object(seq_query_obj) + seq_lens = seq_query_obj["seq_lens"] + query_lens = seq_query_obj["query_lens"] slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - - logger.info("KV cache shape is %s", kv_caches[0].shape) - + logger.debug("My query lens is %s, seq len is %s, rank is %s", + str(query_lens), str(seq_lens), torch.distributed.get_rank()) + # failed = False # reason = "" @@ -226,102 +255,93 @@ def buffer_kv_caches_send_and_listen_for_input_hash( # len(_input_tokens_list)) # failed=True # if failed: - # import subprocess + # import subprocess # subprocess.run("ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9", shell=True) - - + # query_lens contains new KV caches that are added to vLLM. # so we will send them to decode instance - # FIXME(Kuntai): This assume that all requests are prefill. + # FIXME(Kuntai): This assume that all requests are prefill. for idx, qlen in enumerate(query_lens): - start_pos = sum(query_lens[:idx]) end_pos = start_pos + qlen - input_hash = hash(tuple(_input_tokens_list[start_pos:end_pos])) - + input_hash = hash(input_tokens_tuple[start_pos:end_pos]) + for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): + model_executable.model.end_layer): kv_cache = kv_caches[i - model_executable.model.start_layer] - + _, _, num_heads, head_size = kv_cache[0].shape - + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) value_cache = kv_cache[1].reshape(-1, num_heads, head_size) current_slot_mapping = slot_mapping[start_pos:end_pos] ps.get_disagg_group().kv_cache_send( - input_hash, - key_cache[current_slot_mapping]) + input_hash, key_cache[current_slot_mapping]) ps.get_disagg_group().kv_cache_send( - input_hash, - value_cache[current_slot_mapping]) - + input_hash, value_cache[current_slot_mapping]) ps.get_disagg_group().kv_cache_send( - input_hash, - hidden_or_intermediate_states[start_pos:end_pos]) - ps.get_disagg_group().kv_cache_send_finish() + input_hash, hidden_or_intermediate_states[start_pos:end_pos]) + ps.get_disagg_group().kv_cache_send_finish(input_hash) + + logger.error("\033[92mKV send DONE for rank %d\033[0m", + torch.distributed.get_rank()) - logger.error("\033[92mKV send DONE for rank %d\033[0m", torch.distributed.get_rank()) - - def send_input_hash_and_do_kv_caches_recv( - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor] -) -> torch.Tensor: - + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor]) -> torch.Tensor: + # This is disagg decode instance, during prefill state # Need to receive KV from the prefill instance - # FIXME(Kuntai): This impl assumes that all requests are prefill. - - _input_tokens_list = model_input.input_tokens.tolist() - seq_lens = model_input.seq_lens - query_lens = model_input.query_lens - seq_lens = ps.get_tp_group().broadcast_object(seq_lens) - query_lens = ps.get_tp_group().broadcast_object(query_lens) + # FIXME(Kuntai): This impl assumes that all requests are prefill. + input_tokens_tuple = tuple(model_input.input_tokens.tolist()) + seq_query_obj = { + "seq_lens": model_input.seq_lens, + "query_lens": model_input.query_lens, + } + seq_query_obj = ps.get_tp_group().broadcast_object(seq_query_obj) + seq_lens = seq_query_obj["seq_lens"] + query_lens = seq_query_obj["query_lens"] slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - + hidden_or_intermediate_states_for_one_req = [] - + # enumerate different requests - logger.debug("My query lens is %s, seq len is %s, rank is %s", - str(query_lens), - str(seq_lens), - torch.distributed.get_rank()) + logger.debug("My query lens is %s, seq len is %s, rank is %s", + str(query_lens), str(seq_lens), torch.distributed.get_rank()) for idx, qlen in enumerate(query_lens): start_pos = sum(query_lens[:idx]) end_pos = start_pos + qlen - input_hash = hash(tuple(_input_tokens_list[start_pos:end_pos])) + input_hash = hash(input_tokens_tuple[start_pos:end_pos]) num_tokens = qlen - - # notify the prefill instance to start sending KVs associated with input_hash + + # notify the prefill instance to start sending KVs associated with input_hash ps.get_disagg_group().kv_cache_recv_start(input_hash) # receive KV cache from disaggregated prefill instance for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): - + model_executable.model.end_layer): + # get kv cache kv_cache = kv_caches[i - model_executable.model.start_layer] # get corresponding layer layer = model_executable.model.layers[i] - + # get kv cache shape (after sliced by tp) _, _, num_heads, head_size = kv_cache[0].shape key = ps.get_disagg_group().kv_cache_recv( torch.Size([num_tokens, num_heads, head_size]), - kv_cache[0].dtype - ) + kv_cache[0].dtype) value = ps.get_disagg_group().kv_cache_recv( torch.Size([num_tokens, num_heads, head_size]), - kv_cache[0].dtype - ) - + kv_cache[0].dtype) + key_cache, value_cache = kv_cache[0], kv_cache[1] ops.reshape_and_cache_flash( key, @@ -334,16 +354,15 @@ def send_input_hash_and_do_kv_caches_recv( layer.self_attn.attn._v_scale, ) - hidden_or_intermediate_states_for_one_req.append( ps.get_disagg_group().kv_cache_recv( torch.Size([num_tokens, model_executable.config.hidden_size]), - kv_cache[0].dtype - ) - ) + kv_cache[0].dtype)) # concatenate hidden states from different requests - hidden_or_intermediate_states = torch.cat(hidden_or_intermediate_states_for_one_req, dim=0) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) - logger.error("\033[92mKV receive DONE for rank %d\033[0m", torch.distributed.get_rank()) - return hidden_or_intermediate_states \ No newline at end of file + logger.error("\033[92mKV receive DONE for rank %d\033[0m", + torch.distributed.get_rank()) + return hidden_or_intermediate_states diff --git a/vllm/distributed/group_coordinator.py b/vllm/distributed/group_coordinator.py index 03b6a9e3b0565..bfa3c7f3c17cf 100644 --- a/vllm/distributed/group_coordinator.py +++ b/vllm/distributed/group_coordinator.py @@ -631,9 +631,12 @@ def recv_tensor_dict( tensor_dict: Dict[str, Any] = {} for key, value in recv_metadata_list: if isinstance(value, TensorMetadata): + target_device = value.device + if 'cuda' in target_device: + target_device = self.device tensor = torch.empty(value.size, dtype=value.dtype, - device=value.device) + device=target_device) if tensor.numel() == 0: # Skip broadcasting empty tensors. tensor_dict[key] = tensor @@ -653,11 +656,12 @@ def recv_tensor_dict( torch.distributed.recv(tensor, src=self.ranks[src], group=metadata_group) + else: # use group for GPU tensors torch.distributed.recv(tensor, - src=self.ranks[src], - group=group) + src=self.ranks[src], + group=group) if use_all_gather: # do the allgather tensor = all_gather_group.all_gather( # type: ignore @@ -719,4 +723,4 @@ def destroy(self): if self.ca_comm is not None: self.ca_comm = None if self.mq_broadcaster is not None: - self.mq_broadcaster = None \ No newline at end of file + self.mq_broadcaster = None diff --git a/vllm/envs.py b/vllm/envs.py index 4c9e43cfe1405..07a7b647f6bc5 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -145,7 +145,7 @@ def get_default_config_root(): # used when the frontend api server is running in multi-processing mode, # to communicate with the backend engine process over ZMQ. 'VLLM_RPC_PORT': - lambda: int(os.getenv('VLLM_PORT', '5570')), + lambda: int(os.getenv('VLLM_RPC_PORT', '5570')), # If true, will load models from ModelScope instead of Hugging Face Hub. # note that the value is true or false, not numbers diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 3e77af0e20323..300e9a33eba56 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union +import vllm.distributed.distributed_kv as dist_kv from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -43,7 +44,7 @@ def _get_worker_kwargs( """Return worker init args for a given rank.""" if distributed_init_method is None: distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) + get_ip(), get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) return dict( model_config=self.model_config, parallel_config=self.parallel_config, diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 08a35a074b37b..ba222f8b5e405 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -8,6 +8,7 @@ import torch +import vllm.distributed.distributed_kv as dist_kv from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.gpu_executor import create_worker @@ -82,7 +83,7 @@ def _init_executor(self) -> None: # Since it only works for single node, we can use the loopback address # 127.0.0.1 for communication. distributed_init_method = get_distributed_init_method( - "127.0.0.1", get_open_port()) + "127.0.0.1", get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) self.workers: List[ProcessWorkerWrapper] = [] # This is the list of workers that are rank 0 of each TP group EXCEPT diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 4a6825c01fcf8..17f4d36338860 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import vllm.envs as envs +import vllm.distributed.distributed_kv as dist_kv from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray @@ -226,8 +227,11 @@ def sort_by_driver_then_worker_ip(worker): # solves this issue, as it always works for communication inside # the node. driver_ip = "127.0.0.1" + # force vLLM to use the port specified by envs.VLLM_PORT + # this port will be binded by prefill instance + # but the decode instance must use that port to init torch.distributed distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) + driver_ip, get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) # Initialize the actual workers inside worker wrapper. init_worker_all_kwargs = [ diff --git a/vllm/utils.py b/vllm/utils.py index 0d4c22ab761b6..fa54523352645 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -388,18 +388,20 @@ def get_distributed_init_method(ip: str, port: int) -> str: return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" -def get_open_port(port: Optional[int] = None, is_for_dist_init: bool = True) -> int: +def get_open_port(port: Optional[int] = None, force: bool = False) -> int: if port is None: # Default behavior here is to return a port for multi-gpu communication port = envs.VLLM_PORT if port is not None: - if envs.VLLM_DISAGG_PREFILL_ROLE is not None and is_for_dist_init: - # When initializing distributed environment for disagg prefill - # The prefill and decode instance may share the same port - # Skip the binding check as the port may be binded by prefill + if force and port is not None: + # force vLLM to use envs.VLLM_PORT for torch.distributed init + # This is because this port will binded by prefill instance + # But both prefill and decode instance need to use this port to + # initialize torch.distributed return port while True: try: + logger.error('Trying port %d', port) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", port)) return port From 34e6bb324587a02000c4ea38a21b48ffce08ebc5 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 5 Aug 2024 05:06:00 +0000 Subject: [PATCH 138/303] remove useless comments --- vllm/distributed/distributed_kv.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/vllm/distributed/distributed_kv.py b/vllm/distributed/distributed_kv.py index 6c73fc2dd5814..77c153d775c68 100644 --- a/vllm/distributed/distributed_kv.py +++ b/vllm/distributed/distributed_kv.py @@ -237,27 +237,7 @@ def buffer_kv_caches_send_and_listen_for_input_hash( logger.debug("My query lens is %s, seq len is %s, rank is %s", str(query_lens), str(seq_lens), torch.distributed.get_rank()) - - # failed = False - # reason = "" - - # if sum(query_lens) != sum(seq_lens): - # logger.error("Query len sum is %d but seq len sum is %d", sum(query_lens), sum(seq_lens)) - # failed=True - # if sum(query_lens) != len(_input_tokens_list): - # logger.error("Input tokens len is %d, doesn't match with query lens sum %d", - # sum(query_lens), - # len(_input_tokens_list)) - # failed=True - # if slot_mapping.shape[0] != len(_input_tokens_list): - # logger.error("Slot mapping shape is %s, mismatch with input shape %s", - # slot_mapping.shape, - # len(_input_tokens_list)) - # failed=True - # if failed: - # import subprocess - # subprocess.run("ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9", shell=True) - + # query_lens contains new KV caches that are added to vLLM. # so we will send them to decode instance # FIXME(Kuntai): This assume that all requests are prefill. From 55bf3bfb29a89c3870df492ad2ed3542506bac58 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 5 Aug 2024 05:11:02 +0000 Subject: [PATCH 139/303] update disaggregated prefill example --- .../disagg_prefill/disagg_prefill_example.sh | 61 +++++++------------ 1 file changed, 23 insertions(+), 38 deletions(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index 8cfe528ffb58c..a2a6dc0932fdf 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -3,10 +3,8 @@ # We will launch 2 vllm instances (1 for prefill and 1 for decode), # and then transfer the KV cache between them. -export VLLM_LOGGING_LEVEL=DEBUG export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') -# export NCCL_DEBUG=INFO -export NCCL_BUFFSIZE=67108864 +export VLLM_PORT=12345 # a function that waits vLLM server to start wait_for_server() { @@ -18,7 +16,7 @@ wait_for_server() { } # prefilling instance -VLLM_LOGGING_LEVEL=DEBUG VLLM_HOST_IP=$(hostname -I | awk '{print $1}') VLLM_PORT=2345 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 python3 \ +VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8100 \ @@ -28,7 +26,7 @@ VLLM_LOGGING_LEVEL=DEBUG VLLM_HOST_IP=$(hostname -I | awk '{print $1}') VLLM_POR --max-model-len 10000 & # decoding instance -VLLM_LOGGING_LEVEL=DEBUG VLLM_HOST_IP=$(hostname -I | awk '{print $1}') VLLM_PORT=2345 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ +VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8200 \ @@ -46,39 +44,26 @@ wait_for_server 8200 # 1. send the request to prefill instance, with max_tokens set to 1 # 2. send the request again to decode instance, no modification +# send to prefill instance, let it only do prefill by setting max_token=1 +curl -m 5 http://localhost:8100/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "'$i' San Francisco is a", +"max_tokens": 1, +"temperature": 0 +}' -for i in {0..0} -do - # send to prefill instance - curl -m 5 http://localhost:8100/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", - "prompt": "'$i' San Francisco is a", - "max_tokens": 1, - "temperature": 0 - }' +# send to decode instance +curl -m 5 http://localhost:8100/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "'$i' San Francisco is a", +"max_tokens": 50, +"temperature": 0 +}' - curl -m 5 http://localhost:8100/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", - "prompt": "'$i' San Francisco is a", - "max_tokens": 1, - "temperature": 0 - }' - # # send to decode instance - # curl -m 60 http://localhost:8200/v1/completions \ - # -H "Content-Type: application/json" \ - # -d '{ - # "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", - # "prompt": "'$i' San Francisco is a", - # "max_tokens": 5, - # "temperature": 0 - # }' - -done - -# kill command: -# ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 \ No newline at end of file +# clean up +ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 \ No newline at end of file From b525510419103f967c9433bd4553fa80f81f4708 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 6 Aug 2024 19:01:22 +0000 Subject: [PATCH 140/303] add disaggregated prefill overhead benchmark --- .../disagg_overhead_benchmark.sh | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh new file mode 100644 index 0000000000000..12f5150cadda3 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -0,0 +1,148 @@ +#!/bin/bash + +# Requirement: 8x H100 GPUs. + + +# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV +# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests +# Resource: 8x H100 +# Approaches: +# 1. Chunked prefill: 1 vllm instance with tp=8 +# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 +# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance +# Prefilling instance: max_output_token=1 +# Decoding instance: force the input tokens be the same across requests to bypass prefilling + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pkill pt_main_thread + sleep 10 + + # remove vllm config file + rm -rf ~/.config/vllm + + # Print the GPU memory usage + # so that we know if all GPU processes are killed. + gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) + # The memory usage should be 0 MB. + echo "GPU 0 Memory Usage: $gpu_memory_usage MB" +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +benchmark() { + + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + export VLLM_PORT=12345 + + # compare chunked prefill with disaggregated prefill + + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-70B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=50 + qps=$1 + prefix_len=50 + input_len=2048 + output_len=$2 + + # large model + VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + -tp 4 \ + --max-model-len 30000 \ + --gpu-memory-utilization 0.8 & + VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + -tp 4 \ + --max-model-len 30000 \ + --gpu-memory-utilization 0.8 & + + wait_for_server 8100 + wait_for_server 8200 + + # let the prefill instance finish prefill + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len $output_len \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8100 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate $qps + + + # send the request to decode. + # The TTFT of this command will be the overhead of disagg prefill impl. + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len $output_len \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8200 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate $qps + kill_gpu_processes + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_qps=1 + default_output_len=1 + benchmark $default_qps $default_output_len + +} + + +main "$@" From ee6a6ec5c12d5a1014850f8068b2857ed7540b8b Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 7 Aug 2024 02:56:58 +0000 Subject: [PATCH 141/303] change disagg prefill proxy server to support non-streaming case --- .../disagg_prefill_proxy_server.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index 9028d9be86ec5..eb2f2a7149a37 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -1,7 +1,6 @@ from quart import Quart, request, Response, jsonify, make_response import aiohttp import sys -import httpx import traceback import os @@ -14,15 +13,17 @@ async def forward_request(url, data): headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" } - async with session.post(url=url, json=data, - headers=headers) as response: + async with session.post(url=url, json=data, headers=headers) as response: if response.status == 200: - async for chunk_bytes in response.content: - yield chunk_bytes - + if response.headers.get('Transfer-Encoding') == 'chunked': + async for chunk_bytes in response.content.iter_chunked(1024): + yield chunk_bytes + else: + content = await response.read() + yield content + @app.route('/v1/completions', methods=['POST']) async def handle_request(): - try: original_request_data = await request.get_json() @@ -30,7 +31,7 @@ async def handle_request(): prefill_request['max_tokens'] = 1 # finish prefill - async for data in forward_request('http://localhost:8100/v1/completions', prefill_request): + async for _ in forward_request('http://localhost:8100/v1/completions', prefill_request): continue print(f"Request {prefill_request} prefill done. proceeding to decode.") From f3cc91ddb496978cbd625f68ffa1a099780f0400 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 7 Aug 2024 03:01:59 +0000 Subject: [PATCH 142/303] avoid detokenizing the first token in prefill instance -- for shorter latency --- benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index eb2f2a7149a37..bf5c136b9a2d9 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -28,7 +28,10 @@ async def handle_request(): original_request_data = await request.get_json() prefill_request = original_request_data.copy() + # change max_tokens = 1 to let it only do prefill prefill_request['max_tokens'] = 1 + # avoid sampling overhead by setting detokenize = False + prefill_request['detokenize'] = False # finish prefill async for _ in forward_request('http://localhost:8100/v1/completions', prefill_request): From 058226540bdf16e53c8cb4d79528b3e64ea2f785 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 7 Aug 2024 23:06:18 +0000 Subject: [PATCH 143/303] add failure test cases --- try switching to another machine --- .../disagg_benchmarks/disagg_benchmark.sh | 266 ------------------ .../disagg_performance_benchmark.sh | 172 +++++++++++ .../disagg_prefill_proxy_server.py | 7 +- 3 files changed, 175 insertions(+), 270 deletions(-) delete mode 100644 benchmarks/disagg_benchmarks/disagg_benchmark.sh create mode 100644 benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh diff --git a/benchmarks/disagg_benchmarks/disagg_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_benchmark.sh deleted file mode 100644 index fe5afd9fa513b..0000000000000 --- a/benchmarks/disagg_benchmarks/disagg_benchmark.sh +++ /dev/null @@ -1,266 +0,0 @@ -#!/bin/bash - -# Requirement: 8x H100 GPUs. - - -# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV -# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests -# Resource: 8x H100 -# Approaches: -# 1. Chunked prefill: 1 vllm instance with tp=8 -# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 -# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance -# Prefilling instance: max_output_token=1 -# Decoding instance: force the input tokens be the same across requests to bypass prefilling - -set -ex - -kill_gpu_processes() { - # kill all processes on GPU. - pkill pt_main_thread - sleep 10 - - # remove vllm config file - rm -rf ~/.config/vllm - - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. - echo "GPU 0 Memory Usage: $gpu_memory_usage MB" -} - -wait_for_server() { - # wait for vllm server to start - # return 1 if vllm server crashes - local port=$1 - timeout 1200 bash -c " - until curl -s localhost:${port}/v1/completions > /dev/null; do - sleep 1 - done" && return 0 || return 1 -} - - -benchmark() { - - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') - export VLLM_PORT=12345 - - # compare chunked prefill with disaggregated prefill - - results_folder="./results" - model="meta-llama/Meta-Llama-3.1-70B-Instruct" - dataset_name="sonnet" - dataset_path="../sonnet_4x.txt" - num_prompts=50 - qps=$1 - prefix_len=50 - input_len=2048 - output_len=$2 - - - # # chunked prefill with tp=4 - # CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ - # -m vllm.entrypoints.openai.api_server \ - # --model $model \ - # --port 8000 \ - # -tp 4 \ - # --disable-log-stats \ - # --disable-log-requests \ - # --enable-chunked-prefill & - # wait_for_server 8000 - - # python3 ../benchmark_serving.py \ - # --backend vllm \ - # --model $model \ - # --dataset-name $dataset_name \ - # --dataset-path $dataset_path \ - # --sonnet-input-len $input_len \ - # --sonnet-output-len $output_len \ - # --sonnet-prefix-len $prefix_len \ - # --num-prompts $((num_prompts / 2)) \ - # --port 8000 \ - # --save-result \ - # --result-dir $results_folder \ - # --result-filename chunked_prefill_tp4.json \ - # --request-rate $((qps / 2)) - # kill_gpu_processes - - - # # disaggregated prefill - # # prefill with tp=4 - # python3 -m vllm.entrypoints.openai.api_server \ - # --model $model \ - # --port 8000 \ - # -tp 4 \ - # --disable-log-stats \ - # --disable-log-requests & - # wait_for_server 8000 - # # set output-len to 1 so that it only do prefilling - # python3 ../benchmark_serving.py \ - # --backend vllm \ - # --model $model \ - # --dataset-name $dataset_name \ - # --dataset-path $dataset_path \ - # --sonnet-input-len $input_len \ - # --sonnet-output-len 1 \ - # --sonnet-prefix-len $prefix_len \ - # --num-prompts $num_prompts \ - # --port 8000 \ - # --save-result \ - # --result-dir $results_folder \ - # --result-filename disagg_prefill_tp4.json \ - # --request-rate $qps - # kill_gpu_processes - - # # decode with tp=4, enable APC - # python3 -m vllm.entrypoints.openai.api_server \ - # --model $model \ - # --port 8000 \ - # -tp 4 \ - # --enable-prefix-caching \ - # --disable-log-stats \ - # --disable-log-requests & - # wait_for_server 8000 - # # skip prefilling - # # by enabling APC and force the input tokens be the same - # python3 ../benchmark_serving.py \ - # --backend vllm \ - # --model $model \ - # --dataset-name $dataset_name \ - # --dataset-path $dataset_path \ - # --sonnet-input-len $input_len \ - # --sonnet-output-len $output_len \ - # --sonnet-prefix-len $input_len \ - # --num-prompts $num_prompts \ - # --port 8000 \ - # --save-result \ - # --result-dir $results_folder \ - # --result-filename disagg_decode_tp4.json \ - # --request-rate $qps - # kill_gpu_processes - - -# large model -VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model $model \ - --port 8100 \ - -tp 4 \ - --max-model-len 30000 \ - --gpu-memory-utilization 0.8 & -VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model $model \ - --port 8200 \ - -tp 4 \ - --max-model-len 30000 \ - --gpu-memory-utilization 0.8 & - -# # Small Model -# # prefilling instance -# VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 python3 \ -# -m vllm.entrypoints.openai.api_server \ -# --model $model \ -# --port 8100 \ -# -tp 1 \ -# --gpu-memory-utilization 0.8 \ -# --max-model-len 10000 & - -# # decoding instance -# VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ -# -m vllm.entrypoints.openai.api_server \ -# --model $model \ -# --port 8200 \ -# -tp 1 \ -# --gpu-memory-utilization 0.8 \ -# --max-model-len 10000 & - - wait_for_server 8100 - wait_for_server 8200 - - # launch a proxy server that listen from port 8000 - # python3 disagg_prefill_proxy_server.py & - # sleep 1 - - python3 ../benchmark_serving.py \ - --backend vllm \ - --model $model \ - --dataset-name $dataset_name \ - --dataset-path $dataset_path \ - --sonnet-input-len $input_len \ - --sonnet-output-len $output_len \ - --sonnet-prefix-len $prefix_len \ - --num-prompts $num_prompts \ - --port 8100 \ - --save-result \ - --result-dir $results_folder \ - --result-filename disagg_prefill_2xtp4.json \ - --request-rate $qps - - - python3 ../benchmark_serving.py \ - --backend vllm \ - --model $model \ - --dataset-name $dataset_name \ - --dataset-path $dataset_path \ - --sonnet-input-len $input_len \ - --sonnet-output-len $output_len \ - --sonnet-prefix-len $prefix_len \ - --num-prompts $num_prompts \ - --port 8200 \ - --save-result \ - --result-dir $results_folder \ - --result-filename disagg_prefill_2xtp4.json \ - --request-rate $qps - kill_gpu_processes - - # python3 analyze_benchmark_results.py \ - # --results-folder $results_folder \ - # --output-len $output_len \ - # --qps $qps - -} - - -main() { - - (which wget && which curl) || (apt-get update && apt-get install -y wget curl) - (which jq) || (apt-get -y install jq) - (which socat) || (apt-get -y install socat) - - pip install quart httpx - - cd "$(dirname "$0")" - - cd .. - # create sonnet-4x.txt - echo "" > sonnet_4x.txt - for _ in {1..4} - do - cat sonnet.txt >> sonnet_4x.txt - done - cd disagg_benchmarks - - rm -rf results - mkdir results - - default_qps=1 - default_output_len=1 - - # for target_qps in 2 4 8 16 - # do - # benchmark $target_qps $default_output_len - # done - benchmark $default_qps $default_output_len - - # for target_output_len in 5 10 20 40 80 - # do - # benchmark $default_qps $target_output_len - # done - -} - - -main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh new file mode 100644 index 0000000000000..38c0fcdaefc85 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -0,0 +1,172 @@ +#!/bin/bash + +# Requirement: 8x H100 GPUs. + + +# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV +# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests +# Resource: 8x H100 +# Approaches: +# 1. Chunked prefill: 1 vllm instance with tp=8 +# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 +# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance +# Prefilling instance: max_output_token=1 +# Decoding instance: force the input tokens be the same across requests to bypass prefilling + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pkill pt_main_thread + sleep 10 + + # remove vllm config file + rm -rf ~/.config/vllm + + # Print the GPU memory usage + # so that we know if all GPU processes are killed. + gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) + # The memory usage should be 0 MB. + echo "GPU 0 Memory Usage: $gpu_memory_usage MB" +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +benchmark() { + + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + export VLLM_PORT=12345 + # export VLLM_TRACE_FUNCTION=1 + + # compare chunked prefill with disaggregated prefill + + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-70B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=10 + qps=$1 + prefix_len=50 + input_len=2048 + output_len=$2 + + + # baseline: chunked prefill + # CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + # -m vllm.entrypoints.openai.api_server \ + # --model $model \ + # --port 8000 \ + # -tp 4 \ + # --disable-log-stats \ + # --disable-log-requests \ + # --enable-chunked-prefill & + # wait_for_server 8000 + + # python3 ../benchmark_serving.py \ + # --backend vllm \ + # --model $model \ + # --dataset-name $dataset_name \ + # --dataset-path $dataset_path \ + # --sonnet-input-len $input_len \ + # --sonnet-output-len $output_len \ + # --sonnet-prefix-len $prefix_len \ + # --num-prompts $((num_prompts / 2)) \ + # --port 8000 \ + # --save-result \ + # --result-dir $results_folder \ + # --result-filename chunked_prefill_tp4.json \ + # --request-rate $((qps / 2)) + # kill_gpu_processes + + + + # disagg prefill + VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + -tp 4 \ + --max-model-len 30000 \ + --disable-log-stats \ + --disable-log-requests \ + --gpu-memory-utilization 0.8 & + VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + -tp 4 \ + --max-model-len 30000 \ + --disable-log-stats \ + --disable-log-requests \ + --gpu-memory-utilization 0.8 & + + wait_for_server 8100 + wait_for_server 8200 + + # launch a proxy server that listen from port 8000 + python3 disagg_prefill_proxy_server.py & + sleep 1 + + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len $output_len \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8000 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate $qps + + + kill_gpu_processes + + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt so that we can sample 2048 tokens for input + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_qps=10 + default_output_len=10 + + benchmark $default_qps $default_output_len + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index bf5c136b9a2d9..8d9f699d45321 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -15,7 +15,8 @@ async def forward_request(url, data): } async with session.post(url=url, json=data, headers=headers) as response: if response.status == 200: - if response.headers.get('Transfer-Encoding') == 'chunked': + # if response.headers.get('Transfer-Encoding') == 'chunked': + if True: async for chunk_bytes in response.content.iter_chunked(1024): yield chunk_bytes else: @@ -30,9 +31,7 @@ async def handle_request(): prefill_request = original_request_data.copy() # change max_tokens = 1 to let it only do prefill prefill_request['max_tokens'] = 1 - # avoid sampling overhead by setting detokenize = False - prefill_request['detokenize'] = False - + # finish prefill async for _ in forward_request('http://localhost:8100/v1/completions', prefill_request): continue From 89d4ca42168a11b73b69ae3a0969ad4a8fbef310 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 7 Aug 2024 23:08:49 +0000 Subject: [PATCH 144/303] update --- vllm/worker/model_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 120b6406692e9..5c285fd0b933b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1370,6 +1370,8 @@ def execute_model( is_profile_run = (kv_caches is None) or (kv_caches[0] is None) # check if the current run is prefill is_prefill_run = prefill_meta is not None + + logger.debug('Into disagg prefill') # check if we can skip prefilling # We can only skip during prefill phase in disaggregated decode instance @@ -1414,6 +1416,8 @@ def execute_model( model_input, kv_caches, ) + + logger.debug("Out from disagg prefill.") # Compute the logits in the last pipeline stage. From 9f4dba236da05c3af67337c32a21c1bbc531c659 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 8 Aug 2024 06:12:25 +0000 Subject: [PATCH 145/303] remove debugging information --- vllm/worker/model_runner.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5c285fd0b933b..cc18f4b3cd621 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1371,8 +1371,6 @@ def execute_model( # check if the current run is prefill is_prefill_run = prefill_meta is not None - logger.debug('Into disagg prefill') - # check if we can skip prefilling # We can only skip during prefill phase in disaggregated decode instance if any([ @@ -1417,8 +1415,6 @@ def execute_model( kv_caches, ) - logger.debug("Out from disagg prefill.") - # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: From aa55883e755153e7a98c32fa4ad48679743d2829 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 9 Aug 2024 03:30:39 +0000 Subject: [PATCH 146/303] avoid broadcast by finding seqlen inside the attn metadata --- vllm/distributed/distributed_kv.py | 61 ++++++++++++++---------------- 1 file changed, 28 insertions(+), 33 deletions(-) diff --git a/vllm/distributed/distributed_kv.py b/vllm/distributed/distributed_kv.py index 77c153d775c68..3d6dfc93acfe9 100644 --- a/vllm/distributed/distributed_kv.py +++ b/vllm/distributed/distributed_kv.py @@ -21,6 +21,9 @@ IS_KV_PREFILL_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "prefill") IS_KV_DECODE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "decode") +# a magic number +DISTRIBUTED_KV_GLOO_TAG = 24857323 + logger = init_logger(__name__) @@ -140,13 +143,17 @@ def send_input_hash(self, input_hash: int) -> None: # KV cache send go through CPU, and the original `send` only use GPU. # So create a new group for sending input hash. input_hash_tensor = torch.tensor([input_hash], device="cpu").long() - torch.distributed.isend(input_hash_tensor, self.target_rank_for_send, - self.cpu_group) + torch.distributed.isend(input_hash_tensor, + self.target_rank_for_send, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) def recv_input_hash(self) -> int: input_hash_tensor = torch.tensor([0], device="cpu").long() - torch.distributed.irecv(input_hash_tensor, self.target_rank_for_recv, - self.cpu_group).wait() + torch.distributed.irecv(input_hash_tensor, + self.target_rank_for_recv, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG).wait() return input_hash_tensor.item() def recv_input_hash_and_send_kv(self): @@ -155,9 +162,10 @@ def recv_input_hash_and_send_kv(self): # receive the input hash that the decode instance requires logger.debug( - '[rank%d]: Waiting for input hash from rank %d', + '[rank%d]: Waiting for input hash from rank %d, my keys are %s', torch.distributed.get_rank(), self.target_rank_for_recv, + list(self.input_hash_to_kv_sending_requests.keys()), ) input_hash = self.recv_input_hash() logger.debug( @@ -226,25 +234,18 @@ def buffer_kv_caches_send_and_listen_for_input_hash( ) -> None: input_tokens_tuple = tuple(model_input.input_tokens.tolist()) - seq_query_obj = { - "seq_lens": model_input.seq_lens, - "query_lens": model_input.query_lens, - } - seq_query_obj = ps.get_tp_group().broadcast_object(seq_query_obj) - seq_lens = seq_query_obj["seq_lens"] - query_lens = seq_query_obj["query_lens"] + seq_lens = model_input.attn_metadata.seq_lens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - logger.debug("My query lens is %s, seq len is %s, rank is %s", - str(query_lens), str(seq_lens), torch.distributed.get_rank()) + logger.debug("My seq len is %s, rank is %s", str(seq_lens), torch.distributed.get_rank()) # query_lens contains new KV caches that are added to vLLM. # so we will send them to decode instance # FIXME(Kuntai): This assume that all requests are prefill. - for idx, qlen in enumerate(query_lens): + for idx, slen in enumerate(seq_lens): - start_pos = sum(query_lens[:idx]) - end_pos = start_pos + qlen + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen input_hash = hash(input_tokens_tuple[start_pos:end_pos]) for i in range(model_executable.model.start_layer, @@ -267,7 +268,7 @@ def buffer_kv_caches_send_and_listen_for_input_hash( input_hash, hidden_or_intermediate_states[start_pos:end_pos]) ps.get_disagg_group().kv_cache_send_finish(input_hash) - logger.error("\033[92mKV send DONE for rank %d\033[0m", + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) @@ -278,28 +279,22 @@ def send_input_hash_and_do_kv_caches_recv( # This is disagg decode instance, during prefill state # Need to receive KV from the prefill instance - # FIXME(Kuntai): This impl assumes that all requests are prefill. input_tokens_tuple = tuple(model_input.input_tokens.tolist()) - seq_query_obj = { - "seq_lens": model_input.seq_lens, - "query_lens": model_input.query_lens, - } - seq_query_obj = ps.get_tp_group().broadcast_object(seq_query_obj) - seq_lens = seq_query_obj["seq_lens"] - query_lens = seq_query_obj["query_lens"] + seq_lens = model_input.attn_metadata.seq_lens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + logger.debug("My seq len is %s, rank is %s", str(seq_lens), torch.distributed.get_rank()) + hidden_or_intermediate_states_for_one_req = [] # enumerate different requests - logger.debug("My query lens is %s, seq len is %s, rank is %s", - str(query_lens), str(seq_lens), torch.distributed.get_rank()) - for idx, qlen in enumerate(query_lens): + # FIXME(Kuntai): This impl assumes that all requests are prefill. + for idx, slen in enumerate(seq_lens): - start_pos = sum(query_lens[:idx]) - end_pos = start_pos + qlen + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen input_hash = hash(input_tokens_tuple[start_pos:end_pos]) - num_tokens = qlen + num_tokens = slen # notify the prefill instance to start sending KVs associated with input_hash ps.get_disagg_group().kv_cache_recv_start(input_hash) @@ -343,6 +338,6 @@ def send_input_hash_and_do_kv_caches_recv( hidden_or_intermediate_states = torch.cat( hidden_or_intermediate_states_for_one_req, dim=0) - logger.error("\033[92mKV receive DONE for rank %d\033[0m", + logger.error("[rank%d]: KV recv DONE.", torch.distributed.get_rank()) return hidden_or_intermediate_states From 95df02349644a8807b4c646b75d0164fb4b93c99 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 9 Aug 2024 06:33:58 +0000 Subject: [PATCH 147/303] update examples --- .../disagg_prefill/disagg_prefill_example.sh | 38 ++++++------------- 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index a2a6dc0932fdf..f57f5fd86d89c 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -20,50 +20,34 @@ VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 pytho -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8100 \ - -tp 1 \ - --enable-prefix-caching \ - --gpu-memory-utilization 0.8 \ - --max-model-len 10000 & + --max-model-len 10000 \ + --gpu-memory-utilization 0.8 & # decoding instance VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8200 \ - -tp 1 \ - --enable-prefix-caching \ - --gpu-memory-utilization 0.8 \ - --max-model-len 10000 & + --max-model-len 10000 \ + --gpu-memory-utilization 0.8 & # wait until prefill and decode instances are ready wait_for_server 8100 wait_for_server 8200 -# sending an example request -# in disaggregated prefilling, there are two steps of sending a request: -# 1. send the request to prefill instance, with max_tokens set to 1 -# 2. send the request again to decode instance, no modification +# launch a proxy server that opens the service at port 8000 +python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & +sleep 1 -# send to prefill instance, let it only do prefill by setting max_token=1 -curl -m 5 http://localhost:8100/v1/completions \ +# serve an example request +curl http://localhost:8000/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", -"prompt": "'$i' San Francisco is a", -"max_tokens": 1, +"prompt": "San Francisco is a", +"max_tokens": 10, "temperature": 0 }' -# send to decode instance -curl -m 5 http://localhost:8100/v1/completions \ --H "Content-Type: application/json" \ --d '{ -"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", -"prompt": "'$i' San Francisco is a", -"max_tokens": 50, -"temperature": 0 -}' - - # clean up ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 \ No newline at end of file From d92223ad176dae9264555fbacd3fc63056ebedd8 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 9 Aug 2024 06:34:25 +0000 Subject: [PATCH 148/303] support pipeline parallel --- vllm/distributed/distributed_kv.py | 132 ++++++++++++++++++++--------- 1 file changed, 90 insertions(+), 42 deletions(-) diff --git a/vllm/distributed/distributed_kv.py b/vllm/distributed/distributed_kv.py index 3d6dfc93acfe9..5404ea91e2bc4 100644 --- a/vllm/distributed/distributed_kv.py +++ b/vllm/distributed/distributed_kv.py @@ -4,6 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING from collections import defaultdict, deque from concurrent.futures import ThreadPoolExecutor +from threading import Lock +from copy import deepcopy import torch from torch.distributed import Backend, ProcessGroup @@ -13,15 +15,16 @@ from vllm.logger import init_logger import vllm.distributed.parallel_state as ps from vllm import _custom_ops as ops +from vllm.sequence import IntermediateTensors assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode"], \ "VLLM_DISAGG_PREFILL_ROLE can only be prefill or decode." -IS_DISTRIBUTED_KV_INSTANCE = (envs.VLLM_DISAGG_PREFILL_ROLE is not None) +IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE is not None) IS_KV_PREFILL_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "prefill") IS_KV_DECODE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "decode") -# a magic number +# add a tag when sending/recving input hash DISTRIBUTED_KV_GLOO_TAG = 24857323 logger = init_logger(__name__) @@ -66,6 +69,7 @@ def __init__( # use a threadpool to buffer send request in disaggregated prefill self.input_hash_to_kv_sending_requests = defaultdict(deque) self.kv_sending_thread = None + self.input_hash_to_kv_sending_requests_lock = Lock() self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % self.world_size] self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % @@ -104,7 +108,8 @@ def debug_recv(self, def kv_cache_send(self, input_hash: int, - tensor: torch.Tensor, + tensor: Union[torch.Tensor, IntermediateTensors], + is_hidden: bool = False, dst: Optional[int] = None) -> None: """Push the KV cache send request into the send buffer""" """NOTE: `dst` is the local rank of the destination rank.""" @@ -114,17 +119,35 @@ def kv_cache_send(self, else: send_func = self.send - self.input_hash_to_kv_sending_requests[input_hash].append([ - send_func, - # tensor needs to be cloned, if not the tensor may be freed - tensor.clone(), - dst - ]) - - def kv_cache_recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: + if is_hidden and not ps.get_pp_group().is_last_rank: + + assert isinstance(tensor, IntermediateTensors) + + output = deepcopy(tensor.tensors) + for key in output: + output[key] = output[key].contiguous() + + self.input_hash_to_kv_sending_requests[input_hash].append( + [self.send_tensor_dict, output, dst]) + + else: + + assert isinstance(tensor, torch.Tensor) + + self.input_hash_to_kv_sending_requests[input_hash].append([ + send_func, + # tensor needs to be cloned, if not the tensor may be freed + tensor.clone(), + dst + ]) + + def kv_cache_recv( + self, + size: torch.Size, + dtype: torch.dtype, + is_hidden: bool = False, + src: Optional[int] = None + ) -> Union[torch.Tensor, IntermediateTensors]: """Receives a tensor from the src rank (blocking).""" """This API should be used together with `push`""" """NOTE: `src` is the local rank of the destination rank.""" @@ -134,7 +157,10 @@ def kv_cache_recv(self, else: recv_func = self.recv - tensor = recv_func(size, dtype, src) + if is_hidden and not ps.get_pp_group().is_last_rank: + tensor = IntermediateTensors(self.recv_tensor_dict(src)) + else: + tensor = recv_func(size, dtype, src) return tensor @@ -143,34 +169,33 @@ def send_input_hash(self, input_hash: int) -> None: # KV cache send go through CPU, and the original `send` only use GPU. # So create a new group for sending input hash. input_hash_tensor = torch.tensor([input_hash], device="cpu").long() - torch.distributed.isend(input_hash_tensor, - self.target_rank_for_send, - self.cpu_group, - tag=DISTRIBUTED_KV_GLOO_TAG) + torch.distributed.send(input_hash_tensor, + self.target_rank_for_send, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) def recv_input_hash(self) -> int: input_hash_tensor = torch.tensor([0], device="cpu").long() - torch.distributed.irecv(input_hash_tensor, - self.target_rank_for_recv, - self.cpu_group, - tag=DISTRIBUTED_KV_GLOO_TAG).wait() + torch.distributed.recv(input_hash_tensor, + self.target_rank_for_recv, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) return input_hash_tensor.item() def recv_input_hash_and_send_kv(self): try: - - # receive the input hash that the decode instance requires logger.debug( '[rank%d]: Waiting for input hash from rank %d, my keys are %s', torch.distributed.get_rank(), self.target_rank_for_recv, list(self.input_hash_to_kv_sending_requests.keys()), ) + # block the ThreadPoolExecutor, until a new input hash is received input_hash = self.recv_input_hash() - logger.debug( - 'Successfully received input hash %d', - input_hash) + + self.input_hash_to_kv_sending_requests_lock.acquire() + logger.debug('Successfully received input hash %d', input_hash) assert input_hash in self.input_hash_to_kv_sending_requests, \ f"The KV cache of {input_hash} does not exist. "\ f"Existing input hash: {list(self.input_hash_to_kv_sending_requests.keys())}" @@ -184,6 +209,7 @@ def recv_input_hash_and_send_kv(self): if request == []: break request[0](*request[1:]) + if len(self.input_hash_to_kv_sending_requests[input_hash]) == 0: logger.debug('Finish input hash %d, free GPU memory...', input_hash) @@ -195,6 +221,8 @@ def recv_input_hash_and_send_kv(self): 'memory for one of the request.', input_hash) + self.input_hash_to_kv_sending_requests_lock.release() + except Exception as e: # This function is executed in ThreadPoolExecutor # and it will block all exceptions by default @@ -237,8 +265,11 @@ def buffer_kv_caches_send_and_listen_for_input_hash( seq_lens = model_input.attn_metadata.seq_lens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - logger.debug("My seq len is %s, rank is %s", str(seq_lens), torch.distributed.get_rank()) - + logger.debug("My seq len is %s, rank is %s", str(seq_lens), + torch.distributed.get_rank()) + + ps.get_disagg_group().input_hash_to_kv_sending_requests_lock.acquire() + # query_lens contains new KV caches that are added to vLLM. # so we will send them to decode instance # FIXME(Kuntai): This assume that all requests are prefill. @@ -265,11 +296,14 @@ def buffer_kv_caches_send_and_listen_for_input_hash( input_hash, value_cache[current_slot_mapping]) ps.get_disagg_group().kv_cache_send( - input_hash, hidden_or_intermediate_states[start_pos:end_pos]) + input_hash, + hidden_or_intermediate_states[start_pos:end_pos], + is_hidden=True) ps.get_disagg_group().kv_cache_send_finish(input_hash) - logger.debug("[rank%d]: KV send DONE.", - torch.distributed.get_rank()) + ps.get_disagg_group().input_hash_to_kv_sending_requests_lock.release() + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) def send_input_hash_and_do_kv_caches_recv( @@ -283,7 +317,8 @@ def send_input_hash_and_do_kv_caches_recv( seq_lens = model_input.attn_metadata.seq_lens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - logger.debug("My seq len is %s, rank is %s", str(seq_lens), torch.distributed.get_rank()) + logger.debug("My seq len is %s, rank is %s", str(seq_lens), + torch.distributed.get_rank()) hidden_or_intermediate_states_for_one_req = [] @@ -330,14 +365,27 @@ def send_input_hash_and_do_kv_caches_recv( ) hidden_or_intermediate_states_for_one_req.append( - ps.get_disagg_group().kv_cache_recv( - torch.Size([num_tokens, model_executable.config.hidden_size]), - kv_cache[0].dtype)) + ps.get_disagg_group().kv_cache_recv(torch.Size( + [num_tokens, model_executable.config.hidden_size]), + kv_cache[0].dtype, + is_hidden=True)) # concatenate hidden states from different requests - hidden_or_intermediate_states = torch.cat( - hidden_or_intermediate_states_for_one_req, dim=0) - - logger.error("[rank%d]: KV recv DONE.", - torch.distributed.get_rank()) + if isinstance(hidden_or_intermediate_states_for_one_req[0], torch.Tensor): + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + else: + # concat the IntermediateTensors + keys = list(hidden_or_intermediate_states_for_one_req[0].tensors.keys()) + result_its = {} + + for key in keys: + result_its[key] = [] + for its in hidden_or_intermediate_states_for_one_req: + result_its[key].append(its[key]) + result_its[key] = torch.cat(result_its[key], dim=0) + + hidden_or_intermediate_states = IntermediateTensors(result_its) + + logger.error("[rank%d]: KV recv DONE.", torch.distributed.get_rank()) return hidden_or_intermediate_states From a8c202c130f622c2ba37cf10869eb287f96e2ac1 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sat, 10 Aug 2024 09:19:49 +0000 Subject: [PATCH 149/303] update benchmark --- compare chunked prefill w.r.t. disagg prefill --- .../disagg_performance_benchmark.sh | 136 +++++++++--------- 1 file changed, 68 insertions(+), 68 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index 38c0fcdaefc85..dde9a80b59b37 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -17,17 +17,12 @@ set -ex kill_gpu_processes() { # kill all processes on GPU. - pkill pt_main_thread - sleep 10 - - # remove vllm config file - rm -rf ~/.config/vllm - - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. - echo "GPU 0 Memory Usage: $gpu_memory_usage MB" + pkill -f pt_main_thread + pkill -f python3 + pkill -f round_robin_proxy.sh + ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 + for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done + sleep 1 } wait_for_server() { @@ -41,57 +36,40 @@ wait_for_server() { } -benchmark() { - - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') - export VLLM_PORT=12345 - # export VLLM_TRACE_FUNCTION=1 - - # compare chunked prefill with disaggregated prefill - - results_folder="./results" +launch_chunked_prefill() { model="meta-llama/Meta-Llama-3.1-70B-Instruct" - dataset_name="sonnet" - dataset_path="../sonnet_4x.txt" - num_prompts=10 - qps=$1 - prefix_len=50 - input_len=2048 - output_len=$2 - - - # baseline: chunked prefill - # CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ - # -m vllm.entrypoints.openai.api_server \ - # --model $model \ - # --port 8000 \ - # -tp 4 \ - # --disable-log-stats \ - # --disable-log-requests \ - # --enable-chunked-prefill & - # wait_for_server 8000 - - # python3 ../benchmark_serving.py \ - # --backend vllm \ - # --model $model \ - # --dataset-name $dataset_name \ - # --dataset-path $dataset_path \ - # --sonnet-input-len $input_len \ - # --sonnet-output-len $output_len \ - # --sonnet-prefix-len $prefix_len \ - # --num-prompts $((num_prompts / 2)) \ - # --port 8000 \ - # --save-result \ - # --result-dir $results_folder \ - # --result-filename chunked_prefill_tp4.json \ - # --request-rate $((qps / 2)) - # kill_gpu_processes - + # disagg prefill + VLLM_RPC_PORT=5570 CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + -tp 4 \ + --max-model-len 30000 \ + --disable-log-stats \ + --disable-log-requests \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.8 & + VLLM_RPC_PORT=5580 CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + -tp 4 \ + --max-model-len 30000 \ + --disable-log-stats \ + --disable-log-requests \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.8 & + wait_for_server 8100 + wait_for_server 8200 + bash round_robin_proxy.sh & + sleep 1 +} +launch_disagg_prefill() { + model="meta-llama/Meta-Llama-3.1-70B-Instruct" # disagg prefill - VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + VLLM_PORT=12345 VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ --port 8100 \ @@ -100,7 +78,7 @@ benchmark() { --disable-log-stats \ --disable-log-requests \ --gpu-memory-utilization 0.8 & - VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + VLLM_PORT=12345 VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ --port 8200 \ @@ -109,13 +87,24 @@ benchmark() { --disable-log-stats \ --disable-log-requests \ --gpu-memory-utilization 0.8 & - wait_for_server 8100 wait_for_server 8200 - - # launch a proxy server that listen from port 8000 python3 disagg_prefill_proxy_server.py & sleep 1 +} + + +benchmark() { + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-70B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=400 + qps=$1 + prefix_len=50 + input_len=2048 + output_len=$2 + tag=$3 python3 ../benchmark_serving.py \ --backend vllm \ @@ -129,12 +118,10 @@ benchmark() { --port 8000 \ --save-result \ --result-dir $results_folder \ - --result-filename disagg_prefill_2xtp4.json \ + --result-filename $tag-qps-$qps.json \ --request-rate $qps - - kill_gpu_processes - + sleep 2 } @@ -162,9 +149,22 @@ main() { mkdir results default_qps=10 - default_output_len=10 + default_output_len=150 - benchmark $default_qps $default_output_len + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + launch_chunked_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len chunked_prefill + done + kill_gpu_processes + + launch_disagg_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len disagg_prefill + done + kill_gpu_processes } From 310f3a3214a86f8411a60d9d59c07d4686244c94 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sat, 10 Aug 2024 09:20:16 +0000 Subject: [PATCH 150/303] mute round_robin_proxy -- too loud --- benchmarks/disagg_benchmarks/round_robin_proxy.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.sh b/benchmarks/disagg_benchmarks/round_robin_proxy.sh index e996756bc89d6..375bf9e422371 100644 --- a/benchmarks/disagg_benchmarks/round_robin_proxy.sh +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.sh @@ -15,6 +15,5 @@ get_next_port() { # Start the proxy while true; do NEXT_PORT=$(get_next_port) - echo "Forwarding to port $NEXT_PORT" - socat TCP4-LISTEN:8000,reuseaddr,fork TCP4:localhost:$NEXT_PORT + socat TCP4-LISTEN:8000,reuseaddr,fork TCP4:localhost:$NEXT_PORT 2>/dev/null done \ No newline at end of file From 118aab18bda3f07df1af7643bc6208657825287b Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sat, 10 Aug 2024 09:21:54 +0000 Subject: [PATCH 151/303] bug fix: racing conditions, and rare cases where input hash is not cached --- .../disagg_prefill_proxy_server.py | 9 +- vllm/distributed/distributed_kv.py | 200 +++++++++++++----- vllm/worker/model_runner.py | 67 +++--- 3 files changed, 180 insertions(+), 96 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index 8d9f699d45321..5750df7735ad1 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -36,7 +36,7 @@ async def handle_request(): async for _ in forward_request('http://localhost:8100/v1/completions', prefill_request): continue - print(f"Request {prefill_request} prefill done. proceeding to decode.") + print(f"Prefill done. proceeding to decode.") # return decode generator = forward_request('http://localhost:8200/v1/completions', original_request_data) @@ -46,9 +46,10 @@ async def handle_request(): return response except Exception as e: - exc_info = sys.exc_info() - print(e) - print("".join(traceback.format_exception(*exc_info))) + pass + # exc_info = sys.exc_info() + # print(e) + # print("".join(traceback.format_exception(*exc_info))) if __name__ == '__main__': app.run(port=8000) diff --git a/vllm/distributed/distributed_kv.py b/vllm/distributed/distributed_kv.py index 5404ea91e2bc4..fd28a82c294a5 100644 --- a/vllm/distributed/distributed_kv.py +++ b/vllm/distributed/distributed_kv.py @@ -1,11 +1,22 @@ """vLLM distributed KV cache transfer API. These APIs are used in `vllm/worker/model_runner.py`. + +Currently supporting TP and PP. + +Workflow: +- In prefill instance, KV cache sender *buffers* the KV cache send requests +- In decode instance + - KV cache receiver sends the hash of input tokens to sender + - KV cache sender executes send request + - KV cache receiver receives the KV cache """ from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING from collections import defaultdict, deque from concurrent.futures import ThreadPoolExecutor from threading import Lock from copy import deepcopy +import time +import threading import torch from torch.distributed import Backend, ProcessGroup @@ -29,6 +40,24 @@ logger = init_logger(__name__) +import logging + + +class RankFilter(logging.Filter): + + def filter(self, record): + # Only log if rank is 4 + rank = 1 + try: + rank = torch.distributed.get_rank() + except Exception: + pass + return rank % 4 == 0 + + +for handler in logger.handlers: + handler.addFilter(RankFilter()) + class DistributedKVCoordinator(GroupCoordinator): """ @@ -45,7 +74,11 @@ def __init__( group_ranks: List[List[int]], local_rank: int, torch_distributed_backend: Union[str, Backend], - use_pynccl: bool = True, + # DO NOT use pynccl here + # Pynccl send is non-blocking + # and it's possible that the memory is freed before the data being sent + # which may happen at high qps + use_pynccl: bool = False, use_custom_allreduce: bool = False, use_tpu_communicator: bool = True, use_message_queue_broadcaster: bool = False, @@ -136,7 +169,7 @@ def kv_cache_send(self, self.input_hash_to_kv_sending_requests[input_hash].append([ send_func, - # tensor needs to be cloned, if not the tensor may be freed + # use clone to make sure the tensor is contiguous tensor.clone(), dst ]) @@ -164,7 +197,11 @@ def kv_cache_recv( return tensor - def send_input_hash(self, input_hash: int) -> None: + def send_input_hash(self, input_hash: int) -> int: + + logger.debug('[rank%d]: Sending input hash %d to rank %d', + torch.distributed.get_rank(), input_hash, + self.target_rank_for_send) # KV cache send go through CPU, and the original `send` only use GPU. # So create a new group for sending input hash. @@ -173,33 +210,64 @@ def send_input_hash(self, input_hash: int) -> None: self.target_rank_for_send, self.cpu_group, tag=DISTRIBUTED_KV_GLOO_TAG) + return_tensor = torch.tensor([0], device="cpu").long() + torch.distributed.recv(return_tensor, + self.target_rank_for_recv, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) + return return_tensor.item() - def recv_input_hash(self) -> int: + def recv_input_hash(self) -> Optional[int]: + ''' + Receive an input hash, and check if it is already cached + ''' input_hash_tensor = torch.tensor([0], device="cpu").long() torch.distributed.recv(input_hash_tensor, self.target_rank_for_recv, self.cpu_group, tag=DISTRIBUTED_KV_GLOO_TAG) - return input_hash_tensor.item() - - def recv_input_hash_and_send_kv(self): - - try: + input_hash = input_hash_tensor.item() + # a new input hash comes in, see if it is already cached + self.input_hash_to_kv_sending_requests_lock.acquire() + logger.debug('Successfully received input hash %d', input_hash) + if input_hash not in self.input_hash_to_kv_sending_requests: + logger.warning( + f"The KV cache of {input_hash} does not exist. "\ + f"Existing input hash: {list(self.input_hash_to_kv_sending_requests.keys())}") + + # 0 for fail + x = torch.tensor([0], device="cpu").long() + torch.distributed.send(x, + self.target_rank_for_send, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) + return None + else: + logger.debug('Input hash %d exists, start sending', input_hash) + + # 1 for success + x = torch.tensor([1], device="cpu").long() + torch.distributed.send(x, + self.target_rank_for_send, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) + return input_hash + + def kv_cache_send_loop(self): + + while True: logger.debug( '[rank%d]: Waiting for input hash from rank %d, my keys are %s', torch.distributed.get_rank(), self.target_rank_for_recv, list(self.input_hash_to_kv_sending_requests.keys()), ) - # block the ThreadPoolExecutor, until a new input hash is received + # wait for a new input hash + # this function will acquire the lock input_hash = self.recv_input_hash() - - self.input_hash_to_kv_sending_requests_lock.acquire() - logger.debug('Successfully received input hash %d', input_hash) - assert input_hash in self.input_hash_to_kv_sending_requests, \ - f"The KV cache of {input_hash} does not exist. "\ - f"Existing input hash: {list(self.input_hash_to_kv_sending_requests.keys())}" - logger.debug('Input hash %d exists, start sending', input_hash) + if input_hash is None: + self.input_hash_to_kv_sending_requests_lock.release() + continue # execute corresponding kv cache sending jobs in request queue while True: @@ -208,6 +276,7 @@ def recv_input_hash_and_send_kv(self): # An empty request: the KV cahe of one request are all sent if request == []: break + request[0](*request[1:]) if len(self.input_hash_to_kv_sending_requests[input_hash]) == 0: @@ -217,56 +286,57 @@ def recv_input_hash_and_send_kv(self): else: logger.debug( 'The buffer for input hash %d is not empty, meaning that '\ - 'there are two jobs with identical input. Free GPU '\ - 'memory for one of the request.', + 'there are two jobs with identical input.', input_hash) self.input_hash_to_kv_sending_requests_lock.release() - except Exception as e: - # This function is executed in ThreadPoolExecutor - # and it will block all exceptions by default - # so log the potential error message here. - import traceback - import time - exc_info = traceback.format_exc() - # avoid the output of different rank overlaps - time.sleep(torch.distributed.get_rank()) - logger.error("An error occured: %s, stack trace: %s", e, exc_info) - def kv_cache_send_finish(self, input_hash: int): + def kv_cache_send_ready(self, input_hash: int): if self.kv_sending_thread is None: - self.kv_sending_thread = ThreadPoolExecutor(max_workers=1) - - # append an empty job to signal that this is the end of a request + self.kv_sending_thread = threading.Thread( + target=self.kv_cache_send_loop) + self.kv_sending_thread.start() + + # append an empty list to separate requests + # as there might be identical requests, that has the same input hash self.input_hash_to_kv_sending_requests[input_hash].append([]) - job = self.kv_sending_thread.submit(self.recv_input_hash_and_send_kv) - logger.debug(f'Submit job {job} into kv cache sending thread') + logger.debug(f'Buffered input hash {input_hash}') def kv_cache_recv_start(self, input_hash: int): - - logger.debug('[rank%d]: Sending input hash %d to rank %d', - torch.distributed.get_rank(), input_hash, - self.ranks[(self.rank_in_group + 1) % self.world_size]) - # notify the kv cache sender with the input hash id - self.send_input_hash(input_hash) + return self.send_input_hash(input_hash) + + def block_if_buffer_full(self): + + # block vLLM if the KV cache sending buffer is full + # TODO: allow using other policies to handle buffer full + while True: + self.input_hash_to_kv_sending_requests_lock.acquire() + if len(self.input_hash_to_kv_sending_requests.keys()) > 55: + self.input_hash_to_kv_sending_requests_lock.release() + time.sleep(0.1) + else: + self.input_hash_to_kv_sending_requests_lock.release() + break -def buffer_kv_caches_send_and_listen_for_input_hash( +def send_kv_caches_and_hidden_states( model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", kv_caches: List[torch.Tensor], - hidden_or_intermediate_states: torch.Tensor, + hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], ) -> None: input_tokens_tuple = tuple(model_input.input_tokens.tolist()) seq_lens = model_input.attn_metadata.seq_lens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - logger.debug("My seq len is %s, rank is %s", str(seq_lens), - torch.distributed.get_rank()) + # Assumption: current batch is all-prefill requests + assert torch.allclose(model_input.attn_metadata.query_start_loc, + model_input.attn_metadata.seq_start_loc) + assert torch.all(model_input.attn_metadata.context_lens_tensor == 0) ps.get_disagg_group().input_hash_to_kv_sending_requests_lock.acquire() @@ -299,17 +369,22 @@ def buffer_kv_caches_send_and_listen_for_input_hash( input_hash, hidden_or_intermediate_states[start_pos:end_pos], is_hidden=True) - ps.get_disagg_group().kv_cache_send_finish(input_hash) + ps.get_disagg_group().kv_cache_send_ready(input_hash) ps.get_disagg_group().input_hash_to_kv_sending_requests_lock.release() + ps.get_disagg_group().block_if_buffer_full() + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) -def send_input_hash_and_do_kv_caches_recv( - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor]) -> torch.Tensor: +def recv_kv_caches_and_hidden_states( + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] +) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool]: + + bypass_model_exec = True # This is disagg decode instance, during prefill state # Need to receive KV from the prefill instance @@ -317,8 +392,10 @@ def send_input_hash_and_do_kv_caches_recv( seq_lens = model_input.attn_metadata.seq_lens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - logger.debug("My seq len is %s, rank is %s", str(seq_lens), - torch.distributed.get_rank()) + # Assumption: current batch is all-prefill requests + assert torch.allclose(model_input.attn_metadata.query_start_loc, + model_input.attn_metadata.seq_start_loc) + assert torch.all(model_input.attn_metadata.context_lens_tensor == 0) hidden_or_intermediate_states_for_one_req = [] @@ -332,7 +409,13 @@ def send_input_hash_and_do_kv_caches_recv( num_tokens = slen # notify the prefill instance to start sending KVs associated with input_hash - ps.get_disagg_group().kv_cache_recv_start(input_hash) + contain = ps.get_disagg_group().kv_cache_recv_start(input_hash) + + # fail to find input_hash in prefill instance + # this can occur but idk why... + if contain == 0: + bypass_model_exec = False + continue # receive KV cache from disaggregated prefill instance for i in range(model_executable.model.start_layer, @@ -376,16 +459,17 @@ def send_input_hash_and_do_kv_caches_recv( hidden_or_intermediate_states_for_one_req, dim=0) else: # concat the IntermediateTensors - keys = list(hidden_or_intermediate_states_for_one_req[0].tensors.keys()) + keys = list( + hidden_or_intermediate_states_for_one_req[0].tensors.keys()) result_its = {} - + for key in keys: result_its[key] = [] for its in hidden_or_intermediate_states_for_one_req: result_its[key].append(its[key]) result_its[key] = torch.cat(result_its[key], dim=0) - + hidden_or_intermediate_states = IntermediateTensors(result_its) - logger.error("[rank%d]: KV recv DONE.", torch.distributed.get_rank()) - return hidden_or_intermediate_states + logger.debug("[rank%d]: KV recv DONE.", torch.distributed.get_rank()) + return hidden_or_intermediate_states, bypass_model_exec diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cc18f4b3cd621..4d8105bde2c4a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1371,16 +1371,26 @@ def execute_model( # check if the current run is prefill is_prefill_run = prefill_meta is not None - # check if we can skip prefilling - # We can only skip during prefill phase in disaggregated decode instance - if any([ - not is_prefill_run, - not dist_kv.IS_KV_DECODE_INSTANCE, - is_profile_run]): - - # model forwarding - # during forwarding the KV cache will be sent in prefill instance - # see vllm/attention/backends/flash_attn.py for sending impl + # for disaggregated prefilling: allow bypassing model execution + bypass_model_exec = False + + # Recv kv cache for disaggregated prefill + # Skip model execution if all required KV cache are received + if all([ + is_prefill_run, + dist_kv.IS_KV_DECODE_INSTANCE, + not is_profile_run]): + + hidden_or_intermediate_states, bypass = \ + dist_kv.recv_kv_caches_and_hidden_states( + model_executable, + model_input, + kv_caches, + ) + if bypass: + bypass_model_exec = True + + if not bypass_model_exec: hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -1388,32 +1398,21 @@ def execute_model( attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), + device=self.device), **seqlen_agnostic_kwargs) + + # Send KV cache for disaggregated prefill + if all([ + is_prefill_run, + dist_kv.IS_KV_PREFILL_INSTANCE, + not is_profile_run]): - - if all([ - is_prefill_run, - dist_kv.IS_KV_PREFILL_INSTANCE, - not is_profile_run]): - - # transfer KV cache and hidden state - dist_kv.buffer_kv_caches_send_and_listen_for_input_hash( - model_executable, - model_input, - kv_caches, - hidden_or_intermediate_states, - ) - - else: - - # skip prefill, receive KV cache and hidden state - hidden_or_intermediate_states = \ - dist_kv.send_input_hash_and_do_kv_caches_recv( - model_executable, - model_input, - kv_caches, - ) + dist_kv.send_kv_caches_and_hidden_states( + model_executable, + model_input, + kv_caches, + hidden_or_intermediate_states, + ) # Compute the logits in the last pipeline stage. From 96d38b49b1f3811e06728491bb2d33ac71298c12 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 11 Aug 2024 05:29:31 +0000 Subject: [PATCH 152/303] add visualization script --- .../visualize_benchmark_results.py | 90 +++++++------------ 1 file changed, 32 insertions(+), 58 deletions(-) diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py index 8686fb2abf9b9..192f26a1e3cd2 100644 --- a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -2,72 +2,46 @@ import matplotlib.pyplot as plt import yaml import pandas as pd -from tabulate import tabulate +import json -def stringify(x): - return [str(i) for i in x] - if __name__ == "__main__": - - with open("results/chunk_vs_disagg.yaml", "r") as f: - data = yaml.load(f, Loader=yaml.FullLoader) - df = pd.DataFrame.from_dict(data) - - print_df = df.copy() - print_df.drop(columns=[ - "ttft_ratio", - "itl_ratio", - "prefill_decode_ratio", - ], inplace=True) - print_df.to_csv('results/chunk_vs_disagg.csv', index=False) - df["chunk_e2e"] = df["chunk_ttft"] + df["chunk_itl"] * df["output_len"] - df["disagg_e2e"] = df["disagg_ttft"] + df["disagg_itl"] * df["output_len"] - df["e2e_ratio"] = df["chunk_e2e"] / df["disagg_e2e"] + data = [] + for name in ['disagg_prefill', 'chunked_prefill']: + for qps in [2,4,6,8]: + with open(f"results/{name}-qps-{qps}.json", "r") as f: + x = json.load(f) + x['name'] = name + x['qps'] = qps + data.append(x) + + df = pd.DataFrame.from_dict(data) + dis_df = df[df['name'] == 'disagg_prefill'] + chu_df = df[df['name'] == 'chunked_prefill'] + plt.style.use('bmh') plt.rcParams['font.size'] = 20 - - - # qps vs performance - qps_df = df[df["output_len"] == 150].copy() - qps_df.drop(columns=[ - "chunk_itl", - "chunk_ttft", - "disagg_itl", - "disagg_ttft", - "output_len", - "prefill_decode_ratio", - ], inplace=True) - fig, ax = plt.subplots(figsize=(10, 7)) - qps_df.plot( - ax=ax, - kind="bar", - x="qps", - y=["ttft_ratio", "itl_ratio", "e2e_ratio"], - ylabel="$T_{chunked}~/~T_{disagg}$", - rot=0, - ) - ax.hlines(1, -1, 5, color='black') - fig.savefig('results/qps.png') - plt.close(fig) - # prefill decode ratio vs performance - tokens_df = df[df["output_len"] != 12] - fig, ax = plt.subplots(figsize=(10, 7)) - tokens_df.plot( - ax=ax, - kind="bar", - x="output_len", - xlabel="# of output tokens", - y=["ttft_ratio", "itl_ratio", "e2e_ratio", "prefill_decode_ratio"], - ylabel="$T_{chunked}~/~T_{disagg}$", - rot=0, - ) - ax.hlines(1, -1, 5, color='black') - fig.savefig('results/tokens.png') - plt.close(fig) + for key in ['mean_ttft_ms', + 'median_ttft_ms', + 'p99_ttft_ms', + 'mean_itl_ms', + 'median_itl_ms', + 'p99_itl_ms']: + + fig, ax = plt.subplots(figsize=(11, 7)) + plt.plot(dis_df['qps'], dis_df[key], label='disagg_prefill', marker='o', linewidth=4) + plt.plot(chu_df['qps'], chu_df[key], label='chunked_prefill', marker='o', linewidth=4) + ax.legend() + + ax.set_xlabel('QPS') + ax.set_ylabel(key) + ax.set_ylim(bottom=0) + fig.savefig(f'results/{key}.png') + plt.close(fig) + \ No newline at end of file From 3fc0c5cbca1d53c23a39bab8b26bfe56fe95b684 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 11 Aug 2024 05:30:04 +0000 Subject: [PATCH 153/303] fix bug: when KV transfer fails, do not return hidden state --- vllm/distributed/distributed_kv.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/distributed_kv.py b/vllm/distributed/distributed_kv.py index fd28a82c294a5..9005a325d6bff 100644 --- a/vllm/distributed/distributed_kv.py +++ b/vllm/distributed/distributed_kv.py @@ -314,7 +314,7 @@ def block_if_buffer_full(self): # TODO: allow using other policies to handle buffer full while True: self.input_hash_to_kv_sending_requests_lock.acquire() - if len(self.input_hash_to_kv_sending_requests.keys()) > 55: + if len(self.input_hash_to_kv_sending_requests.keys()) > 40: self.input_hash_to_kv_sending_requests_lock.release() time.sleep(0.1) else: @@ -452,6 +452,11 @@ def recv_kv_caches_and_hidden_states( [num_tokens, model_executable.config.hidden_size]), kv_cache[0].dtype, is_hidden=True)) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # so we need to recompute the hidden state + return [], bypass_model_exec # concatenate hidden states from different requests if isinstance(hidden_or_intermediate_states_for_one_req[0], torch.Tensor): From f9aadd87de1bb284d42692b8968b76dfaa775606 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 26 Aug 2024 10:55:10 -0700 Subject: [PATCH 154/303] add new abstractions --- vllm/distributed/kv_transfer/__init__.py | 0 .../kv_transfer/kv_database/__init__.py | 0 .../kv_transfer/kv_database/base.py | 16 + .../kv_transfer/kv_pipe/__init__.py | 0 vllm/distributed/kv_transfer/kv_pipe/base.py | 13 + .../kv_pipe/torch_distributed_pipe.py | 88 +++++ .../kv_transfer/kv_serde/__init__.py | 0 vllm/distributed/kv_transfer/kv_serde/base.py | 13 + vllm/distributed/kv_transfer/vllm_adapter.py | 341 ++++++++++++++++++ 9 files changed, 471 insertions(+) create mode 100644 vllm/distributed/kv_transfer/__init__.py create mode 100644 vllm/distributed/kv_transfer/kv_database/__init__.py create mode 100644 vllm/distributed/kv_transfer/kv_database/base.py create mode 100644 vllm/distributed/kv_transfer/kv_pipe/__init__.py create mode 100644 vllm/distributed/kv_transfer/kv_pipe/base.py create mode 100644 vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py create mode 100644 vllm/distributed/kv_transfer/kv_serde/__init__.py create mode 100644 vllm/distributed/kv_transfer/kv_serde/base.py create mode 100644 vllm/distributed/kv_transfer/vllm_adapter.py diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/kv_database/__init__.py b/vllm/distributed/kv_transfer/kv_database/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/kv_database/base.py b/vllm/distributed/kv_transfer/kv_database/base.py new file mode 100644 index 0000000000000..ae17650754bf1 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_database/base.py @@ -0,0 +1,16 @@ + +from abc import ABC, abstractmethod +from typing import Optional +import torch + + +class KV_Database(ABC): + + @abstractmethod + def insert(self, input_tokens, kv, roi): + raise NotImplementedError + + @abstractmethod + def drop_select(self, input_tokens, roi) -> Optional[torch.Tensor]: + raise NotImplementedError + \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_pipe/__init__.py b/vllm/distributed/kv_transfer/kv_pipe/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py new file mode 100644 index 0000000000000..625656adc2664 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -0,0 +1,13 @@ + +from abc import ABC, abstractmethod + + +class KVPipeBase(ABC): + + @abstractmethod + def send_tensor(self, tensor): + raise NotImplementedError + + @abstractmethod + def recv_tensor(self): + raise NotImplementedError \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py new file mode 100644 index 0000000000000..97fab48171983 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -0,0 +1,88 @@ + +from vllm.distributed.group_coordinator import GroupCoordinator +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from torch.distributed import Backend, ProcessGroup +import torch +from typing import List, Union, Optional + +class TorchDistributedPipe(KVPipeBase, GroupCoordinator): + +class DistributedKVCoordinator(GroupCoordinator): + """ + A class designated for distributed KV transfer + + Target use cases: + 1. Disaggregated prefill + 2. Remote KV cache storage + + """ + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + # DO NOT use pynccl here + # Pynccl send is non-blocking + # and it's possible that the memory is freed before the data being sent + # which may happen at high qps + use_pynccl: bool = False, + use_custom_allreduce: bool = False, + use_tpu_communicator: bool = True, + use_message_queue_broadcaster: bool = False, + blocking_send_recv: bool = False, + ): + + super().__init__( + group_ranks, + local_rank, + torch_distributed_backend, + use_pynccl, + use_custom_allreduce, + use_tpu_communicator, + use_message_queue_broadcaster, + ) + + # if turned on, will use CPU-based communication to perform a series of sanity check. + # but it adds ~5ms delay, so please turn it off in performance-demanding usecases (e.g. disaggregated prefill) + self.blocking_send_recv = blocking_send_recv + self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % + self.world_size] + self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % + self.world_size] + torch.set_default_device(self.device) + + def send_tensor(self, + tensor: torch.Tensor) -> None: + """ + Sends a tensor to the destination rank in a non-blocking way. + Flow: send tensor dim -- send tensor shape -- send tensor data + """ + + dim_tensor = torch.tensor([len(tensor.shape)], dtype=torch.int).to(self.device, non_blocking=True) + shape_tensor = torch.tensor(tensor.shape, dtype=torch.int).to(self.device, non_blocking=True) + + torch.distributed.isend(dim_tensor, self.target_rank_for_send, self.device_group) + torch.distributed.isend(shape_tensor, self.target_rank_for_send, self.device_group) + torch.distributed.isend(tensor, self.target_rank_for_send, self.device_group) + + def recv_tensor(self) -> torch.Tensor: + """Receives a tensor from the src rank. Blocking.""" + + # FIXME(Kuntai): this incurs frequent data moving between CPU and GPU + # can be optimized by pre-allocating tensors on GPU. + dim_tensor = torch.tensor([0], dtype=torch.int).to(self.device) + torch.distributed.irecv(dim_tensor, self.target_rank_for_recv, self.device_group) + dim = dim_tensor.item() + shape_tensor = torch.zeros(dim, dtype=torch.int).to(self.device) + torch.distributed.irecv(shape_tensor, self.target_rank_for_recv, self.device_group) + return_tensor = torch.zeros(shape_tensor, dtype=torch.float32).to(self.device) + torch.distributed.irecv(return_tensor, self.target_rank_for_recv, self.device_group) + + result = self.recv_tensor_dict(src) + tensor = result["tensor"] + assert torch.allclose(result["mean"], tensor.float().mean()) + assert result["shape"] == tensor.shape + assert result[ + "shape"] == size, f"The shape sent by sender is {result['shape']} but trying to receive {size}" + return tensor diff --git a/vllm/distributed/kv_transfer/kv_serde/__init__.py b/vllm/distributed/kv_transfer/kv_serde/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/kv_serde/base.py b/vllm/distributed/kv_transfer/kv_serde/base.py new file mode 100644 index 0000000000000..64168553ff15c --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_serde/base.py @@ -0,0 +1,13 @@ + +import torch +from abc import ABC, abstractmethod + +class KV_serde(ABC): + + @abstractmethod + def serialize(self, tensor: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def deserialize(self, data: torch.Tensor) -> torch.Tensor: + raise NotImplementedError \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py new file mode 100644 index 0000000000000..683185df1ab5d --- /dev/null +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -0,0 +1,341 @@ +"""vLLM distributed KV cache transfer API. +These APIs are used in `vllm/worker/model_runner.py`. + +Currently supporting TP and PP. + +Workflow: +- In prefill instance, KV cache sender *buffers* the KV cache send requests +- In decode instance + - KV cache receiver sends the hash of input tokens to sender + - KV cache sender executes send request + - KV cache receiver receives the KV cache +""" +from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING +from collections import defaultdict, deque +from concurrent.futures import ThreadPoolExecutor +from threading import Lock +from copy import deepcopy +import time +import threading + +import torch +from torch.distributed import Backend, ProcessGroup + +import vllm.envs as envs +from vllm.distributed.group_coordinator import GroupCoordinator +from vllm.logger import init_logger +import vllm.distributed.parallel_state as ps +from vllm import _custom_ops as ops +from vllm.sequence import IntermediateTensors + +assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode"], \ + "VLLM_DISAGG_PREFILL_ROLE can only be prefill or decode." + +IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE is not None) +IS_KV_PREFILL_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "prefill") +IS_KV_DECODE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "decode") + +# add a tag when sending/recving input hash +DISTRIBUTED_KV_GLOO_TAG = 24857323 + +logger = init_logger(__name__) + +import logging + + +class RankFilter(logging.Filter): + + def filter(self, record): + # Only log if rank is 4 + rank = 1 + try: + rank = torch.distributed.get_rank() + except Exception: + pass + return rank % 4 == 0 + + +for handler in logger.handlers: + handler.addFilter(RankFilter()) + + +class DistributedKVCoordinator(GroupCoordinator): + """ + A class designated for distributed KV transfer + + Target use cases: + 1. Disaggregated prefill + 2. Remote KV cache storage + + """ + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + # DO NOT use pynccl here + # Pynccl send is non-blocking + # and it's possible that the memory is freed before the data being sent + # which may happen at high qps + use_pynccl: bool = False, + use_custom_allreduce: bool = False, + use_tpu_communicator: bool = True, + use_message_queue_broadcaster: bool = False, + use_cpu_comm_for_sanity_check: bool = False, + ): + + super().__init__( + group_ranks, + local_rank, + torch_distributed_backend, + use_pynccl, + use_custom_allreduce, + use_tpu_communicator, + use_message_queue_broadcaster, + ) + + # if turned on, will use CPU-based communication to perform a series of sanity check. + # but it adds ~5ms delay, so please turn it off in performance-demanding usecases (e.g. disaggregated prefill) + self.use_cpu_comm_for_sanity_check = use_cpu_comm_for_sanity_check + + # use a threadpool to buffer send request in disaggregated prefill + self.input_hash_to_kv_sending_requests = defaultdict(deque) + self.kv_sending_thread = None + self.input_hash_to_kv_sending_requests_lock = Lock() + self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % + self.world_size] + self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % + self.world_size] + + torch.set_default_device(self.device) + + def debug_send(self, + tensor: torch.Tensor, + dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """Will send several metadata. Useful for debugging.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + self.send_tensor_dict( + { + "tensor": tensor, + "mean": tensor.float().mean(), + "shape": tensor.shape + }, dst) + + def debug_recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the local rank of the destination rank.""" + + result = self.recv_tensor_dict(src) + tensor = result["tensor"] + assert torch.allclose(result["mean"], tensor.float().mean()) + assert result["shape"] == tensor.shape + assert result[ + "shape"] == size, f"The shape sent by sender is {result['shape']} but trying to receive {size}" + return tensor + + def kv_cache_send(self, + input_hash: int, + tensor: Union[torch.Tensor, IntermediateTensors], + is_hidden: bool = False, + dst: Optional[int] = None) -> None: + """Push the KV cache send request into the send buffer""" + """NOTE: `dst` is the local rank of the destination rank.""" + + if self.use_cpu_comm_for_sanity_check: + send_func = self.debug_send + else: + send_func = self.send + + if is_hidden and not ps.get_pp_group().is_last_rank: + + assert isinstance(tensor, IntermediateTensors) + + output = deepcopy(tensor.tensors) + for key in output: + output[key] = output[key].contiguous() + + self.input_hash_to_kv_sending_requests[input_hash].append( + [self.send_tensor_dict, output, dst]) + + else: + + assert isinstance(tensor, torch.Tensor) + + self.input_hash_to_kv_sending_requests[input_hash].append([ + send_func, + # use clone to make sure the tensor is contiguous + tensor.clone(), + dst + ]) + + def kv_cache_recv( + self, + size: torch.Size, + dtype: torch.dtype, + is_hidden: bool = False, + src: Optional[int] = None + ) -> Union[torch.Tensor, IntermediateTensors]: + """Receives a tensor from the src rank (blocking).""" + """This API should be used together with `push`""" + """NOTE: `src` is the local rank of the destination rank.""" + + if self.use_cpu_comm_for_sanity_check: + recv_func = self.debug_recv + else: + recv_func = self.recv + + if is_hidden and not ps.get_pp_group().is_last_rank: + tensor = IntermediateTensors(self.recv_tensor_dict(src)) + else: + tensor = recv_func(size, dtype, src) + + return tensor + + def send_input_hash(self, input_hash: int) -> int: + + logger.debug('[rank%d]: Sending input hash %d to rank %d', + torch.distributed.get_rank(), input_hash, + self.target_rank_for_send) + + # KV cache send go through CPU, and the original `send` only use GPU. + # So create a new group for sending input hash. + input_hash_tensor = torch.tensor([input_hash], device="cpu").long() + torch.distributed.send(input_hash_tensor, + self.target_rank_for_send, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) + return_tensor = torch.tensor([0], device="cpu").long() + torch.distributed.recv(return_tensor, + self.target_rank_for_recv, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) + return return_tensor.item() + + def recv_input_hash(self) -> Optional[int]: + ''' + Receive an input hash, and check if it is already cached + ''' + input_hash_tensor = torch.tensor([0], device="cpu").long() + torch.distributed.recv(input_hash_tensor, + self.target_rank_for_recv, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) + input_hash = input_hash_tensor.item() + # a new input hash comes in, see if it is already cached + self.input_hash_to_kv_sending_requests_lock.acquire() + logger.debug('Successfully received input hash %d', input_hash) + if input_hash not in self.input_hash_to_kv_sending_requests: + logger.warning( + f"The KV cache of {input_hash} does not exist. "\ + f"Existing input hash: {list(self.input_hash_to_kv_sending_requests.keys())}") + + # 0 for fail + x = torch.tensor([0], device="cpu").long() + torch.distributed.send(x, + self.target_rank_for_send, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) + return None + else: + logger.debug('Input hash %d exists, start sending', input_hash) + + # 1 for success + x = torch.tensor([1], device="cpu").long() + torch.distributed.send(x, + self.target_rank_for_send, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) + return input_hash + + def kv_cache_send_loop(self): + + while True: + logger.debug( + '[rank%d]: Waiting for input hash from rank %d, my keys are %s', + torch.distributed.get_rank(), + self.target_rank_for_recv, + list(self.input_hash_to_kv_sending_requests.keys()), + ) + # wait for a new input hash + # this function will acquire the lock + input_hash = self.recv_input_hash() + if input_hash is None: + self.input_hash_to_kv_sending_requests_lock.release() + continue + + # execute corresponding kv cache sending jobs in request queue + while True: + request = self.input_hash_to_kv_sending_requests[ + input_hash].popleft() + # An empty request: the KV cahe of one request are all sent + if request == []: + break + + request[0](*request[1:]) + + if len(self.input_hash_to_kv_sending_requests[input_hash]) == 0: + logger.debug('Finish input hash %d, free GPU memory...', + input_hash) + del self.input_hash_to_kv_sending_requests[input_hash] + else: + logger.debug( + 'The buffer for input hash %d is not empty, meaning that '\ + 'there are two jobs with identical input.', + input_hash) + + self.input_hash_to_kv_sending_requests_lock.release() + + + def kv_cache_send_ready(self, input_hash: int): + + if self.kv_sending_thread is None: + self.kv_sending_thread = threading.Thread( + target=self.kv_cache_send_loop) + self.kv_sending_thread.start() + + # append an empty list to separate requests + # as there might be identical requests, that has the same input hash + self.input_hash_to_kv_sending_requests[input_hash].append([]) + logger.debug(f'Buffered input hash {input_hash}') + + def kv_cache_recv_start(self, input_hash: int): + # notify the kv cache sender with the input hash id + return self.send_input_hash(input_hash) + + def block_if_buffer_full(self): + + # block vLLM if the KV cache sending buffer is full + # TODO: allow using other policies to handle buffer full + while True: + self.input_hash_to_kv_sending_requests_lock.acquire() + if len(self.input_hash_to_kv_sending_requests.keys()) > 40: + self.input_hash_to_kv_sending_requests_lock.release() + time.sleep(0.1) + else: + self.input_hash_to_kv_sending_requests_lock.release() + break + + +def buffer_kv( + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], +) -> None: + + pass + + +def recv_kv( + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] +) -> Tuple[List[torch.Tensor], Union[torch.Tensor, IntermediateTensors]]: + + pass \ No newline at end of file From db66a1e383bff0c2c0d7656289f6d7ac46b1f6ea Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 27 Aug 2024 22:15:53 -0700 Subject: [PATCH 155/303] major revision: add 3-layer abstractions. Transport, lookup buffer, and adapter. --- .../__init__.py | 0 .../{kv_database => kv_lookup_buffer}/base.py | 6 +- .../simple_kv_lookup_buffer.py | 122 ++++++ .../kv_pipe/torch_distributed_pipe.py | 91 ++-- .../kv_transfer/kv_serde/__init__.py | 0 vllm/distributed/kv_transfer/kv_serde/base.py | 13 - vllm/distributed/kv_transfer/vllm_adapter.py | 396 ++++++------------ 7 files changed, 311 insertions(+), 317 deletions(-) rename vllm/distributed/kv_transfer/{kv_database => kv_lookup_buffer}/__init__.py (100%) rename vllm/distributed/kv_transfer/{kv_database => kv_lookup_buffer}/base.py (65%) create mode 100644 vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py delete mode 100644 vllm/distributed/kv_transfer/kv_serde/__init__.py delete mode 100644 vllm/distributed/kv_transfer/kv_serde/base.py diff --git a/vllm/distributed/kv_transfer/kv_database/__init__.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py similarity index 100% rename from vllm/distributed/kv_transfer/kv_database/__init__.py rename to vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py diff --git a/vllm/distributed/kv_transfer/kv_database/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py similarity index 65% rename from vllm/distributed/kv_transfer/kv_database/base.py rename to vllm/distributed/kv_transfer/kv_lookup_buffer/base.py index ae17650754bf1..5ac8fbb244446 100644 --- a/vllm/distributed/kv_transfer/kv_database/base.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -4,10 +4,12 @@ import torch -class KV_Database(ABC): +class KVLookupBufferBase(ABC): @abstractmethod - def insert(self, input_tokens, kv, roi): + def insert(self, + input_tokens: torch.Tensor, + kv: torch.Tensor, roi) -> None: raise NotImplementedError @abstractmethod diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py new file mode 100644 index 0000000000000..c43a41575aee0 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py @@ -0,0 +1,122 @@ + +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \ + KVLookupBufferBase +from typing import Dict, Tuple, List, Optional +import threading +import torch +from collections import deque + +class SimpleKVLookupBuffer(KVLookupBufferBase): + + def __init__(self, pipe): + + self.tokens_roi_kv_buffer = deque() + + self.buffer_size = 0 + self.buffer_lock = threading.Lock() + self.pipe = pipe + self.request_handling_thread = None + + + def _matches(self, tokens_roi_sender, tokens_roi_recver): + + # tokens_roi_sender: tokens and roi of the producer (in the buffer) + # tokens_roi_recver: tokens and roi of the consumer (query) + + tokens_sender = tokens_roi_sender[0] + tokens_recver = tokens_roi_recver[0] + roi_sender = tokens_roi_sender[1] + roi_recver = tokens_roi_recver[1] + + if tokens_recver is None: + # consumer sends an empty request + # semantics: DROP SELECT * LIMIT 1 + # so any of the data in the buffer can be drop-selected + return True + + + if tokens_sender == tokens_recver[:tokens_sender.shape[0]]: + # drastically simplified + # accept a match as long as + + return True + + + def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None: + + assert tensor is not None, "Use self.pipe.send(None) instead" + self.buffer_size -= tensor.element_size() * tensor.numel() + tensor = tensor.clone() + self.pipe.send_tensor(tensor) + + def _add_to_buffer(self, input_tokens, roi, kv): + + self.buffer_size += input_tokens.element_size() * input_tokens.numel() + self.buffer_size += roi.element_size() * roi.numel() + self.buffer_size += kv.element_size() * kv.numel() + self.tokens_roi_kv_buffer.append((input_tokens, roi, kv)) + + + + + + def drop_select_handler(self): + + while True: + input_tokens = self.pipe.recv_tensor() + roi = self.pipe.recv_tensor() + tokens_roi = [input_tokens, roi] + + matched_idx = None + + # perform input tokens and roi matching + with self.buffer_lock: + + for idx, tokens_roi_kv in enumerate(self.tokens_roi_kv_buffer): + if self._matches(tokens_roi_kv, tokens_roi): + matched_idx = idx + break + + if matched_idx is not None: + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.tokens_roi_kv_buffer[matched_idx] + self._send_tensor_and_dec_size(matched_item[0].clone()) + self._send_tensor_and_dec_size(matched_item[1].clone()) + self._send_tensor_and_dec_size(matched_item[2].clone()) + del self.tokens_roi_kv_buffer[matched_idx] + + else: + # no match, just send None + self.pipe.send_tensor(None) + self.pipe.send_tensor(None) + self.pipe.send_tensor(None) + + + def drop_select(self, input_tokens, roi): + + assert self.request_handling_thread is None, \ + "drop_select should be called by the receiver" + + self.pipe.send_tensor(input_tokens.clone()) + self.pipe.send_tensor(roi.clone()) + + input_tokens = self.pipe.recv_tensor() + roi = self.pipe.recv_tensor() + kv = self.pipe.recv_tensor() + + return [input_tokens, roi, kv] + + + def insert(self, input_tokens, roi, kv) -> None: + + # when calling the insert, the current process is a sender + # need to launch the request handler and start listening to request. + if self.request_handling_thread is None: + self.request_handling_thread = threading.Thread( + target=self.drop_select_handler) + self.request_handling_thread.start() + + with self.buffer_lock: + self._add_to_buffer(input_tokens, roi, kv) + \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index 97fab48171983..4a16f5908679b 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -4,19 +4,19 @@ from torch.distributed import Backend, ProcessGroup import torch from typing import List, Union, Optional +import threading +from concurrent.futures import ThreadPoolExecutor +import time +import threading + + +class BrokenPipeException(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) class TorchDistributedPipe(KVPipeBase, GroupCoordinator): -class DistributedKVCoordinator(GroupCoordinator): - """ - A class designated for distributed KV transfer - - Target use cases: - 1. Disaggregated prefill - 2. Remote KV cache storage - - """ - def __init__( self, group_ranks: List[List[int]], @@ -30,7 +30,6 @@ def __init__( use_custom_allreduce: bool = False, use_tpu_communicator: bool = True, use_message_queue_broadcaster: bool = False, - blocking_send_recv: bool = False, ): super().__init__( @@ -45,44 +44,64 @@ def __init__( # if turned on, will use CPU-based communication to perform a series of sanity check. # but it adds ~5ms delay, so please turn it off in performance-demanding usecases (e.g. disaggregated prefill) - self.blocking_send_recv = blocking_send_recv self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % self.world_size] self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % self.world_size] torch.set_default_device(self.device) + self.kv_sending_thread = None + self.buffer_size = 0 + self.buffer_size_lock = threading.lock() + + self.nan_tensor = torch.tensor(['nan']) + self.broken = False + + + def send_tensor_wrapper(self, tensor: torch.Tensor) -> None: + """Wrapper for send_tensor_dict""" + tensor_size = tensor['tensor'].element_size() * tensor['tensor'].numel() + self.send_tensor_dict({'tensor': tensor}, self.target_rank_for_send) + + with self.buffer_size_lock: + self.buffer_size = self.buffer_size - tensor_size + + def block_if_full(self): + + while self.buffer_size > 1e9: + time.sleep(0.05) + def send_tensor(self, - tensor: torch.Tensor) -> None: + tensor: Optional[torch.Tensor]) -> None: """ Sends a tensor to the destination rank in a non-blocking way. Flow: send tensor dim -- send tensor shape -- send tensor data """ - dim_tensor = torch.tensor([len(tensor.shape)], dtype=torch.int).to(self.device, non_blocking=True) - shape_tensor = torch.tensor(tensor.shape, dtype=torch.int).to(self.device, non_blocking=True) + if self.kv_sending_thread is None: + self.kv_sending_thread = ThreadPoolExecutor(max_workers=1) + + if tensor is None: + tensor = self.nan_tensor + tensor_size = 0 + else: + tensor_size = tensor.element_size() * tensor.numel() + + self.block_if_full() + + with self.buffer_size_lock: + self.buffer_size = self.buffer_size + tensor_size + + self.kv_sending_thread.submit(self.send_tensor_wrapper, tensor) + - torch.distributed.isend(dim_tensor, self.target_rank_for_send, self.device_group) - torch.distributed.isend(shape_tensor, self.target_rank_for_send, self.device_group) - torch.distributed.isend(tensor, self.target_rank_for_send, self.device_group) - def recv_tensor(self) -> torch.Tensor: + def recv_tensor(self) -> Optional[torch.Tensor]: """Receives a tensor from the src rank. Blocking.""" - # FIXME(Kuntai): this incurs frequent data moving between CPU and GPU - # can be optimized by pre-allocating tensors on GPU. - dim_tensor = torch.tensor([0], dtype=torch.int).to(self.device) - torch.distributed.irecv(dim_tensor, self.target_rank_for_recv, self.device_group) - dim = dim_tensor.item() - shape_tensor = torch.zeros(dim, dtype=torch.int).to(self.device) - torch.distributed.irecv(shape_tensor, self.target_rank_for_recv, self.device_group) - return_tensor = torch.zeros(shape_tensor, dtype=torch.float32).to(self.device) - torch.distributed.irecv(return_tensor, self.target_rank_for_recv, self.device_group) - - result = self.recv_tensor_dict(src) - tensor = result["tensor"] - assert torch.allclose(result["mean"], tensor.float().mean()) - assert result["shape"] == tensor.shape - assert result[ - "shape"] == size, f"The shape sent by sender is {result['shape']} but trying to receive {size}" - return tensor + tensor = self.recv_tensor_dict(self.target_rank_for_recv)['tensor'] + if tensor.isnan().item(): + return None + else: + return tensor + \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_serde/__init__.py b/vllm/distributed/kv_transfer/kv_serde/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/distributed/kv_transfer/kv_serde/base.py b/vllm/distributed/kv_transfer/kv_serde/base.py deleted file mode 100644 index 64168553ff15c..0000000000000 --- a/vllm/distributed/kv_transfer/kv_serde/base.py +++ /dev/null @@ -1,13 +0,0 @@ - -import torch -from abc import ABC, abstractmethod - -class KV_serde(ABC): - - @abstractmethod - def serialize(self, tensor: torch.Tensor) -> torch.Tensor: - raise NotImplementedError - - @abstractmethod - def deserialize(self, data: torch.Tensor) -> torch.Tensor: - raise NotImplementedError \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 683185df1ab5d..07358567e783f 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -27,14 +27,20 @@ import vllm.distributed.parallel_state as ps from vllm import _custom_ops as ops from vllm.sequence import IntermediateTensors +from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import TorchDistributedPipe +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer import SimpleKVLookupBuffer -assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode"], \ +assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode", "lmc"], \ "VLLM_DISAGG_PREFILL_ROLE can only be prefill or decode." -IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE is not None) +IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"]) IS_KV_PREFILL_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "prefill") IS_KV_DECODE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "decode") +'''Jiayi starts here''' +IS_LMC_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "lmc") +'''Jiayi ends here''' + # add a tag when sending/recving input hash DISTRIBUTED_KV_GLOO_TAG = 24857323 @@ -43,30 +49,13 @@ import logging -class RankFilter(logging.Filter): - - def filter(self, record): - # Only log if rank is 4 - rank = 1 - try: - rank = torch.distributed.get_rank() - except Exception: - pass - return rank % 4 == 0 - - -for handler in logger.handlers: - handler.addFilter(RankFilter()) - - -class DistributedKVCoordinator(GroupCoordinator): +class KV_transfer_agent: """ A class designated for distributed KV transfer Target use cases: 1. Disaggregated prefill 2. Remote KV cache storage - """ def __init__( @@ -74,18 +63,14 @@ def __init__( group_ranks: List[List[int]], local_rank: int, torch_distributed_backend: Union[str, Backend], - # DO NOT use pynccl here - # Pynccl send is non-blocking - # and it's possible that the memory is freed before the data being sent - # which may happen at high qps use_pynccl: bool = False, use_custom_allreduce: bool = False, use_tpu_communicator: bool = True, - use_message_queue_broadcaster: bool = False, - use_cpu_comm_for_sanity_check: bool = False, + use_message_queue_broadcaster: bool = False ): - - super().__init__( + + # init pipe + self.pipe = TorchDistributedPipe( group_ranks, local_rank, torch_distributed_backend, @@ -94,248 +79,127 @@ def __init__( use_tpu_communicator, use_message_queue_broadcaster, ) + # init lookup buffer + self.buffer = SimpleKVLookupBuffer(self.pipe) - # if turned on, will use CPU-based communication to perform a series of sanity check. - # but it adds ~5ms delay, so please turn it off in performance-demanding usecases (e.g. disaggregated prefill) - self.use_cpu_comm_for_sanity_check = use_cpu_comm_for_sanity_check - - # use a threadpool to buffer send request in disaggregated prefill - self.input_hash_to_kv_sending_requests = defaultdict(deque) - self.kv_sending_thread = None - self.input_hash_to_kv_sending_requests_lock = Lock() - self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % - self.world_size] - self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % - self.world_size] - - torch.set_default_device(self.device) - - def debug_send(self, - tensor: torch.Tensor, - dst: Optional[int] = None) -> None: - """Sends a tensor to the destination rank in a non-blocking way""" - """Will send several metadata. Useful for debugging.""" - """NOTE: `dst` is the local rank of the destination rank.""" - - self.send_tensor_dict( - { - "tensor": tensor, - "mean": tensor.float().mean(), - "shape": tensor.shape - }, dst) - - def debug_recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: - """Receives a tensor from the src rank.""" - """NOTE: `src` is the local rank of the destination rank.""" - - result = self.recv_tensor_dict(src) - tensor = result["tensor"] - assert torch.allclose(result["mean"], tensor.float().mean()) - assert result["shape"] == tensor.shape - assert result[ - "shape"] == size, f"The shape sent by sender is {result['shape']} but trying to receive {size}" - return tensor - - def kv_cache_send(self, - input_hash: int, - tensor: Union[torch.Tensor, IntermediateTensors], - is_hidden: bool = False, - dst: Optional[int] = None) -> None: - """Push the KV cache send request into the send buffer""" - """NOTE: `dst` is the local rank of the destination rank.""" - - if self.use_cpu_comm_for_sanity_check: - send_func = self.debug_send - else: - send_func = self.send - - if is_hidden and not ps.get_pp_group().is_last_rank: - - assert isinstance(tensor, IntermediateTensors) - - output = deepcopy(tensor.tensors) - for key in output: - output[key] = output[key].contiguous() - - self.input_hash_to_kv_sending_requests[input_hash].append( - [self.send_tensor_dict, output, dst]) - - else: - - assert isinstance(tensor, torch.Tensor) - - self.input_hash_to_kv_sending_requests[input_hash].append([ - send_func, - # use clone to make sure the tensor is contiguous - tensor.clone(), - dst - ]) - - def kv_cache_recv( - self, - size: torch.Size, - dtype: torch.dtype, - is_hidden: bool = False, - src: Optional[int] = None - ) -> Union[torch.Tensor, IntermediateTensors]: - """Receives a tensor from the src rank (blocking).""" - """This API should be used together with `push`""" - """NOTE: `src` is the local rank of the destination rank.""" - - if self.use_cpu_comm_for_sanity_check: - recv_func = self.debug_recv - else: - recv_func = self.recv - - if is_hidden and not ps.get_pp_group().is_last_rank: - tensor = IntermediateTensors(self.recv_tensor_dict(src)) - else: - tensor = recv_func(size, dtype, src) - - return tensor - - def send_input_hash(self, input_hash: int) -> int: - - logger.debug('[rank%d]: Sending input hash %d to rank %d', - torch.distributed.get_rank(), input_hash, - self.target_rank_for_send) - - # KV cache send go through CPU, and the original `send` only use GPU. - # So create a new group for sending input hash. - input_hash_tensor = torch.tensor([input_hash], device="cpu").long() - torch.distributed.send(input_hash_tensor, - self.target_rank_for_send, - self.cpu_group, - tag=DISTRIBUTED_KV_GLOO_TAG) - return_tensor = torch.tensor([0], device="cpu").long() - torch.distributed.recv(return_tensor, - self.target_rank_for_recv, - self.cpu_group, - tag=DISTRIBUTED_KV_GLOO_TAG) - return return_tensor.item() - - def recv_input_hash(self) -> Optional[int]: - ''' - Receive an input hash, and check if it is already cached - ''' - input_hash_tensor = torch.tensor([0], device="cpu").long() - torch.distributed.recv(input_hash_tensor, - self.target_rank_for_recv, - self.cpu_group, - tag=DISTRIBUTED_KV_GLOO_TAG) - input_hash = input_hash_tensor.item() - # a new input hash comes in, see if it is already cached - self.input_hash_to_kv_sending_requests_lock.acquire() - logger.debug('Successfully received input hash %d', input_hash) - if input_hash not in self.input_hash_to_kv_sending_requests: - logger.warning( - f"The KV cache of {input_hash} does not exist. "\ - f"Existing input hash: {list(self.input_hash_to_kv_sending_requests.keys())}") + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], + ) -> None: + + #input_tokens_tuple = tuple(model_input.input_tokens.tolist()) + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + logger.debug(f"sending request {idx}") + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen - # 0 for fail - x = torch.tensor([0], device="cpu").long() - torch.distributed.send(x, - self.target_rank_for_send, - self.cpu_group, - tag=DISTRIBUTED_KV_GLOO_TAG) - return None - else: - logger.debug('Input hash %d exists, start sending', input_hash) + keys, values = [], [] - # 1 for success - x = torch.tensor([1], device="cpu").long() - torch.distributed.send(x, - self.target_rank_for_send, - self.cpu_group, - tag=DISTRIBUTED_KV_GLOO_TAG) - return input_hash - - def kv_cache_send_loop(self): - - while True: - logger.debug( - '[rank%d]: Waiting for input hash from rank %d, my keys are %s', - torch.distributed.get_rank(), - self.target_rank_for_recv, - list(self.input_hash_to_kv_sending_requests.keys()), + + for l in range(model_executable.model.start_layer, + model_executable.model.end_layer): + logger.debug(f"sending layer {l}") + kv_cache = kv_caches[l - model_executable.model.start_layer] + + _, _, num_heads, head_size = kv_cache[0].shape + + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqeeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + self.buffer.insert( + input_tokens_tensor[start_pos:end_pos], + None, + keys, + values, + hidden_or_intermediate_states[start_pos:end_pos] ) - # wait for a new input hash - # this function will acquire the lock - input_hash = self.recv_input_hash() - if input_hash is None: - self.input_hash_to_kv_sending_requests_lock.release() - continue - - # execute corresponding kv cache sending jobs in request queue - while True: - request = self.input_hash_to_kv_sending_requests[ - input_hash].popleft() - # An empty request: the KV cahe of one request are all sent - if request == []: - break - - request[0](*request[1:]) - - if len(self.input_hash_to_kv_sending_requests[input_hash]) == 0: - logger.debug('Finish input hash %d, free GPU memory...', - input_hash) - del self.input_hash_to_kv_sending_requests[input_hash] - else: - logger.debug( - 'The buffer for input hash %d is not empty, meaning that '\ - 'there are two jobs with identical input.', - input_hash) - - self.input_hash_to_kv_sending_requests_lock.release() - - - def kv_cache_send_ready(self, input_hash: int): - - if self.kv_sending_thread is None: - self.kv_sending_thread = threading.Thread( - target=self.kv_cache_send_loop) - self.kv_sending_thread.start() - - # append an empty list to separate requests - # as there might be identical requests, that has the same input hash - self.input_hash_to_kv_sending_requests[input_hash].append([]) - logger.debug(f'Buffered input hash {input_hash}') - - def kv_cache_recv_start(self, input_hash: int): - # notify the kv cache sender with the input hash id - return self.send_input_hash(input_hash) + - def block_if_buffer_full(self): - - # block vLLM if the KV cache sending buffer is full - # TODO: allow using other policies to handle buffer full - while True: - self.input_hash_to_kv_sending_requests_lock.acquire() - if len(self.input_hash_to_kv_sending_requests.keys()) > 40: - self.input_hash_to_kv_sending_requests_lock.release() - time.sleep(0.1) - else: - self.input_hash_to_kv_sending_requests_lock.release() - break - - -def buffer_kv( - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], -) -> None: - - pass + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) -def recv_kv( - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor] -) -> Tuple[List[torch.Tensor], Union[torch.Tensor, IntermediateTensors]]: - - pass \ No newline at end of file + def recv_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool]: + + bypass_model_exec = True + + # This is disagg decode instance, during prefill state + # Need to receive KV from the prefill instance + input_tokens_tuple = tuple(model_input.input_tokens.tolist()) + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + + hidden_or_intermediate_states_for_one_req = [] + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. + for idx, slen in enumerate(seq_lens): + + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_input_tokens = input_tokens_tuple[start_pos:end_pos] + num_tokens = slen + + ret = self.buffer.drop_select(current_input_tokens, None) + if ret[0] is None: + # didn't find any match. + self.bypass_model_exec = False + continue + + _, _, keys, values, hidden = ret + + # receive KV cache from disaggregated prefill instance + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + # get kv cache + kv_cache = kv_caches[i - model_executable.model.start_layer] + # get corresponding layer + layer = model_executable.model.layers[i] + + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i], + values[i], + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # so we need to recompute the hidden state + return [], bypass_model_exec + + # concatenate hidden states from different requests + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + logger.debug("[rank%d]: KV recv DONE.", torch.distributed.get_rank()) + return hidden_or_intermediate_states, bypass_model_exec \ No newline at end of file From e04430c574f89590778f3d2aea55518c1e648a50 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 27 Aug 2024 22:20:06 -0700 Subject: [PATCH 156/303] add kv transfer test --- tests/random_send_recv.py | 0 tests/test_send_recv.sh | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/random_send_recv.py create mode 100644 tests/test_send_recv.sh diff --git a/tests/random_send_recv.py b/tests/random_send_recv.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/test_send_recv.sh b/tests/test_send_recv.sh new file mode 100644 index 0000000000000..e69de29bb2d1d From 30f9bb670d8e4110de5ec906c3695b438fbe390d Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 28 Aug 2024 07:25:58 +0000 Subject: [PATCH 157/303] add test cases for pipe --- tests/kv_transfer/test_send_recv.py | 81 +++++++++++++++++++++++++++++ tests/kv_transfer/test_send_recv.sh | 3 ++ 2 files changed, 84 insertions(+) create mode 100644 tests/kv_transfer/test_send_recv.py create mode 100644 tests/kv_transfer/test_send_recv.sh diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py new file mode 100644 index 0000000000000..c8be888ef6aa3 --- /dev/null +++ b/tests/kv_transfer/test_send_recv.py @@ -0,0 +1,81 @@ + +import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp +import torch +import os +import random +from tqdm import tqdm + +my_rank = int(os.environ['RANK']) + + +torch.distributed.init_process_group( + init_method="tcp://127.0.0.1:23456", + world_size=2, + rank=my_rank) + +print("initialized! My rank is %d" % my_rank) + + +pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "nccl") + +print('My device is ', pipe.device, ' default: ', torch.cuda.current_device()) +print(pipe.target_rank_for_send, pipe.target_rank_for_recv) + +# test run + +if my_rank == 0: + x = torch.tensor([1]).to(pipe.device) + pipe.send(x, 1) + + +else: + y = torch.tensor([0]).to(pipe.device) + y = pipe.recv(y.shape, y.dtype) + + assert y.item() == 1 + +# if my_rank == 0: +# x = torch.tensor([1]).to(pipe.device) +# torch.distributed.send(x, dst=1, group=pipe.device_group) +# else: +# x = torch.tensor([0]).to(pipe.device) +# torch.distributed.recv(x, src=0, group=pipe.device_group) +# assert x.item() == 1 + +print(my_rank, 'Test run successed! ') + +if my_rank == 0: + # send a tensor 1000 times + for i in range(3): + + mean = random.randint(10, 100) + std = random.randint(10, 100) + size = [random.randint(10, 100), random.randint(10, 100)] + x = torch.normal(mean, std, size=size).to(pipe.device) + + if i % 10 == 0: + pipe.send_tensor(None) + pipe.send_tensor(None) + pipe.send_tensor(None) + else: + pipe.send_tensor(x) + pipe.send_tensor(x.mean()) + pipe.send_tensor(x.std()) + +else: + # recv a tensor 1000 times + for i in tqdm(range(2)): + + x = pipe.recv_tensor() + mean = pipe.recv_tensor() + std = pipe.recv_tensor() + + if x is None: + assert mean is None, std is None + else: + assert x.mean() == mean + assert x.std() == std + + + + \ No newline at end of file diff --git a/tests/kv_transfer/test_send_recv.sh b/tests/kv_transfer/test_send_recv.sh new file mode 100644 index 0000000000000..2a478871bd0e7 --- /dev/null +++ b/tests/kv_transfer/test_send_recv.sh @@ -0,0 +1,3 @@ + +RANK=0 python3 test_send_recv.py & +RANK=1 python3 test_send_recv.py & From bbce62ea8cf0f207461f0776839a838623465325 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 28 Aug 2024 07:26:20 +0000 Subject: [PATCH 158/303] bug fix --- .../kv_transfer/kv_pipe/torch_distributed_pipe.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index 4a16f5908679b..dd540ead9440e 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -10,6 +10,11 @@ import threading +# if the tensor is only one-element and only contains this number +# this means that the sended object is None. +NONE_INT = -150886311 + + class BrokenPipeException(Exception): def __init__(self, message): self.message = message @@ -52,13 +57,14 @@ def __init__( self.kv_sending_thread = None self.buffer_size = 0 - self.buffer_size_lock = threading.lock() + self.buffer_size_lock = threading.Lock() - self.nan_tensor = torch.tensor(['nan']) + self.none_tensor = torch.tensor([NONE_INT]).to(self.device) self.broken = False def send_tensor_wrapper(self, tensor: torch.Tensor) -> None: + print('Sending ', tensor) """Wrapper for send_tensor_dict""" tensor_size = tensor['tensor'].element_size() * tensor['tensor'].numel() self.send_tensor_dict({'tensor': tensor}, self.target_rank_for_send) @@ -82,7 +88,7 @@ def send_tensor(self, self.kv_sending_thread = ThreadPoolExecutor(max_workers=1) if tensor is None: - tensor = self.nan_tensor + tensor = self.none_tensor tensor_size = 0 else: tensor_size = tensor.element_size() * tensor.numel() @@ -100,7 +106,7 @@ def recv_tensor(self) -> Optional[torch.Tensor]: """Receives a tensor from the src rank. Blocking.""" tensor = self.recv_tensor_dict(self.target_rank_for_recv)['tensor'] - if tensor.isnan().item(): + if tensor.numel() == 1 and tensor.item() == 150886311: return None else: return tensor From 927800d9aa0b61f17edbc333b25bafd197ccf4d5 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 29 Aug 2024 07:02:04 +0000 Subject: [PATCH 159/303] finalize send-recv test --- tests/kv_transfer/test_send_recv.py | 154 +++++++++++------- .../kv_pipe/torch_distributed_pipe.py | 5 +- 2 files changed, 97 insertions(+), 62 deletions(-) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index c8be888ef6aa3..5d59c38c4f276 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -4,78 +4,114 @@ import os import random from tqdm import tqdm +import time -my_rank = int(os.environ['RANK']) +def test_run(my_rank, pipe): + # test run + if my_rank == 0: + x = torch.tensor([1]).to(pipe.device) + pipe.send_tensor(x) + else: + y = pipe.recv_tensor() + assert y.item() == 1 -torch.distributed.init_process_group( - init_method="tcp://127.0.0.1:23456", - world_size=2, - rank=my_rank) -print("initialized! My rank is %d" % my_rank) - - -pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "nccl") - -print('My device is ', pipe.device, ' default: ', torch.cuda.current_device()) -print(pipe.target_rank_for_send, pipe.target_rank_for_recv) - -# test run - -if my_rank == 0: - x = torch.tensor([1]).to(pipe.device) - pipe.send(x, 1) +def stress_test(my_rank, pipe): + torch.distributed.barrier() -else: - y = torch.tensor([0]).to(pipe.device) - y = pipe.recv(y.shape, y.dtype) + tensors = [] - assert y.item() == 1 - -# if my_rank == 0: -# x = torch.tensor([1]).to(pipe.device) -# torch.distributed.send(x, dst=1, group=pipe.device_group) -# else: -# x = torch.tensor([0]).to(pipe.device) -# torch.distributed.recv(x, src=0, group=pipe.device_group) -# assert x.item() == 1 + if my_rank == 0: + for i in tqdm(range(2000)): + mean = random.randint(1, 10) + std = random.randint(1, 10) + size = [random.randint(900, 1000), random.randint(900, 1000)] + x = torch.normal(mean * 1.0, std * 1.0, size=size).to(pipe.device) + + # 5% probability of sending a None + if random.randint(1, 100) < 5: + tensors.append(None) + tensors.append(None) + tensors.append(None) + else: + tensors.append(x) + tensors.append(x.mean()) + tensors.append(x.std()) + + torch.distributed.barrier() + + for i in tqdm(range(2000)): + if my_rank == 0: + pipe.send_tensor(tensors[3*i]) + pipe.send_tensor(tensors[3*i+1]) + pipe.send_tensor(tensors[3*i+2]) + else: + x = pipe.recv_tensor() + mean = pipe.recv_tensor() + std = pipe.recv_tensor() + if x is None: + assert mean is None + assert std is None + else: + assert x.mean() == mean + assert x.std() == std -print(my_rank, 'Test run successed! ') + torch.distributed.barrier() -if my_rank == 0: - # send a tensor 1000 times - for i in range(3): + print("Stress test passed.") + + + +def latency_test(my_rank, pipe, nelement, ntensor): + + latencies = [] + + torch.distributed.barrier() + + for i in tqdm(range(1000)): - mean = random.randint(10, 100) - std = random.randint(10, 100) - size = [random.randint(10, 100), random.randint(10, 100)] - x = torch.normal(mean, std, size=size).to(pipe.device) + tensors = [] - if i % 10 == 0: - pipe.send_tensor(None) - pipe.send_tensor(None) - pipe.send_tensor(None) - else: - pipe.send_tensor(x) - pipe.send_tensor(x.mean()) - pipe.send_tensor(x.std()) - -else: - # recv a tensor 1000 times - for i in tqdm(range(2)): + if my_rank == 0: + # create tensor + tensors = [torch.rand(nelement).to(pipe.device) for _ in range(ntensor)] - x = pipe.recv_tensor() - mean = pipe.recv_tensor() - std = pipe.recv_tensor() + torch.distributed.barrier() - if x is None: - assert mean is None, std is None + if my_rank == 0: + t = torch.tensor(time.time(), dtype=torch.float64).to(pipe.device) + for tensor in tensors: + pipe.send_tensor(tensor) + pipe.send_tensor(t) else: - assert x.mean() == mean - assert x.std() == std + for _ in range(ntensor): + pipe.recv_tensor() + t = pipe.recv_tensor() + latencies.append(time.time() - t.item()) + + torch.distributed.barrier() + + print('Latency test passed.') + print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms') - - \ No newline at end of file +if __name__ == "__main__": + + my_rank = int(os.environ['RANK']) + + + torch.distributed.init_process_group( + init_method="tcp://127.0.0.1:23456", + world_size=2, + rank=my_rank) + + print("initialized! My rank is %d" % my_rank) + + + pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "nccl") + + test_run(my_rank, pipe) + stress_test(my_rank, pipe) + latency_test(my_rank, pipe, 1024*8*128, 80) diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index dd540ead9440e..d3663ac7667dd 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -64,9 +64,8 @@ def __init__( def send_tensor_wrapper(self, tensor: torch.Tensor) -> None: - print('Sending ', tensor) """Wrapper for send_tensor_dict""" - tensor_size = tensor['tensor'].element_size() * tensor['tensor'].numel() + tensor_size = tensor.element_size() * tensor.numel() self.send_tensor_dict({'tensor': tensor}, self.target_rank_for_send) with self.buffer_size_lock: @@ -106,7 +105,7 @@ def recv_tensor(self) -> Optional[torch.Tensor]: """Receives a tensor from the src rank. Blocking.""" tensor = self.recv_tensor_dict(self.target_rank_for_recv)['tensor'] - if tensor.numel() == 1 and tensor.item() == 150886311: + if tensor.numel() == 1 and tensor.item() == NONE_INT: return None else: return tensor From 6680ea780d57fed1cb162c7efef3e677992d3bd2 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 29 Aug 2024 07:48:27 +0000 Subject: [PATCH 160/303] update test case so that there are both send and recv --- tests/kv_transfer/test_send_recv.py | 33 ++++++++++++++--------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 5d59c38c4f276..5a2a72cf177fa 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -23,27 +23,26 @@ def stress_test(my_rank, pipe): tensors = [] - if my_rank == 0: - for i in tqdm(range(2000)): - mean = random.randint(1, 10) - std = random.randint(1, 10) - size = [random.randint(900, 1000), random.randint(900, 1000)] - x = torch.normal(mean * 1.0, std * 1.0, size=size).to(pipe.device) - - # 5% probability of sending a None - if random.randint(1, 100) < 5: - tensors.append(None) - tensors.append(None) - tensors.append(None) - else: - tensors.append(x) - tensors.append(x.mean()) - tensors.append(x.std()) + for i in tqdm(range(2000)): + mean = random.randint(1, 10) + std = random.randint(1, 10) + size = [random.randint(900, 1000), random.randint(900, 1000)] + x = torch.normal(mean * 1.0, std * 1.0, size=size).to(pipe.device) + + # 5% probability of sending a None + if random.randint(1, 100) < 5: + tensors.append(None) + tensors.append(None) + tensors.append(None) + else: + tensors.append(x) + tensors.append(x.mean()) + tensors.append(x.std()) torch.distributed.barrier() for i in tqdm(range(2000)): - if my_rank == 0: + if my_rank == int((i % 10) > 3): pipe.send_tensor(tensors[3*i]) pipe.send_tensor(tensors[3*i+1]) pipe.send_tensor(tensors[3*i+2]) From dfbfe80ba5a26a02011d651b0b78c8c38f3a3b86 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 29 Aug 2024 08:08:27 +0000 Subject: [PATCH 161/303] update kv lookup buffer --- I am TOOOOOOO sleepy --- .../simple_kv_lookup_buffer.py | 46 ++++++++++--------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py index c43a41575aee0..9864e5f2d42bf 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py @@ -10,7 +10,7 @@ class SimpleKVLookupBuffer(KVLookupBufferBase): def __init__(self, pipe): - self.tokens_roi_kv_buffer = deque() + self.buffer = deque() self.buffer_size = 0 self.buffer_lock = threading.Lock() @@ -35,11 +35,13 @@ def _matches(self, tokens_roi_sender, tokens_roi_recver): return True - if tokens_sender == tokens_recver[:tokens_sender.shape[0]]: + min_length = min(len(tokens_sender), len(tokens_receiver)) + if tokens_sender[:min_length] == tokens_recver[:min_length]: # drastically simplified - # accept a match as long as - + # common prefix matching return True + + return None def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None: @@ -49,15 +51,14 @@ def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None: tensor = tensor.clone() self.pipe.send_tensor(tensor) - def _add_to_buffer(self, input_tokens, roi, kv): + def _add_to_buffer(self, input_tokens, roi, key, value, hidden): self.buffer_size += input_tokens.element_size() * input_tokens.numel() self.buffer_size += roi.element_size() * roi.numel() - self.buffer_size += kv.element_size() * kv.numel() - self.tokens_roi_kv_buffer.append((input_tokens, roi, kv)) - - - + self.buffer_size += key.element_size() * key.numel() + self.buffer_size += value.element_size() * value.numel() + self.buffer_size += hidden.element_size() * hidden.numel() + self.buffer.append([input_tokens, roi, kv, hidden]) def drop_select_handler(self): @@ -81,17 +82,15 @@ def drop_select_handler(self): # need to clone the tensor # in case the tensor is freed before sending finishes matched_item = self.tokens_roi_kv_buffer[matched_idx] - self._send_tensor_and_dec_size(matched_item[0].clone()) - self._send_tensor_and_dec_size(matched_item[1].clone()) - self._send_tensor_and_dec_size(matched_item[2].clone()) + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor.clone()) del self.tokens_roi_kv_buffer[matched_idx] else: # no match, just send None - self.pipe.send_tensor(None) - self.pipe.send_tensor(None) - self.pipe.send_tensor(None) - + for _ in range(5): + self.pipe.send_tensor(None) + def drop_select(self, input_tokens, roi): @@ -103,12 +102,17 @@ def drop_select(self, input_tokens, roi): input_tokens = self.pipe.recv_tensor() roi = self.pipe.recv_tensor() - kv = self.pipe.recv_tensor() + key = self.pipe.recv_tensor() + value = self.pipe.recv_tensor() + hidden = self.pipe.recv_tensor() - return [input_tokens, roi, kv] + return [input_tokens, roi, key, value, hidden] - def insert(self, input_tokens, roi, kv) -> None: + def insert(self, input_tokens, roi, key, value, hidden) -> None: + + with self.buffer_lock: + self._add_to_buffer(input_tokens, roi, key, value, hidden) # when calling the insert, the current process is a sender # need to launch the request handler and start listening to request. @@ -117,6 +121,4 @@ def insert(self, input_tokens, roi, kv) -> None: target=self.drop_select_handler) self.request_handling_thread.start() - with self.buffer_lock: - self._add_to_buffer(input_tokens, roi, kv) \ No newline at end of file From b566b18e4101190787077f7a889f10c2229e951a Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 4 Sep 2024 06:15:39 +0000 Subject: [PATCH 162/303] add lookup buffer test --- tests/kv_transfer/test_lookup_buffer.py | 132 ++++++++++++++++++++++++ tests/kv_transfer/test_lookup_buffer.sh | 3 + 2 files changed, 135 insertions(+) create mode 100644 tests/kv_transfer/test_lookup_buffer.py create mode 100644 tests/kv_transfer/test_lookup_buffer.sh diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py new file mode 100644 index 0000000000000..477eafbfafc25 --- /dev/null +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -0,0 +1,132 @@ + +import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp +import vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer as sklb +import torch +import os +import random +from tqdm import tqdm +import time + + +def test_run(my_rank, buffer): + # test run + tokens = torch.tensor([1,2,3]).to(buffer.pipe.device) + + if my_rank == 0: + key = 2.0 * torch.ones([5, 6]).to(buffer.pipe.device) + value = 3.0 * torch.ones([5, 6]).to(buffer.pipe.device) + + placeholder = torch.tensor([1]).to(buffer.pipe.device) + + buffer.insert(tokens, placeholder, key, value, placeholder) + + else: + placeholder = torch.tensor([1]).to(buffer.pipe.device) + tok, roi, key, value, hidden = buffer.drop_select(tokens, placeholder) + assert torch.allclose(tokens, tok) + assert torch.allclose(key, 2.0 * torch.ones([5, 6])) + assert torch.allclose(value, 3.0 * torch.ones([5, 6])) + + torch.distributed.barrier() + + if my_rank == 0: + assert buffer.buffer_size == 0 + assert len(buffer.buffer) == 0 + + +def stress_test(my_rank, pipe): + + torch.distributed.barrier() + + tensors = [] + + for i in tqdm(range(2000)): + mean = random.randint(1, 10) + std = random.randint(1, 10) + size = [random.randint(900, 1000), random.randint(900, 1000)] + x = torch.normal(mean * 1.0, std * 1.0, size=size).to(pipe.device) + + # 5% probability of sending a None + if random.randint(1, 100) < 5: + tensors.append(None) + tensors.append(None) + tensors.append(None) + else: + tensors.append(x) + tensors.append(x.mean()) + tensors.append(x.std()) + + torch.distributed.barrier() + + for i in tqdm(range(2000)): + if my_rank == int((i % 10) > 3): + pipe.send_tensor(tensors[3*i]) + pipe.send_tensor(tensors[3*i+1]) + pipe.send_tensor(tensors[3*i+2]) + else: + x = pipe.recv_tensor() + mean = pipe.recv_tensor() + std = pipe.recv_tensor() + if x is None: + assert mean is None + assert std is None + else: + assert x.mean() == mean + assert x.std() == std + + torch.distributed.barrier() + + print("Stress test passed.") + + + +def latency_test(my_rank, pipe, nelement, ntensor): + + latencies = [] + + torch.distributed.barrier() + + for i in tqdm(range(1000)): + + tensors = [] + + if my_rank == 0: + # create tensor + tensors = [torch.rand(nelement).to(pipe.device) for _ in range(ntensor)] + + torch.distributed.barrier() + + if my_rank == 0: + t = torch.tensor(time.time(), dtype=torch.float64).to(pipe.device) + for tensor in tensors: + pipe.send_tensor(tensor) + pipe.send_tensor(t) + else: + for _ in range(ntensor): + pipe.recv_tensor() + t = pipe.recv_tensor() + latencies.append(time.time() - t.item()) + + torch.distributed.barrier() + + print('Latency test passed.') + print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms') + + +if __name__ == "__main__": + + my_rank = int(os.environ['RANK']) + + + torch.distributed.init_process_group( + init_method="tcp://127.0.0.1:23456", + world_size=2, + rank=my_rank) + + print("initialized! My rank is %d" % my_rank) + + + pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "nccl") + buffer = sklb.SimpleKVLookupBuffer(pipe) + + test_run(my_rank, buffer) diff --git a/tests/kv_transfer/test_lookup_buffer.sh b/tests/kv_transfer/test_lookup_buffer.sh new file mode 100644 index 0000000000000..336b540e70542 --- /dev/null +++ b/tests/kv_transfer/test_lookup_buffer.sh @@ -0,0 +1,3 @@ + +RANK=0 python3 test_lookup_buffer.py & +RANK=1 python3 test_lookup_buffer.py & From fc2c972bdf1beb203c7711f40a803795b2e1502d Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 4 Sep 2024 06:16:10 +0000 Subject: [PATCH 163/303] update lookup buffer --- .../simple_kv_lookup_buffer.py | 69 ++++++++++++++----- 1 file changed, 50 insertions(+), 19 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py index 9864e5f2d42bf..a5b0ee4c3c722 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py @@ -35,30 +35,51 @@ def _matches(self, tokens_roi_sender, tokens_roi_recver): return True - min_length = min(len(tokens_sender), len(tokens_receiver)) - if tokens_sender[:min_length] == tokens_recver[:min_length]: + min_length = min(len(tokens_sender), len(tokens_recver)) + if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]): # drastically simplified # common prefix matching - return True + print("min length is ", min_length) + return min_length - return None + return 0 def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None: assert tensor is not None, "Use self.pipe.send(None) instead" self.buffer_size -= tensor.element_size() * tensor.numel() - tensor = tensor.clone() self.pipe.send_tensor(tensor) + + def _get_element_size(self, data): + + if data == [] or data is None: + return 0 + if isinstance(data, torch.Tensor): + return data.element_size() * data.numel() + + assert False, "Unknown data type %s" % type(data) def _add_to_buffer(self, input_tokens, roi, key, value, hidden): + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() + if isinstance(key, torch.Tensor): + key = key.clone() + if isinstance(value, torch.Tensor): + value = value.clone() + if isinstance(hidden, torch.Tensor): + hidden = hidden.clone() + - self.buffer_size += input_tokens.element_size() * input_tokens.numel() - self.buffer_size += roi.element_size() * roi.numel() - self.buffer_size += key.element_size() * key.numel() - self.buffer_size += value.element_size() * value.numel() - self.buffer_size += hidden.element_size() * hidden.numel() - self.buffer.append([input_tokens, roi, kv, hidden]) + buffer_item = [input_tokens, roi, key, value, hidden] + + with self.buffer_lock: + for data in buffer_item: + self.buffer_size += self._get_element_size(data) + self.buffer.append(buffer_item) def drop_select_handler(self): @@ -66,25 +87,29 @@ def drop_select_handler(self): while True: input_tokens = self.pipe.recv_tensor() roi = self.pipe.recv_tensor() - tokens_roi = [input_tokens, roi] + tokens_roi_recver = [input_tokens, roi] matched_idx = None # perform input tokens and roi matching with self.buffer_lock: - for idx, tokens_roi_kv in enumerate(self.tokens_roi_kv_buffer): - if self._matches(tokens_roi_kv, tokens_roi): + for idx, tokens_roi_sender in enumerate(self.buffer): + if self._matches(tokens_roi_sender, tokens_roi_recver) > 0: matched_idx = idx break + + + print("Got a match ", matched_idx) if matched_idx is not None: # need to clone the tensor # in case the tensor is freed before sending finishes - matched_item = self.tokens_roi_kv_buffer[matched_idx] + matched_item = self.buffer[matched_idx] + print(matched_item) for tensor in matched_item: - self._send_tensor_and_dec_size(tensor.clone()) - del self.tokens_roi_kv_buffer[matched_idx] + self._send_tensor_and_dec_size(tensor) + del self.buffer[matched_idx] else: # no match, just send None @@ -96,9 +121,15 @@ def drop_select(self, input_tokens, roi): assert self.request_handling_thread is None, \ "drop_select should be called by the receiver" + + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() - self.pipe.send_tensor(input_tokens.clone()) - self.pipe.send_tensor(roi.clone()) + self.pipe.send_tensor(input_tokens) + self.pipe.send_tensor(roi) input_tokens = self.pipe.recv_tensor() roi = self.pipe.recv_tensor() From b2c765c9444b3687aa8baeb098e2b6adc6089fc9 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 6 Sep 2024 06:37:09 +0000 Subject: [PATCH 164/303] finish lookup buffer test --- tests/kv_transfer/test_lookup_buffer.py | 142 +++++++++--------- .../simple_kv_lookup_buffer.py | 82 ++++++---- 2 files changed, 120 insertions(+), 104 deletions(-) diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index 477eafbfafc25..aa98a7804ecde 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -9,24 +9,32 @@ def test_run(my_rank, buffer): - # test run - tokens = torch.tensor([1,2,3]).to(buffer.pipe.device) + # buffer should be empty in the beginning + if my_rank == 0: + assert buffer.buffer_size == 0 + assert len(buffer.buffer) == 0 + + + # insert + tokens = torch.tensor([1,2,3]).to(buffer.pipe.device) + roi = (tokens > 0) if my_rank == 0: key = 2.0 * torch.ones([5, 6]).to(buffer.pipe.device) value = 3.0 * torch.ones([5, 6]).to(buffer.pipe.device) placeholder = torch.tensor([1]).to(buffer.pipe.device) - buffer.insert(tokens, placeholder, key, value, placeholder) + buffer.insert(tokens, roi, key, value, placeholder) + torch.distributed.barrier() - else: - placeholder = torch.tensor([1]).to(buffer.pipe.device) - tok, roi, key, value, hidden = buffer.drop_select(tokens, placeholder) + # drop_select + if my_rank == 1: + tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi) assert torch.allclose(tokens, tok) + assert torch.allclose(roi, roi_) assert torch.allclose(key, 2.0 * torch.ones([5, 6])) assert torch.allclose(value, 3.0 * torch.ones([5, 6])) - torch.distributed.barrier() if my_rank == 0: @@ -34,84 +42,72 @@ def test_run(my_rank, buffer): assert len(buffer.buffer) == 0 -def stress_test(my_rank, pipe): +def stress_test(my_rank, buf): torch.distributed.barrier() + torch.manual_seed(100) + + device = buf.pipe.device - tensors = [] + reqs = [ + ( + torch.rand(100).to(device), # tokens + torch.ones(100).bool().to(device), # roi + torch.rand(100).to(device), # key + torch.rand(100).to(device), # value + torch.rand(100).to(device), # hidden + ) for i in range(200)] + + random.seed(my_rank) + random.shuffle(reqs) - for i in tqdm(range(2000)): - mean = random.randint(1, 10) - std = random.randint(1, 10) - size = [random.randint(900, 1000), random.randint(900, 1000)] - x = torch.normal(mean * 1.0, std * 1.0, size=size).to(pipe.device) - - # 5% probability of sending a None - if random.randint(1, 100) < 5: - tensors.append(None) - tensors.append(None) - tensors.append(None) - else: - tensors.append(x) - tensors.append(x.mean()) - tensors.append(x.std()) - torch.distributed.barrier() - for i in tqdm(range(2000)): - if my_rank == int((i % 10) > 3): - pipe.send_tensor(tensors[3*i]) - pipe.send_tensor(tensors[3*i+1]) - pipe.send_tensor(tensors[3*i+2]) + n = 0 + + # the buffer size can only store 100 reqs + # so the sender will occasionally block.needs to wait for the receiver. + for req in tqdm(reqs): + if my_rank == 0: + buf.insert(*req) else: - x = pipe.recv_tensor() - mean = pipe.recv_tensor() - std = pipe.recv_tensor() - if x is None: - assert mean is None - assert std is None + tok, roi, k, v, h = req + tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi) + + if tok_ is None: + assert roi_ is None + assert k_ is None + assert v_ is None + assert h_ is None + n += 1 else: - assert x.mean() == mean - assert x.std() == std - + assert torch.allclose(tok, tok_) + assert torch.allclose(roi, roi_) + assert torch.allclose(k, k_) + assert torch.allclose(v, v_) + assert torch.allclose(h, h_) + print('Rand %d done' % my_rank) torch.distributed.barrier() - - print("Stress test passed.") + if my_rank == 0: + x = torch.tensor([0]) + torch.distributed.recv(x, 1) + # the # of None received is the kv that are not selected + assert x.item() == len(buf.buffer) + # and the size of the buffer should be 2000 * buffer len + print(buf.buffer_size) + assert buf.buffer_size == 1700 * len(buf.buffer) + else: + torch.distributed.send(torch.tensor([n]), 0) + + -def latency_test(my_rank, pipe, nelement, ntensor): + + - latencies = [] - torch.distributed.barrier() - for i in tqdm(range(1000)): - - tensors = [] - - if my_rank == 0: - # create tensor - tensors = [torch.rand(nelement).to(pipe.device) for _ in range(ntensor)] - - torch.distributed.barrier() - - if my_rank == 0: - t = torch.tensor(time.time(), dtype=torch.float64).to(pipe.device) - for tensor in tensors: - pipe.send_tensor(tensor) - pipe.send_tensor(t) - else: - for _ in range(ntensor): - pipe.recv_tensor() - t = pipe.recv_tensor() - latencies.append(time.time() - t.item()) - - torch.distributed.barrier() - - print('Latency test passed.') - print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms') - if __name__ == "__main__": @@ -127,6 +123,10 @@ def latency_test(my_rank, pipe, nelement, ntensor): pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "nccl") - buffer = sklb.SimpleKVLookupBuffer(pipe) + buffer = sklb.SimpleKVLookupBuffer(pipe, 170000) test_run(my_rank, buffer) + + stress_test(my_rank, buffer) + + print('Done') diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py index a5b0ee4c3c722..84566789f6965 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py @@ -5,14 +5,16 @@ import threading import torch from collections import deque +import time class SimpleKVLookupBuffer(KVLookupBufferBase): - def __init__(self, pipe): + def __init__(self, pipe, buffer_size_thresh): self.buffer = deque() self.buffer_size = 0 + self.buffer_size_threshold = buffer_size_thresh self.buffer_lock = threading.Lock() self.pipe = pipe self.request_handling_thread = None @@ -33,13 +35,17 @@ def _matches(self, tokens_roi_sender, tokens_roi_recver): # semantics: DROP SELECT * LIMIT 1 # so any of the data in the buffer can be drop-selected return True + + + # I am assuming that roi is a mask on tokens + tokens_sender = tokens_sender[roi_sender] + tokens_recver = tokens_recver[roi_recver] min_length = min(len(tokens_sender), len(tokens_recver)) if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]): # drastically simplified # common prefix matching - print("min length is ", min_length) return min_length return 0 @@ -75,7 +81,7 @@ def _add_to_buffer(self, input_tokens, roi, key, value, hidden): buffer_item = [input_tokens, roi, key, value, hidden] - + with self.buffer_lock: for data in buffer_item: self.buffer_size += self._get_element_size(data) @@ -83,38 +89,42 @@ def _add_to_buffer(self, input_tokens, roi, key, value, hidden): def drop_select_handler(self): + + try: - while True: - input_tokens = self.pipe.recv_tensor() - roi = self.pipe.recv_tensor() - tokens_roi_recver = [input_tokens, roi] - - matched_idx = None - - # perform input tokens and roi matching - with self.buffer_lock: + while True: + input_tokens = self.pipe.recv_tensor() + roi = self.pipe.recv_tensor() + tokens_roi_recver = [input_tokens, roi] - for idx, tokens_roi_sender in enumerate(self.buffer): - if self._matches(tokens_roi_sender, tokens_roi_recver) > 0: - matched_idx = idx - break + matched_length = 0 + + # perform input tokens and roi matching + with self.buffer_lock: + for _ in range(len(self.buffer)): + + temp_length = self._matches(self.buffer[0], tokens_roi_recver) + if temp_length > 0: + matched_length = temp_length + break + # rotate the element we just accessed to the end + self.buffer.rotate(-1) - print("Got a match ", matched_idx) - - if matched_idx is not None: - # need to clone the tensor - # in case the tensor is freed before sending finishes - matched_item = self.buffer[matched_idx] - print(matched_item) - for tensor in matched_item: - self._send_tensor_and_dec_size(tensor) - del self.buffer[matched_idx] - - else: - # no match, just send None - for _ in range(5): - self.pipe.send_tensor(None) + if matched_length > 0: + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor) + + else: + # no match, just send None + for _ in range(5): + self.pipe.send_tensor(None) + except RuntimeError as e: + if 'Connection closed by peer' not in str(e): + raise e def drop_select(self, input_tokens, roi): @@ -138,12 +148,18 @@ def drop_select(self, input_tokens, roi): hidden = self.pipe.recv_tensor() return [input_tokens, roi, key, value, hidden] + + + def full_handler(self): + time.sleep(0.001) def insert(self, input_tokens, roi, key, value, hidden) -> None: + + while self.buffer_size > self.buffer_size_threshold: + self.full_handler() - with self.buffer_lock: - self._add_to_buffer(input_tokens, roi, key, value, hidden) + self._add_to_buffer(input_tokens, roi, key, value, hidden) # when calling the insert, the current process is a sender # need to launch the request handler and start listening to request. From 8aef9dcf1e701b96f2956e9e432ef94f6ace7563 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sat, 7 Sep 2024 22:27:15 +0000 Subject: [PATCH 165/303] update parallel state to use the new class method --- vllm/distributed/distributed_kv.py | 480 ----------------------------- vllm/distributed/parallel_state.py | 12 +- 2 files changed, 2 insertions(+), 490 deletions(-) delete mode 100644 vllm/distributed/distributed_kv.py diff --git a/vllm/distributed/distributed_kv.py b/vllm/distributed/distributed_kv.py deleted file mode 100644 index 9005a325d6bff..0000000000000 --- a/vllm/distributed/distributed_kv.py +++ /dev/null @@ -1,480 +0,0 @@ -"""vLLM distributed KV cache transfer API. -These APIs are used in `vllm/worker/model_runner.py`. - -Currently supporting TP and PP. - -Workflow: -- In prefill instance, KV cache sender *buffers* the KV cache send requests -- In decode instance - - KV cache receiver sends the hash of input tokens to sender - - KV cache sender executes send request - - KV cache receiver receives the KV cache -""" -from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING -from collections import defaultdict, deque -from concurrent.futures import ThreadPoolExecutor -from threading import Lock -from copy import deepcopy -import time -import threading - -import torch -from torch.distributed import Backend, ProcessGroup - -import vllm.envs as envs -from vllm.distributed.group_coordinator import GroupCoordinator -from vllm.logger import init_logger -import vllm.distributed.parallel_state as ps -from vllm import _custom_ops as ops -from vllm.sequence import IntermediateTensors - -assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode"], \ - "VLLM_DISAGG_PREFILL_ROLE can only be prefill or decode." - -IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE is not None) -IS_KV_PREFILL_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "prefill") -IS_KV_DECODE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "decode") - -# add a tag when sending/recving input hash -DISTRIBUTED_KV_GLOO_TAG = 24857323 - -logger = init_logger(__name__) - -import logging - - -class RankFilter(logging.Filter): - - def filter(self, record): - # Only log if rank is 4 - rank = 1 - try: - rank = torch.distributed.get_rank() - except Exception: - pass - return rank % 4 == 0 - - -for handler in logger.handlers: - handler.addFilter(RankFilter()) - - -class DistributedKVCoordinator(GroupCoordinator): - """ - A class designated for distributed KV transfer - - Target use cases: - 1. Disaggregated prefill - 2. Remote KV cache storage - - """ - - def __init__( - self, - group_ranks: List[List[int]], - local_rank: int, - torch_distributed_backend: Union[str, Backend], - # DO NOT use pynccl here - # Pynccl send is non-blocking - # and it's possible that the memory is freed before the data being sent - # which may happen at high qps - use_pynccl: bool = False, - use_custom_allreduce: bool = False, - use_tpu_communicator: bool = True, - use_message_queue_broadcaster: bool = False, - use_cpu_comm_for_sanity_check: bool = False, - ): - - super().__init__( - group_ranks, - local_rank, - torch_distributed_backend, - use_pynccl, - use_custom_allreduce, - use_tpu_communicator, - use_message_queue_broadcaster, - ) - - # if turned on, will use CPU-based communication to perform a series of sanity check. - # but it adds ~5ms delay, so please turn it off in performance-demanding usecases (e.g. disaggregated prefill) - self.use_cpu_comm_for_sanity_check = use_cpu_comm_for_sanity_check - - # use a threadpool to buffer send request in disaggregated prefill - self.input_hash_to_kv_sending_requests = defaultdict(deque) - self.kv_sending_thread = None - self.input_hash_to_kv_sending_requests_lock = Lock() - self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % - self.world_size] - self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % - self.world_size] - - torch.set_default_device(self.device) - - def debug_send(self, - tensor: torch.Tensor, - dst: Optional[int] = None) -> None: - """Sends a tensor to the destination rank in a non-blocking way""" - """Will send several metadata. Useful for debugging.""" - """NOTE: `dst` is the local rank of the destination rank.""" - - self.send_tensor_dict( - { - "tensor": tensor, - "mean": tensor.float().mean(), - "shape": tensor.shape - }, dst) - - def debug_recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: - """Receives a tensor from the src rank.""" - """NOTE: `src` is the local rank of the destination rank.""" - - result = self.recv_tensor_dict(src) - tensor = result["tensor"] - assert torch.allclose(result["mean"], tensor.float().mean()) - assert result["shape"] == tensor.shape - assert result[ - "shape"] == size, f"The shape sent by sender is {result['shape']} but trying to receive {size}" - return tensor - - def kv_cache_send(self, - input_hash: int, - tensor: Union[torch.Tensor, IntermediateTensors], - is_hidden: bool = False, - dst: Optional[int] = None) -> None: - """Push the KV cache send request into the send buffer""" - """NOTE: `dst` is the local rank of the destination rank.""" - - if self.use_cpu_comm_for_sanity_check: - send_func = self.debug_send - else: - send_func = self.send - - if is_hidden and not ps.get_pp_group().is_last_rank: - - assert isinstance(tensor, IntermediateTensors) - - output = deepcopy(tensor.tensors) - for key in output: - output[key] = output[key].contiguous() - - self.input_hash_to_kv_sending_requests[input_hash].append( - [self.send_tensor_dict, output, dst]) - - else: - - assert isinstance(tensor, torch.Tensor) - - self.input_hash_to_kv_sending_requests[input_hash].append([ - send_func, - # use clone to make sure the tensor is contiguous - tensor.clone(), - dst - ]) - - def kv_cache_recv( - self, - size: torch.Size, - dtype: torch.dtype, - is_hidden: bool = False, - src: Optional[int] = None - ) -> Union[torch.Tensor, IntermediateTensors]: - """Receives a tensor from the src rank (blocking).""" - """This API should be used together with `push`""" - """NOTE: `src` is the local rank of the destination rank.""" - - if self.use_cpu_comm_for_sanity_check: - recv_func = self.debug_recv - else: - recv_func = self.recv - - if is_hidden and not ps.get_pp_group().is_last_rank: - tensor = IntermediateTensors(self.recv_tensor_dict(src)) - else: - tensor = recv_func(size, dtype, src) - - return tensor - - def send_input_hash(self, input_hash: int) -> int: - - logger.debug('[rank%d]: Sending input hash %d to rank %d', - torch.distributed.get_rank(), input_hash, - self.target_rank_for_send) - - # KV cache send go through CPU, and the original `send` only use GPU. - # So create a new group for sending input hash. - input_hash_tensor = torch.tensor([input_hash], device="cpu").long() - torch.distributed.send(input_hash_tensor, - self.target_rank_for_send, - self.cpu_group, - tag=DISTRIBUTED_KV_GLOO_TAG) - return_tensor = torch.tensor([0], device="cpu").long() - torch.distributed.recv(return_tensor, - self.target_rank_for_recv, - self.cpu_group, - tag=DISTRIBUTED_KV_GLOO_TAG) - return return_tensor.item() - - def recv_input_hash(self) -> Optional[int]: - ''' - Receive an input hash, and check if it is already cached - ''' - input_hash_tensor = torch.tensor([0], device="cpu").long() - torch.distributed.recv(input_hash_tensor, - self.target_rank_for_recv, - self.cpu_group, - tag=DISTRIBUTED_KV_GLOO_TAG) - input_hash = input_hash_tensor.item() - # a new input hash comes in, see if it is already cached - self.input_hash_to_kv_sending_requests_lock.acquire() - logger.debug('Successfully received input hash %d', input_hash) - if input_hash not in self.input_hash_to_kv_sending_requests: - logger.warning( - f"The KV cache of {input_hash} does not exist. "\ - f"Existing input hash: {list(self.input_hash_to_kv_sending_requests.keys())}") - - # 0 for fail - x = torch.tensor([0], device="cpu").long() - torch.distributed.send(x, - self.target_rank_for_send, - self.cpu_group, - tag=DISTRIBUTED_KV_GLOO_TAG) - return None - else: - logger.debug('Input hash %d exists, start sending', input_hash) - - # 1 for success - x = torch.tensor([1], device="cpu").long() - torch.distributed.send(x, - self.target_rank_for_send, - self.cpu_group, - tag=DISTRIBUTED_KV_GLOO_TAG) - return input_hash - - def kv_cache_send_loop(self): - - while True: - logger.debug( - '[rank%d]: Waiting for input hash from rank %d, my keys are %s', - torch.distributed.get_rank(), - self.target_rank_for_recv, - list(self.input_hash_to_kv_sending_requests.keys()), - ) - # wait for a new input hash - # this function will acquire the lock - input_hash = self.recv_input_hash() - if input_hash is None: - self.input_hash_to_kv_sending_requests_lock.release() - continue - - # execute corresponding kv cache sending jobs in request queue - while True: - request = self.input_hash_to_kv_sending_requests[ - input_hash].popleft() - # An empty request: the KV cahe of one request are all sent - if request == []: - break - - request[0](*request[1:]) - - if len(self.input_hash_to_kv_sending_requests[input_hash]) == 0: - logger.debug('Finish input hash %d, free GPU memory...', - input_hash) - del self.input_hash_to_kv_sending_requests[input_hash] - else: - logger.debug( - 'The buffer for input hash %d is not empty, meaning that '\ - 'there are two jobs with identical input.', - input_hash) - - self.input_hash_to_kv_sending_requests_lock.release() - - - def kv_cache_send_ready(self, input_hash: int): - - if self.kv_sending_thread is None: - self.kv_sending_thread = threading.Thread( - target=self.kv_cache_send_loop) - self.kv_sending_thread.start() - - # append an empty list to separate requests - # as there might be identical requests, that has the same input hash - self.input_hash_to_kv_sending_requests[input_hash].append([]) - logger.debug(f'Buffered input hash {input_hash}') - - def kv_cache_recv_start(self, input_hash: int): - # notify the kv cache sender with the input hash id - return self.send_input_hash(input_hash) - - def block_if_buffer_full(self): - - # block vLLM if the KV cache sending buffer is full - # TODO: allow using other policies to handle buffer full - while True: - self.input_hash_to_kv_sending_requests_lock.acquire() - if len(self.input_hash_to_kv_sending_requests.keys()) > 40: - self.input_hash_to_kv_sending_requests_lock.release() - time.sleep(0.1) - else: - self.input_hash_to_kv_sending_requests_lock.release() - break - - -def send_kv_caches_and_hidden_states( - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], -) -> None: - - input_tokens_tuple = tuple(model_input.input_tokens.tolist()) - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - - # Assumption: current batch is all-prefill requests - assert torch.allclose(model_input.attn_metadata.query_start_loc, - model_input.attn_metadata.seq_start_loc) - assert torch.all(model_input.attn_metadata.context_lens_tensor == 0) - - ps.get_disagg_group().input_hash_to_kv_sending_requests_lock.acquire() - - # query_lens contains new KV caches that are added to vLLM. - # so we will send them to decode instance - # FIXME(Kuntai): This assume that all requests are prefill. - for idx, slen in enumerate(seq_lens): - - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - input_hash = hash(input_tokens_tuple[start_pos:end_pos]) - - for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): - kv_cache = kv_caches[i - model_executable.model.start_layer] - - _, _, num_heads, head_size = kv_cache[0].shape - - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) - - current_slot_mapping = slot_mapping[start_pos:end_pos] - - ps.get_disagg_group().kv_cache_send( - input_hash, key_cache[current_slot_mapping]) - ps.get_disagg_group().kv_cache_send( - input_hash, value_cache[current_slot_mapping]) - - ps.get_disagg_group().kv_cache_send( - input_hash, - hidden_or_intermediate_states[start_pos:end_pos], - is_hidden=True) - ps.get_disagg_group().kv_cache_send_ready(input_hash) - - ps.get_disagg_group().input_hash_to_kv_sending_requests_lock.release() - - ps.get_disagg_group().block_if_buffer_full() - - logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) - - -def recv_kv_caches_and_hidden_states( - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: List[torch.Tensor] -) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool]: - - bypass_model_exec = True - - # This is disagg decode instance, during prefill state - # Need to receive KV from the prefill instance - input_tokens_tuple = tuple(model_input.input_tokens.tolist()) - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - - # Assumption: current batch is all-prefill requests - assert torch.allclose(model_input.attn_metadata.query_start_loc, - model_input.attn_metadata.seq_start_loc) - assert torch.all(model_input.attn_metadata.context_lens_tensor == 0) - - hidden_or_intermediate_states_for_one_req = [] - - # enumerate different requests - # FIXME(Kuntai): This impl assumes that all requests are prefill. - for idx, slen in enumerate(seq_lens): - - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - input_hash = hash(input_tokens_tuple[start_pos:end_pos]) - num_tokens = slen - - # notify the prefill instance to start sending KVs associated with input_hash - contain = ps.get_disagg_group().kv_cache_recv_start(input_hash) - - # fail to find input_hash in prefill instance - # this can occur but idk why... - if contain == 0: - bypass_model_exec = False - continue - - # receive KV cache from disaggregated prefill instance - for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): - - # get kv cache - kv_cache = kv_caches[i - model_executable.model.start_layer] - # get corresponding layer - layer = model_executable.model.layers[i] - - # get kv cache shape (after sliced by tp) - _, _, num_heads, head_size = kv_cache[0].shape - key = ps.get_disagg_group().kv_cache_recv( - torch.Size([num_tokens, num_heads, head_size]), - kv_cache[0].dtype) - value = ps.get_disagg_group().kv_cache_recv( - torch.Size([num_tokens, num_heads, head_size]), - kv_cache[0].dtype) - - key_cache, value_cache = kv_cache[0], kv_cache[1] - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - layer.self_attn.attn._v_scale, - ) - - hidden_or_intermediate_states_for_one_req.append( - ps.get_disagg_group().kv_cache_recv(torch.Size( - [num_tokens, model_executable.config.hidden_size]), - kv_cache[0].dtype, - is_hidden=True)) - - if not bypass_model_exec: - # Some of the KV cache is not retrieved - # so we need to recompute the hidden state - return [], bypass_model_exec - - # concatenate hidden states from different requests - if isinstance(hidden_or_intermediate_states_for_one_req[0], torch.Tensor): - hidden_or_intermediate_states = torch.cat( - hidden_or_intermediate_states_for_one_req, dim=0) - else: - # concat the IntermediateTensors - keys = list( - hidden_or_intermediate_states_for_one_req[0].tensors.keys()) - result_its = {} - - for key in keys: - result_its[key] = [] - for its in hidden_or_intermediate_states_for_one_req: - result_its[key].append(its[key]) - result_its[key] = torch.cat(result_its[key], dim=0) - - hidden_or_intermediate_states = IntermediateTensors(result_its) - - logger.debug("[rank%d]: KV recv DONE.", torch.distributed.get_rank()) - return hidden_or_intermediate_states, bypass_model_exec diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 86e26b46c9e92..22bc109833c1c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -39,7 +39,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.distributed.group_coordinator import GroupCoordinator -import vllm.distributed.distributed_kv as dist_kv +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv @@ -350,19 +350,11 @@ def initialize_model_parallel( # decode global rank: i + world_size group_ranks.append([i, i + world_size]) logger.debug("Distributed group is %s", str(group_ranks)) - _DISAGG = dist_kv.DistributedKVCoordinator( + _DISAGG = dist_kv.KV_transfer_agent( group_ranks=group_ranks, local_rank=get_world_group().local_rank, torch_distributed_backend=backend, ) - # follow by a warmup, to warmup nccl - # necessary, as NCCL may not be warmed up when tp and pp are both 1. - temp_tensor = torch.tensor([1.]).to(_DISAGG.device) - if dist_kv.IS_KV_PREFILL_INSTANCE: - _DISAGG.send(temp_tensor) - else: - recv_tensor = _DISAGG.recv(temp_tensor.shape, temp_tensor.dtype) - assert torch.allclose(temp_tensor, recv_tensor) logger.debug("_DISAGG initialized for rank %d", torch.distributed.get_rank()) From 1b6125d0af553379ed6bd5f9f64a86cb6c425a55 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 8 Sep 2024 01:13:29 +0000 Subject: [PATCH 166/303] move the implementatio to worker_base.py --- vllm/distributed/group_coordinator.py | 5 -- vllm/distributed/kv_transfer/vllm_adapter.py | 38 ++++----- vllm/distributed/parallel_state.py | 4 +- vllm/executor/gpu_executor.py | 2 +- vllm/executor/multiproc_gpu_executor.py | 2 +- vllm/executor/ray_gpu_executor.py | 2 +- vllm/worker/model_runner.py | 72 +++++----------- vllm/worker/worker_base.py | 89 ++++++++++++++++++-- 8 files changed, 129 insertions(+), 85 deletions(-) diff --git a/vllm/distributed/group_coordinator.py b/vllm/distributed/group_coordinator.py index bfa3c7f3c17cf..e043beeb9969a 100644 --- a/vllm/distributed/group_coordinator.py +++ b/vllm/distributed/group_coordinator.py @@ -387,7 +387,6 @@ def broadcast_object_list(self, """Broadcast the input object list. NOTE: `src` is the local rank of the source rank. """ - assert src < self.world_size, f"Invalid src rank ({src})" # Bypass the function if we are using only 1 GPU. if self.world_size == 1: @@ -402,7 +401,6 @@ def send_object(self, obj: Any, dst: int) -> None: """Send the input object list to the destination rank.""" """NOTE: `dst` is the local rank of the destination rank.""" - assert dst < self.world_size, f"Invalid dst rank ({dst})" assert dst != self.rank_in_group, ( "Invalid destination rank. Destination rank is the same " @@ -432,7 +430,6 @@ def recv_object(self, src: int) -> Any: """Receive the input object list from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" - assert src < self.world_size, f"Invalid src rank ({src})" assert src != self.rank_in_group, ( "Invalid source rank. Source rank is the same as the current rank." @@ -570,7 +567,6 @@ def send_tensor_dict( if dst is None: dst = (self.rank_in_group + 1) % self.world_size - assert dst < self.world_size, f"Invalid dst rank ({dst})" metadata_list: List[Tuple[Any, Any]] = [] assert isinstance( @@ -625,7 +621,6 @@ def recv_tensor_dict( if src is None: src = (self.rank_in_group - 1) % self.world_size - assert src < self.world_size, f"Invalid src rank ({src})" recv_metadata_list = self.recv_object(src=src) tensor_dict: Dict[str, Any] = {} diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 07358567e783f..c5747330d1ac9 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -30,19 +30,19 @@ from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import TorchDistributedPipe from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer import SimpleKVLookupBuffer -assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode", "lmc"], \ - "VLLM_DISAGG_PREFILL_ROLE can only be prefill or decode." +assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode", "lmcache"], \ + "VLLM_DISAGG_PREFILL_ROLE can only be prefill, decode or lmcache." + +# currently the connections are hard-coded. +# we only handle 2 cases: +# - prefill vLLM --> decode vLLM +# - vLLM --> LMCache IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"]) IS_KV_PREFILL_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "prefill") IS_KV_DECODE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "decode") +IS_LMCACHE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "lmcache") -'''Jiayi starts here''' -IS_LMC_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "lmc") -'''Jiayi ends here''' - -# add a tag when sending/recving input hash -DISTRIBUTED_KV_GLOO_TAG = 24857323 logger = init_logger(__name__) @@ -80,7 +80,7 @@ def __init__( use_message_queue_broadcaster, ) # init lookup buffer - self.buffer = SimpleKVLookupBuffer(self.pipe) + self.buffer = SimpleKVLookupBuffer(self.pipe, 1000**3) def send_kv_caches_and_hidden_states( self, @@ -90,7 +90,6 @@ def send_kv_caches_and_hidden_states( hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], ) -> None: - #input_tokens_tuple = tuple(model_input.input_tokens.tolist()) input_tokens_tensor = model_input.input_tokens seq_lens = model_input.attn_metadata.seq_lens slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() @@ -99,16 +98,15 @@ def send_kv_caches_and_hidden_states( # so we will send them to decode instance # FIXME(Kuntai): This assume that all requests are prefill. for idx, slen in enumerate(seq_lens): - logger.debug(f"sending request {idx}") start_pos = sum(seq_lens[:idx]) end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] keys, values = [], [] for l in range(model_executable.model.start_layer, model_executable.model.end_layer): - logger.debug(f"sending layer {l}") kv_cache = kv_caches[l - model_executable.model.start_layer] _, _, num_heads, head_size = kv_cache[0].shape @@ -118,14 +116,14 @@ def send_kv_caches_and_hidden_states( current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - keys.append(key_cache[current_slot_mapping].unsqeeze(0)) + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) values.append(value_cache[current_slot_mapping].unsqueeze(0)) keys = torch.cat(keys, dim=0) values = torch.cat(values, dim=0) self.buffer.insert( - input_tokens_tensor[start_pos:end_pos], - None, + current_tokens, + torch.ones_like(current_tokens, dtype=bool), keys, values, hidden_or_intermediate_states[start_pos:end_pos] @@ -146,7 +144,7 @@ def recv_kv_caches_and_hidden_states( # This is disagg decode instance, during prefill state # Need to receive KV from the prefill instance - input_tokens_tuple = tuple(model_input.input_tokens.tolist()) + input_tokens_tensor = model_input.input_tokens seq_lens = model_input.attn_metadata.seq_lens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() @@ -158,10 +156,12 @@ def recv_kv_caches_and_hidden_states( start_pos = sum(seq_lens[:idx]) end_pos = start_pos + slen - current_input_tokens = input_tokens_tuple[start_pos:end_pos] + current_tokens = input_tokens_tensor[start_pos:end_pos] num_tokens = slen - ret = self.buffer.drop_select(current_input_tokens, None) + ret = self.buffer.drop_select( + current_tokens, + torch.ones_like(current_tokens, dtype=bool)) if ret[0] is None: # didn't find any match. self.bypass_model_exec = False @@ -202,4 +202,4 @@ def recv_kv_caches_and_hidden_states( hidden_or_intermediate_states_for_one_req, dim=0) logger.debug("[rank%d]: KV recv DONE.", torch.distributed.get_rank()) - return hidden_or_intermediate_states, bypass_model_exec \ No newline at end of file + return hidden_or_intermediate_states, bypass_model_exec, model_input \ No newline at end of file diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 22bc109833c1c..c48b113de9705 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -107,10 +107,10 @@ def get_pp_group() -> GroupCoordinator: # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group -_DISAGG: Optional[dist_kv.DistributedKVCoordinator] = None +_DISAGG: Optional[dist_kv.KV_transfer_agent] = None -def get_disagg_group() -> dist_kv.DistributedKVCoordinator: +def get_disagg_group() -> dist_kv.KV_transfer_agent: assert _DISAGG is not None, ( "disaggregated prefill parallel group is not initialized") return _DISAGG diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 300e9a33eba56..c52c21600d030 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union -import vllm.distributed.distributed_kv as dist_kv +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index ba222f8b5e405..7857eab2e551b 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -8,7 +8,7 @@ import torch -import vllm.distributed.distributed_kv as dist_kv +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.gpu_executor import create_worker diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 17f4d36338860..0d14d65d26caa 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import vllm.envs as envs -import vllm.distributed.distributed_kv as dist_kv +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4d8105bde2c4a..6c2982f6f781f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -12,7 +12,6 @@ import torch.distributed import torch.nn as nn -import vllm.distributed.distributed_kv as dist_kv try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper @@ -1366,56 +1365,25 @@ def execute_model( "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_seqlen_agnostic else {} - # check if the current run is profiling - is_profile_run = (kv_caches is None) or (kv_caches[0] is None) - # check if the current run is prefill - is_prefill_run = prefill_meta is not None + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) - # for disaggregated prefilling: allow bypassing model execution - bypass_model_exec = False - - # Recv kv cache for disaggregated prefill - # Skip model execution if all required KV cache are received - if all([ - is_prefill_run, - dist_kv.IS_KV_DECODE_INSTANCE, - not is_profile_run]): - - hidden_or_intermediate_states, bypass = \ - dist_kv.recv_kv_caches_and_hidden_states( - model_executable, - model_input, - kv_caches, - ) - if bypass: - bypass_model_exec = True - - if not bypass_model_exec: - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) - - # Send KV cache for disaggregated prefill - if all([ - is_prefill_run, - dist_kv.IS_KV_PREFILL_INSTANCE, - not is_profile_run]): - - dist_kv.send_kv_caches_and_hidden_states( - model_executable, - model_input, - kv_caches, - hidden_or_intermediate_states, - ) - + return hidden_or_intermediate_states - # Compute the logits in the last pipeline stage. + @torch.inference_mode() + def postprocess_model( + self, + model_input, + hidden_or_intermediate_states, + + ): if not get_pp_group().is_last_rank: return hidden_or_intermediate_states @@ -1431,7 +1399,7 @@ def execute_model( sampling_metadata=model_input.sampling_metadata, ) - + decode_meta = model_input.attn_metadata.decode_metadata if self.return_hidden_states: # we only need to pass hidden states of most recent token assert model_input.sampling_metadata is not None @@ -1447,7 +1415,9 @@ def execute_model( output.hidden_states = hidden_states return [output] - + + + class CUDAGraphRunner: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index e56440693b895..6fd94312483e5 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -16,6 +16,9 @@ update_environment_variables) from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv +import vllm.distributed.parallel_state as ps + logger = init_logger(__name__) @@ -212,6 +215,7 @@ def execute_worker(self, worker_input: WorkerInput) -> None: Process an execution request. """ raise NotImplementedError + def execute_model( self, @@ -269,11 +273,52 @@ def execute_model( intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict( all_gather_group=get_tp_group())) - - output = self.model_runner.execute_model( - model_input, self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, intermediate_tensors, - num_steps) + + + # for disaggregated prefilling: allow bypassing model execution + bypass_model_exec = False + + + # receive KV cache. + # NOTE(kuntai): + # If only a part of KV cache is received, we will adjust model_input + # to avoid prefill on the part of KV caches that are already received. + # This will not happen for disaggregated prefill, but will happen + # when connecting to a KV cache database (like LMCache). + if self.need_recv_kv(model_input, worker_input): + hidden_or_intermediate_states, bypass_model_exec, model_input = \ + ps.get_disagg_group().recv_kv_caches_and_hidden_states( + # model is used to know which layer the current worker + # is working on, so that we can receive KV for only those + # layers. + self.model_runner.model, + model_input, + self.kv_cache[worker_input.virtual_engine], + ) + + if not bypass_model_exec: + hidden_or_intermediate_states = self.model_runner.execute_model( + model_input, self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, intermediate_tensors, + num_steps) + + # sending out KV cache + if self.need_send_kv(model_input, worker_input): + ps.get_disagg_group().send_kv_caches_and_hidden_states( + # model is used to know which layer the current worker + # is working on, so that we can send KV for only those + # layers. + self.model_runner.model, + model_input, + self.kv_cache[worker_input.virtual_engine], + hidden_or_intermediate_states, + ) + + # Get model output based on hidden state. + output = self.model_runner.postprocess_model( + model_input, + hidden_or_intermediate_states, + ) if not get_pp_group().is_last_rank: # output is IntermediateTensors @@ -284,6 +329,40 @@ def execute_model( # output is List[SamplerOutput] return output + def need_recv_kv(self, model_input, worker_input) -> bool: + + kv_caches = self.kv_cache[worker_input.virtual_engine] + prefill_meta = model_input.attn_metadata.prefill_metadata + + # check if the current run is profiling + is_profile_run = (kv_caches is None) or (kv_caches[0] is None) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + # for disaggregated prefilling: allow bypassing model execution + + return all([ + is_prefill_run, + dist_kv.IS_KV_DECODE_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE, + not is_profile_run]) + + + def need_send_kv(self, model_input, worker_input) -> bool: + + kv_caches = self.kv_cache[worker_input.virtual_engine] + prefill_meta = model_input.attn_metadata.prefill_metadata + model_executable = self.model_runner.model + + # check if the current run is profiling + is_profile_run = (kv_caches is None) or (kv_caches[0] is None) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + return all([ + is_prefill_run, + dist_kv.IS_KV_PREFILL_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE, + not is_profile_run]) + + def _execute_model_spmd( self, execute_model_req: ExecuteModelRequest, From c4102ef057210f4a5a8ac99cb125cf6c95043167 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 8 Sep 2024 08:14:15 +0000 Subject: [PATCH 167/303] update test --- tests/kv_transfer/test_send_recv.py | 36 ++++++++++++++++++----------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 5a2a72cf177fa..95a7528f0f7a8 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -9,12 +9,15 @@ def test_run(my_rank, pipe): # test run + x = torch.tensor([1]).to(pipe.device) + y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device) if my_rank == 0: - x = torch.tensor([1]).to(pipe.device) + pipe.send_tensor(x) + pipe.send_tensor(y) else: - y = pipe.recv_tensor() - assert y.item() == 1 + assert torch.allclose(x, pipe.recv_tensor()) + assert torch.allclose(y, pipe.recv_tensor()) def stress_test(my_rank, pipe): @@ -23,21 +26,24 @@ def stress_test(my_rank, pipe): tensors = [] + for i in tqdm(range(2000)): - mean = random.randint(1, 10) - std = random.randint(1, 10) - size = [random.randint(900, 1000), random.randint(900, 1000)] - x = torch.normal(mean * 1.0, std * 1.0, size=size).to(pipe.device) + mean = torch.rand(1).item() + std = torch.rand(1).item() + size = torch.randint(900, 1000, (2,)) + x = torch.normal(mean * 1.0, std * 1.0, size=size.tolist()).to(pipe.device) # 5% probability of sending a None - if random.randint(1, 100) < 5: + if torch.rand(1).item() < 0.05: tensors.append(None) tensors.append(None) tensors.append(None) else: tensors.append(x) - tensors.append(x.mean()) - tensors.append(x.std()) + tensors.append(x.mean().unsqueeze(0)) + tensors.append(x.std().unsqueeze(0)) + + torch.distributed.barrier() @@ -54,8 +60,9 @@ def stress_test(my_rank, pipe): assert mean is None assert std is None else: - assert x.mean() == mean - assert x.std() == std + assert torch.allclose(x, tensors[3*i]) + assert x.mean() == mean[0] + assert x.std() == std[0] torch.distributed.barrier() @@ -80,7 +87,7 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.distributed.barrier() if my_rank == 0: - t = torch.tensor(time.time(), dtype=torch.float64).to(pipe.device) + t = torch.tensor([time.time()], dtype=torch.float64).to(pipe.device) for tensor in tensors: pipe.send_tensor(tensor) pipe.send_tensor(t) @@ -110,7 +117,8 @@ def latency_test(my_rank, pipe, nelement, ntensor): pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "nccl") - + + torch.manual_seed(0) test_run(my_rank, pipe) stress_test(my_rank, pipe) latency_test(my_rank, pipe, 1024*8*128, 80) From a576532438e54bc545b86dc19d4292a32539492e Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 8 Sep 2024 08:14:39 +0000 Subject: [PATCH 168/303] update a new implementation for distributed pipe. Much less CPU communication --- .../simple_kv_lookup_buffer.py | 5 + .../kv_pipe/torch_distributed_pipe.py | 251 +++++++++++++++--- vllm/distributed/kv_transfer/vllm_adapter.py | 10 +- 3 files changed, 225 insertions(+), 41 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py index 84566789f6965..407ac7c9bcfc1 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py @@ -7,6 +7,10 @@ from collections import deque import time +from vllm.logger import init_logger + +logger = init_logger(__name__) + class SimpleKVLookupBuffer(KVLookupBufferBase): def __init__(self, pipe, buffer_size_thresh): @@ -157,6 +161,7 @@ def full_handler(self): def insert(self, input_tokens, roi, key, value, hidden) -> None: while self.buffer_size > self.buffer_size_threshold: + logger.debug("KV transfer buffer is full. Handling...") self.full_handler() self._add_to_buffer(input_tokens, roi, key, value, hidden) diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index d3663ac7667dd..1bf54badac255 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -3,16 +3,76 @@ from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from torch.distributed import Backend, ProcessGroup import torch -from typing import List, Union, Optional +from typing import Any, Dict, List, Optional, Tuple, Union import threading from concurrent.futures import ThreadPoolExecutor import time import threading +from collections import namedtuple +from typing import Dict, Any, Tuple, List +import pickle + +from vllm.logger import init_logger + + +logger = init_logger(__name__) + + +# auxilary function to send tensordict +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]] +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list: List[torch.Tensor] = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size()))) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list # if the tensor is only one-element and only contains this number # this means that the sended object is None. NONE_INT = -150886311 +FLOAT16_INT = -543205003776624 +INT64_INT = -375623078607432 +BOOL_INT = -28035262008646 +BFLOAT16_INT = -452084912267662 +FLOAT32_INT = -1049557997456592 +FLOAT64_INT = -452201007054137 + +DTYPE2INT = { + torch.float16: FLOAT16_INT, + torch.int64: INT64_INT, + torch.bool: BOOL_INT, + torch.bfloat16: BFLOAT16_INT, + torch.float32: FLOAT32_INT, + torch.float64: FLOAT64_INT, +} + +INT2DTYPE = { + FLOAT16_INT: torch.float16, + INT64_INT: torch.int64, + BOOL_INT: torch.bool, + BFLOAT16_INT: torch.bfloat16, + FLOAT32_INT: torch.float32, + FLOAT64_INT: torch.float64, +} class BrokenPipeException(Exception): @@ -20,32 +80,41 @@ def __init__(self, message): self.message = message super().__init__(self.message) -class TorchDistributedPipe(KVPipeBase, GroupCoordinator): +class TorchDistributedPipe(KVPipeBase): def __init__( self, group_ranks: List[List[int]], local_rank: int, - torch_distributed_backend: Union[str, Backend], - # DO NOT use pynccl here - # Pynccl send is non-blocking - # and it's possible that the memory is freed before the data being sent - # which may happen at high qps - use_pynccl: bool = False, - use_custom_allreduce: bool = False, - use_tpu_communicator: bool = True, - use_message_queue_broadcaster: bool = False, + torch_distributed_backend: Union[str, Backend] ): - super().__init__( - group_ranks, - local_rank, - torch_distributed_backend, - use_pynccl, - use_custom_allreduce, - use_tpu_communicator, - use_message_queue_broadcaster, - ) + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + assert self.rank_in_group <= 1 + + if torch.cuda.is_available(): + self.device = torch.device(f"cuda:{local_rank}") + else: + self.device = torch.device("cpu") # if turned on, will use CPU-based communication to perform a series of sanity check. # but it adds ~5ms delay, so please turn it off in performance-demanding usecases (e.g. disaggregated prefill) @@ -61,19 +130,131 @@ def __init__( self.none_tensor = torch.tensor([NONE_INT]).to(self.device) self.broken = False + + # create a dummy tensor + # this tensor is used + self.dummy_cpu_tensor_for_send = torch.tensor([1],device='cpu') + self.dummy_cpu_tensor_for_recv = torch.tensor([1],device='cpu') + + self.dtype_tensor_for_recv = torch.tensor([0]).to(self.device) + self.numdim_tensor_for_recv = torch.tensor([-1]).to(self.device) + self.dims_tensor_for_recv = torch.ones([100], dtype=int).to(self.device) + + def quick_send(self, tensor, prep): + + group = self.device_group + + # NCCL is NOT fully duplex + # need to explicitly sync using CPU + # to guarantee that there is only 1-directional data happening now + torch.distributed.send( + self.dummy_cpu_tensor_for_send, + dst=self.target_rank_for_send, + group=self.cpu_group + ) + + torch.distributed.send( + prep['dtype'], + dst=self.target_rank_for_send, + group=group + ) + torch.distributed.send( + prep['numdim'], + dst=self.target_rank_for_send, + group=group + ) + torch.distributed.send( + prep['dims'], + dst=self.target_rank_for_send, + group=group + ) + torch.distributed.send( + tensor, + dst=self.target_rank_for_send, + group=group + ) + + + def quick_recv(self): + + # receive is sequential, so we can reuse the GPU buffer + group = self.device_group + + # NCCL is NOT fully duplex + # need to explicitly sync using CPU + # to guarantee that there is only 1-directional data happening now + torch.distributed.recv( + self.dummy_cpu_tensor_for_recv, + src=self.target_rank_for_recv, + group=self.cpu_group + ) + + torch.distributed.recv( + self.dtype_tensor_for_recv, + src=self.target_rank_for_recv, + group=group + ) + torch.distributed.recv( + self.numdim_tensor_for_recv, + src=self.target_rank_for_recv, + group=group + ) + + numdim = self.numdim_tensor_for_recv.item() + torch.distributed.recv( + self.dims_tensor_for_recv[:numdim], + src=self.target_rank_for_recv, + group=group + ) + + dtype = INT2DTYPE[self.dtype_tensor_for_recv.item()] + shape = self.dims_tensor_for_recv[:numdim].tolist() + + buffer = torch.zeros(shape, dtype=dtype).to(self.device) - def send_tensor_wrapper(self, tensor: torch.Tensor) -> None: - """Wrapper for send_tensor_dict""" - tensor_size = tensor.element_size() * tensor.numel() - self.send_tensor_dict({'tensor': tensor}, self.target_rank_for_send) + torch.distributed.recv( + buffer, + src=self.target_rank_for_recv, + group=group + ) + + return buffer - with self.buffer_size_lock: - self.buffer_size = self.buffer_size - tensor_size + + + def prep_send(self, tensor): + + # prepare a series of tensor before send + dtype_tensor = torch.tensor([DTYPE2INT[tensor.dtype]]).to(self.device, non_blocking=True) + numdim_tensor = torch.tensor(len(tensor.shape)).to(self.device, non_blocking=True) + dims_tensor = torch.tensor(tensor.shape).to(self.device, non_blocking=True) + + return { + 'dtype': dtype_tensor, + 'numdim': numdim_tensor, + 'dims': dims_tensor + } + + + def send_tensor_wrapper(self, tensor, prep) -> None: + + try: + """Wrapper for send_tensor_dict""" + tensor_size = tensor.element_size() * tensor.numel() + # self.send_tensor_dict({'tensor': tensor}) + self.quick_send(tensor, prep) + + with self.buffer_size_lock: + self.buffer_size = self.buffer_size - tensor_size + except Exception as e: + logger.error("Encountering exception in KV sending thread") + logger.error("%s", e) def block_if_full(self): while self.buffer_size > 1e9: + logger.debug("KV cache transfer pipe is full. Waiting...") time.sleep(0.05) def send_tensor(self, @@ -92,21 +273,27 @@ def send_tensor(self, else: tensor_size = tensor.element_size() * tensor.numel() + assert 0 < len(tensor.shape) < 100, "Send tensor does not support tensor with 0 dim or >=100 dim. Got %d" % len(tensor.shape) + self.block_if_full() with self.buffer_size_lock: self.buffer_size = self.buffer_size + tensor_size - self.kv_sending_thread.submit(self.send_tensor_wrapper, tensor) - - - + # self.kv_sending_thread.submit(self.send_tensor_wrapper, tensor) + prep = self.prep_send(tensor) + self.kv_sending_thread.submit( + self.send_tensor_wrapper, + tensor, prep) + def recv_tensor(self) -> Optional[torch.Tensor]: """Receives a tensor from the src rank. Blocking.""" - tensor = self.recv_tensor_dict(self.target_rank_for_recv)['tensor'] + tensor = self.quick_recv() if tensor.numel() == 1 and tensor.item() == NONE_INT: return None else: return tensor - \ No newline at end of file + + + diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index c5747330d1ac9..d13d132f5dfee 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -63,10 +63,6 @@ def __init__( group_ranks: List[List[int]], local_rank: int, torch_distributed_backend: Union[str, Backend], - use_pynccl: bool = False, - use_custom_allreduce: bool = False, - use_tpu_communicator: bool = True, - use_message_queue_broadcaster: bool = False ): # init pipe @@ -74,13 +70,9 @@ def __init__( group_ranks, local_rank, torch_distributed_backend, - use_pynccl, - use_custom_allreduce, - use_tpu_communicator, - use_message_queue_broadcaster, ) # init lookup buffer - self.buffer = SimpleKVLookupBuffer(self.pipe, 1000**3) + self.buffer = SimpleKVLookupBuffer(self.pipe, 1000**3 * 10) def send_kv_caches_and_hidden_states( self, From 24a231eae92dab83565d17275d49e37557ecd2dc Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 8 Sep 2024 08:31:26 +0000 Subject: [PATCH 169/303] update tensor sending and receiving. Use CPU to transfer metadata instead. --- .../kv_pipe/torch_distributed_pipe.py | 130 +++--------------- 1 file changed, 20 insertions(+), 110 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index 1bf54badac255..760a9662d5155 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -18,34 +18,6 @@ logger = init_logger(__name__) -# auxilary function to send tensordict -TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) - -def _split_tensor_dict( - tensor_dict: Dict[str, Union[torch.Tensor, Any]] -) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: - """Split the tensor dictionary into two parts: - 1. A list of (key, value) pairs. If the value is a tensor, it is replaced - by its metadata. - 2. A list of tensors. - """ - metadata_list: List[Tuple[str, Any]] = [] - tensor_list: List[torch.Tensor] = [] - for key, value in tensor_dict.items(): - if isinstance(value, torch.Tensor): - # Note: we cannot use `value.device` here, - # because it contains not only the device type but also the device - # index (e.g. "cuda:0"). We only need the device type. - # receiving side will set the device index. - device = value.device.type - metadata_list.append( - (key, TensorMetadata(device, value.dtype, value.size()))) - tensor_list.append(value) - else: - metadata_list.append((key, value)) - return metadata_list, tensor_list - - # if the tensor is only one-element and only contains this number # this means that the sended object is None. NONE_INT = -150886311 @@ -131,119 +103,58 @@ def __init__( self.none_tensor = torch.tensor([NONE_INT]).to(self.device) self.broken = False - # create a dummy tensor - # this tensor is used - self.dummy_cpu_tensor_for_send = torch.tensor([1],device='cpu') - self.dummy_cpu_tensor_for_recv = torch.tensor([1],device='cpu') - - self.dtype_tensor_for_recv = torch.tensor([0]).to(self.device) - self.numdim_tensor_for_recv = torch.tensor([-1]).to(self.device) - self.dims_tensor_for_recv = torch.ones([100], dtype=int).to(self.device) - - def quick_send(self, tensor, prep): + def quick_send(self, tensor): group = self.device_group # NCCL is NOT fully duplex - # need to explicitly sync using CPU - # to guarantee that there is only 1-directional data happening now - torch.distributed.send( - self.dummy_cpu_tensor_for_send, + # so CPU communication is ALWAYS necessary + torch.distributed.send_object_list( + [tensor.dtype, tensor.shape, str(tensor.device)], dst=self.target_rank_for_send, group=self.cpu_group ) - torch.distributed.send( - prep['dtype'], - dst=self.target_rank_for_send, - group=group - ) - torch.distributed.send( - prep['numdim'], - dst=self.target_rank_for_send, - group=group - ) - torch.distributed.send( - prep['dims'], - dst=self.target_rank_for_send, - group=group - ) torch.distributed.send( tensor, dst=self.target_rank_for_send, - group=group + group=self.device_group ) def quick_recv(self): - # receive is sequential, so we can reuse the GPU buffer - group = self.device_group - # NCCL is NOT fully duplex - # need to explicitly sync using CPU - # to guarantee that there is only 1-directional data happening now - torch.distributed.recv( - self.dummy_cpu_tensor_for_recv, + # so CPU communication is necessary + metadata = [None, None, None] + torch.distributed.recv_object_list( + metadata, src=self.target_rank_for_recv, group=self.cpu_group ) - torch.distributed.recv( - self.dtype_tensor_for_recv, - src=self.target_rank_for_recv, - group=group - ) - torch.distributed.recv( - self.numdim_tensor_for_recv, - src=self.target_rank_for_recv, - group=group - ) - - numdim = self.numdim_tensor_for_recv.item() - torch.distributed.recv( - self.dims_tensor_for_recv[:numdim], - src=self.target_rank_for_recv, - group=group - ) - - dtype = INT2DTYPE[self.dtype_tensor_for_recv.item()] - shape = self.dims_tensor_for_recv[:numdim].tolist() - - buffer = torch.zeros(shape, dtype=dtype).to(self.device) + dtype, shape, device = metadata + if 'cuda' in device: + device = self.device + else: + device = 'cpu' + buffer = torch.zeros(shape, dtype=dtype).to(device, non_blocking=True) torch.distributed.recv( buffer, src=self.target_rank_for_recv, - group=group + group=self.device_group ) - return buffer - - - def prep_send(self, tensor): - - # prepare a series of tensor before send - dtype_tensor = torch.tensor([DTYPE2INT[tensor.dtype]]).to(self.device, non_blocking=True) - numdim_tensor = torch.tensor(len(tensor.shape)).to(self.device, non_blocking=True) - dims_tensor = torch.tensor(tensor.shape).to(self.device, non_blocking=True) - - return { - 'dtype': dtype_tensor, - 'numdim': numdim_tensor, - 'dims': dims_tensor - } - def send_tensor_wrapper(self, tensor, prep) -> None: + def send_tensor_wrapper(self, tensor) -> None: try: - """Wrapper for send_tensor_dict""" tensor_size = tensor.element_size() * tensor.numel() - # self.send_tensor_dict({'tensor': tensor}) - self.quick_send(tensor, prep) + self.quick_send(tensor) with self.buffer_size_lock: self.buffer_size = self.buffer_size - tensor_size @@ -280,11 +191,10 @@ def send_tensor(self, with self.buffer_size_lock: self.buffer_size = self.buffer_size + tensor_size - # self.kv_sending_thread.submit(self.send_tensor_wrapper, tensor) - prep = self.prep_send(tensor) + # prepare the metadata before sending the tensor. self.kv_sending_thread.submit( self.send_tensor_wrapper, - tensor, prep) + tensor) def recv_tensor(self) -> Optional[torch.Tensor]: """Receives a tensor from the src rank. Blocking.""" From dca877ab7def76ada27324d4729eb8a3832a24e8 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 8 Sep 2024 08:52:07 +0000 Subject: [PATCH 170/303] update benchmark: use small model for quick iteration --- .../disagg_overhead_benchmark.sh | 43 +++++++++++++------ 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh index 12f5150cadda3..d264f18156438 100644 --- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -50,29 +50,44 @@ benchmark() { # compare chunked prefill with disaggregated prefill results_folder="./results" - model="meta-llama/Meta-Llama-3.1-70B-Instruct" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" dataset_name="sonnet" dataset_path="../sonnet_4x.txt" - num_prompts=50 + num_prompts=20 qps=$1 prefix_len=50 input_len=2048 output_len=$2 # large model - VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model $model \ - --port 8100 \ - -tp 4 \ - --max-model-len 30000 \ - --gpu-memory-utilization 0.8 & - VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + # VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + # -m vllm.entrypoints.openai.api_server \ + # --model $model \ + # --port 8100 \ + # -tp 4 \ + # --max-model-len 30000 \ + # --gpu-memory-utilization 0.8 & + # VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + # -m vllm.entrypoints.openai.api_server \ + # --model $model \ + # --port 8200 \ + # -tp 4 \ + # --max-model-len 30000 \ + # --gpu-memory-utilization 0.8 & + + VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ - --model $model \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.8 & + +# decoding instance +VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8200 \ - -tp 4 \ - --max-model-len 30000 \ + --max-model-len 10000 \ --gpu-memory-utilization 0.8 & wait_for_server 8100 @@ -92,7 +107,7 @@ benchmark() { --save-result \ --result-dir $results_folder \ --result-filename disagg_prefill_2xtp4.json \ - --request-rate $qps + --request-rate "inf" # send the request to decode. From 9f81f41f813333bf077f792ee509f0ba7fe6eac8 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 8 Sep 2024 08:57:40 +0000 Subject: [PATCH 171/303] update implementation --- benchmarks/sonnet_4x.txt | 2070 +++++++++++++++++ .../kv_pipe/torch_distributed_pipe.py | 4 +- vllm/worker/worker_base.py | 1 + 3 files changed, 2072 insertions(+), 3 deletions(-) create mode 100644 benchmarks/sonnet_4x.txt diff --git a/benchmarks/sonnet_4x.txt b/benchmarks/sonnet_4x.txt new file mode 100644 index 0000000000000..02f39a9fb14fb --- /dev/null +++ b/benchmarks/sonnet_4x.txt @@ -0,0 +1,2070 @@ + +FROM fairest creatures we desire increase, +That thereby beauty's rose might never die, +But as the riper should by time decease, +His tender heir might bear his memory: +But thou, contracted to thine own bright eyes, +Feed'st thy light'st flame with self-substantial fuel, +Making a famine where abundance lies, +Thyself thy foe, to thy sweet self too cruel. +Thou that art now the world's fresh ornament +And only herald to the gaudy spring, +Within thine own bud buriest thy content +And, tender churl, makest waste in niggarding. +Pity the world, or else this glutton be, +To eat the world's due, by the grave and thee. +When forty winters shall beseige thy brow, +And dig deep trenches in thy beauty's field, +Thy youth's proud livery, so gazed on now, +Will be a tatter'd weed, of small worth held: +Then being ask'd where all thy beauty lies, +Where all the treasure of thy lusty days, +To say, within thine own deep-sunken eyes, +Were an all-eating shame and thriftless praise. +How much more praise deserved thy beauty's use, +If thou couldst answer 'This fair child of mine +Shall sum my count and make my old excuse,' +Proving his beauty by succession thine! +This were to be new made when thou art old, +And see thy blood warm when thou feel'st it cold. +Look in thy glass, and tell the face thou viewest +Now is the time that face should form another; +Whose fresh repair if now thou not renewest, +Thou dost beguile the world, unbless some mother. +For where is she so fair whose unear'd womb +Disdains the tillage of thy husbandry? +Or who is he so fond will be the tomb +Of his self-love, to stop posterity? +Thou art thy mother's glass, and she in thee +Calls back the lovely April of her prime: +So thou through windows of thine age shall see +Despite of wrinkles this thy golden time. +But if thou live, remember'd not to be, +Die single, and thine image dies with thee. +Unthrifty loveliness, why dost thou spend +Upon thyself thy beauty's legacy? +Nature's bequest gives nothing but doth lend, +And being frank she lends to those are free. +Then, beauteous niggard, why dost thou abuse +The bounteous largess given thee to give? +Profitless usurer, why dost thou use +So great a sum of sums, yet canst not live? +For having traffic with thyself alone, +Thou of thyself thy sweet self dost deceive. +Then how, when nature calls thee to be gone, +What acceptable audit canst thou leave? +Thy unused beauty must be tomb'd with thee, +Which, used, lives th' executor to be. +Those hours, that with gentle work did frame +The lovely gaze where every eye doth dwell, +Will play the tyrants to the very same +And that unfair which fairly doth excel: +For never-resting time leads summer on +To hideous winter and confounds him there; +Sap cheque'd with frost and lusty leaves quite gone, +Beauty o'ersnow'd and bareness every where: +Then, were not summer's distillation left, +A liquid prisoner pent in walls of glass, +Beauty's effect with beauty were bereft, +Nor it nor no remembrance what it was: +But flowers distill'd though they with winter meet, +Leese but their show; their substance still lives sweet. +Then let not winter's ragged hand deface +In thee thy summer, ere thou be distill'd: +Make sweet some vial; treasure thou some place +With beauty's treasure, ere it be self-kill'd. +That use is not forbidden usury, +Which happies those that pay the willing loan; +That's for thyself to breed another thee, +Or ten times happier, be it ten for one; +Ten times thyself were happier than thou art, +If ten of thine ten times refigured thee: +Then what could death do, if thou shouldst depart, +Leaving thee living in posterity? +Be not self-will'd, for thou art much too fair +To be death's conquest and make worms thine heir. +Lo! in the orient when the gracious light +Lifts up his burning head, each under eye +Doth homage to his new-appearing sight, +Serving with looks his sacred majesty; +And having climb'd the steep-up heavenly hill, +Resembling strong youth in his middle age, +yet mortal looks adore his beauty still, +Attending on his golden pilgrimage; +But when from highmost pitch, with weary car, +Like feeble age, he reeleth from the day, +The eyes, 'fore duteous, now converted are +From his low tract and look another way: +So thou, thyself out-going in thy noon, +Unlook'd on diest, unless thou get a son. +Music to hear, why hear'st thou music sadly? +Sweets with sweets war not, joy delights in joy. +Why lovest thou that which thou receivest not gladly, +Or else receivest with pleasure thine annoy? +If the true concord of well-tuned sounds, +By unions married, do offend thine ear, +They do but sweetly chide thee, who confounds +In singleness the parts that thou shouldst bear. +Mark how one string, sweet husband to another, +Strikes each in each by mutual ordering, +Resembling sire and child and happy mother +Who all in one, one pleasing note do sing: +Whose speechless song, being many, seeming one, +Sings this to thee: 'thou single wilt prove none.' +Is it for fear to wet a widow's eye +That thou consumest thyself in single life? +Ah! if thou issueless shalt hap to die. +The world will wail thee, like a makeless wife; +The world will be thy widow and still weep +That thou no form of thee hast left behind, +When every private widow well may keep +By children's eyes her husband's shape in mind. +Look, what an unthrift in the world doth spend +Shifts but his place, for still the world enjoys it; +But beauty's waste hath in the world an end, +And kept unused, the user so destroys it. +No love toward others in that bosom sits +That on himself such murderous shame commits. +For shame! deny that thou bear'st love to any, +Who for thyself art so unprovident. +Grant, if thou wilt, thou art beloved of many, +But that thou none lovest is most evident; +For thou art so possess'd with murderous hate +That 'gainst thyself thou stick'st not to conspire. +Seeking that beauteous roof to ruinate +Which to repair should be thy chief desire. +O, change thy thought, that I may change my mind! +Shall hate be fairer lodged than gentle love? +Be, as thy presence is, gracious and kind, +Or to thyself at least kind-hearted prove: +Make thee another self, for love of me, +That beauty still may live in thine or thee. +As fast as thou shalt wane, so fast thou growest +In one of thine, from that which thou departest; +And that fresh blood which youngly thou bestowest +Thou mayst call thine when thou from youth convertest. +Herein lives wisdom, beauty and increase: +Without this, folly, age and cold decay: +If all were minded so, the times should cease +And threescore year would make the world away. +Let those whom Nature hath not made for store, +Harsh featureless and rude, barrenly perish: +Look, whom she best endow'd she gave the more; +Which bounteous gift thou shouldst in bounty cherish: +She carved thee for her seal, and meant thereby +Thou shouldst print more, not let that copy die. +When I do count the clock that tells the time, +And see the brave day sunk in hideous night; +When I behold the violet past prime, +And sable curls all silver'd o'er with white; +When lofty trees I see barren of leaves +Which erst from heat did canopy the herd, +And summer's green all girded up in sheaves +Borne on the bier with white and bristly beard, +Then of thy beauty do I question make, +That thou among the wastes of time must go, +Since sweets and beauties do themselves forsake +And die as fast as they see others grow; +And nothing 'gainst Time's scythe can make defence +Save breed, to brave him when he takes thee hence. +O, that you were yourself! but, love, you are +No longer yours than you yourself here live: +Against this coming end you should prepare, +And your sweet semblance to some other give. +So should that beauty which you hold in lease +Find no determination: then you were +Yourself again after yourself's decease, +When your sweet issue your sweet form should bear. +Who lets so fair a house fall to decay, +Which husbandry in honour might uphold +Against the stormy gusts of winter's day +And barren rage of death's eternal cold? +O, none but unthrifts! Dear my love, you know +You had a father: let your son say so. +Not from the stars do I my judgment pluck; +And yet methinks I have astronomy, +But not to tell of good or evil luck, +Of plagues, of dearths, or seasons' quality; +Nor can I fortune to brief minutes tell, +Pointing to each his thunder, rain and wind, +Or say with princes if it shall go well, +By oft predict that I in heaven find: +But from thine eyes my knowledge I derive, +And, constant stars, in them I read such art +As truth and beauty shall together thrive, +If from thyself to store thou wouldst convert; +Or else of thee this I prognosticate: +Thy end is truth's and beauty's doom and date. +When I consider every thing that grows +Holds in perfection but a little moment, +That this huge stage presenteth nought but shows +Whereon the stars in secret influence comment; +When I perceive that men as plants increase, +Cheered and cheque'd even by the self-same sky, +Vaunt in their youthful sap, at height decrease, +And wear their brave state out of memory; +Then the conceit of this inconstant stay +Sets you most rich in youth before my sight, +Where wasteful Time debateth with Decay, +To change your day of youth to sullied night; +And all in war with Time for love of you, +As he takes from you, I engraft you new. +But wherefore do not you a mightier way +Make war upon this bloody tyrant, Time? +And fortify yourself in your decay +With means more blessed than my barren rhyme? +Now stand you on the top of happy hours, +And many maiden gardens yet unset +With virtuous wish would bear your living flowers, +Much liker than your painted counterfeit: +So should the lines of life that life repair, +Which this, Time's pencil, or my pupil pen, +Neither in inward worth nor outward fair, +Can make you live yourself in eyes of men. +To give away yourself keeps yourself still, +And you must live, drawn by your own sweet skill. +Who will believe my verse in time to come, +If it were fill'd with your most high deserts? +Though yet, heaven knows, it is but as a tomb +Which hides your life and shows not half your parts. +If I could write the beauty of your eyes +And in fresh numbers number all your graces, +The age to come would say 'This poet lies: +Such heavenly touches ne'er touch'd earthly faces.' +So should my papers yellow'd with their age +Be scorn'd like old men of less truth than tongue, +And your true rights be term'd a poet's rage +And stretched metre of an antique song: +But were some child of yours alive that time, +You should live twice; in it and in my rhyme. +Shall I compare thee to a summer's day? +Thou art more lovely and more temperate: +Rough winds do shake the darling buds of May, +And summer's lease hath all too short a date: +Sometime too hot the eye of heaven shines, +And often is his gold complexion dimm'd; +And every fair from fair sometime declines, +By chance or nature's changing course untrimm'd; +But thy eternal summer shall not fade +Nor lose possession of that fair thou owest; +Nor shall Death brag thou wander'st in his shade, +When in eternal lines to time thou growest: +So long as men can breathe or eyes can see, +So long lives this and this gives life to thee. +Devouring Time, blunt thou the lion's paws, +And make the earth devour her own sweet brood; +Pluck the keen teeth from the fierce tiger's jaws, +And burn the long-lived phoenix in her blood; +Make glad and sorry seasons as thou fleets, +And do whate'er thou wilt, swift-footed Time, +To the wide world and all her fading sweets; +But I forbid thee one most heinous crime: +O, carve not with thy hours my love's fair brow, +Nor draw no lines there with thine antique pen; +Him in thy course untainted do allow +For beauty's pattern to succeeding men. +Yet, do thy worst, old Time: despite thy wrong, +My love shall in my verse ever live young. +A woman's face with Nature's own hand painted +Hast thou, the master-mistress of my passion; +A woman's gentle heart, but not acquainted +With shifting change, as is false women's fashion; +An eye more bright than theirs, less false in rolling, +Gilding the object whereupon it gazeth; +A man in hue, all 'hues' in his controlling, +Much steals men's eyes and women's souls amazeth. +And for a woman wert thou first created; +Till Nature, as she wrought thee, fell a-doting, +And by addition me of thee defeated, +By adding one thing to my purpose nothing. +But since she prick'd thee out for women's pleasure, +Mine be thy love and thy love's use their treasure. +So is it not with me as with that Muse +Stirr'd by a painted beauty to his verse, +Who heaven itself for ornament doth use +And every fair with his fair doth rehearse +Making a couplement of proud compare, +With sun and moon, with earth and sea's rich gems, +With April's first-born flowers, and all things rare +That heaven's air in this huge rondure hems. +O' let me, true in love, but truly write, +And then believe me, my love is as fair +As any mother's child, though not so bright +As those gold candles fix'd in heaven's air: +Let them say more than like of hearsay well; +I will not praise that purpose not to sell. +My glass shall not persuade me I am old, +So long as youth and thou are of one date; +But when in thee time's furrows I behold, +Then look I death my days should expiate. +For all that beauty that doth cover thee +Is but the seemly raiment of my heart, +Which in thy breast doth live, as thine in me: +How can I then be elder than thou art? +O, therefore, love, be of thyself so wary +As I, not for myself, but for thee will; +Bearing thy heart, which I will keep so chary +As tender nurse her babe from faring ill. +Presume not on thy heart when mine is slain; +Thou gavest me thine, not to give back again. +As an unperfect actor on the stage +Who with his fear is put besides his part, +Or some fierce thing replete with too much rage, +Whose strength's abundance weakens his own heart. +So I, for fear of trust, forget to say +The perfect ceremony of love's rite, +And in mine own love's strength seem to decay, +O'ercharged with burden of mine own love's might. +O, let my books be then the eloquence +And dumb presagers of my speaking breast, +Who plead for love and look for recompense +More than that tongue that more hath more express'd. +O, learn to read what silent love hath writ: +To hear with eyes belongs to love's fine wit. +Mine eye hath play'd the painter and hath stell'd +Thy beauty's form in table of my heart; +My body is the frame wherein 'tis held, +And perspective it is the painter's art. +For through the painter must you see his skill, +To find where your true image pictured lies; +Which in my bosom's shop is hanging still, +That hath his windows glazed with thine eyes. +Now see what good turns eyes for eyes have done: +Mine eyes have drawn thy shape, and thine for me +Are windows to my breast, where-through the sun +Delights to peep, to gaze therein on thee; +Yet eyes this cunning want to grace their art; +They draw but what they see, know not the heart. +Let those who are in favour with their stars +Of public honour and proud titles boast, +Whilst I, whom fortune of such triumph bars, +Unlook'd for joy in that I honour most. +Great princes' favourites their fair leaves spread +But as the marigold at the sun's eye, +And in themselves their pride lies buried, +For at a frown they in their glory die. +The painful warrior famoused for fight, +After a thousand victories once foil'd, +Is from the book of honour razed quite, +And all the rest forgot for which he toil'd: +Then happy I, that love and am beloved +Where I may not remove nor be removed. +Lord of my love, to whom in vassalage +Thy merit hath my duty strongly knit, +To thee I send this written embassage, +To witness duty, not to show my wit: +Duty so great, which wit so poor as mine +May make seem bare, in wanting words to show it, +But that I hope some good conceit of thine +In thy soul's thought, all naked, will bestow it; +Till whatsoever star that guides my moving +Points on me graciously with fair aspect +And puts apparel on my tatter'd loving, +To show me worthy of thy sweet respect: +Then may I dare to boast how I do love thee; +Till then not show my head where thou mayst prove me. +Weary with toil, I haste me to my bed, +The dear repose for limbs with travel tired; +But then begins a journey in my head, +To work my mind, when body's work's expired: +For then my thoughts, from far where I abide, +Intend a zealous pilgrimage to thee, +And keep my drooping eyelids open wide, +Looking on darkness which the blind do see +Save that my soul's imaginary sight +Presents thy shadow to my sightless view, +Which, like a jewel hung in ghastly night, +Makes black night beauteous and her old face new. +Lo! thus, by day my limbs, by night my mind, +For thee and for myself no quiet find. +How can I then return in happy plight, +That am debarr'd the benefit of rest? +When day's oppression is not eased by night, +But day by night, and night by day, oppress'd? +And each, though enemies to either's reign, +Do in consent shake hands to torture me; +The one by toil, the other to complain +How far I toil, still farther off from thee. +I tell the day, to please them thou art bright +And dost him grace when clouds do blot the heaven: +So flatter I the swart-complexion'd night, +When sparkling stars twire not thou gild'st the even. +But day doth daily draw my sorrows longer +And night doth nightly make grief's strength seem stronger. +When, in disgrace with fortune and men's eyes, +I all alone beweep my outcast state +And trouble deal heaven with my bootless cries +And look upon myself and curse my fate, +Wishing me like to one more rich in hope, +Featured like him, like him with friends possess'd, +Desiring this man's art and that man's scope, +With what I most enjoy contented least; +Yet in these thoughts myself almost despising, +Haply I think on thee, and then my state, +Like to the lark at break of day arising +From sullen earth, sings hymns at heaven's gate; +For thy sweet love remember'd such wealth brings +That then I scorn to change my state with kings. +When to the sessions of sweet silent thought +I summon up remembrance of things past, +I sigh the lack of many a thing I sought, +And with old woes new wail my dear time's waste: +Then can I drown an eye, unused to flow, +For precious friends hid in death's dateless night, +And weep afresh love's long since cancell'd woe, +And moan the expense of many a vanish'd sight: +Then can I grieve at grievances foregone, +And heavily from woe to woe tell o'er +The sad account of fore-bemoaned moan, +Which I new pay as if not paid before. +But if the while I think on thee, dear friend, +All losses are restored and sorrows end. +Thy bosom is endeared with all hearts, +Which I by lacking have supposed dead, +And there reigns love and all love's loving parts, +And all those friends which I thought buried. +How many a holy and obsequious tear +Hath dear religious love stol'n from mine eye +As interest of the dead, which now appear +But things removed that hidden in thee lie! +Thou art the grave where buried love doth live, +Hung with the trophies of my lovers gone, +Who all their parts of me to thee did give; +That due of many now is thine alone: +Their images I loved I view in thee, +And thou, all they, hast all the all of me. +If thou survive my well-contented day, +When that churl Death my bones with dust shall cover, +And shalt by fortune once more re-survey +These poor rude lines of thy deceased lover, +Compare them with the bettering of the time, +And though they be outstripp'd by every pen, +Reserve them for my love, not for their rhyme, +Exceeded by the height of happier men. +O, then vouchsafe me but this loving thought: +'Had my friend's Muse grown with this growing age, +A dearer birth than this his love had brought, +To march in ranks of better equipage: +But since he died and poets better prove, +Theirs for their style I'll read, his for his love.' +Full many a glorious morning have I seen +Flatter the mountain-tops with sovereign eye, +Kissing with golden face the meadows green, +Gilding pale streams with heavenly alchemy; +Anon permit the basest clouds to ride +With ugly rack on his celestial face, +And from the forlorn world his visage hide, +Stealing unseen to west with this disgrace: +Even so my sun one early morn did shine +With all triumphant splendor on my brow; +But out, alack! he was but one hour mine; +The region cloud hath mask'd him from me now. +Yet him for this my love no whit disdaineth; +Suns of the world may stain when heaven's sun staineth. +Why didst thou promise such a beauteous day, +And make me travel forth without my cloak, +To let base clouds o'ertake me in my way, +Hiding thy bravery in their rotten smoke? +'Tis not enough that through the cloud thou break, +To dry the rain on my storm-beaten face, +For no man well of such a salve can speak +That heals the wound and cures not the disgrace: +Nor can thy shame give physic to my grief; +Though thou repent, yet I have still the loss: +The offender's sorrow lends but weak relief +To him that bears the strong offence's cross. +Ah! but those tears are pearl which thy love sheds, +And they are rich and ransom all ill deeds. +No more be grieved at that which thou hast done: +Roses have thorns, and silver fountains mud; +Clouds and eclipses stain both moon and sun, +And loathsome canker lives in sweetest bud. +All men make faults, and even I in this, +Authorizing thy trespass with compare, +Myself corrupting, salving thy amiss, +Excusing thy sins more than thy sins are; +For to thy sensual fault I bring in sense-- +Thy adverse party is thy advocate-- +And 'gainst myself a lawful plea commence: +Such civil war is in my love and hate +That I an accessary needs must be +To that sweet thief which sourly robs from me. +Let me confess that we two must be twain, +Although our undivided loves are one: +So shall those blots that do with me remain +Without thy help by me be borne alone. +In our two loves there is but one respect, +Though in our lives a separable spite, +Which though it alter not love's sole effect, +Yet doth it steal sweet hours from love's delight. +I may not evermore acknowledge thee, +Lest my bewailed guilt should do thee shame, +Nor thou with public kindness honour me, +Unless thou take that honour from thy name: +But do not so; I love thee in such sort +As, thou being mine, mine is thy good report. +As a decrepit father takes delight +To see his active child do deeds of youth, +So I, made lame by fortune's dearest spite, +Take all my comfort of thy worth and truth. +For whether beauty, birth, or wealth, or wit, +Or any of these all, or all, or more, +Entitled in thy parts do crowned sit, +I make my love engrafted to this store: +So then I am not lame, poor, nor despised, +Whilst that this shadow doth such substance give +That I in thy abundance am sufficed +And by a part of all thy glory live. +Look, what is best, that best I wish in thee: +This wish I have; then ten times happy me!FROM fairest creatures we desire increase, +That thereby beauty's rose might never die, +But as the riper should by time decease, +His tender heir might bear his memory: +But thou, contracted to thine own bright eyes, +Feed'st thy light'st flame with self-substantial fuel, +Making a famine where abundance lies, +Thyself thy foe, to thy sweet self too cruel. +Thou that art now the world's fresh ornament +And only herald to the gaudy spring, +Within thine own bud buriest thy content +And, tender churl, makest waste in niggarding. +Pity the world, or else this glutton be, +To eat the world's due, by the grave and thee. +When forty winters shall beseige thy brow, +And dig deep trenches in thy beauty's field, +Thy youth's proud livery, so gazed on now, +Will be a tatter'd weed, of small worth held: +Then being ask'd where all thy beauty lies, +Where all the treasure of thy lusty days, +To say, within thine own deep-sunken eyes, +Were an all-eating shame and thriftless praise. +How much more praise deserved thy beauty's use, +If thou couldst answer 'This fair child of mine +Shall sum my count and make my old excuse,' +Proving his beauty by succession thine! +This were to be new made when thou art old, +And see thy blood warm when thou feel'st it cold. +Look in thy glass, and tell the face thou viewest +Now is the time that face should form another; +Whose fresh repair if now thou not renewest, +Thou dost beguile the world, unbless some mother. +For where is she so fair whose unear'd womb +Disdains the tillage of thy husbandry? +Or who is he so fond will be the tomb +Of his self-love, to stop posterity? +Thou art thy mother's glass, and she in thee +Calls back the lovely April of her prime: +So thou through windows of thine age shall see +Despite of wrinkles this thy golden time. +But if thou live, remember'd not to be, +Die single, and thine image dies with thee. +Unthrifty loveliness, why dost thou spend +Upon thyself thy beauty's legacy? +Nature's bequest gives nothing but doth lend, +And being frank she lends to those are free. +Then, beauteous niggard, why dost thou abuse +The bounteous largess given thee to give? +Profitless usurer, why dost thou use +So great a sum of sums, yet canst not live? +For having traffic with thyself alone, +Thou of thyself thy sweet self dost deceive. +Then how, when nature calls thee to be gone, +What acceptable audit canst thou leave? +Thy unused beauty must be tomb'd with thee, +Which, used, lives th' executor to be. +Those hours, that with gentle work did frame +The lovely gaze where every eye doth dwell, +Will play the tyrants to the very same +And that unfair which fairly doth excel: +For never-resting time leads summer on +To hideous winter and confounds him there; +Sap cheque'd with frost and lusty leaves quite gone, +Beauty o'ersnow'd and bareness every where: +Then, were not summer's distillation left, +A liquid prisoner pent in walls of glass, +Beauty's effect with beauty were bereft, +Nor it nor no remembrance what it was: +But flowers distill'd though they with winter meet, +Leese but their show; their substance still lives sweet. +Then let not winter's ragged hand deface +In thee thy summer, ere thou be distill'd: +Make sweet some vial; treasure thou some place +With beauty's treasure, ere it be self-kill'd. +That use is not forbidden usury, +Which happies those that pay the willing loan; +That's for thyself to breed another thee, +Or ten times happier, be it ten for one; +Ten times thyself were happier than thou art, +If ten of thine ten times refigured thee: +Then what could death do, if thou shouldst depart, +Leaving thee living in posterity? +Be not self-will'd, for thou art much too fair +To be death's conquest and make worms thine heir. +Lo! in the orient when the gracious light +Lifts up his burning head, each under eye +Doth homage to his new-appearing sight, +Serving with looks his sacred majesty; +And having climb'd the steep-up heavenly hill, +Resembling strong youth in his middle age, +yet mortal looks adore his beauty still, +Attending on his golden pilgrimage; +But when from highmost pitch, with weary car, +Like feeble age, he reeleth from the day, +The eyes, 'fore duteous, now converted are +From his low tract and look another way: +So thou, thyself out-going in thy noon, +Unlook'd on diest, unless thou get a son. +Music to hear, why hear'st thou music sadly? +Sweets with sweets war not, joy delights in joy. +Why lovest thou that which thou receivest not gladly, +Or else receivest with pleasure thine annoy? +If the true concord of well-tuned sounds, +By unions married, do offend thine ear, +They do but sweetly chide thee, who confounds +In singleness the parts that thou shouldst bear. +Mark how one string, sweet husband to another, +Strikes each in each by mutual ordering, +Resembling sire and child and happy mother +Who all in one, one pleasing note do sing: +Whose speechless song, being many, seeming one, +Sings this to thee: 'thou single wilt prove none.' +Is it for fear to wet a widow's eye +That thou consumest thyself in single life? +Ah! if thou issueless shalt hap to die. +The world will wail thee, like a makeless wife; +The world will be thy widow and still weep +That thou no form of thee hast left behind, +When every private widow well may keep +By children's eyes her husband's shape in mind. +Look, what an unthrift in the world doth spend +Shifts but his place, for still the world enjoys it; +But beauty's waste hath in the world an end, +And kept unused, the user so destroys it. +No love toward others in that bosom sits +That on himself such murderous shame commits. +For shame! deny that thou bear'st love to any, +Who for thyself art so unprovident. +Grant, if thou wilt, thou art beloved of many, +But that thou none lovest is most evident; +For thou art so possess'd with murderous hate +That 'gainst thyself thou stick'st not to conspire. +Seeking that beauteous roof to ruinate +Which to repair should be thy chief desire. +O, change thy thought, that I may change my mind! +Shall hate be fairer lodged than gentle love? +Be, as thy presence is, gracious and kind, +Or to thyself at least kind-hearted prove: +Make thee another self, for love of me, +That beauty still may live in thine or thee. +As fast as thou shalt wane, so fast thou growest +In one of thine, from that which thou departest; +And that fresh blood which youngly thou bestowest +Thou mayst call thine when thou from youth convertest. +Herein lives wisdom, beauty and increase: +Without this, folly, age and cold decay: +If all were minded so, the times should cease +And threescore year would make the world away. +Let those whom Nature hath not made for store, +Harsh featureless and rude, barrenly perish: +Look, whom she best endow'd she gave the more; +Which bounteous gift thou shouldst in bounty cherish: +She carved thee for her seal, and meant thereby +Thou shouldst print more, not let that copy die. +When I do count the clock that tells the time, +And see the brave day sunk in hideous night; +When I behold the violet past prime, +And sable curls all silver'd o'er with white; +When lofty trees I see barren of leaves +Which erst from heat did canopy the herd, +And summer's green all girded up in sheaves +Borne on the bier with white and bristly beard, +Then of thy beauty do I question make, +That thou among the wastes of time must go, +Since sweets and beauties do themselves forsake +And die as fast as they see others grow; +And nothing 'gainst Time's scythe can make defence +Save breed, to brave him when he takes thee hence. +O, that you were yourself! but, love, you are +No longer yours than you yourself here live: +Against this coming end you should prepare, +And your sweet semblance to some other give. +So should that beauty which you hold in lease +Find no determination: then you were +Yourself again after yourself's decease, +When your sweet issue your sweet form should bear. +Who lets so fair a house fall to decay, +Which husbandry in honour might uphold +Against the stormy gusts of winter's day +And barren rage of death's eternal cold? +O, none but unthrifts! Dear my love, you know +You had a father: let your son say so. +Not from the stars do I my judgment pluck; +And yet methinks I have astronomy, +But not to tell of good or evil luck, +Of plagues, of dearths, or seasons' quality; +Nor can I fortune to brief minutes tell, +Pointing to each his thunder, rain and wind, +Or say with princes if it shall go well, +By oft predict that I in heaven find: +But from thine eyes my knowledge I derive, +And, constant stars, in them I read such art +As truth and beauty shall together thrive, +If from thyself to store thou wouldst convert; +Or else of thee this I prognosticate: +Thy end is truth's and beauty's doom and date. +When I consider every thing that grows +Holds in perfection but a little moment, +That this huge stage presenteth nought but shows +Whereon the stars in secret influence comment; +When I perceive that men as plants increase, +Cheered and cheque'd even by the self-same sky, +Vaunt in their youthful sap, at height decrease, +And wear their brave state out of memory; +Then the conceit of this inconstant stay +Sets you most rich in youth before my sight, +Where wasteful Time debateth with Decay, +To change your day of youth to sullied night; +And all in war with Time for love of you, +As he takes from you, I engraft you new. +But wherefore do not you a mightier way +Make war upon this bloody tyrant, Time? +And fortify yourself in your decay +With means more blessed than my barren rhyme? +Now stand you on the top of happy hours, +And many maiden gardens yet unset +With virtuous wish would bear your living flowers, +Much liker than your painted counterfeit: +So should the lines of life that life repair, +Which this, Time's pencil, or my pupil pen, +Neither in inward worth nor outward fair, +Can make you live yourself in eyes of men. +To give away yourself keeps yourself still, +And you must live, drawn by your own sweet skill. +Who will believe my verse in time to come, +If it were fill'd with your most high deserts? +Though yet, heaven knows, it is but as a tomb +Which hides your life and shows not half your parts. +If I could write the beauty of your eyes +And in fresh numbers number all your graces, +The age to come would say 'This poet lies: +Such heavenly touches ne'er touch'd earthly faces.' +So should my papers yellow'd with their age +Be scorn'd like old men of less truth than tongue, +And your true rights be term'd a poet's rage +And stretched metre of an antique song: +But were some child of yours alive that time, +You should live twice; in it and in my rhyme. +Shall I compare thee to a summer's day? +Thou art more lovely and more temperate: +Rough winds do shake the darling buds of May, +And summer's lease hath all too short a date: +Sometime too hot the eye of heaven shines, +And often is his gold complexion dimm'd; +And every fair from fair sometime declines, +By chance or nature's changing course untrimm'd; +But thy eternal summer shall not fade +Nor lose possession of that fair thou owest; +Nor shall Death brag thou wander'st in his shade, +When in eternal lines to time thou growest: +So long as men can breathe or eyes can see, +So long lives this and this gives life to thee. +Devouring Time, blunt thou the lion's paws, +And make the earth devour her own sweet brood; +Pluck the keen teeth from the fierce tiger's jaws, +And burn the long-lived phoenix in her blood; +Make glad and sorry seasons as thou fleets, +And do whate'er thou wilt, swift-footed Time, +To the wide world and all her fading sweets; +But I forbid thee one most heinous crime: +O, carve not with thy hours my love's fair brow, +Nor draw no lines there with thine antique pen; +Him in thy course untainted do allow +For beauty's pattern to succeeding men. +Yet, do thy worst, old Time: despite thy wrong, +My love shall in my verse ever live young. +A woman's face with Nature's own hand painted +Hast thou, the master-mistress of my passion; +A woman's gentle heart, but not acquainted +With shifting change, as is false women's fashion; +An eye more bright than theirs, less false in rolling, +Gilding the object whereupon it gazeth; +A man in hue, all 'hues' in his controlling, +Much steals men's eyes and women's souls amazeth. +And for a woman wert thou first created; +Till Nature, as she wrought thee, fell a-doting, +And by addition me of thee defeated, +By adding one thing to my purpose nothing. +But since she prick'd thee out for women's pleasure, +Mine be thy love and thy love's use their treasure. +So is it not with me as with that Muse +Stirr'd by a painted beauty to his verse, +Who heaven itself for ornament doth use +And every fair with his fair doth rehearse +Making a couplement of proud compare, +With sun and moon, with earth and sea's rich gems, +With April's first-born flowers, and all things rare +That heaven's air in this huge rondure hems. +O' let me, true in love, but truly write, +And then believe me, my love is as fair +As any mother's child, though not so bright +As those gold candles fix'd in heaven's air: +Let them say more than like of hearsay well; +I will not praise that purpose not to sell. +My glass shall not persuade me I am old, +So long as youth and thou are of one date; +But when in thee time's furrows I behold, +Then look I death my days should expiate. +For all that beauty that doth cover thee +Is but the seemly raiment of my heart, +Which in thy breast doth live, as thine in me: +How can I then be elder than thou art? +O, therefore, love, be of thyself so wary +As I, not for myself, but for thee will; +Bearing thy heart, which I will keep so chary +As tender nurse her babe from faring ill. +Presume not on thy heart when mine is slain; +Thou gavest me thine, not to give back again. +As an unperfect actor on the stage +Who with his fear is put besides his part, +Or some fierce thing replete with too much rage, +Whose strength's abundance weakens his own heart. +So I, for fear of trust, forget to say +The perfect ceremony of love's rite, +And in mine own love's strength seem to decay, +O'ercharged with burden of mine own love's might. +O, let my books be then the eloquence +And dumb presagers of my speaking breast, +Who plead for love and look for recompense +More than that tongue that more hath more express'd. +O, learn to read what silent love hath writ: +To hear with eyes belongs to love's fine wit. +Mine eye hath play'd the painter and hath stell'd +Thy beauty's form in table of my heart; +My body is the frame wherein 'tis held, +And perspective it is the painter's art. +For through the painter must you see his skill, +To find where your true image pictured lies; +Which in my bosom's shop is hanging still, +That hath his windows glazed with thine eyes. +Now see what good turns eyes for eyes have done: +Mine eyes have drawn thy shape, and thine for me +Are windows to my breast, where-through the sun +Delights to peep, to gaze therein on thee; +Yet eyes this cunning want to grace their art; +They draw but what they see, know not the heart. +Let those who are in favour with their stars +Of public honour and proud titles boast, +Whilst I, whom fortune of such triumph bars, +Unlook'd for joy in that I honour most. +Great princes' favourites their fair leaves spread +But as the marigold at the sun's eye, +And in themselves their pride lies buried, +For at a frown they in their glory die. +The painful warrior famoused for fight, +After a thousand victories once foil'd, +Is from the book of honour razed quite, +And all the rest forgot for which he toil'd: +Then happy I, that love and am beloved +Where I may not remove nor be removed. +Lord of my love, to whom in vassalage +Thy merit hath my duty strongly knit, +To thee I send this written embassage, +To witness duty, not to show my wit: +Duty so great, which wit so poor as mine +May make seem bare, in wanting words to show it, +But that I hope some good conceit of thine +In thy soul's thought, all naked, will bestow it; +Till whatsoever star that guides my moving +Points on me graciously with fair aspect +And puts apparel on my tatter'd loving, +To show me worthy of thy sweet respect: +Then may I dare to boast how I do love thee; +Till then not show my head where thou mayst prove me. +Weary with toil, I haste me to my bed, +The dear repose for limbs with travel tired; +But then begins a journey in my head, +To work my mind, when body's work's expired: +For then my thoughts, from far where I abide, +Intend a zealous pilgrimage to thee, +And keep my drooping eyelids open wide, +Looking on darkness which the blind do see +Save that my soul's imaginary sight +Presents thy shadow to my sightless view, +Which, like a jewel hung in ghastly night, +Makes black night beauteous and her old face new. +Lo! thus, by day my limbs, by night my mind, +For thee and for myself no quiet find. +How can I then return in happy plight, +That am debarr'd the benefit of rest? +When day's oppression is not eased by night, +But day by night, and night by day, oppress'd? +And each, though enemies to either's reign, +Do in consent shake hands to torture me; +The one by toil, the other to complain +How far I toil, still farther off from thee. +I tell the day, to please them thou art bright +And dost him grace when clouds do blot the heaven: +So flatter I the swart-complexion'd night, +When sparkling stars twire not thou gild'st the even. +But day doth daily draw my sorrows longer +And night doth nightly make grief's strength seem stronger. +When, in disgrace with fortune and men's eyes, +I all alone beweep my outcast state +And trouble deal heaven with my bootless cries +And look upon myself and curse my fate, +Wishing me like to one more rich in hope, +Featured like him, like him with friends possess'd, +Desiring this man's art and that man's scope, +With what I most enjoy contented least; +Yet in these thoughts myself almost despising, +Haply I think on thee, and then my state, +Like to the lark at break of day arising +From sullen earth, sings hymns at heaven's gate; +For thy sweet love remember'd such wealth brings +That then I scorn to change my state with kings. +When to the sessions of sweet silent thought +I summon up remembrance of things past, +I sigh the lack of many a thing I sought, +And with old woes new wail my dear time's waste: +Then can I drown an eye, unused to flow, +For precious friends hid in death's dateless night, +And weep afresh love's long since cancell'd woe, +And moan the expense of many a vanish'd sight: +Then can I grieve at grievances foregone, +And heavily from woe to woe tell o'er +The sad account of fore-bemoaned moan, +Which I new pay as if not paid before. +But if the while I think on thee, dear friend, +All losses are restored and sorrows end. +Thy bosom is endeared with all hearts, +Which I by lacking have supposed dead, +And there reigns love and all love's loving parts, +And all those friends which I thought buried. +How many a holy and obsequious tear +Hath dear religious love stol'n from mine eye +As interest of the dead, which now appear +But things removed that hidden in thee lie! +Thou art the grave where buried love doth live, +Hung with the trophies of my lovers gone, +Who all their parts of me to thee did give; +That due of many now is thine alone: +Their images I loved I view in thee, +And thou, all they, hast all the all of me. +If thou survive my well-contented day, +When that churl Death my bones with dust shall cover, +And shalt by fortune once more re-survey +These poor rude lines of thy deceased lover, +Compare them with the bettering of the time, +And though they be outstripp'd by every pen, +Reserve them for my love, not for their rhyme, +Exceeded by the height of happier men. +O, then vouchsafe me but this loving thought: +'Had my friend's Muse grown with this growing age, +A dearer birth than this his love had brought, +To march in ranks of better equipage: +But since he died and poets better prove, +Theirs for their style I'll read, his for his love.' +Full many a glorious morning have I seen +Flatter the mountain-tops with sovereign eye, +Kissing with golden face the meadows green, +Gilding pale streams with heavenly alchemy; +Anon permit the basest clouds to ride +With ugly rack on his celestial face, +And from the forlorn world his visage hide, +Stealing unseen to west with this disgrace: +Even so my sun one early morn did shine +With all triumphant splendor on my brow; +But out, alack! he was but one hour mine; +The region cloud hath mask'd him from me now. +Yet him for this my love no whit disdaineth; +Suns of the world may stain when heaven's sun staineth. +Why didst thou promise such a beauteous day, +And make me travel forth without my cloak, +To let base clouds o'ertake me in my way, +Hiding thy bravery in their rotten smoke? +'Tis not enough that through the cloud thou break, +To dry the rain on my storm-beaten face, +For no man well of such a salve can speak +That heals the wound and cures not the disgrace: +Nor can thy shame give physic to my grief; +Though thou repent, yet I have still the loss: +The offender's sorrow lends but weak relief +To him that bears the strong offence's cross. +Ah! but those tears are pearl which thy love sheds, +And they are rich and ransom all ill deeds. +No more be grieved at that which thou hast done: +Roses have thorns, and silver fountains mud; +Clouds and eclipses stain both moon and sun, +And loathsome canker lives in sweetest bud. +All men make faults, and even I in this, +Authorizing thy trespass with compare, +Myself corrupting, salving thy amiss, +Excusing thy sins more than thy sins are; +For to thy sensual fault I bring in sense-- +Thy adverse party is thy advocate-- +And 'gainst myself a lawful plea commence: +Such civil war is in my love and hate +That I an accessary needs must be +To that sweet thief which sourly robs from me. +Let me confess that we two must be twain, +Although our undivided loves are one: +So shall those blots that do with me remain +Without thy help by me be borne alone. +In our two loves there is but one respect, +Though in our lives a separable spite, +Which though it alter not love's sole effect, +Yet doth it steal sweet hours from love's delight. +I may not evermore acknowledge thee, +Lest my bewailed guilt should do thee shame, +Nor thou with public kindness honour me, +Unless thou take that honour from thy name: +But do not so; I love thee in such sort +As, thou being mine, mine is thy good report. +As a decrepit father takes delight +To see his active child do deeds of youth, +So I, made lame by fortune's dearest spite, +Take all my comfort of thy worth and truth. +For whether beauty, birth, or wealth, or wit, +Or any of these all, or all, or more, +Entitled in thy parts do crowned sit, +I make my love engrafted to this store: +So then I am not lame, poor, nor despised, +Whilst that this shadow doth such substance give +That I in thy abundance am sufficed +And by a part of all thy glory live. +Look, what is best, that best I wish in thee: +This wish I have; then ten times happy me!FROM fairest creatures we desire increase, +That thereby beauty's rose might never die, +But as the riper should by time decease, +His tender heir might bear his memory: +But thou, contracted to thine own bright eyes, +Feed'st thy light'st flame with self-substantial fuel, +Making a famine where abundance lies, +Thyself thy foe, to thy sweet self too cruel. +Thou that art now the world's fresh ornament +And only herald to the gaudy spring, +Within thine own bud buriest thy content +And, tender churl, makest waste in niggarding. +Pity the world, or else this glutton be, +To eat the world's due, by the grave and thee. +When forty winters shall beseige thy brow, +And dig deep trenches in thy beauty's field, +Thy youth's proud livery, so gazed on now, +Will be a tatter'd weed, of small worth held: +Then being ask'd where all thy beauty lies, +Where all the treasure of thy lusty days, +To say, within thine own deep-sunken eyes, +Were an all-eating shame and thriftless praise. +How much more praise deserved thy beauty's use, +If thou couldst answer 'This fair child of mine +Shall sum my count and make my old excuse,' +Proving his beauty by succession thine! +This were to be new made when thou art old, +And see thy blood warm when thou feel'st it cold. +Look in thy glass, and tell the face thou viewest +Now is the time that face should form another; +Whose fresh repair if now thou not renewest, +Thou dost beguile the world, unbless some mother. +For where is she so fair whose unear'd womb +Disdains the tillage of thy husbandry? +Or who is he so fond will be the tomb +Of his self-love, to stop posterity? +Thou art thy mother's glass, and she in thee +Calls back the lovely April of her prime: +So thou through windows of thine age shall see +Despite of wrinkles this thy golden time. +But if thou live, remember'd not to be, +Die single, and thine image dies with thee. +Unthrifty loveliness, why dost thou spend +Upon thyself thy beauty's legacy? +Nature's bequest gives nothing but doth lend, +And being frank she lends to those are free. +Then, beauteous niggard, why dost thou abuse +The bounteous largess given thee to give? +Profitless usurer, why dost thou use +So great a sum of sums, yet canst not live? +For having traffic with thyself alone, +Thou of thyself thy sweet self dost deceive. +Then how, when nature calls thee to be gone, +What acceptable audit canst thou leave? +Thy unused beauty must be tomb'd with thee, +Which, used, lives th' executor to be. +Those hours, that with gentle work did frame +The lovely gaze where every eye doth dwell, +Will play the tyrants to the very same +And that unfair which fairly doth excel: +For never-resting time leads summer on +To hideous winter and confounds him there; +Sap cheque'd with frost and lusty leaves quite gone, +Beauty o'ersnow'd and bareness every where: +Then, were not summer's distillation left, +A liquid prisoner pent in walls of glass, +Beauty's effect with beauty were bereft, +Nor it nor no remembrance what it was: +But flowers distill'd though they with winter meet, +Leese but their show; their substance still lives sweet. +Then let not winter's ragged hand deface +In thee thy summer, ere thou be distill'd: +Make sweet some vial; treasure thou some place +With beauty's treasure, ere it be self-kill'd. +That use is not forbidden usury, +Which happies those that pay the willing loan; +That's for thyself to breed another thee, +Or ten times happier, be it ten for one; +Ten times thyself were happier than thou art, +If ten of thine ten times refigured thee: +Then what could death do, if thou shouldst depart, +Leaving thee living in posterity? +Be not self-will'd, for thou art much too fair +To be death's conquest and make worms thine heir. +Lo! in the orient when the gracious light +Lifts up his burning head, each under eye +Doth homage to his new-appearing sight, +Serving with looks his sacred majesty; +And having climb'd the steep-up heavenly hill, +Resembling strong youth in his middle age, +yet mortal looks adore his beauty still, +Attending on his golden pilgrimage; +But when from highmost pitch, with weary car, +Like feeble age, he reeleth from the day, +The eyes, 'fore duteous, now converted are +From his low tract and look another way: +So thou, thyself out-going in thy noon, +Unlook'd on diest, unless thou get a son. +Music to hear, why hear'st thou music sadly? +Sweets with sweets war not, joy delights in joy. +Why lovest thou that which thou receivest not gladly, +Or else receivest with pleasure thine annoy? +If the true concord of well-tuned sounds, +By unions married, do offend thine ear, +They do but sweetly chide thee, who confounds +In singleness the parts that thou shouldst bear. +Mark how one string, sweet husband to another, +Strikes each in each by mutual ordering, +Resembling sire and child and happy mother +Who all in one, one pleasing note do sing: +Whose speechless song, being many, seeming one, +Sings this to thee: 'thou single wilt prove none.' +Is it for fear to wet a widow's eye +That thou consumest thyself in single life? +Ah! if thou issueless shalt hap to die. +The world will wail thee, like a makeless wife; +The world will be thy widow and still weep +That thou no form of thee hast left behind, +When every private widow well may keep +By children's eyes her husband's shape in mind. +Look, what an unthrift in the world doth spend +Shifts but his place, for still the world enjoys it; +But beauty's waste hath in the world an end, +And kept unused, the user so destroys it. +No love toward others in that bosom sits +That on himself such murderous shame commits. +For shame! deny that thou bear'st love to any, +Who for thyself art so unprovident. +Grant, if thou wilt, thou art beloved of many, +But that thou none lovest is most evident; +For thou art so possess'd with murderous hate +That 'gainst thyself thou stick'st not to conspire. +Seeking that beauteous roof to ruinate +Which to repair should be thy chief desire. +O, change thy thought, that I may change my mind! +Shall hate be fairer lodged than gentle love? +Be, as thy presence is, gracious and kind, +Or to thyself at least kind-hearted prove: +Make thee another self, for love of me, +That beauty still may live in thine or thee. +As fast as thou shalt wane, so fast thou growest +In one of thine, from that which thou departest; +And that fresh blood which youngly thou bestowest +Thou mayst call thine when thou from youth convertest. +Herein lives wisdom, beauty and increase: +Without this, folly, age and cold decay: +If all were minded so, the times should cease +And threescore year would make the world away. +Let those whom Nature hath not made for store, +Harsh featureless and rude, barrenly perish: +Look, whom she best endow'd she gave the more; +Which bounteous gift thou shouldst in bounty cherish: +She carved thee for her seal, and meant thereby +Thou shouldst print more, not let that copy die. +When I do count the clock that tells the time, +And see the brave day sunk in hideous night; +When I behold the violet past prime, +And sable curls all silver'd o'er with white; +When lofty trees I see barren of leaves +Which erst from heat did canopy the herd, +And summer's green all girded up in sheaves +Borne on the bier with white and bristly beard, +Then of thy beauty do I question make, +That thou among the wastes of time must go, +Since sweets and beauties do themselves forsake +And die as fast as they see others grow; +And nothing 'gainst Time's scythe can make defence +Save breed, to brave him when he takes thee hence. +O, that you were yourself! but, love, you are +No longer yours than you yourself here live: +Against this coming end you should prepare, +And your sweet semblance to some other give. +So should that beauty which you hold in lease +Find no determination: then you were +Yourself again after yourself's decease, +When your sweet issue your sweet form should bear. +Who lets so fair a house fall to decay, +Which husbandry in honour might uphold +Against the stormy gusts of winter's day +And barren rage of death's eternal cold? +O, none but unthrifts! Dear my love, you know +You had a father: let your son say so. +Not from the stars do I my judgment pluck; +And yet methinks I have astronomy, +But not to tell of good or evil luck, +Of plagues, of dearths, or seasons' quality; +Nor can I fortune to brief minutes tell, +Pointing to each his thunder, rain and wind, +Or say with princes if it shall go well, +By oft predict that I in heaven find: +But from thine eyes my knowledge I derive, +And, constant stars, in them I read such art +As truth and beauty shall together thrive, +If from thyself to store thou wouldst convert; +Or else of thee this I prognosticate: +Thy end is truth's and beauty's doom and date. +When I consider every thing that grows +Holds in perfection but a little moment, +That this huge stage presenteth nought but shows +Whereon the stars in secret influence comment; +When I perceive that men as plants increase, +Cheered and cheque'd even by the self-same sky, +Vaunt in their youthful sap, at height decrease, +And wear their brave state out of memory; +Then the conceit of this inconstant stay +Sets you most rich in youth before my sight, +Where wasteful Time debateth with Decay, +To change your day of youth to sullied night; +And all in war with Time for love of you, +As he takes from you, I engraft you new. +But wherefore do not you a mightier way +Make war upon this bloody tyrant, Time? +And fortify yourself in your decay +With means more blessed than my barren rhyme? +Now stand you on the top of happy hours, +And many maiden gardens yet unset +With virtuous wish would bear your living flowers, +Much liker than your painted counterfeit: +So should the lines of life that life repair, +Which this, Time's pencil, or my pupil pen, +Neither in inward worth nor outward fair, +Can make you live yourself in eyes of men. +To give away yourself keeps yourself still, +And you must live, drawn by your own sweet skill. +Who will believe my verse in time to come, +If it were fill'd with your most high deserts? +Though yet, heaven knows, it is but as a tomb +Which hides your life and shows not half your parts. +If I could write the beauty of your eyes +And in fresh numbers number all your graces, +The age to come would say 'This poet lies: +Such heavenly touches ne'er touch'd earthly faces.' +So should my papers yellow'd with their age +Be scorn'd like old men of less truth than tongue, +And your true rights be term'd a poet's rage +And stretched metre of an antique song: +But were some child of yours alive that time, +You should live twice; in it and in my rhyme. +Shall I compare thee to a summer's day? +Thou art more lovely and more temperate: +Rough winds do shake the darling buds of May, +And summer's lease hath all too short a date: +Sometime too hot the eye of heaven shines, +And often is his gold complexion dimm'd; +And every fair from fair sometime declines, +By chance or nature's changing course untrimm'd; +But thy eternal summer shall not fade +Nor lose possession of that fair thou owest; +Nor shall Death brag thou wander'st in his shade, +When in eternal lines to time thou growest: +So long as men can breathe or eyes can see, +So long lives this and this gives life to thee. +Devouring Time, blunt thou the lion's paws, +And make the earth devour her own sweet brood; +Pluck the keen teeth from the fierce tiger's jaws, +And burn the long-lived phoenix in her blood; +Make glad and sorry seasons as thou fleets, +And do whate'er thou wilt, swift-footed Time, +To the wide world and all her fading sweets; +But I forbid thee one most heinous crime: +O, carve not with thy hours my love's fair brow, +Nor draw no lines there with thine antique pen; +Him in thy course untainted do allow +For beauty's pattern to succeeding men. +Yet, do thy worst, old Time: despite thy wrong, +My love shall in my verse ever live young. +A woman's face with Nature's own hand painted +Hast thou, the master-mistress of my passion; +A woman's gentle heart, but not acquainted +With shifting change, as is false women's fashion; +An eye more bright than theirs, less false in rolling, +Gilding the object whereupon it gazeth; +A man in hue, all 'hues' in his controlling, +Much steals men's eyes and women's souls amazeth. +And for a woman wert thou first created; +Till Nature, as she wrought thee, fell a-doting, +And by addition me of thee defeated, +By adding one thing to my purpose nothing. +But since she prick'd thee out for women's pleasure, +Mine be thy love and thy love's use their treasure. +So is it not with me as with that Muse +Stirr'd by a painted beauty to his verse, +Who heaven itself for ornament doth use +And every fair with his fair doth rehearse +Making a couplement of proud compare, +With sun and moon, with earth and sea's rich gems, +With April's first-born flowers, and all things rare +That heaven's air in this huge rondure hems. +O' let me, true in love, but truly write, +And then believe me, my love is as fair +As any mother's child, though not so bright +As those gold candles fix'd in heaven's air: +Let them say more than like of hearsay well; +I will not praise that purpose not to sell. +My glass shall not persuade me I am old, +So long as youth and thou are of one date; +But when in thee time's furrows I behold, +Then look I death my days should expiate. +For all that beauty that doth cover thee +Is but the seemly raiment of my heart, +Which in thy breast doth live, as thine in me: +How can I then be elder than thou art? +O, therefore, love, be of thyself so wary +As I, not for myself, but for thee will; +Bearing thy heart, which I will keep so chary +As tender nurse her babe from faring ill. +Presume not on thy heart when mine is slain; +Thou gavest me thine, not to give back again. +As an unperfect actor on the stage +Who with his fear is put besides his part, +Or some fierce thing replete with too much rage, +Whose strength's abundance weakens his own heart. +So I, for fear of trust, forget to say +The perfect ceremony of love's rite, +And in mine own love's strength seem to decay, +O'ercharged with burden of mine own love's might. +O, let my books be then the eloquence +And dumb presagers of my speaking breast, +Who plead for love and look for recompense +More than that tongue that more hath more express'd. +O, learn to read what silent love hath writ: +To hear with eyes belongs to love's fine wit. +Mine eye hath play'd the painter and hath stell'd +Thy beauty's form in table of my heart; +My body is the frame wherein 'tis held, +And perspective it is the painter's art. +For through the painter must you see his skill, +To find where your true image pictured lies; +Which in my bosom's shop is hanging still, +That hath his windows glazed with thine eyes. +Now see what good turns eyes for eyes have done: +Mine eyes have drawn thy shape, and thine for me +Are windows to my breast, where-through the sun +Delights to peep, to gaze therein on thee; +Yet eyes this cunning want to grace their art; +They draw but what they see, know not the heart. +Let those who are in favour with their stars +Of public honour and proud titles boast, +Whilst I, whom fortune of such triumph bars, +Unlook'd for joy in that I honour most. +Great princes' favourites their fair leaves spread +But as the marigold at the sun's eye, +And in themselves their pride lies buried, +For at a frown they in their glory die. +The painful warrior famoused for fight, +After a thousand victories once foil'd, +Is from the book of honour razed quite, +And all the rest forgot for which he toil'd: +Then happy I, that love and am beloved +Where I may not remove nor be removed. +Lord of my love, to whom in vassalage +Thy merit hath my duty strongly knit, +To thee I send this written embassage, +To witness duty, not to show my wit: +Duty so great, which wit so poor as mine +May make seem bare, in wanting words to show it, +But that I hope some good conceit of thine +In thy soul's thought, all naked, will bestow it; +Till whatsoever star that guides my moving +Points on me graciously with fair aspect +And puts apparel on my tatter'd loving, +To show me worthy of thy sweet respect: +Then may I dare to boast how I do love thee; +Till then not show my head where thou mayst prove me. +Weary with toil, I haste me to my bed, +The dear repose for limbs with travel tired; +But then begins a journey in my head, +To work my mind, when body's work's expired: +For then my thoughts, from far where I abide, +Intend a zealous pilgrimage to thee, +And keep my drooping eyelids open wide, +Looking on darkness which the blind do see +Save that my soul's imaginary sight +Presents thy shadow to my sightless view, +Which, like a jewel hung in ghastly night, +Makes black night beauteous and her old face new. +Lo! thus, by day my limbs, by night my mind, +For thee and for myself no quiet find. +How can I then return in happy plight, +That am debarr'd the benefit of rest? +When day's oppression is not eased by night, +But day by night, and night by day, oppress'd? +And each, though enemies to either's reign, +Do in consent shake hands to torture me; +The one by toil, the other to complain +How far I toil, still farther off from thee. +I tell the day, to please them thou art bright +And dost him grace when clouds do blot the heaven: +So flatter I the swart-complexion'd night, +When sparkling stars twire not thou gild'st the even. +But day doth daily draw my sorrows longer +And night doth nightly make grief's strength seem stronger. +When, in disgrace with fortune and men's eyes, +I all alone beweep my outcast state +And trouble deal heaven with my bootless cries +And look upon myself and curse my fate, +Wishing me like to one more rich in hope, +Featured like him, like him with friends possess'd, +Desiring this man's art and that man's scope, +With what I most enjoy contented least; +Yet in these thoughts myself almost despising, +Haply I think on thee, and then my state, +Like to the lark at break of day arising +From sullen earth, sings hymns at heaven's gate; +For thy sweet love remember'd such wealth brings +That then I scorn to change my state with kings. +When to the sessions of sweet silent thought +I summon up remembrance of things past, +I sigh the lack of many a thing I sought, +And with old woes new wail my dear time's waste: +Then can I drown an eye, unused to flow, +For precious friends hid in death's dateless night, +And weep afresh love's long since cancell'd woe, +And moan the expense of many a vanish'd sight: +Then can I grieve at grievances foregone, +And heavily from woe to woe tell o'er +The sad account of fore-bemoaned moan, +Which I new pay as if not paid before. +But if the while I think on thee, dear friend, +All losses are restored and sorrows end. +Thy bosom is endeared with all hearts, +Which I by lacking have supposed dead, +And there reigns love and all love's loving parts, +And all those friends which I thought buried. +How many a holy and obsequious tear +Hath dear religious love stol'n from mine eye +As interest of the dead, which now appear +But things removed that hidden in thee lie! +Thou art the grave where buried love doth live, +Hung with the trophies of my lovers gone, +Who all their parts of me to thee did give; +That due of many now is thine alone: +Their images I loved I view in thee, +And thou, all they, hast all the all of me. +If thou survive my well-contented day, +When that churl Death my bones with dust shall cover, +And shalt by fortune once more re-survey +These poor rude lines of thy deceased lover, +Compare them with the bettering of the time, +And though they be outstripp'd by every pen, +Reserve them for my love, not for their rhyme, +Exceeded by the height of happier men. +O, then vouchsafe me but this loving thought: +'Had my friend's Muse grown with this growing age, +A dearer birth than this his love had brought, +To march in ranks of better equipage: +But since he died and poets better prove, +Theirs for their style I'll read, his for his love.' +Full many a glorious morning have I seen +Flatter the mountain-tops with sovereign eye, +Kissing with golden face the meadows green, +Gilding pale streams with heavenly alchemy; +Anon permit the basest clouds to ride +With ugly rack on his celestial face, +And from the forlorn world his visage hide, +Stealing unseen to west with this disgrace: +Even so my sun one early morn did shine +With all triumphant splendor on my brow; +But out, alack! he was but one hour mine; +The region cloud hath mask'd him from me now. +Yet him for this my love no whit disdaineth; +Suns of the world may stain when heaven's sun staineth. +Why didst thou promise such a beauteous day, +And make me travel forth without my cloak, +To let base clouds o'ertake me in my way, +Hiding thy bravery in their rotten smoke? +'Tis not enough that through the cloud thou break, +To dry the rain on my storm-beaten face, +For no man well of such a salve can speak +That heals the wound and cures not the disgrace: +Nor can thy shame give physic to my grief; +Though thou repent, yet I have still the loss: +The offender's sorrow lends but weak relief +To him that bears the strong offence's cross. +Ah! but those tears are pearl which thy love sheds, +And they are rich and ransom all ill deeds. +No more be grieved at that which thou hast done: +Roses have thorns, and silver fountains mud; +Clouds and eclipses stain both moon and sun, +And loathsome canker lives in sweetest bud. +All men make faults, and even I in this, +Authorizing thy trespass with compare, +Myself corrupting, salving thy amiss, +Excusing thy sins more than thy sins are; +For to thy sensual fault I bring in sense-- +Thy adverse party is thy advocate-- +And 'gainst myself a lawful plea commence: +Such civil war is in my love and hate +That I an accessary needs must be +To that sweet thief which sourly robs from me. +Let me confess that we two must be twain, +Although our undivided loves are one: +So shall those blots that do with me remain +Without thy help by me be borne alone. +In our two loves there is but one respect, +Though in our lives a separable spite, +Which though it alter not love's sole effect, +Yet doth it steal sweet hours from love's delight. +I may not evermore acknowledge thee, +Lest my bewailed guilt should do thee shame, +Nor thou with public kindness honour me, +Unless thou take that honour from thy name: +But do not so; I love thee in such sort +As, thou being mine, mine is thy good report. +As a decrepit father takes delight +To see his active child do deeds of youth, +So I, made lame by fortune's dearest spite, +Take all my comfort of thy worth and truth. +For whether beauty, birth, or wealth, or wit, +Or any of these all, or all, or more, +Entitled in thy parts do crowned sit, +I make my love engrafted to this store: +So then I am not lame, poor, nor despised, +Whilst that this shadow doth such substance give +That I in thy abundance am sufficed +And by a part of all thy glory live. +Look, what is best, that best I wish in thee: +This wish I have; then ten times happy me!FROM fairest creatures we desire increase, +That thereby beauty's rose might never die, +But as the riper should by time decease, +His tender heir might bear his memory: +But thou, contracted to thine own bright eyes, +Feed'st thy light'st flame with self-substantial fuel, +Making a famine where abundance lies, +Thyself thy foe, to thy sweet self too cruel. +Thou that art now the world's fresh ornament +And only herald to the gaudy spring, +Within thine own bud buriest thy content +And, tender churl, makest waste in niggarding. +Pity the world, or else this glutton be, +To eat the world's due, by the grave and thee. +When forty winters shall beseige thy brow, +And dig deep trenches in thy beauty's field, +Thy youth's proud livery, so gazed on now, +Will be a tatter'd weed, of small worth held: +Then being ask'd where all thy beauty lies, +Where all the treasure of thy lusty days, +To say, within thine own deep-sunken eyes, +Were an all-eating shame and thriftless praise. +How much more praise deserved thy beauty's use, +If thou couldst answer 'This fair child of mine +Shall sum my count and make my old excuse,' +Proving his beauty by succession thine! +This were to be new made when thou art old, +And see thy blood warm when thou feel'st it cold. +Look in thy glass, and tell the face thou viewest +Now is the time that face should form another; +Whose fresh repair if now thou not renewest, +Thou dost beguile the world, unbless some mother. +For where is she so fair whose unear'd womb +Disdains the tillage of thy husbandry? +Or who is he so fond will be the tomb +Of his self-love, to stop posterity? +Thou art thy mother's glass, and she in thee +Calls back the lovely April of her prime: +So thou through windows of thine age shall see +Despite of wrinkles this thy golden time. +But if thou live, remember'd not to be, +Die single, and thine image dies with thee. +Unthrifty loveliness, why dost thou spend +Upon thyself thy beauty's legacy? +Nature's bequest gives nothing but doth lend, +And being frank she lends to those are free. +Then, beauteous niggard, why dost thou abuse +The bounteous largess given thee to give? +Profitless usurer, why dost thou use +So great a sum of sums, yet canst not live? +For having traffic with thyself alone, +Thou of thyself thy sweet self dost deceive. +Then how, when nature calls thee to be gone, +What acceptable audit canst thou leave? +Thy unused beauty must be tomb'd with thee, +Which, used, lives th' executor to be. +Those hours, that with gentle work did frame +The lovely gaze where every eye doth dwell, +Will play the tyrants to the very same +And that unfair which fairly doth excel: +For never-resting time leads summer on +To hideous winter and confounds him there; +Sap cheque'd with frost and lusty leaves quite gone, +Beauty o'ersnow'd and bareness every where: +Then, were not summer's distillation left, +A liquid prisoner pent in walls of glass, +Beauty's effect with beauty were bereft, +Nor it nor no remembrance what it was: +But flowers distill'd though they with winter meet, +Leese but their show; their substance still lives sweet. +Then let not winter's ragged hand deface +In thee thy summer, ere thou be distill'd: +Make sweet some vial; treasure thou some place +With beauty's treasure, ere it be self-kill'd. +That use is not forbidden usury, +Which happies those that pay the willing loan; +That's for thyself to breed another thee, +Or ten times happier, be it ten for one; +Ten times thyself were happier than thou art, +If ten of thine ten times refigured thee: +Then what could death do, if thou shouldst depart, +Leaving thee living in posterity? +Be not self-will'd, for thou art much too fair +To be death's conquest and make worms thine heir. +Lo! in the orient when the gracious light +Lifts up his burning head, each under eye +Doth homage to his new-appearing sight, +Serving with looks his sacred majesty; +And having climb'd the steep-up heavenly hill, +Resembling strong youth in his middle age, +yet mortal looks adore his beauty still, +Attending on his golden pilgrimage; +But when from highmost pitch, with weary car, +Like feeble age, he reeleth from the day, +The eyes, 'fore duteous, now converted are +From his low tract and look another way: +So thou, thyself out-going in thy noon, +Unlook'd on diest, unless thou get a son. +Music to hear, why hear'st thou music sadly? +Sweets with sweets war not, joy delights in joy. +Why lovest thou that which thou receivest not gladly, +Or else receivest with pleasure thine annoy? +If the true concord of well-tuned sounds, +By unions married, do offend thine ear, +They do but sweetly chide thee, who confounds +In singleness the parts that thou shouldst bear. +Mark how one string, sweet husband to another, +Strikes each in each by mutual ordering, +Resembling sire and child and happy mother +Who all in one, one pleasing note do sing: +Whose speechless song, being many, seeming one, +Sings this to thee: 'thou single wilt prove none.' +Is it for fear to wet a widow's eye +That thou consumest thyself in single life? +Ah! if thou issueless shalt hap to die. +The world will wail thee, like a makeless wife; +The world will be thy widow and still weep +That thou no form of thee hast left behind, +When every private widow well may keep +By children's eyes her husband's shape in mind. +Look, what an unthrift in the world doth spend +Shifts but his place, for still the world enjoys it; +But beauty's waste hath in the world an end, +And kept unused, the user so destroys it. +No love toward others in that bosom sits +That on himself such murderous shame commits. +For shame! deny that thou bear'st love to any, +Who for thyself art so unprovident. +Grant, if thou wilt, thou art beloved of many, +But that thou none lovest is most evident; +For thou art so possess'd with murderous hate +That 'gainst thyself thou stick'st not to conspire. +Seeking that beauteous roof to ruinate +Which to repair should be thy chief desire. +O, change thy thought, that I may change my mind! +Shall hate be fairer lodged than gentle love? +Be, as thy presence is, gracious and kind, +Or to thyself at least kind-hearted prove: +Make thee another self, for love of me, +That beauty still may live in thine or thee. +As fast as thou shalt wane, so fast thou growest +In one of thine, from that which thou departest; +And that fresh blood which youngly thou bestowest +Thou mayst call thine when thou from youth convertest. +Herein lives wisdom, beauty and increase: +Without this, folly, age and cold decay: +If all were minded so, the times should cease +And threescore year would make the world away. +Let those whom Nature hath not made for store, +Harsh featureless and rude, barrenly perish: +Look, whom she best endow'd she gave the more; +Which bounteous gift thou shouldst in bounty cherish: +She carved thee for her seal, and meant thereby +Thou shouldst print more, not let that copy die. +When I do count the clock that tells the time, +And see the brave day sunk in hideous night; +When I behold the violet past prime, +And sable curls all silver'd o'er with white; +When lofty trees I see barren of leaves +Which erst from heat did canopy the herd, +And summer's green all girded up in sheaves +Borne on the bier with white and bristly beard, +Then of thy beauty do I question make, +That thou among the wastes of time must go, +Since sweets and beauties do themselves forsake +And die as fast as they see others grow; +And nothing 'gainst Time's scythe can make defence +Save breed, to brave him when he takes thee hence. +O, that you were yourself! but, love, you are +No longer yours than you yourself here live: +Against this coming end you should prepare, +And your sweet semblance to some other give. +So should that beauty which you hold in lease +Find no determination: then you were +Yourself again after yourself's decease, +When your sweet issue your sweet form should bear. +Who lets so fair a house fall to decay, +Which husbandry in honour might uphold +Against the stormy gusts of winter's day +And barren rage of death's eternal cold? +O, none but unthrifts! Dear my love, you know +You had a father: let your son say so. +Not from the stars do I my judgment pluck; +And yet methinks I have astronomy, +But not to tell of good or evil luck, +Of plagues, of dearths, or seasons' quality; +Nor can I fortune to brief minutes tell, +Pointing to each his thunder, rain and wind, +Or say with princes if it shall go well, +By oft predict that I in heaven find: +But from thine eyes my knowledge I derive, +And, constant stars, in them I read such art +As truth and beauty shall together thrive, +If from thyself to store thou wouldst convert; +Or else of thee this I prognosticate: +Thy end is truth's and beauty's doom and date. +When I consider every thing that grows +Holds in perfection but a little moment, +That this huge stage presenteth nought but shows +Whereon the stars in secret influence comment; +When I perceive that men as plants increase, +Cheered and cheque'd even by the self-same sky, +Vaunt in their youthful sap, at height decrease, +And wear their brave state out of memory; +Then the conceit of this inconstant stay +Sets you most rich in youth before my sight, +Where wasteful Time debateth with Decay, +To change your day of youth to sullied night; +And all in war with Time for love of you, +As he takes from you, I engraft you new. +But wherefore do not you a mightier way +Make war upon this bloody tyrant, Time? +And fortify yourself in your decay +With means more blessed than my barren rhyme? +Now stand you on the top of happy hours, +And many maiden gardens yet unset +With virtuous wish would bear your living flowers, +Much liker than your painted counterfeit: +So should the lines of life that life repair, +Which this, Time's pencil, or my pupil pen, +Neither in inward worth nor outward fair, +Can make you live yourself in eyes of men. +To give away yourself keeps yourself still, +And you must live, drawn by your own sweet skill. +Who will believe my verse in time to come, +If it were fill'd with your most high deserts? +Though yet, heaven knows, it is but as a tomb +Which hides your life and shows not half your parts. +If I could write the beauty of your eyes +And in fresh numbers number all your graces, +The age to come would say 'This poet lies: +Such heavenly touches ne'er touch'd earthly faces.' +So should my papers yellow'd with their age +Be scorn'd like old men of less truth than tongue, +And your true rights be term'd a poet's rage +And stretched metre of an antique song: +But were some child of yours alive that time, +You should live twice; in it and in my rhyme. +Shall I compare thee to a summer's day? +Thou art more lovely and more temperate: +Rough winds do shake the darling buds of May, +And summer's lease hath all too short a date: +Sometime too hot the eye of heaven shines, +And often is his gold complexion dimm'd; +And every fair from fair sometime declines, +By chance or nature's changing course untrimm'd; +But thy eternal summer shall not fade +Nor lose possession of that fair thou owest; +Nor shall Death brag thou wander'st in his shade, +When in eternal lines to time thou growest: +So long as men can breathe or eyes can see, +So long lives this and this gives life to thee. +Devouring Time, blunt thou the lion's paws, +And make the earth devour her own sweet brood; +Pluck the keen teeth from the fierce tiger's jaws, +And burn the long-lived phoenix in her blood; +Make glad and sorry seasons as thou fleets, +And do whate'er thou wilt, swift-footed Time, +To the wide world and all her fading sweets; +But I forbid thee one most heinous crime: +O, carve not with thy hours my love's fair brow, +Nor draw no lines there with thine antique pen; +Him in thy course untainted do allow +For beauty's pattern to succeeding men. +Yet, do thy worst, old Time: despite thy wrong, +My love shall in my verse ever live young. +A woman's face with Nature's own hand painted +Hast thou, the master-mistress of my passion; +A woman's gentle heart, but not acquainted +With shifting change, as is false women's fashion; +An eye more bright than theirs, less false in rolling, +Gilding the object whereupon it gazeth; +A man in hue, all 'hues' in his controlling, +Much steals men's eyes and women's souls amazeth. +And for a woman wert thou first created; +Till Nature, as she wrought thee, fell a-doting, +And by addition me of thee defeated, +By adding one thing to my purpose nothing. +But since she prick'd thee out for women's pleasure, +Mine be thy love and thy love's use their treasure. +So is it not with me as with that Muse +Stirr'd by a painted beauty to his verse, +Who heaven itself for ornament doth use +And every fair with his fair doth rehearse +Making a couplement of proud compare, +With sun and moon, with earth and sea's rich gems, +With April's first-born flowers, and all things rare +That heaven's air in this huge rondure hems. +O' let me, true in love, but truly write, +And then believe me, my love is as fair +As any mother's child, though not so bright +As those gold candles fix'd in heaven's air: +Let them say more than like of hearsay well; +I will not praise that purpose not to sell. +My glass shall not persuade me I am old, +So long as youth and thou are of one date; +But when in thee time's furrows I behold, +Then look I death my days should expiate. +For all that beauty that doth cover thee +Is but the seemly raiment of my heart, +Which in thy breast doth live, as thine in me: +How can I then be elder than thou art? +O, therefore, love, be of thyself so wary +As I, not for myself, but for thee will; +Bearing thy heart, which I will keep so chary +As tender nurse her babe from faring ill. +Presume not on thy heart when mine is slain; +Thou gavest me thine, not to give back again. +As an unperfect actor on the stage +Who with his fear is put besides his part, +Or some fierce thing replete with too much rage, +Whose strength's abundance weakens his own heart. +So I, for fear of trust, forget to say +The perfect ceremony of love's rite, +And in mine own love's strength seem to decay, +O'ercharged with burden of mine own love's might. +O, let my books be then the eloquence +And dumb presagers of my speaking breast, +Who plead for love and look for recompense +More than that tongue that more hath more express'd. +O, learn to read what silent love hath writ: +To hear with eyes belongs to love's fine wit. +Mine eye hath play'd the painter and hath stell'd +Thy beauty's form in table of my heart; +My body is the frame wherein 'tis held, +And perspective it is the painter's art. +For through the painter must you see his skill, +To find where your true image pictured lies; +Which in my bosom's shop is hanging still, +That hath his windows glazed with thine eyes. +Now see what good turns eyes for eyes have done: +Mine eyes have drawn thy shape, and thine for me +Are windows to my breast, where-through the sun +Delights to peep, to gaze therein on thee; +Yet eyes this cunning want to grace their art; +They draw but what they see, know not the heart. +Let those who are in favour with their stars +Of public honour and proud titles boast, +Whilst I, whom fortune of such triumph bars, +Unlook'd for joy in that I honour most. +Great princes' favourites their fair leaves spread +But as the marigold at the sun's eye, +And in themselves their pride lies buried, +For at a frown they in their glory die. +The painful warrior famoused for fight, +After a thousand victories once foil'd, +Is from the book of honour razed quite, +And all the rest forgot for which he toil'd: +Then happy I, that love and am beloved +Where I may not remove nor be removed. +Lord of my love, to whom in vassalage +Thy merit hath my duty strongly knit, +To thee I send this written embassage, +To witness duty, not to show my wit: +Duty so great, which wit so poor as mine +May make seem bare, in wanting words to show it, +But that I hope some good conceit of thine +In thy soul's thought, all naked, will bestow it; +Till whatsoever star that guides my moving +Points on me graciously with fair aspect +And puts apparel on my tatter'd loving, +To show me worthy of thy sweet respect: +Then may I dare to boast how I do love thee; +Till then not show my head where thou mayst prove me. +Weary with toil, I haste me to my bed, +The dear repose for limbs with travel tired; +But then begins a journey in my head, +To work my mind, when body's work's expired: +For then my thoughts, from far where I abide, +Intend a zealous pilgrimage to thee, +And keep my drooping eyelids open wide, +Looking on darkness which the blind do see +Save that my soul's imaginary sight +Presents thy shadow to my sightless view, +Which, like a jewel hung in ghastly night, +Makes black night beauteous and her old face new. +Lo! thus, by day my limbs, by night my mind, +For thee and for myself no quiet find. +How can I then return in happy plight, +That am debarr'd the benefit of rest? +When day's oppression is not eased by night, +But day by night, and night by day, oppress'd? +And each, though enemies to either's reign, +Do in consent shake hands to torture me; +The one by toil, the other to complain +How far I toil, still farther off from thee. +I tell the day, to please them thou art bright +And dost him grace when clouds do blot the heaven: +So flatter I the swart-complexion'd night, +When sparkling stars twire not thou gild'st the even. +But day doth daily draw my sorrows longer +And night doth nightly make grief's strength seem stronger. +When, in disgrace with fortune and men's eyes, +I all alone beweep my outcast state +And trouble deal heaven with my bootless cries +And look upon myself and curse my fate, +Wishing me like to one more rich in hope, +Featured like him, like him with friends possess'd, +Desiring this man's art and that man's scope, +With what I most enjoy contented least; +Yet in these thoughts myself almost despising, +Haply I think on thee, and then my state, +Like to the lark at break of day arising +From sullen earth, sings hymns at heaven's gate; +For thy sweet love remember'd such wealth brings +That then I scorn to change my state with kings. +When to the sessions of sweet silent thought +I summon up remembrance of things past, +I sigh the lack of many a thing I sought, +And with old woes new wail my dear time's waste: +Then can I drown an eye, unused to flow, +For precious friends hid in death's dateless night, +And weep afresh love's long since cancell'd woe, +And moan the expense of many a vanish'd sight: +Then can I grieve at grievances foregone, +And heavily from woe to woe tell o'er +The sad account of fore-bemoaned moan, +Which I new pay as if not paid before. +But if the while I think on thee, dear friend, +All losses are restored and sorrows end. +Thy bosom is endeared with all hearts, +Which I by lacking have supposed dead, +And there reigns love and all love's loving parts, +And all those friends which I thought buried. +How many a holy and obsequious tear +Hath dear religious love stol'n from mine eye +As interest of the dead, which now appear +But things removed that hidden in thee lie! +Thou art the grave where buried love doth live, +Hung with the trophies of my lovers gone, +Who all their parts of me to thee did give; +That due of many now is thine alone: +Their images I loved I view in thee, +And thou, all they, hast all the all of me. +If thou survive my well-contented day, +When that churl Death my bones with dust shall cover, +And shalt by fortune once more re-survey +These poor rude lines of thy deceased lover, +Compare them with the bettering of the time, +And though they be outstripp'd by every pen, +Reserve them for my love, not for their rhyme, +Exceeded by the height of happier men. +O, then vouchsafe me but this loving thought: +'Had my friend's Muse grown with this growing age, +A dearer birth than this his love had brought, +To march in ranks of better equipage: +But since he died and poets better prove, +Theirs for their style I'll read, his for his love.' +Full many a glorious morning have I seen +Flatter the mountain-tops with sovereign eye, +Kissing with golden face the meadows green, +Gilding pale streams with heavenly alchemy; +Anon permit the basest clouds to ride +With ugly rack on his celestial face, +And from the forlorn world his visage hide, +Stealing unseen to west with this disgrace: +Even so my sun one early morn did shine +With all triumphant splendor on my brow; +But out, alack! he was but one hour mine; +The region cloud hath mask'd him from me now. +Yet him for this my love no whit disdaineth; +Suns of the world may stain when heaven's sun staineth. +Why didst thou promise such a beauteous day, +And make me travel forth without my cloak, +To let base clouds o'ertake me in my way, +Hiding thy bravery in their rotten smoke? +'Tis not enough that through the cloud thou break, +To dry the rain on my storm-beaten face, +For no man well of such a salve can speak +That heals the wound and cures not the disgrace: +Nor can thy shame give physic to my grief; +Though thou repent, yet I have still the loss: +The offender's sorrow lends but weak relief +To him that bears the strong offence's cross. +Ah! but those tears are pearl which thy love sheds, +And they are rich and ransom all ill deeds. +No more be grieved at that which thou hast done: +Roses have thorns, and silver fountains mud; +Clouds and eclipses stain both moon and sun, +And loathsome canker lives in sweetest bud. +All men make faults, and even I in this, +Authorizing thy trespass with compare, +Myself corrupting, salving thy amiss, +Excusing thy sins more than thy sins are; +For to thy sensual fault I bring in sense-- +Thy adverse party is thy advocate-- +And 'gainst myself a lawful plea commence: +Such civil war is in my love and hate +That I an accessary needs must be +To that sweet thief which sourly robs from me. +Let me confess that we two must be twain, +Although our undivided loves are one: +So shall those blots that do with me remain +Without thy help by me be borne alone. +In our two loves there is but one respect, +Though in our lives a separable spite, +Which though it alter not love's sole effect, +Yet doth it steal sweet hours from love's delight. +I may not evermore acknowledge thee, +Lest my bewailed guilt should do thee shame, +Nor thou with public kindness honour me, +Unless thou take that honour from thy name: +But do not so; I love thee in such sort +As, thou being mine, mine is thy good report. +As a decrepit father takes delight +To see his active child do deeds of youth, +So I, made lame by fortune's dearest spite, +Take all my comfort of thy worth and truth. +For whether beauty, birth, or wealth, or wit, +Or any of these all, or all, or more, +Entitled in thy parts do crowned sit, +I make my love engrafted to this store: +So then I am not lame, poor, nor despised, +Whilst that this shadow doth such substance give +That I in thy abundance am sufficed +And by a part of all thy glory live. +Look, what is best, that best I wish in thee: +This wish I have; then ten times happy me! \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index 760a9662d5155..53d97fe295292 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -139,7 +139,7 @@ def quick_recv(self): device = self.device else: device = 'cpu' - buffer = torch.zeros(shape, dtype=dtype).to(device, non_blocking=True) + buffer = torch.zeros(shape, dtype=dtype).to(device) torch.distributed.recv( buffer, @@ -184,8 +184,6 @@ def send_tensor(self, else: tensor_size = tensor.element_size() * tensor.numel() - assert 0 < len(tensor.shape) < 100, "Send tensor does not support tensor with 0 dim or >=100 dim. Got %d" % len(tensor.shape) - self.block_if_full() with self.buffer_size_lock: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 6fd94312483e5..1b60b444709e6 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -295,6 +295,7 @@ def execute_model( model_input, self.kv_cache[worker_input.virtual_engine], ) + assert bypass_model_exec if not bypass_model_exec: hidden_or_intermediate_states = self.model_runner.execute_model( From bb865884e086a7f3617e36e8b4428565e80a2ecb Mon Sep 17 00:00:00 2001 From: ApostaC Date: Tue, 10 Sep 2024 18:32:53 -0500 Subject: [PATCH 172/303] [Add] optimized implementation for KV transfer pipe --- .../kv_pipe/torch_distributed_pipe.py | 247 ++++++++++++------ 1 file changed, 166 insertions(+), 81 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index 53d97fe295292..994afc82ffa4c 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -1,20 +1,12 @@ - -from vllm.distributed.group_coordinator import GroupCoordinator -from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from torch.distributed import Backend, ProcessGroup +from torch.distributed import Backend import torch -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Union import threading from concurrent.futures import ThreadPoolExecutor import time -import threading -from collections import namedtuple -from typing import Dict, Any, Tuple, List -import pickle from vllm.logger import init_logger - logger = init_logger(__name__) @@ -52,34 +44,32 @@ def __init__(self, message): self.message = message super().__init__(self.message) -class TorchDistributedPipe(KVPipeBase): - + +class TorchDistributedPipe: + METADATA_LENGTH = 16 + MAX_TENSOR_DIMENSIONS = 14 + METADATA_DTYPE = torch.int64 + def __init__( self, group_ranks: List[List[int]], local_rank: int, - torch_distributed_backend: Union[str, Backend] + torch_distributed_backend: Union[str, Backend], ): - self.rank = torch.distributed.get_rank() self.local_rank = local_rank self.device_group = None - self.cpu_group = None for ranks in group_ranks: device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend) - # a group with `gloo` backend, to allow direct coordination between - # processes through the CPU. - cpu_group = torch.distributed.new_group(ranks, backend="gloo") + ranks, backend=torch_distributed_backend + ) if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) self.rank_in_group = ranks.index(self.rank) self.device_group = device_group - self.cpu_group = cpu_group - assert self.cpu_group is not None assert self.device_group is not None assert self.rank_in_group <= 1 @@ -88,95 +78,167 @@ def __init__( else: self.device = torch.device("cpu") - # if turned on, will use CPU-based communication to perform a series of sanity check. - # but it adds ~5ms delay, so please turn it off in performance-demanding usecases (e.g. disaggregated prefill) - self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % - self.world_size] - self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % - self.world_size] + self.target_rank_for_send = self.ranks[ + (self.rank_in_group + 1) % self.world_size + ] + self.target_rank_for_recv = self.ranks[ + (self.rank_in_group - 1) % self.world_size + ] + + # FIXME: why we need this? torch.set_default_device(self.device) - self.kv_sending_thread = None + self.transport_thread = None self.buffer_size = 0 self.buffer_size_lock = threading.Lock() - self.none_tensor = torch.tensor([NONE_INT]).to(self.device) - self.broken = False + self.none_tensor = torch.tensor([NONE_INT], device=self.device) + + # On-device tensors to be reused for recv + self.rcv_metadata_buffer = torch.zeros( + self.METADATA_LENGTH, dtype=self.METADATA_DTYPE, device=self.device + ) + + def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Create the metadata on based on the input tensor, and move it to GPU. + The metadata's length is `TorchDistributedPipe.METADATA_LENGTH`. - - def quick_send(self, tensor): + Currently, the metadata is a int64 tensor and it includes dtype, number + of dimensions, and the shape information of the input tensor. - group = self.device_group - # NCCL is NOT fully duplex - # so CPU communication is ALWAYS necessary - torch.distributed.send_object_list( - [tensor.dtype, tensor.shape, str(tensor.device)], - dst=self.target_rank_for_send, - group=self.cpu_group + The information follows the layout below: + - metadata[0] -- dtype + - metadata[1] -- number of dimensions + - metadata[2 : 2+ndims] -- the shape of the input tensor + + Parameters: + - tensor: the input tensor + + Returns: + - metadata: the metadata tensor, on self.device + """ + buffer = torch.empty(self.METADATA_LENGTH, dtype=self.METADATA_DTYPE) + buffer[0] = DTYPE2INT[tensor.dtype] + ndims = len(tensor.shape) + buffer[1] = len(tensor.shape) + buffer[2 : 2 + ndims] = torch.tensor( + tensor.shape, dtype=self.METADATA_DTYPE ) + return buffer.to(self.device) + + def _prepare_recv_buffer( + self, d_metadata_buffer: torch.Tensor + ) -> torch.Tensor: + """ + Create a buffer to receive the tensor based on the metadata. + + Parameters: + - d_metadata_buffer: the metadata tensor on self.device + + Returns: + - buffer: the buffer tensor to receive the tensor, on self.device + """ + h_buffer = d_metadata_buffer.cpu().numpy() + dtype = INT2DTYPE[h_buffer[0]] + ndims = h_buffer[1] + shape = tuple(h_buffer[2 : 2 + ndims]) + return torch.empty(shape, dtype=dtype, device=self.device) + def _send_metadata(self, d_metadata_buffer: torch.Tensor): + """ + Send the metadata buffer to the target rank. + """ torch.distributed.send( - tensor, + d_metadata_buffer, dst=self.target_rank_for_send, - group=self.device_group + group=self.device_group, ) + def _recv_metadata(self) -> torch.Tensor: + """ + Receive the metadata buffer from the target rank. - def quick_recv(self): + Returns: + - metadata_buffer: the metadata buffer tensor, on self.device - # NCCL is NOT fully duplex - # so CPU communication is necessary - metadata = [None, None, None] - torch.distributed.recv_object_list( - metadata, + Note: + The current implementation uses the assumption that there is no + race conditions during sending/receiving. Therefore, the metadata + buffer can be reused + """ + torch.distributed.recv( + self.rcv_metadata_buffer, src=self.target_rank_for_recv, - group=self.cpu_group + group=self.device_group, ) - - dtype, shape, device = metadata - if 'cuda' in device: - device = self.device - else: - device = 'cpu' - buffer = torch.zeros(shape, dtype=dtype).to(device) - + return self.rcv_metadata_buffer + + def _send_impl(self, tensor): + """ + The actual implementation of sending the tensor to the target rank. + This function will first send the metadata, and then send the tensor. + + Parameters: + - tensor: the input tensor to be sent + """ + + metadata = self._make_metadata(tensor) + self._send_metadata(metadata) + + torch.distributed.send( + tensor, dst=self.target_rank_for_send, group=self.device_group + ) + + def _recv_impl(self) -> torch.Tensor: + """ + The actual implementation of receiving the tensor from the target rank. + This function will first receive the metadata, then receive the tensor. + + This function will block if there is no tensor to receive. + + Returns: + - buffer: the received tensor, on self.device + """ + d_metadata = self._recv_metadata() + buffer = self._prepare_recv_buffer(d_metadata) + torch.distributed.recv( - buffer, - src=self.target_rank_for_recv, - group=self.device_group + buffer, src=self.target_rank_for_recv, group=self.device_group ) - return buffer - - - def send_tensor_wrapper(self, tensor) -> None: + return buffer + def send_tensor_wrapper(self, tensor): try: + """Wrapper for send_tensor_dict""" tensor_size = tensor.element_size() * tensor.numel() - self.quick_send(tensor) - + self._send_impl(tensor) + with self.buffer_size_lock: self.buffer_size = self.buffer_size - tensor_size except Exception as e: logger.error("Encountering exception in KV sending thread") logger.error("%s", e) - + def block_if_full(self): - + """ + Block the current thread if the buffer size is larger than 1e9. + """ + # TODO: replace this 1e9 with a configurable parameter or a constant while self.buffer_size > 1e9: logger.debug("KV cache transfer pipe is full. Waiting...") time.sleep(0.05) - def send_tensor(self, - tensor: Optional[torch.Tensor]) -> None: + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: """ Sends a tensor to the destination rank in a non-blocking way. Flow: send tensor dim -- send tensor shape -- send tensor data """ - - if self.kv_sending_thread is None: - self.kv_sending_thread = ThreadPoolExecutor(max_workers=1) + + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) if tensor is None: tensor = self.none_tensor @@ -184,24 +246,47 @@ def send_tensor(self, else: tensor_size = tensor.element_size() * tensor.numel() + assert ( + 0 < len(tensor.shape) < self.MAX_TENSOR_DIMENSIONS + ), f"Only support dimensions within 1-{self.MAX_TENSOR_DIMENSIONS}" + self.block_if_full() with self.buffer_size_lock: + # print("Remaining size:", self.buffer_size) self.buffer_size = self.buffer_size + tensor_size - + # prepare the metadata before sending the tensor. - self.kv_sending_thread.submit( - self.send_tensor_wrapper, - tensor) - + self.transport_thread.submit( + self.send_tensor_wrapper, + tensor, + ) + def recv_tensor(self) -> Optional[torch.Tensor]: """Receives a tensor from the src rank. Blocking.""" - - tensor = self.quick_recv() + + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + future = self.transport_thread.submit(self._recv_impl) + + try: + tensor = future.result() + except Exception as e: + logger.error("Encountering exception in KV receiving thread") + logger.error("%s", e) + if tensor.numel() == 1 and tensor.item() == NONE_INT: return None else: return tensor - - + def close(self): + """ + Close the pipe and release the resources. + """ + if ( + hasattr(self, "transport_thread") + and self.transport_thread is not None + ): + self.transport_thread.shutdown() From ffb792bd7bfde2b0d85d0b963b146e9827a1400f Mon Sep 17 00:00:00 2001 From: ApostaC Date: Fri, 13 Sep 2024 01:37:27 +0000 Subject: [PATCH 173/303] [Fix] the implementation of KV lookup buffer --- tests/kv_transfer/test_lookup_buffer.py | 46 +-- tests/kv_transfer/test_send_recv.py | 22 +- .../kv_transfer/kv_lookup_buffer/base.py | 8 +- .../simple_kv_lookup_buffer.py | 59 +++- vllm/distributed/kv_transfer/kv_pipe/base.py | 6 +- .../kv_pipe/torch_distributed_pipe.py | 22 +- .../kv_pipe/torch_distributed_pipe.py.bkup | 300 ++++++++++++++++++ 7 files changed, 415 insertions(+), 48 deletions(-) create mode 100644 vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py.bkup diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index aa98a7804ecde..5041bf0264839 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -7,8 +7,9 @@ from tqdm import tqdm import time +# TODO: the test depends on a lot of fields in the current implementation. We should have standard interface instead direct field access -def test_run(my_rank, buffer): +def test_run(my_rank, buffer, device): # buffer should be empty in the beginning if my_rank == 0: @@ -17,15 +18,19 @@ def test_run(my_rank, buffer): # insert - tokens = torch.tensor([1,2,3]).to(buffer.pipe.device) + tokens = torch.tensor([1,2,3]).to(device) roi = (tokens > 0) if my_rank == 0: - key = 2.0 * torch.ones([5, 6]).to(buffer.pipe.device) - value = 3.0 * torch.ones([5, 6]).to(buffer.pipe.device) + key = 2.0 * torch.ones([5, 6]).to(device) + value = 3.0 * torch.ones([5, 6]).to(device) - placeholder = torch.tensor([1]).to(buffer.pipe.device) + placeholder = torch.tensor([1]).to(device) buffer.insert(tokens, roi, key, value, placeholder) + + #for i in range(2000): + # print("Here:", i) + # time.sleep(0.01) torch.distributed.barrier() # drop_select @@ -33,22 +38,21 @@ def test_run(my_rank, buffer): tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi) assert torch.allclose(tokens, tok) assert torch.allclose(roi, roi_) - assert torch.allclose(key, 2.0 * torch.ones([5, 6])) - assert torch.allclose(value, 3.0 * torch.ones([5, 6])) + assert torch.allclose(key, 2.0 * torch.ones([5, 6], device = device)) + assert torch.allclose(value, 3.0 * torch.ones([5, 6], device = device)) torch.distributed.barrier() if my_rank == 0: assert buffer.buffer_size == 0 assert len(buffer.buffer) == 0 + + print("Test run passed!") - -def stress_test(my_rank, buf): +def stress_test(my_rank, buf, device): torch.distributed.barrier() torch.manual_seed(100) - device = buf.pipe.device - reqs = [ ( torch.rand(100).to(device), # tokens @@ -56,7 +60,7 @@ def stress_test(my_rank, buf): torch.rand(100).to(device), # key torch.rand(100).to(device), # value torch.rand(100).to(device), # hidden - ) for i in range(200)] + ) for i in tqdm(range(200))] random.seed(my_rank) random.shuffle(reqs) @@ -86,7 +90,7 @@ def stress_test(my_rank, buf): assert torch.allclose(k, k_) assert torch.allclose(v, v_) assert torch.allclose(h, h_) - print('Rand %d done' % my_rank) + print('Rank %d done' % my_rank) torch.distributed.barrier() @@ -101,13 +105,9 @@ def stress_test(my_rank, buf): else: torch.distributed.send(torch.tensor([n]), 0) + print("Passed stress test!") - - - - - if __name__ == "__main__": @@ -123,10 +123,14 @@ def stress_test(my_rank, buf): pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "nccl") - buffer = sklb.SimpleKVLookupBuffer(pipe, 170000) + cpu_pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "gloo") + buffer = sklb.SimpleKVLookupBuffer(cpu_pipe, pipe, 170000) - test_run(my_rank, buffer) + test_run(my_rank, buffer, pipe.device) - stress_test(my_rank, buffer) + stress_test(my_rank, buffer, pipe.device) + buffer.close() + pipe.close() + cpu_pipe.close() print('Done') diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 95a7528f0f7a8..4bf757d7c8492 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -12,12 +12,28 @@ def test_run(my_rank, pipe): x = torch.tensor([1]).to(pipe.device) y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device) if my_rank == 0: - pipe.send_tensor(x) + print("sent tensor x") pipe.send_tensor(y) + print("sent tensor y") + x2 = pipe.recv_tensor() + print("received x2 = ", x2) + y2 = pipe.recv_tensor() + print("received y2 = ", x2) + else: - assert torch.allclose(x, pipe.recv_tensor()) - assert torch.allclose(y, pipe.recv_tensor()) + x2 = pipe.recv_tensor() + print("received x2 = ", x2) + y2 = pipe.recv_tensor() + print("received y2 = ", x2) + pipe.send_tensor(x) + print("sent tensor x") + pipe.send_tensor(y) + print("sent tensor y") + + assert torch.allclose(x, x2) + assert torch.allclose(y, y2) + def stress_test(my_rank, pipe): diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py index 5ac8fbb244446..733bc82bf53f9 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -15,4 +15,10 @@ def insert(self, @abstractmethod def drop_select(self, input_tokens, roi) -> Optional[torch.Tensor]: raise NotImplementedError - \ No newline at end of file + + @abstractmethod + def close(self): + """ + Close the buffer, release resources. + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py index 407ac7c9bcfc1..df52dd65692e4 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py @@ -1,6 +1,7 @@ from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \ KVLookupBufferBase +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from typing import Dict, Tuple, List, Optional import threading import torch @@ -13,16 +14,24 @@ class SimpleKVLookupBuffer(KVLookupBufferBase): - def __init__(self, pipe, buffer_size_thresh): + def __init__(self, signal_pipe, data_pipe, buffer_size_thresh): + """ + signal_pipe: on CPU -- avoid recv() stops the python intepreter + data_pipe: on GPU + """ self.buffer = deque() self.buffer_size = 0 self.buffer_size_threshold = buffer_size_thresh self.buffer_lock = threading.Lock() - self.pipe = pipe + self.signal_pipe = signal_pipe + self.data_pipe = data_pipe self.request_handling_thread = None + self.normal_signal = torch.tensor([0]) + self.end_signal = None + def _matches(self, tokens_roi_sender, tokens_roi_recver): @@ -57,9 +66,9 @@ def _matches(self, tokens_roi_sender, tokens_roi_recver): def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None: - assert tensor is not None, "Use self.pipe.send(None) instead" + assert tensor is not None, "Use self.data_pipe.send(None) instead" self.buffer_size -= tensor.element_size() * tensor.numel() - self.pipe.send_tensor(tensor) + self.data_pipe.send_tensor(tensor) def _get_element_size(self, data): @@ -91,14 +100,22 @@ def _add_to_buffer(self, input_tokens, roi, key, value, hidden): self.buffer_size += self._get_element_size(data) self.buffer.append(buffer_item) + def _is_end_signal(self, signal): + return signal is None def drop_select_handler(self): try: while True: - input_tokens = self.pipe.recv_tensor() - roi = self.pipe.recv_tensor() + signal = self.signal_pipe.recv_tensor() + if self._is_end_signal(signal): + logger.info("Received end signal!") + break + + input_tokens = self.data_pipe.recv_tensor() + + roi = self.data_pipe.recv_tensor() tokens_roi_recver = [input_tokens, roi] matched_length = 0 @@ -125,10 +142,13 @@ def drop_select_handler(self): else: # no match, just send None for _ in range(5): - self.pipe.send_tensor(None) + self.data_pipe.send_tensor(None) + except RuntimeError as e: if 'Connection closed by peer' not in str(e): raise e + + logger.debug("closing drop_select_handler") def drop_select(self, input_tokens, roi): @@ -142,14 +162,15 @@ def drop_select(self, input_tokens, roi): if isinstance(roi, torch.Tensor): roi = roi.clone() - self.pipe.send_tensor(input_tokens) - self.pipe.send_tensor(roi) + self.signal_pipe.send_tensor(self.normal_signal) + self.data_pipe.send_tensor(input_tokens) + self.data_pipe.send_tensor(roi) - input_tokens = self.pipe.recv_tensor() - roi = self.pipe.recv_tensor() - key = self.pipe.recv_tensor() - value = self.pipe.recv_tensor() - hidden = self.pipe.recv_tensor() + input_tokens = self.data_pipe.recv_tensor() + roi = self.data_pipe.recv_tensor() + key = self.data_pipe.recv_tensor() + value = self.data_pipe.recv_tensor() + hidden = self.data_pipe.recv_tensor() return [input_tokens, roi, key, value, hidden] @@ -173,4 +194,12 @@ def insert(self, input_tokens, roi, key, value, hidden) -> None: target=self.drop_select_handler) self.request_handling_thread.start() - \ No newline at end of file + + def close(self): + + if hasattr(self, "request_handling_thread") and self.request_handling_thread is not None: + self.request_handling_thread.join() + + else: + # TODO: have a explicit close signal and have a explicit way to check if it's requester + self.signal_pipe.send_tensor(self.end_signal) diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py index 625656adc2664..7662a5893ceb2 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -10,4 +10,8 @@ def send_tensor(self, tensor): @abstractmethod def recv_tensor(self): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError + + @abstractmethod + def close(self): + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index 994afc82ffa4c..caa9e6aabd935 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -50,6 +50,7 @@ class TorchDistributedPipe: MAX_TENSOR_DIMENSIONS = 14 METADATA_DTYPE = torch.int64 + def __init__( self, group_ranks: List[List[int]], @@ -73,10 +74,7 @@ def __init__( assert self.device_group is not None assert self.rank_in_group <= 1 - if torch.cuda.is_available(): - self.device = torch.device(f"cuda:{local_rank}") - else: - self.device = torch.device("cpu") + self.device = self._select_device(torch_distributed_backend) self.target_rank_for_send = self.ranks[ (self.rank_in_group + 1) % self.world_size @@ -99,6 +97,12 @@ def __init__( self.METADATA_LENGTH, dtype=self.METADATA_DTYPE, device=self.device ) + def _select_device(self, backend: Union[str, Backend]): + if torch.cuda.is_available() and backend == Backend.NCCL: + return torch.device(f"cuda:{self.local_rank}") + else: + return "cpu" + def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: """ Create the metadata on based on the input tensor, and move it to GPU. @@ -168,11 +172,12 @@ def _recv_metadata(self) -> torch.Tensor: race conditions during sending/receiving. Therefore, the metadata buffer can be reused """ - torch.distributed.recv( + task = torch.distributed.recv( self.rcv_metadata_buffer, src=self.target_rank_for_recv, group=self.device_group, ) + return self.rcv_metadata_buffer def _send_impl(self, tensor): @@ -256,15 +261,16 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: # print("Remaining size:", self.buffer_size) self.buffer_size = self.buffer_size + tensor_size - # prepare the metadata before sending the tensor. + + #self.send_tensor_wrapper(tensor) self.transport_thread.submit( self.send_tensor_wrapper, tensor, ) + def recv_tensor(self) -> Optional[torch.Tensor]: """Receives a tensor from the src rank. Blocking.""" - if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) @@ -276,6 +282,8 @@ def recv_tensor(self) -> Optional[torch.Tensor]: logger.error("Encountering exception in KV receiving thread") logger.error("%s", e) + #tensor = self._recv_impl() + if tensor.numel() == 1 and tensor.item() == NONE_INT: return None else: diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py.bkup b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py.bkup new file mode 100644 index 0000000000000..489052285475d --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py.bkup @@ -0,0 +1,300 @@ +from torch.distributed import Backend +import torch +from typing import List, Optional, Union +import threading +from concurrent.futures import ThreadPoolExecutor +import time + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +# if the tensor is only one-element and only contains this number +# this means that the sended object is None. +NONE_INT = -150886311 +FLOAT16_INT = -543205003776624 +INT64_INT = -375623078607432 +BOOL_INT = -28035262008646 +BFLOAT16_INT = -452084912267662 +FLOAT32_INT = -1049557997456592 +FLOAT64_INT = -452201007054137 + +DTYPE2INT = { + torch.float16: FLOAT16_INT, + torch.int64: INT64_INT, + torch.bool: BOOL_INT, + torch.bfloat16: BFLOAT16_INT, + torch.float32: FLOAT32_INT, + torch.float64: FLOAT64_INT, +} + +INT2DTYPE = { + FLOAT16_INT: torch.float16, + INT64_INT: torch.int64, + BOOL_INT: torch.bool, + BFLOAT16_INT: torch.bfloat16, + FLOAT32_INT: torch.float32, + FLOAT64_INT: torch.float64, +} + + +class BrokenPipeException(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +class TorchDistributedPipe: + METADATA_LENGTH = 16 + MAX_TENSOR_DIMENSIONS = 14 + METADATA_DTYPE = torch.int64 + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + + assert self.device_group is not None + assert self.rank_in_group <= 1 + + if torch.cuda.is_available(): + self.device = torch.device(f"cuda:{local_rank}") + else: + self.device = torch.device("cpu") + + self.target_rank_for_send = self.ranks[ + (self.rank_in_group + 1) % self.world_size + ] + self.target_rank_for_recv = self.ranks[ + (self.rank_in_group - 1) % self.world_size + ] + + # FIXME: why we need this? + torch.set_default_device(self.device) + + self.transport_thread = None + self.buffer_size = 0 + self.buffer_size_lock = threading.Lock() + + self.none_tensor = torch.tensor([NONE_INT], device=self.device) + + # On-device tensors to be reused for recv + self.rcv_metadata_buffer = torch.zeros( + self.METADATA_LENGTH, dtype=self.METADATA_DTYPE, device=self.device + ) + + self.pending_recv = None + + def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Create the metadata on based on the input tensor, and move it to GPU. + The metadata's length is `TorchDistributedPipe.METADATA_LENGTH`. + + Currently, the metadata is a int64 tensor and it includes dtype, number + of dimensions, and the shape information of the input tensor. + + + The information follows the layout below: + - metadata[0] -- dtype + - metadata[1] -- number of dimensions + - metadata[2 : 2+ndims] -- the shape of the input tensor + + Parameters: + - tensor: the input tensor + + Returns: + - metadata: the metadata tensor, on self.device + """ + buffer = torch.empty(self.METADATA_LENGTH, dtype=self.METADATA_DTYPE) + buffer[0] = DTYPE2INT[tensor.dtype] + ndims = len(tensor.shape) + buffer[1] = len(tensor.shape) + buffer[2 : 2 + ndims] = torch.tensor( + tensor.shape, dtype=self.METADATA_DTYPE + ) + return buffer.to(self.device) + + def _prepare_recv_buffer( + self, d_metadata_buffer: torch.Tensor + ) -> torch.Tensor: + """ + Create a buffer to receive the tensor based on the metadata. + + Parameters: + - d_metadata_buffer: the metadata tensor on self.device + + Returns: + - buffer: the buffer tensor to receive the tensor, on self.device + """ + h_buffer = d_metadata_buffer.cpu().numpy() + dtype = INT2DTYPE[h_buffer[0]] + ndims = h_buffer[1] + shape = tuple(h_buffer[2 : 2 + ndims]) + return torch.empty(shape, dtype=dtype, device=self.device) + + def _send_metadata(self, d_metadata_buffer: torch.Tensor): + """ + Send the metadata buffer to the target rank. + """ + torch.distributed.send( + d_metadata_buffer, + dst=self.target_rank_for_send, + group=self.device_group, + ) + + def _recv_metadata(self) -> torch.Tensor: + """ + Receive the metadata buffer from the target rank. + + Returns: + - metadata_buffer: the metadata buffer tensor, on self.device + + Note: + The current implementation uses the assumption that there is no + race conditions during sending/receiving. Therefore, the metadata + buffer can be reused + """ + torch.distributed.recv( + self.rcv_metadata_buffer, + src=self.target_rank_for_recv, + group=self.device_group, + ) + return self.rcv_metadata_buffer + + def _send_impl(self, tensor): + """ + The actual implementation of sending the tensor to the target rank. + This function will first send the metadata, and then send the tensor. + + Parameters: + - tensor: the input tensor to be sent + """ + + metadata = self._make_metadata(tensor) + self._send_metadata(metadata) + + torch.distributed.send( + tensor, dst=self.target_rank_for_send, group=self.device_group + ) + + def _recv_impl(self) -> torch.Tensor: + """ + The actual implementation of receiving the tensor from the target rank. + This function will first receive the metadata, then receive the tensor. + + This function will block if there is no tensor to receive. + + Returns: + - buffer: the received tensor, on self.device + """ + d_metadata = self._recv_metadata() + buffer = self._prepare_recv_buffer(d_metadata) + + torch.distributed.recv( + buffer, src=self.target_rank_for_recv, group=self.device_group + ) + + return buffer + + def send_tensor_wrapper(self, tensor): + try: + """Wrapper for send_tensor_dict""" + tensor_size = tensor.element_size() * tensor.numel() + self._send_impl(tensor) + + with self.buffer_size_lock: + self.buffer_size = self.buffer_size - tensor_size + except Exception as e: + logger.error("Encountering exception in KV sending thread") + logger.error("%s", e) + + def block_if_full(self): + """ + Block the current thread if the buffer size is larger than 1e9. + """ + # TODO: replace this 1e9 with a configurable parameter or a constant + while self.buffer_size > 1e9: + logger.debug("KV cache transfer pipe is full. Waiting...") + time.sleep(0.05) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """ + Sends a tensor to the destination rank in a non-blocking way. + Flow: send tensor dim -- send tensor shape -- send tensor data + """ + + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + if tensor is None: + tensor = self.none_tensor + tensor_size = 0 + else: + tensor_size = tensor.element_size() * tensor.numel() + + assert ( + 0 < len(tensor.shape) < self.MAX_TENSOR_DIMENSIONS + ), f"Only support dimensions within 1-{self.MAX_TENSOR_DIMENSIONS}" + + self.block_if_full() + + with self.buffer_size_lock: + # print("Remaining size:", self.buffer_size) + self.buffer_size = self.buffer_size + tensor_size + + # prepare the metadata before sending the tensor. + self.transport_thread.submit( + self.send_tensor_wrapper, + tensor, + ) + + def recv_tensor(self, timeout: float = None) -> Optional[torch.Tensor]: + """Receives a tensor from the src rank. Blocking.""" + + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + if self.pending_recv is None: + self.pending_recv = self.transport_thread.submit(self._recv_impl) + + try: + tensor = self.pending_recv.result(timeout=timeout) + self.pending_recv = None + + except TimeoutError as e: + raise e + + except Exception as e: + logger.error("Encountering exception in KV receiving thread") + logger.error("%s", e) + + if tensor.numel() == 1 and tensor.item() == NONE_INT: + return None + else: + return tensor + + def close(self): + """ + Close the pipe and release the resources. + """ + if ( + hasattr(self, "transport_thread") + and self.transport_thread is not None + ): + self.transport_thread.shutdown() From d7d32c1e378ee8cf519eb9f538fd3fcd717034f2 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Fri, 13 Sep 2024 01:41:08 +0000 Subject: [PATCH 174/303] remove unused file --- .../kv_pipe/torch_distributed_pipe.py.bkup | 300 ------------------ 1 file changed, 300 deletions(-) delete mode 100644 vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py.bkup diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py.bkup b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py.bkup deleted file mode 100644 index 489052285475d..0000000000000 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py.bkup +++ /dev/null @@ -1,300 +0,0 @@ -from torch.distributed import Backend -import torch -from typing import List, Optional, Union -import threading -from concurrent.futures import ThreadPoolExecutor -import time - -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -# if the tensor is only one-element and only contains this number -# this means that the sended object is None. -NONE_INT = -150886311 -FLOAT16_INT = -543205003776624 -INT64_INT = -375623078607432 -BOOL_INT = -28035262008646 -BFLOAT16_INT = -452084912267662 -FLOAT32_INT = -1049557997456592 -FLOAT64_INT = -452201007054137 - -DTYPE2INT = { - torch.float16: FLOAT16_INT, - torch.int64: INT64_INT, - torch.bool: BOOL_INT, - torch.bfloat16: BFLOAT16_INT, - torch.float32: FLOAT32_INT, - torch.float64: FLOAT64_INT, -} - -INT2DTYPE = { - FLOAT16_INT: torch.float16, - INT64_INT: torch.int64, - BOOL_INT: torch.bool, - BFLOAT16_INT: torch.bfloat16, - FLOAT32_INT: torch.float32, - FLOAT64_INT: torch.float64, -} - - -class BrokenPipeException(Exception): - def __init__(self, message): - self.message = message - super().__init__(self.message) - - -class TorchDistributedPipe: - METADATA_LENGTH = 16 - MAX_TENSOR_DIMENSIONS = 14 - METADATA_DTYPE = torch.int64 - - def __init__( - self, - group_ranks: List[List[int]], - local_rank: int, - torch_distributed_backend: Union[str, Backend], - ): - self.rank = torch.distributed.get_rank() - self.local_rank = local_rank - self.device_group = None - - for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) - if self.rank in ranks: - self.ranks = ranks - self.world_size = len(ranks) - self.rank_in_group = ranks.index(self.rank) - self.device_group = device_group - - assert self.device_group is not None - assert self.rank_in_group <= 1 - - if torch.cuda.is_available(): - self.device = torch.device(f"cuda:{local_rank}") - else: - self.device = torch.device("cpu") - - self.target_rank_for_send = self.ranks[ - (self.rank_in_group + 1) % self.world_size - ] - self.target_rank_for_recv = self.ranks[ - (self.rank_in_group - 1) % self.world_size - ] - - # FIXME: why we need this? - torch.set_default_device(self.device) - - self.transport_thread = None - self.buffer_size = 0 - self.buffer_size_lock = threading.Lock() - - self.none_tensor = torch.tensor([NONE_INT], device=self.device) - - # On-device tensors to be reused for recv - self.rcv_metadata_buffer = torch.zeros( - self.METADATA_LENGTH, dtype=self.METADATA_DTYPE, device=self.device - ) - - self.pending_recv = None - - def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: - """ - Create the metadata on based on the input tensor, and move it to GPU. - The metadata's length is `TorchDistributedPipe.METADATA_LENGTH`. - - Currently, the metadata is a int64 tensor and it includes dtype, number - of dimensions, and the shape information of the input tensor. - - - The information follows the layout below: - - metadata[0] -- dtype - - metadata[1] -- number of dimensions - - metadata[2 : 2+ndims] -- the shape of the input tensor - - Parameters: - - tensor: the input tensor - - Returns: - - metadata: the metadata tensor, on self.device - """ - buffer = torch.empty(self.METADATA_LENGTH, dtype=self.METADATA_DTYPE) - buffer[0] = DTYPE2INT[tensor.dtype] - ndims = len(tensor.shape) - buffer[1] = len(tensor.shape) - buffer[2 : 2 + ndims] = torch.tensor( - tensor.shape, dtype=self.METADATA_DTYPE - ) - return buffer.to(self.device) - - def _prepare_recv_buffer( - self, d_metadata_buffer: torch.Tensor - ) -> torch.Tensor: - """ - Create a buffer to receive the tensor based on the metadata. - - Parameters: - - d_metadata_buffer: the metadata tensor on self.device - - Returns: - - buffer: the buffer tensor to receive the tensor, on self.device - """ - h_buffer = d_metadata_buffer.cpu().numpy() - dtype = INT2DTYPE[h_buffer[0]] - ndims = h_buffer[1] - shape = tuple(h_buffer[2 : 2 + ndims]) - return torch.empty(shape, dtype=dtype, device=self.device) - - def _send_metadata(self, d_metadata_buffer: torch.Tensor): - """ - Send the metadata buffer to the target rank. - """ - torch.distributed.send( - d_metadata_buffer, - dst=self.target_rank_for_send, - group=self.device_group, - ) - - def _recv_metadata(self) -> torch.Tensor: - """ - Receive the metadata buffer from the target rank. - - Returns: - - metadata_buffer: the metadata buffer tensor, on self.device - - Note: - The current implementation uses the assumption that there is no - race conditions during sending/receiving. Therefore, the metadata - buffer can be reused - """ - torch.distributed.recv( - self.rcv_metadata_buffer, - src=self.target_rank_for_recv, - group=self.device_group, - ) - return self.rcv_metadata_buffer - - def _send_impl(self, tensor): - """ - The actual implementation of sending the tensor to the target rank. - This function will first send the metadata, and then send the tensor. - - Parameters: - - tensor: the input tensor to be sent - """ - - metadata = self._make_metadata(tensor) - self._send_metadata(metadata) - - torch.distributed.send( - tensor, dst=self.target_rank_for_send, group=self.device_group - ) - - def _recv_impl(self) -> torch.Tensor: - """ - The actual implementation of receiving the tensor from the target rank. - This function will first receive the metadata, then receive the tensor. - - This function will block if there is no tensor to receive. - - Returns: - - buffer: the received tensor, on self.device - """ - d_metadata = self._recv_metadata() - buffer = self._prepare_recv_buffer(d_metadata) - - torch.distributed.recv( - buffer, src=self.target_rank_for_recv, group=self.device_group - ) - - return buffer - - def send_tensor_wrapper(self, tensor): - try: - """Wrapper for send_tensor_dict""" - tensor_size = tensor.element_size() * tensor.numel() - self._send_impl(tensor) - - with self.buffer_size_lock: - self.buffer_size = self.buffer_size - tensor_size - except Exception as e: - logger.error("Encountering exception in KV sending thread") - logger.error("%s", e) - - def block_if_full(self): - """ - Block the current thread if the buffer size is larger than 1e9. - """ - # TODO: replace this 1e9 with a configurable parameter or a constant - while self.buffer_size > 1e9: - logger.debug("KV cache transfer pipe is full. Waiting...") - time.sleep(0.05) - - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: - """ - Sends a tensor to the destination rank in a non-blocking way. - Flow: send tensor dim -- send tensor shape -- send tensor data - """ - - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - - if tensor is None: - tensor = self.none_tensor - tensor_size = 0 - else: - tensor_size = tensor.element_size() * tensor.numel() - - assert ( - 0 < len(tensor.shape) < self.MAX_TENSOR_DIMENSIONS - ), f"Only support dimensions within 1-{self.MAX_TENSOR_DIMENSIONS}" - - self.block_if_full() - - with self.buffer_size_lock: - # print("Remaining size:", self.buffer_size) - self.buffer_size = self.buffer_size + tensor_size - - # prepare the metadata before sending the tensor. - self.transport_thread.submit( - self.send_tensor_wrapper, - tensor, - ) - - def recv_tensor(self, timeout: float = None) -> Optional[torch.Tensor]: - """Receives a tensor from the src rank. Blocking.""" - - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - - if self.pending_recv is None: - self.pending_recv = self.transport_thread.submit(self._recv_impl) - - try: - tensor = self.pending_recv.result(timeout=timeout) - self.pending_recv = None - - except TimeoutError as e: - raise e - - except Exception as e: - logger.error("Encountering exception in KV receiving thread") - logger.error("%s", e) - - if tensor.numel() == 1 and tensor.item() == NONE_INT: - return None - else: - return tensor - - def close(self): - """ - Close the pipe and release the resources. - """ - if ( - hasattr(self, "transport_thread") - and self.transport_thread is not None - ): - self.transport_thread.shutdown() From 417ccb35c6cd391942c29344ef0056333e716e0e Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Fri, 13 Sep 2024 04:25:08 -0500 Subject: [PATCH 175/303] update vllm adapter --- vllm/distributed/kv_transfer/vllm_adapter.py | 154 ++++++++++++++++++- 1 file changed, 146 insertions(+), 8 deletions(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index d13d132f5dfee..98d861fe9a12e 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -30,6 +30,9 @@ from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import TorchDistributedPipe from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer import SimpleKVLookupBuffer +from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata +from copy import deepcopy + assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode", "lmcache"], \ "VLLM_DISAGG_PREFILL_ROLE can only be prefill, decode or lmcache." @@ -77,7 +80,7 @@ def __init__( def send_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", + model_input: ModelInputForGPUWithSamplingMetadata, kv_caches: List[torch.Tensor], hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], ) -> None: @@ -128,9 +131,9 @@ def send_kv_caches_and_hidden_states( def recv_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", + model_input: ModelInputForGPUWithSamplingMetadata, kv_caches: List[torch.Tensor] - ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool]: + ) -> List[Union[torch.Tensor, IntermediateTensors], bool, ModelInputForGPUWithSamplingMetadata]: bypass_model_exec = True @@ -142,6 +145,10 @@ def recv_kv_caches_and_hidden_states( hidden_or_intermediate_states_for_one_req = [] + input_tokens_list = [] + num_computed_tokens_list = [] + start_pos_list = [] + # enumerate different requests # FIXME(Kuntai): This impl assumes that all requests are prefill. for idx, slen in enumerate(seq_lens): @@ -151,16 +158,27 @@ def recv_kv_caches_and_hidden_states( current_tokens = input_tokens_tensor[start_pos:end_pos] num_tokens = slen + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + ret = self.buffer.drop_select( current_tokens, torch.ones_like(current_tokens, dtype=bool)) if ret[0] is None: # didn't find any match. self.bypass_model_exec = False + num_computed_tokens_list.append(0) continue - _, _, keys, values, hidden = ret - + # TODO(Jiayi): change the logic here (need roi) + _, roi, keys, values, hidden = ret + + # Jiayi: currently assume roi is a prefix + num_computed_tokens = len(roi) + num_computed_tokens_list.append(num_computed_tokens) + is_complete = (num_computed_tokens == num_tokens) + end_pos = start_pos + num_computed_tokens + # receive KV cache from disaggregated prefill instance for i in range(model_executable.model.start_layer, model_executable.model.end_layer): @@ -184,14 +202,134 @@ def recv_kv_caches_and_hidden_states( hidden_or_intermediate_states_for_one_req.append(hidden) + # FIXME(Jiayi): we need to support only skip m out of n reqs in a batch + # same for prefix caching if not bypass_model_exec: # Some of the KV cache is not retrieved # so we need to recompute the hidden state - return [], bypass_model_exec - + logger.debug("[rank%d]: KV EMPTY recv DONE.", torch.distributed.get_rank()) + return None, bypass_model_exec, None + + if not is_complete: + rebuilt_model_input = self.adpat_model_input( + model_input, + input_tokens_list, + num_computed_tokens_list, + start_pos_list, + slot_mapping, + device=kv_cache[0].device, + ) + logger.debug("[rank%d]: KV PARTIAL recv DONE.", torch.distributed.get_rank()) + return None, bypass_model_exec, rebuilt_model_input + # concatenate hidden states from different requests hidden_or_intermediate_states = torch.cat( hidden_or_intermediate_states_for_one_req, dim=0) logger.debug("[rank%d]: KV recv DONE.", torch.distributed.get_rank()) - return hidden_or_intermediate_states, bypass_model_exec, model_input \ No newline at end of file + return hidden_or_intermediate_states, bypass_model_exec, model_input + + + def adpat_model_input( + self, + model_input: ModelInputForGPUWithSamplingMetadata, + input_tokens_list: List[torch.Tensor], + num_computed_tokens_list: List[int], + start_pos_list: List[int], + slot_mapping_flat: torch.Tensor, + device: torch.device, + ) -> ModelInputForGPUWithSamplingMetadata: + rebuilt_input_tokens = [] + rebuilt_input_positions= [] + rebuilt_query_lens = [] + + rebuilt_num_prefills = 0 + rebuilt_num_prefill_tokens = 0 + rebuilt_slot_mapping = [] + rebuilt_max_query_len = 0 + + rebuilt_block_tables = [] + + rebuilt_query_start_loc = [0] + rebuilt_context_lens_tensor = [] + rebuilt_selected_token_indices = [] + + for idx in range(len(input_tokens_list)): + token_tensor = input_tokens_list[idx] + num_token = len(token_tensor) + num_computed_token = num_computed_tokens_list[idx] + start_pos = start_pos_list[idx] + + rebuilt_input_tokens.append(token_tensor[num_computed_token:]) + # TODO(Jiayi): please check the correctness of next line + rebuilt_input_positions.append(model_input.input_positions[start_pos+num_computed_token:start_pos+num_token]) + q_len = num_token - num_computed_token + rebuilt_query_lens.append(q_len) + + # Attn metadata-related + rebuilt_num_prefills += 1 + rebuilt_num_prefill_tokens += q_len + rebuilt_slot_mapping.append(slot_mapping_flat[start_pos+num_computed_token:start_pos+num_token]) + rebuilt_max_query_len = max(q_len, rebuilt_max_query_len) + # TODO(Jiayi): remove hard-code (block_size=16) + blk_size = 16 + temp_block_table = [i//blk_size for i in range(start_pos, start_pos+num_token, blk_size)] + rebuilt_block_tables.append(temp_block_table) + rebuilt_query_start_loc.append(q_len) #start with 0 + rebuilt_context_lens_tensor.append(num_computed_token) + + # Sampling metadata related + #seq_groups (use rebuilt query lens) + rebuilt_selected_token_indices.append(start_pos+q_len-1) + + + # rebuilt attn_metadata + rebuilt_attn_metadata = deepcopy(model_input.attn_metadata) + rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills + rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens + rebuilt_attn_metadata.slot_mapping = torch.cat(rebuilt_slot_mapping).to(device) + rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len + + rebuilt_attn_metadata.block_tables = torch.tensor( + rebuilt_block_tables, + dtype=model_input.attn_metadata.block_tables.dtype + ).to(device) + + rebuilt_attn_metadata.query_start_loc = torch.tensor( + rebuilt_query_start_loc, + dtype=model_input.attn_metadata.query_start_loc.dtype).to(device) + rebuilt_attn_metadata.context_lens_tensor = torch.tensor( + rebuilt_context_lens_tensor, + dtype=model_input.attn_metadata.context_lens_tensor.dtype, + ).to(device) + + rebuilt_attn_metadata._cached_prefill_metadata = None + + # rebuilt sampling_metadata + rebuilt_sampling_metadata = deepcopy(model_input.sampling_metadata) + for idx, q_len in enumerate(rebuilt_query_lens): + rebuilt_sampling_metadata.seq_groups[idx].query_len = q_len + rebuilt_sampling_metadata.selected_token_indices = torch.tensor( + rebuilt_selected_token_indices, + dtype=model_input.sampling_metadata.selected_token_indices.dtype, + ).to(device) + + rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( + input_tokens = torch.cat(rebuilt_input_tokens).to(device), + input_positions = torch.cat(rebuilt_input_positions).to(device), + seq_lens = model_input.seq_lens, + query_lens = rebuilt_query_lens, + lora_mapping = model_input.lora_mapping, + lora_requests = model_input.lora_requests, + attn_metadata = rebuilt_attn_metadata, + prompt_adapter_mapping = model_input.prompt_adapter_mapping, + prompt_adapter_requests = model_input.prompt_adapter_requests, + multi_modal_kwargs = model_input.multi_modal_kwargs, + request_ids_to_seq_ids = model_input.request_ids_to_seq_ids, + finished_requests_ids = model_input.finished_requests_ids, + virtual_engine = model_input.virtual_engine, + sampling_metadata = rebuilt_sampling_metadata, + is_prompt = model_input.is_prompt, + ) + + return rebuilt_model_input \ No newline at end of file From 0176ebb18a1698e07b0e9e635060001a0f6b419a Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Fri, 13 Sep 2024 04:31:33 -0500 Subject: [PATCH 176/303] update worker_base --- vllm/distributed/kv_transfer/vllm_adapter.py | 3 +++ vllm/worker/worker_base.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 98d861fe9a12e..e3e9a2b3187ec 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -68,12 +68,15 @@ def __init__( torch_distributed_backend: Union[str, Backend], ): + # FIXME(Jiayi): we need two pipes + # one or send and one for recv # init pipe self.pipe = TorchDistributedPipe( group_ranks, local_rank, torch_distributed_backend, ) + # FIXME(Jiayi): buffer initializtion should be updated accordingly # init lookup buffer self.buffer = SimpleKVLookupBuffer(self.pipe, 1000**3 * 10) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 1b60b444709e6..2fc5d3200b138 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -295,7 +295,7 @@ def execute_model( model_input, self.kv_cache[worker_input.virtual_engine], ) - assert bypass_model_exec + #assert bypass_model_exec if not bypass_model_exec: hidden_or_intermediate_states = self.model_runner.execute_model( From 84fd0b826dc93f15861feaf5ea73fa142f571985 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Fri, 13 Sep 2024 09:47:42 -0500 Subject: [PATCH 177/303] update comm initialization --- vllm/distributed/kv_transfer/vllm_adapter.py | 36 +++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index e3e9a2b3187ec..981c23891a6f1 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -68,15 +68,33 @@ def __init__( torch_distributed_backend: Union[str, Backend], ): - # FIXME(Jiayi): we need two pipes - # one or send and one for recv - # init pipe - self.pipe = TorchDistributedPipe( - group_ranks, - local_rank, - torch_distributed_backend, - ) - # FIXME(Jiayi): buffer initializtion should be updated accordingly + # init two pipes: one or send and one for recv + if IS_KV_PREFILL_INSTANCE or IS_LMCACHE_INSTANCE: + self.send_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + torch_distributed_backend, + ) + self.recv_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + torch_distributed_backend, + ) + elif IS_KV_DECODE_INSTANCE: + self.recv_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + torch_distributed_backend, + ) + self.send_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + torch_distributed_backend, + ) + + + # FIXME(Jiayi): buffer initializtion should be adapted accordingly + # Signal pipe needs to be initialized on both vllm and lmc side # init lookup buffer self.buffer = SimpleKVLookupBuffer(self.pipe, 1000**3 * 10) From 826ca70daa1cd7ae991c516287823f22893544c3 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Fri, 13 Sep 2024 22:44:23 +0000 Subject: [PATCH 178/303] update --- .../kv_lookup_buffer/simple_kv_lookup_buffer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py index df52dd65692e4..c06093413d7e4 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py @@ -50,15 +50,14 @@ def _matches(self, tokens_roi_sender, tokens_roi_recver): return True - # I am assuming that roi is a mask on tokens + # Assuming that roi is a mask on tokens tokens_sender = tokens_sender[roi_sender] tokens_recver = tokens_recver[roi_recver] + # simple common prefix matching min_length = min(len(tokens_sender), len(tokens_recver)) if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]): - # drastically simplified - # common prefix matching return min_length return 0 @@ -148,7 +147,7 @@ def drop_select_handler(self): if 'Connection closed by peer' not in str(e): raise e - logger.debug("closing drop_select_handler") + logger.debug("Closing drop_select_handler") def drop_select(self, input_tokens, roi): From 3425ab64e8afae1b4aca4d69628b69d584033a6f Mon Sep 17 00:00:00 2001 From: ApostaC Date: Fri, 13 Sep 2024 22:56:33 +0000 Subject: [PATCH 179/303] update documentation --- tests/kv_transfer/test_lookup_buffer.py | 2 +- .../kv_transfer/kv_pipe/torch_distributed_pipe.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index 5041bf0264839..5ccbc5b0f5865 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -70,7 +70,7 @@ def stress_test(my_rank, buf, device): n = 0 # the buffer size can only store 100 reqs - # so the sender will occasionally block.needs to wait for the receiver. + # so the sender will occasionally block to wait for the receiver. for req in tqdm(reqs): if my_rank == 0: buf.insert(*req) diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index caa9e6aabd935..c77da1e6b75b8 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -10,9 +10,11 @@ logger = init_logger(__name__) -# if the tensor is only one-element and only contains this number +# if the tensor is only one-element and only contains NONE_INT # this means that the sended object is None. NONE_INT = -150886311 + +# Mapping tensor dtype to a int, used for tensor metadata transmission FLOAT16_INT = -543205003776624 INT64_INT = -375623078607432 BOOL_INT = -28035262008646 @@ -258,11 +260,9 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: self.block_if_full() with self.buffer_size_lock: - # print("Remaining size:", self.buffer_size) self.buffer_size = self.buffer_size + tensor_size - #self.send_tensor_wrapper(tensor) self.transport_thread.submit( self.send_tensor_wrapper, tensor, From 9f3a3a50bbe2bace6c2a17d139a68d92eecd8939 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Fri, 13 Sep 2024 23:04:17 +0000 Subject: [PATCH 180/303] adjust vllm adapter: now we separate CPU and device into different pipes --- vllm/distributed/kv_transfer/vllm_adapter.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index d13d132f5dfee..ba2c06c24b2d9 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -66,13 +66,19 @@ def __init__( ): # init pipe - self.pipe = TorchDistributedPipe( + self.device_pipe = TorchDistributedPipe( group_ranks, local_rank, torch_distributed_backend, ) + self.cpu_pipe = TorchDistributedPipe( + group_ranks, + local_ranks, + "gloo" + ) # init lookup buffer - self.buffer = SimpleKVLookupBuffer(self.pipe, 1000**3 * 10) + # TODO: replace this 1e9 with a configurable parameter or a constant + self.buffer = SimpleKVLookupBuffer(self.cpu_pipe, self.device_pipe, 1e9) def send_kv_caches_and_hidden_states( self, From ce79d596fc68dbbb4e57921490a0a6a41cdc6715 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Fri, 13 Sep 2024 23:20:33 +0000 Subject: [PATCH 181/303] build 2 pipes in vLLM adapter --- .../kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py | 2 +- vllm/distributed/kv_transfer/vllm_adapter.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py index c06093413d7e4..de2e7cf5d5d57 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py @@ -181,7 +181,7 @@ def full_handler(self): def insert(self, input_tokens, roi, key, value, hidden) -> None: while self.buffer_size > self.buffer_size_threshold: - logger.debug("KV transfer buffer is full. Handling...") + # logger.debug("KV transfer buffer is full. Handling...") self.full_handler() self._add_to_buffer(input_tokens, roi, key, value, hidden) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index ba2c06c24b2d9..d85ffe79abdf5 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -73,12 +73,12 @@ def __init__( ) self.cpu_pipe = TorchDistributedPipe( group_ranks, - local_ranks, + local_rank, "gloo" ) # init lookup buffer # TODO: replace this 1e9 with a configurable parameter or a constant - self.buffer = SimpleKVLookupBuffer(self.cpu_pipe, self.device_pipe, 1e9) + self.buffer = SimpleKVLookupBuffer(self.cpu_pipe, self.device_pipe, 1e9 * 10) def send_kv_caches_and_hidden_states( self, From 34dfdde0face48494c69f1e5d1b20cd6ca96ab27 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Fri, 13 Sep 2024 23:39:27 +0000 Subject: [PATCH 182/303] documentation chagne --- vllm/distributed/kv_transfer/vllm_adapter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index d85ffe79abdf5..42f451f6bc8fa 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -1,14 +1,14 @@ """vLLM distributed KV cache transfer API. These APIs are used in `vllm/worker/model_runner.py`. -Currently supporting TP and PP. +Currently supporting TP and PP, but TP and PP must be the same. Workflow: -- In prefill instance, KV cache sender *buffers* the KV cache send requests +- In prefill instance, vLLM `insert` that buffers the KV cache into lookup buffer. - In decode instance - - KV cache receiver sends the hash of input tokens to sender - - KV cache sender executes send request - - KV cache receiver receives the KV cache + - vLLM first runs `drop_select` to send input tokens and a mask on input tokens to sender + - The prefill instance send back the matching KV caches + - vLLM then store the KV cache into paged memory. """ from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING from collections import defaultdict, deque From 9355be358a8fcfe86d945c8e8893515dab84d1c0 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Sat, 14 Sep 2024 10:09:17 -0500 Subject: [PATCH 183/303] update vllm_adapter --- vllm/distributed/kv_transfer/vllm_adapter.py | 35 ++++++++++++++++---- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 11bea4bcef12a..2ed4bea473454 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -68,7 +68,7 @@ def __init__( torch_distributed_backend: Union[str, Backend], ): - + ''' # init pipe self.device_pipe = TorchDistributedPipe( group_ranks, @@ -80,30 +80,50 @@ def __init__( local_rank, "gloo" ) - - # init two pipes: one or send and one for recv + ''' + # init 4 pipes: 2 * (one for send and one for recv) if IS_KV_PREFILL_INSTANCE or IS_LMCACHE_INSTANCE: self.send_pipe = TorchDistributedPipe( group_ranks, local_rank, torch_distributed_backend, ) + self.send_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + ) self.recv_pipe = TorchDistributedPipe( group_ranks, local_rank, torch_distributed_backend, ) + self.recv_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + ) elif IS_KV_DECODE_INSTANCE: self.recv_pipe = TorchDistributedPipe( group_ranks, local_rank, torch_distributed_backend, ) + self.recv_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + ) self.send_pipe = TorchDistributedPipe( group_ranks, local_rank, torch_distributed_backend, ) + self.send_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + ) # FIXME(Jiayi): buffer initializtion should be adapted accordingly @@ -111,7 +131,10 @@ def __init__( # init lookup buffer # TODO: replace this 1e9 with a configurable parameter or a constant - self.buffer = SimpleKVLookupBuffer(self.cpu_pipe, self.device_pipe, 1e9 * 10) + #self.buffer = SimpleKVLookupBuffer(self.cpu_pipe, self.device_pipe, 1e9 * 10) + + self.send_buffer = SimpleKVLookupBuffer(self.send_pipe, self.send_signal_pipe, 1e9 * 10) + self.recv_buffer = SimpleKVLookupBuffer(self.recv_pipe, self.recv_signal_pipe, 1e9 * 10) def send_kv_caches_and_hidden_states( self, @@ -152,7 +175,7 @@ def send_kv_caches_and_hidden_states( keys = torch.cat(keys, dim=0) values = torch.cat(values, dim=0) - self.buffer.insert( + self.send_buffer.insert( current_tokens, torch.ones_like(current_tokens, dtype=bool), keys, @@ -197,7 +220,7 @@ def recv_kv_caches_and_hidden_states( input_tokens_list.append(current_tokens) start_pos_list.append(start_pos) - ret = self.buffer.drop_select( + ret = self.recv_buffer.drop_select( current_tokens, torch.ones_like(current_tokens, dtype=bool)) if ret[0] is None: From 54b68c9a170215196862054bece059bd47e96f01 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Sun, 15 Sep 2024 06:45:09 -0500 Subject: [PATCH 184/303] minor fix --- vllm/distributed/kv_transfer/vllm_adapter.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 2ed4bea473454..db1e4a39dd0c3 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -30,7 +30,6 @@ from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import TorchDistributedPipe from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer import SimpleKVLookupBuffer -from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from copy import deepcopy assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode", "lmcache"], \ @@ -139,7 +138,7 @@ def __init__( def send_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, - model_input: ModelInputForGPUWithSamplingMetadata, + model_input: "ModelInputForGPUWithSamplingMetadata", kv_caches: List[torch.Tensor], hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], ) -> None: @@ -190,9 +189,9 @@ def send_kv_caches_and_hidden_states( def recv_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, - model_input: ModelInputForGPUWithSamplingMetadata, + model_input: "ModelInputForGPUWithSamplingMetadata", kv_caches: List[torch.Tensor] - ) -> List[Union[torch.Tensor, IntermediateTensors], bool, ModelInputForGPUWithSamplingMetadata]: + ) -> List[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: bypass_model_exec = True @@ -291,13 +290,13 @@ def recv_kv_caches_and_hidden_states( def adpat_model_input( self, - model_input: ModelInputForGPUWithSamplingMetadata, + model_input: "ModelInputForGPUWithSamplingMetadata", input_tokens_list: List[torch.Tensor], num_computed_tokens_list: List[int], start_pos_list: List[int], slot_mapping_flat: torch.Tensor, device: torch.device, - ) -> ModelInputForGPUWithSamplingMetadata: + ) -> "ModelInputForGPUWithSamplingMetadata": rebuilt_input_tokens = [] rebuilt_input_positions= [] rebuilt_query_lens = [] @@ -373,6 +372,7 @@ def adpat_model_input( dtype=model_input.sampling_metadata.selected_token_indices.dtype, ).to(device) + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( input_tokens = torch.cat(rebuilt_input_tokens).to(device), input_positions = torch.cat(rebuilt_input_positions).to(device), From 2dff6580684f8436e7d18878fed6681227a7142b Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Sun, 15 Sep 2024 07:02:00 -0500 Subject: [PATCH 185/303] fix type hint --- vllm/distributed/kv_transfer/vllm_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index db1e4a39dd0c3..b60236700ab26 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -191,7 +191,7 @@ def recv_kv_caches_and_hidden_states( model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", kv_caches: List[torch.Tensor] - ) -> List[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: bypass_model_exec = True From c6a6714310b24d70188efd2dccc0ad24f699bc43 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Sun, 15 Sep 2024 09:40:01 -0500 Subject: [PATCH 186/303] fix comm init --- vllm/distributed/parallel_state.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c48b113de9705..92f2fe03d2bc1 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -191,10 +191,10 @@ def init_distributed_environment( # this backend is used for WORLD maybe_disagg_world_size = world_size maybe_disagg_rank = rank - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE: maybe_disagg_world_size = world_size * 2 logger.debug("Disaggregated prefill enabled.") - if dist_kv.IS_KV_PREFILL_INSTANCE: + if dist_kv.IS_KV_PREFILL_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE: # for prefill, the ranks are [0, world_size) maybe_disagg_rank = rank else: @@ -227,7 +227,7 @@ def init_distributed_environment( if _WORLD is None: ranks = [[i for i in range(world_size)]] # offset the distributed group - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE: ranks = include_decoding_groups_if_disagg_enabled( ranks, world_size) @@ -289,7 +289,7 @@ def initialize_model_parallel( world_size: int = torch.distributed.get_world_size() backend = backend or torch.distributed.get_backend( get_world_group().device_group) - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE: # Disaggregated prefill enabled # The world_size for this vLLM instance is tp * pp, but torch.distributed contains 2 vLLM instances, its world size is 2 * tp * pp # Adjust the world_size to match. @@ -341,7 +341,8 @@ def initialize_model_parallel( use_custom_allreduce=False) logger.debug("_PP initialized for rank %d", torch.distributed.get_rank()) - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + # TODO(Jiayi): perhaps we need to separate lmcache and disagg + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE: global _DISAGG logger.debug("Disaggregated prefill enabled, create _DISAGG group") group_ranks = [] From fef35b24cfb2884f3e26b86c85ec9d76ae629039 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 17:42:34 +0000 Subject: [PATCH 187/303] bug fix: remove self from bypass_model_exec --- vllm/distributed/kv_transfer/vllm_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index b60236700ab26..b74bf3c758268 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -224,7 +224,7 @@ def recv_kv_caches_and_hidden_states( torch.ones_like(current_tokens, dtype=bool)) if ret[0] is None: # didn't find any match. - self.bypass_model_exec = False + bypass_model_exec = False num_computed_tokens_list.append(0) continue From 4d0b5cdf9ccbd99d6641da0af17a42e2fe7fad52 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 18:59:38 +0000 Subject: [PATCH 188/303] bug fix: should init SimpleKVLookupBuffer with signal pipe first and then data pipe --- vllm/distributed/kv_transfer/vllm_adapter.py | 51 +++++++------------- 1 file changed, 18 insertions(+), 33 deletions(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index b74bf3c758268..93edeaa1d5d01 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -67,21 +67,9 @@ def __init__( torch_distributed_backend: Union[str, Backend], ): - ''' - # init pipe - self.device_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - torch_distributed_backend, - ) - self.cpu_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo" - ) - ''' - # init 4 pipes: 2 * (one for send and one for recv) - if IS_KV_PREFILL_INSTANCE or IS_LMCACHE_INSTANCE: + if IS_LMCACHE_INSTANCE: + # when vLLM is connected with LMCache + # it needs to both send and recv KV cache self.send_pipe = TorchDistributedPipe( group_ranks, local_rank, @@ -102,27 +90,26 @@ def __init__( local_rank, "gloo", ) - elif IS_KV_DECODE_INSTANCE: - self.recv_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - torch_distributed_backend, - ) - self.recv_signal_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo", - ) - self.send_pipe = TorchDistributedPipe( + self.send_buffer = SimpleKVLookupBuffer(self.send_signal_pipe, self.send_pipe, 1e9 * 10) + self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, self.recv_pipe, 1e9 * 10) + else: + # when performing disaggregated prefill, only 1 pipe is needed + # at prefill instance this pipe is used for send KV cache + # at decode instance this pipe is used for recv KV cache + self.pipe = TorchDistributedPipe( group_ranks, local_rank, torch_distributed_backend, ) - self.send_signal_pipe = TorchDistributedPipe( + self.signal_pipe = TorchDistributedPipe( group_ranks, local_rank, "gloo", ) + self.send_buffer = SimpleKVLookupBuffer(self.signal_pipe, self.pipe, 1e9 * 10) + self.recv_buffer = self.send_buffer + + # FIXME(Jiayi): buffer initializtion should be adapted accordingly @@ -132,9 +119,7 @@ def __init__( # TODO: replace this 1e9 with a configurable parameter or a constant #self.buffer = SimpleKVLookupBuffer(self.cpu_pipe, self.device_pipe, 1e9 * 10) - self.send_buffer = SimpleKVLookupBuffer(self.send_pipe, self.send_signal_pipe, 1e9 * 10) - self.recv_buffer = SimpleKVLookupBuffer(self.recv_pipe, self.recv_signal_pipe, 1e9 * 10) - + def send_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, @@ -269,7 +254,7 @@ def recv_kv_caches_and_hidden_states( return None, bypass_model_exec, None if not is_complete: - rebuilt_model_input = self.adpat_model_input( + rebuilt_model_input = self.build_partial_prefill_input( model_input, input_tokens_list, num_computed_tokens_list, @@ -288,7 +273,7 @@ def recv_kv_caches_and_hidden_states( return hidden_or_intermediate_states, bypass_model_exec, model_input - def adpat_model_input( + def build_partial_prefill_input( self, model_input: "ModelInputForGPUWithSamplingMetadata", input_tokens_list: List[torch.Tensor], From 31b891db4170a00572c3c81a962555411f9bee3c Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 19:04:15 +0000 Subject: [PATCH 189/303] adjust torch distributed logging --- .../kv_transfer/kv_pipe/torch_distributed_pipe.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index c77da1e6b75b8..3a6a94bb0e752 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -226,8 +226,14 @@ def send_tensor_wrapper(self, tensor): with self.buffer_size_lock: self.buffer_size = self.buffer_size - tensor_size except Exception as e: - logger.error("Encountering exception in KV sending thread") - logger.error("%s", e) + logger.error("[rank%d]: Exception when trying to send %s, msg: %s", + torch.distributed.get_rank(), + str(tensor), + str(e)) + import traceback + traceback.print_exc() + + def block_if_full(self): """ @@ -279,10 +285,11 @@ def recv_tensor(self) -> Optional[torch.Tensor]: try: tensor = future.result() except Exception as e: + # the underlying pipe is likely broken logger.error("Encountering exception in KV receiving thread") logger.error("%s", e) - - #tensor = self._recv_impl() + # fault tolerance: if the pipe is broken, return None + return None if tensor.numel() == 1 and tensor.item() == NONE_INT: return None From 7e68d089032077815cb9164035325154036c7b9e Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 19:04:29 +0000 Subject: [PATCH 190/303] remove unnecessaqry comments --- vllm/distributed/kv_transfer/vllm_adapter.py | 30 ++++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 93edeaa1d5d01..15b512512055c 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -65,6 +65,8 @@ def __init__( group_ranks: List[List[int]], local_rank: int, torch_distributed_backend: Union[str, Backend], + # FIXME(Kuntai): remove this hardcoding + lookup_buffer_size: int = 1e10 ): if IS_LMCACHE_INSTANCE: @@ -90,8 +92,14 @@ def __init__( local_rank, "gloo", ) - self.send_buffer = SimpleKVLookupBuffer(self.send_signal_pipe, self.send_pipe, 1e9 * 10) - self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, self.recv_pipe, 1e9 * 10) + self.send_buffer = SimpleKVLookupBuffer( + self.send_signal_pipe, + self.send_pipe, + self.lookup_buffer_size) + self.recv_buffer = SimpleKVLookupBuffer( + self.recv_signal_pipe, + self.recv_pipe, + self.lookup_buffer_size) else: # when performing disaggregated prefill, only 1 pipe is needed # at prefill instance this pipe is used for send KV cache @@ -106,19 +114,11 @@ def __init__( local_rank, "gloo", ) - self.send_buffer = SimpleKVLookupBuffer(self.signal_pipe, self.pipe, 1e9 * 10) - self.recv_buffer = self.send_buffer - - - - - # FIXME(Jiayi): buffer initializtion should be adapted accordingly - # Signal pipe needs to be initialized on both vllm and lmc side - - # init lookup buffer - # TODO: replace this 1e9 with a configurable parameter or a constant - #self.buffer = SimpleKVLookupBuffer(self.cpu_pipe, self.device_pipe, 1e9 * 10) - + self.send_buffer = SimpleKVLookupBuffer( + self.signal_pipe, + self.pipe, + self.lookup_buffer_size) + self.recv_buffer = self.send_buffer def send_kv_caches_and_hidden_states( self, From 85c7a644f87fa5477fb6ea5812b70882981fc7e6 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 19:18:00 +0000 Subject: [PATCH 191/303] remove unnecessary comments --- tests/kv_transfer/test_lookup_buffer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index 5ccbc5b0f5865..ae19d068be9fa 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -27,10 +27,7 @@ def test_run(my_rank, buffer, device): placeholder = torch.tensor([1]).to(device) buffer.insert(tokens, roi, key, value, placeholder) - - #for i in range(2000): - # print("Here:", i) - # time.sleep(0.01) + torch.distributed.barrier() # drop_select From 01fe335ce3fb3c7023fd527db8a000e3a96e8996 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 20:09:47 +0000 Subject: [PATCH 192/303] update documentation --- vllm/distributed/kv_transfer/vllm_adapter.py | 31 +++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 15b512512055c..e504d886466f6 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -1,14 +1,19 @@ """vLLM distributed KV cache transfer API. -These APIs are used in `vllm/worker/model_runner.py`. +These APIs are used in `vllm/worker/worker_base.py`. -Currently supporting TP and PP, but TP and PP must be the same. +Currently supporting TP. The TP between prefill and decode instance needs to be the same. -Workflow: -- In prefill instance, vLLM `insert` that buffers the KV cache into lookup buffer. +Workflow (disaggregated prefill) +- In prefill instance + - After prefill, vLLM `insert` its KV caches into a lookup buffer. + - The prefill instance will also open up a thread that listens to `drop_select` request. - In decode instance - - vLLM first runs `drop_select` to send input tokens and a mask on input tokens to sender - - The prefill instance send back the matching KV caches - - vLLM then store the KV cache into paged memory. + - vLLM first runs `drop_select` to send input tokens and a mask on input tokens (we call it roi, region of interest) to prefill instance + - The prefill instance then respond to `drop_select` request by + - Finding a match in current lookup buffer. + - Clone and send the matched item out + - Delete the matched item in the lookup buffer to free up GPU memory. + - The decode vLLM then store the KV cache into paged memory. """ from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING from collections import defaultdict, deque @@ -69,6 +74,8 @@ def __init__( lookup_buffer_size: int = 1e10 ): + self.lookup_buffer_size = lookup_buffer_size + if IS_LMCACHE_INSTANCE: # when vLLM is connected with LMCache # it needs to both send and recv KV cache @@ -114,11 +121,12 @@ def __init__( local_rank, "gloo", ) - self.send_buffer = SimpleKVLookupBuffer( + buffer = SimpleKVLookupBuffer( self.signal_pipe, self.pipe, self.lookup_buffer_size) - self.recv_buffer = self.send_buffer + self.send_buffer = buffer + self.recv_buffer = buffer def send_kv_caches_and_hidden_states( self, @@ -178,6 +186,7 @@ def recv_kv_caches_and_hidden_states( kv_caches: List[torch.Tensor] ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: + # When this flag is set to False, it means that bypass_model_exec = True # This is disagg decode instance, during prefill state @@ -226,9 +235,7 @@ def recv_kv_caches_and_hidden_states( for i in range(model_executable.model.start_layer, model_executable.model.end_layer): - # get kv cache kv_cache = kv_caches[i - model_executable.model.start_layer] - # get corresponding layer layer = model_executable.model.layers[i] key_cache, value_cache = kv_cache[0], kv_cache[1] @@ -297,6 +304,7 @@ def build_partial_prefill_input( rebuilt_context_lens_tensor = [] rebuilt_selected_token_indices = [] + # recounting query and context lengths for idx in range(len(input_tokens_list)): token_tensor = input_tokens_list[idx] num_token = len(token_tensor) @@ -357,6 +365,7 @@ def build_partial_prefill_input( dtype=model_input.sampling_metadata.selected_token_indices.dtype, ).to(device) + # import here to avoid circular import. from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( input_tokens = torch.cat(rebuilt_input_tokens).to(device), From caaaeb8a1bf94a202be5bd6e6a3a2dacfefe8402 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 20:15:51 +0000 Subject: [PATCH 193/303] update overhead benchmark --- .../disagg_overhead_benchmark.sh | 2 +- .../simple_kv_lookup_buffer.py | 35 ++++++++++++++----- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh index d264f18156438..f0ee54357af74 100644 --- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -53,7 +53,7 @@ benchmark() { model="meta-llama/Meta-Llama-3.1-8B-Instruct" dataset_name="sonnet" dataset_path="../sonnet_4x.txt" - num_prompts=20 + num_prompts=10 qps=$1 prefix_len=50 input_len=2048 diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py index de2e7cf5d5d57..6172bf092fb03 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py @@ -2,7 +2,7 @@ from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \ KVLookupBufferBase from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from typing import Dict, Tuple, List, Optional +from typing import Dict, Tuple, List, Optional, Union import threading import torch from collections import deque @@ -14,10 +14,19 @@ class SimpleKVLookupBuffer(KVLookupBufferBase): - def __init__(self, signal_pipe, data_pipe, buffer_size_thresh): + def __init__(self, + signal_pipe: KVPipeBase, + data_pipe: KVPipeBase, + buffer_size_thresh: int): """ - signal_pipe: on CPU -- avoid recv() stops the python intepreter - data_pipe: on GPU + signal_pipe: on CPU + + NOTE: on-device recv will block all threads in the process, making the + KV cache producer unable to listen to new request while transmitting + KV cache. Luckily CPU recv only blocks the current thread so we use + CPU recv to listen to new request. + + data_pipe: on device (e.g. GPU) """ self.buffer = deque() @@ -33,7 +42,9 @@ def __init__(self, signal_pipe, data_pipe, buffer_size_thresh): self.end_signal = None - def _matches(self, tokens_roi_sender, tokens_roi_recver): + def _matches(self, + tokens_roi_sender: List[torch.Tensor], + tokens_roi_recver: List[torch.Tensor]): # tokens_roi_sender: tokens and roi of the producer (in the buffer) # tokens_roi_recver: tokens and roi of the consumer (query) @@ -69,7 +80,7 @@ def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None: self.buffer_size -= tensor.element_size() * tensor.numel() self.data_pipe.send_tensor(tensor) - def _get_element_size(self, data): + def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): if data == [] or data is None: return 0 @@ -78,7 +89,12 @@ def _get_element_size(self, data): assert False, "Unknown data type %s" % type(data) - def _add_to_buffer(self, input_tokens, roi, key, value, hidden): + def _add_to_buffer(self, + input_tokens: torch.Tensor, + roi: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + hidden: torch.Tensor): if isinstance(input_tokens, torch.Tensor): input_tokens = input_tokens.clone() @@ -150,7 +166,9 @@ def drop_select_handler(self): logger.debug("Closing drop_select_handler") - def drop_select(self, input_tokens, roi): + def drop_select(self, + input_tokens: torch.Tensor, + roi: torch.Tensor): assert self.request_handling_thread is None, \ "drop_select should be called by the receiver" @@ -183,6 +201,7 @@ def insert(self, input_tokens, roi, key, value, hidden) -> None: while self.buffer_size > self.buffer_size_threshold: # logger.debug("KV transfer buffer is full. Handling...") self.full_handler() + self._add_to_buffer(input_tokens, roi, key, value, hidden) From 515c47b4cea8f638e80fef32e06ad07765594016 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 20:46:39 +0000 Subject: [PATCH 194/303] remove group coordinator import --- vllm/distributed/kv_transfer/vllm_adapter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index e504d886466f6..9a6b55cbbe660 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -27,7 +27,6 @@ from torch.distributed import Backend, ProcessGroup import vllm.envs as envs -from vllm.distributed.group_coordinator import GroupCoordinator from vllm.logger import init_logger import vllm.distributed.parallel_state as ps from vllm import _custom_ops as ops From f166cf8db07fa93bd626564e0547cace2615232b Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 20:48:19 +0000 Subject: [PATCH 195/303] remove syntax bug --- vllm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index 48ca1a67123b9..1adab61917265 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -529,7 +529,7 @@ def get_open_zmq_ipc_path() -> str: return f"ipc://{base_rpc_path}/{uuid4()}" -def get_open_port(, force: bool = False) -> int: +def get_open_port(force: bool = False) -> int: port = envs.VLLM_PORT if port is not None: if force and port is not None: From f320518485307b5eb7a4006253929b9bbc654bf7 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 21:09:35 +0000 Subject: [PATCH 196/303] update round robin proxy. Prior bash-based impl is buggy --- .../disagg_benchmarks/round_robin_proxy.py | 94 +++++++++++++++++++ .../disagg_benchmarks/round_robin_proxy.sh | 19 ---- 2 files changed, 94 insertions(+), 19 deletions(-) create mode 100644 benchmarks/disagg_benchmarks/round_robin_proxy.py delete mode 100644 benchmarks/disagg_benchmarks/round_robin_proxy.sh diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py new file mode 100644 index 0000000000000..04a30f774670a --- /dev/null +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -0,0 +1,94 @@ +import asyncio +import aiohttp +from aiohttp import web +import itertools + +class AsyncRoundRobinProxy: + def __init__(self, backend_ports): + self.backend_ports = itertools.cycle(backend_ports) + self.session = None + + async def start(self): + self.session = aiohttp.ClientSession() + + async def stop(self): + if self.session: + await self.session.close() + + async def handle_request(self, request): + backend_port = next(self.backend_ports) + print("forwarding to port", backend_port) + backend_url = f"http://localhost:{backend_port}{request.path_qs}" + + try: + async with self.session.request( + method=request.method, + url=backend_url, + headers=request.headers, + data=await request.read() + ) as backend_response: + response = web.StreamResponse( + status=backend_response.status, + headers=backend_response.headers + ) + await response.prepare(request) + + async for chunk in backend_response.content.iter_any(): + await response.write(chunk) + + await response.write_eof() + return response + + except aiohttp.ClientError as e: + return web.Response(text=f"Backend error: {str(e)}", status=502) + +async def run_backend(port): + async def handle(request): + if request.path == '/stream': + response = web.StreamResponse( + status=200, + headers={'Content-Type': 'text/plain'} + ) + await response.prepare(request) + for i in range(10): + await response.write(f"Chunk {i}\n".encode()) + await asyncio.sleep(0.5) # Simulate delay between chunks + return response + else: + return web.Response(text=f"Response from backend on port {port}") + + app = web.Application() + app.router.add_route('*', '/{tail:.*}', handle) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, 'localhost', port) + await site.start() + print(f"Backend running on http://localhost:{port}") + +async def main(): + proxy = AsyncRoundRobinProxy([8100, 8200]) + await proxy.start() + + app = web.Application() + app.router.add_route('*', '/{tail:.*}', proxy.handle_request) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, 'localhost', 8000) + + await asyncio.gather( + site.start(), + run_backend(8100), + run_backend(8200) + ) + + print("Proxy running on http://localhost:8000") + + try: + await asyncio.Future() # Run forever + finally: + await proxy.stop() + await runner.cleanup() + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.sh b/benchmarks/disagg_benchmarks/round_robin_proxy.sh deleted file mode 100644 index 375bf9e422371..0000000000000 --- a/benchmarks/disagg_benchmarks/round_robin_proxy.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -# Define the ports to forward to -PORTS=(8100 8200) -NUM_PORTS=${#PORTS[@]} -CURRENT=0 - -# Function to handle the round-robin logic -get_next_port() { - NEXT_PORT=${PORTS[$CURRENT]} - CURRENT=$(( (CURRENT + 1) % NUM_PORTS )) - echo $NEXT_PORT -} - -# Start the proxy -while true; do - NEXT_PORT=$(get_next_port) - socat TCP4-LISTEN:8000,reuseaddr,fork TCP4:localhost:$NEXT_PORT 2>/dev/null -done \ No newline at end of file From 5b4a3e3140d99b812180a4eb9e4c0afef529726e Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 21:12:35 +0000 Subject: [PATCH 197/303] update docs for disagg overhead benchmark --- .../disagg_overhead_benchmark.sh | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh index f0ee54357af74..36116172ab7c2 100644 --- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -1,17 +1,10 @@ #!/bin/bash -# Requirement: 8x H100 GPUs. - - -# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV -# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests -# Resource: 8x H100 -# Approaches: -# 1. Chunked prefill: 1 vllm instance with tp=8 -# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 -# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance -# Prefilling instance: max_output_token=1 -# Decoding instance: force the input tokens be the same across requests to bypass prefilling +# benchmark the overhead of disaggregated prefill. +# methodology: +# - send all request to prefill vLLM instance. It will buffer KV cache. +# - then send all request to decode instance. +# - The TTFT of decode instance is the overhead. set -ex From 01b2fd3068a709efd62b52e31a99b0cf9cbca030 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 21:15:00 +0000 Subject: [PATCH 198/303] use new round robin proxy in performance benchmark --- benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index dde9a80b59b37..715fe56d6c597 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -19,7 +19,6 @@ kill_gpu_processes() { # kill all processes on GPU. pkill -f pt_main_thread pkill -f python3 - pkill -f round_robin_proxy.sh ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done sleep 1 @@ -61,7 +60,7 @@ launch_chunked_prefill() { --gpu-memory-utilization 0.8 & wait_for_server 8100 wait_for_server 8200 - bash round_robin_proxy.sh & + python3 round_robin_proxy.py & sleep 1 } @@ -149,7 +148,7 @@ main() { mkdir results default_qps=10 - default_output_len=150 + default_output_len=10 export VLLM_LOGGING_LEVEL=DEBUG export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') From 54bd11f169daf6f9a0d639ce8695db71598778c8 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 21:19:42 +0000 Subject: [PATCH 199/303] update --- .../disagg_benchmarks/round_robin_proxy.py | 117 ++++++------------ 1 file changed, 40 insertions(+), 77 deletions(-) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py index 04a30f774670a..8751e24a08d33 100644 --- a/benchmarks/disagg_benchmarks/round_robin_proxy.py +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -3,92 +3,55 @@ from aiohttp import web import itertools -class AsyncRoundRobinProxy: - def __init__(self, backend_ports): - self.backend_ports = itertools.cycle(backend_ports) - self.session = None - - async def start(self): - self.session = aiohttp.ClientSession() - - async def stop(self): - if self.session: - await self.session.close() +class RoundRobinProxy: + def __init__(self, target_ports): + self.target_ports = target_ports + self.port_cycle = itertools.cycle(self.target_ports) async def handle_request(self, request): - backend_port = next(self.backend_ports) - print("forwarding to port", backend_port) - backend_url = f"http://localhost:{backend_port}{request.path_qs}" - - try: - async with self.session.request( - method=request.method, - url=backend_url, - headers=request.headers, - data=await request.read() - ) as backend_response: - response = web.StreamResponse( - status=backend_response.status, - headers=backend_response.headers - ) - await response.prepare(request) - - async for chunk in backend_response.content.iter_any(): - await response.write(chunk) - - await response.write_eof() - return response - - except aiohttp.ClientError as e: - return web.Response(text=f"Backend error: {str(e)}", status=502) - -async def run_backend(port): - async def handle(request): - if request.path == '/stream': - response = web.StreamResponse( - status=200, - headers={'Content-Type': 'text/plain'} - ) - await response.prepare(request) - for i in range(10): - await response.write(f"Chunk {i}\n".encode()) - await asyncio.sleep(0.5) # Simulate delay between chunks - return response - else: - return web.Response(text=f"Response from backend on port {port}") - - app = web.Application() - app.router.add_route('*', '/{tail:.*}', handle) - runner = web.AppRunner(app) - await runner.setup() - site = web.TCPSite(runner, 'localhost', port) - await site.start() - print(f"Backend running on http://localhost:{port}") + target_port = next(self.port_cycle) + target_url = f"http://localhost:{target_port}{request.path_qs}" + + async with aiohttp.ClientSession() as session: + try: + # Forward the request + async with session.request( + method=request.method, + url=target_url, + headers=request.headers, + data=request.content, + ) as response: + # Start sending the response + resp = web.StreamResponse( + status=response.status, + headers=response.headers + ) + await resp.prepare(request) + + # Stream the response content + async for chunk in response.content.iter_any(): + await resp.write(chunk) + + await resp.write_eof() + return resp + + except Exception as e: + return web.Response(text=f"Error: {str(e)}", status=500) async def main(): - proxy = AsyncRoundRobinProxy([8100, 8200]) - await proxy.start() - + proxy = RoundRobinProxy([8100, 8200]) app = web.Application() - app.router.add_route('*', '/{tail:.*}', proxy.handle_request) + app.router.add_route('*', '/{path:.*}', proxy.handle_request) runner = web.AppRunner(app) await runner.setup() site = web.TCPSite(runner, 'localhost', 8000) + await site.start() - await asyncio.gather( - site.start(), - run_backend(8100), - run_backend(8200) - ) - - print("Proxy running on http://localhost:8000") - - try: - await asyncio.Future() # Run forever - finally: - await proxy.stop() - await runner.cleanup() + print("Proxy server started on http://localhost:8000") + + # Keep the server running + await asyncio.Event().wait() -if __name__ == "__main__": +if __name__ == '__main__': asyncio.run(main()) \ No newline at end of file From b19f346a0c72799ccfaf9e9f89b5b6a8938e0330 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 22:13:26 +0000 Subject: [PATCH 200/303] update benchmarking script --- .../analyze_benchmark_results.py | 48 ------------------- .../disagg_performance_benchmark.sh | 17 +++---- 2 files changed, 9 insertions(+), 56 deletions(-) delete mode 100644 benchmarks/disagg_benchmarks/analyze_benchmark_results.py diff --git a/benchmarks/disagg_benchmarks/analyze_benchmark_results.py b/benchmarks/disagg_benchmarks/analyze_benchmark_results.py deleted file mode 100644 index 4b675c675d25f..0000000000000 --- a/benchmarks/disagg_benchmarks/analyze_benchmark_results.py +++ /dev/null @@ -1,48 +0,0 @@ - -import argparse -import json -import yaml -import os -from pathlib import Path - -def load(path): - - with open(str(path), 'r') as f: - return json.loads(f.read()) - -def main(args): - - results = Path(args.results_folder) - - chunk = load(results / "chunked_prefill_tp4.json") - prefill = load(results / "disagg_prefill_tp4.json") - decode = load(results / "disagg_decode_tp4.json") - - ttft_ratio = chunk["mean_ttft_ms"] / prefill["mean_ttft_ms"] - itl_ratio = chunk["mean_itl_ms"] / decode["mean_itl_ms"] - prefill_decode_ratio = prefill["mean_ttft_ms"] / (decode["mean_itl_ms"] * args.output_len) - - with open(results / args.output_file, 'a') as f: - f.write(yaml.dump([{ - 'qps': args.qps, - 'output_len': args.output_len, - 'prefill_decode_ratio': prefill_decode_ratio, - 'ttft_ratio': ttft_ratio, - 'itl_ratio': itl_ratio, - "chunk_ttft": chunk["mean_ttft_ms"], - "chunk_itl": chunk["mean_itl_ms"], - "disagg_ttft": prefill["mean_ttft_ms"], - "disagg_itl": decode["mean_itl_ms"] - }])) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Analyze benchmark results") - parser.add_argument("--results-folder", required=True, help="Path to the results folder") - parser.add_argument("--output-len", type=int, required=True, help="Target output length") - parser.add_argument("--qps", type=int, required=True, help="Target QPS") - parser.add_argument("--output-file", type=str, default="chunk_vs_disagg.yaml") - - args = parser.parse_args() - main(args) - \ No newline at end of file diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index 715fe56d6c597..734679660c233 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -43,7 +43,7 @@ launch_chunked_prefill() { --model $model \ --port 8100 \ -tp 4 \ - --max-model-len 30000 \ + --max-model-len 10000 \ --disable-log-stats \ --disable-log-requests \ --enable-chunked-prefill \ @@ -53,7 +53,7 @@ launch_chunked_prefill() { --model $model \ --port 8200 \ -tp 4 \ - --max-model-len 30000 \ + --max-model-len 10000 \ --disable-log-stats \ --disable-log-requests \ --enable-chunked-prefill \ @@ -73,7 +73,7 @@ launch_disagg_prefill() { --model $model \ --port 8100 \ -tp 4 \ - --max-model-len 30000 \ + --max-model-len 10000 \ --disable-log-stats \ --disable-log-requests \ --gpu-memory-utilization 0.8 & @@ -82,7 +82,7 @@ launch_disagg_prefill() { --model $model \ --port 8200 \ -tp 4 \ - --max-model-len 30000 \ + --max-model-len 10000 \ --disable-log-stats \ --disable-log-requests \ --gpu-memory-utilization 0.8 & @@ -98,10 +98,10 @@ benchmark() { model="meta-llama/Meta-Llama-3.1-70B-Instruct" dataset_name="sonnet" dataset_path="../sonnet_4x.txt" - num_prompts=400 + num_prompts=200 qps=$1 prefix_len=50 - input_len=2048 + input_len=1024 output_len=$2 tag=$3 @@ -131,7 +131,7 @@ main() { (which jq) || (apt-get -y install jq) (which socat) || (apt-get -y install socat) - pip install quart httpx + pip install quart httpx matplotlib aiohttp cd "$(dirname "$0")" @@ -147,7 +147,6 @@ main() { rm -rf results mkdir results - default_qps=10 default_output_len=10 export VLLM_LOGGING_LEVEL=DEBUG @@ -165,6 +164,8 @@ main() { done kill_gpu_processes + python3 visualize_benchmark_results.py + } From cb7ff06e1a9c230ee48479833df2da21ec96a7b9 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 22:18:40 +0000 Subject: [PATCH 201/303] revert changes in model_runner.py --- no change needed for disagg prefill --- vllm/worker/model_runner.py | 39 +++++++++++++------------------------ 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8c4899a8b7f50..447d303a57fd8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,7 +14,6 @@ import torch.distributed import torch.nn as nn - import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState @@ -1545,30 +1544,21 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) - + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) + if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() - return hidden_or_intermediate_states - - @torch.inference_mode() - def postprocess_model( - self, - model_input, - hidden_or_intermediate_states, - - ): + # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: if (self.is_driver_worker and hidden_or_intermediate_states is not None @@ -1586,7 +1576,7 @@ def postprocess_model( hidden_or_intermediate_states.tensors["model_forward_time"] = ( torch.tensor(model_forward_time + orig_model_forward_time)) return hidden_or_intermediate_states - + logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) @@ -1618,7 +1608,6 @@ def postprocess_model( output.model_forward_time = (orig_model_forward_time + model_forward_time) - decode_meta = model_input.attn_metadata.decode_metadata if self.return_hidden_states: # we only need to pass hidden states of most recent token assert model_input.sampling_metadata is not None @@ -1635,9 +1624,7 @@ def postprocess_model( output.hidden_states = hidden_states return [output] - - - + class CUDAGraphRunner: @@ -1808,4 +1795,4 @@ def _get_max_graph_batch_size(max_num_seqs: int) -> int: if padded_size in _BATCH_SIZES_TO_CAPTURE: return padded_size assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] - return _BATCH_SIZES_TO_CAPTURE[-1] + return _BATCH_SIZES_TO_CAPTURE[-1] \ No newline at end of file From dd8c86d3e171ee045014ad65c6431b8493c25f86 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 22:21:02 +0000 Subject: [PATCH 202/303] no I was wrong --- vllm/worker/model_runner.py | 56 ++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 447d303a57fd8..ab38302b3321a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,6 +14,7 @@ import torch.distributed import torch.nn as nn + import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState @@ -54,6 +55,7 @@ _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict, dump_input_when_exception) +from vllm import _custom_ops as ops if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -1544,21 +1546,30 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) + if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() - # Compute the logits in the last pipeline stage. + return hidden_or_intermediate_states + + @torch.inference_mode() + def postprocess_model( + self, + model_input, + hidden_or_intermediate_states, + + ): if not get_pp_group().is_last_rank: if (self.is_driver_worker and hidden_or_intermediate_states is not None @@ -1576,7 +1587,7 @@ def execute_model( hidden_or_intermediate_states.tensors["model_forward_time"] = ( torch.tensor(model_forward_time + orig_model_forward_time)) return hidden_or_intermediate_states - + logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) @@ -1591,23 +1602,8 @@ def execute_model( logits=logits, sampling_metadata=model_input.sampling_metadata, ) - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time - and output is not None): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - # If there are multiple workers, we are still tracking the latency - # from the start time of the driver worker to the end time of the - # driver worker. The model forward time will then end up covering - # the communication time as well. - output.model_forward_time = (orig_model_forward_time + - model_forward_time) + decode_meta = model_input.attn_metadata.decode_metadata if self.return_hidden_states: # we only need to pass hidden states of most recent token assert model_input.sampling_metadata is not None @@ -1624,7 +1620,9 @@ def execute_model( output.hidden_states = hidden_states return [output] - + + + class CUDAGraphRunner: @@ -1795,4 +1793,4 @@ def _get_max_graph_batch_size(max_num_seqs: int) -> int: if padded_size in _BATCH_SIZES_TO_CAPTURE: return padded_size assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] - return _BATCH_SIZES_TO_CAPTURE[-1] \ No newline at end of file + return _BATCH_SIZES_TO_CAPTURE[-1] From 4e8043c34c77e5155091c1cb95daefdd91238647 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 22:23:20 +0000 Subject: [PATCH 203/303] update benchmark --- benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index 734679660c233..1da5669dd1cd0 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -147,7 +147,7 @@ main() { rm -rf results mkdir results - default_output_len=10 + default_output_len=6 export VLLM_LOGGING_LEVEL=DEBUG export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') From b51f8913b3e595a272590c22fb393f89f33e8d4a Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 22:23:48 +0000 Subject: [PATCH 204/303] remove sonnet 4x --- it can be automatically generated via benchmarking script --- benchmarks/sonnet_4x.txt | 2070 -------------------------------------- 1 file changed, 2070 deletions(-) delete mode 100644 benchmarks/sonnet_4x.txt diff --git a/benchmarks/sonnet_4x.txt b/benchmarks/sonnet_4x.txt deleted file mode 100644 index 02f39a9fb14fb..0000000000000 --- a/benchmarks/sonnet_4x.txt +++ /dev/null @@ -1,2070 +0,0 @@ - -FROM fairest creatures we desire increase, -That thereby beauty's rose might never die, -But as the riper should by time decease, -His tender heir might bear his memory: -But thou, contracted to thine own bright eyes, -Feed'st thy light'st flame with self-substantial fuel, -Making a famine where abundance lies, -Thyself thy foe, to thy sweet self too cruel. -Thou that art now the world's fresh ornament -And only herald to the gaudy spring, -Within thine own bud buriest thy content -And, tender churl, makest waste in niggarding. -Pity the world, or else this glutton be, -To eat the world's due, by the grave and thee. -When forty winters shall beseige thy brow, -And dig deep trenches in thy beauty's field, -Thy youth's proud livery, so gazed on now, -Will be a tatter'd weed, of small worth held: -Then being ask'd where all thy beauty lies, -Where all the treasure of thy lusty days, -To say, within thine own deep-sunken eyes, -Were an all-eating shame and thriftless praise. -How much more praise deserved thy beauty's use, -If thou couldst answer 'This fair child of mine -Shall sum my count and make my old excuse,' -Proving his beauty by succession thine! -This were to be new made when thou art old, -And see thy blood warm when thou feel'st it cold. -Look in thy glass, and tell the face thou viewest -Now is the time that face should form another; -Whose fresh repair if now thou not renewest, -Thou dost beguile the world, unbless some mother. -For where is she so fair whose unear'd womb -Disdains the tillage of thy husbandry? -Or who is he so fond will be the tomb -Of his self-love, to stop posterity? -Thou art thy mother's glass, and she in thee -Calls back the lovely April of her prime: -So thou through windows of thine age shall see -Despite of wrinkles this thy golden time. -But if thou live, remember'd not to be, -Die single, and thine image dies with thee. -Unthrifty loveliness, why dost thou spend -Upon thyself thy beauty's legacy? -Nature's bequest gives nothing but doth lend, -And being frank she lends to those are free. -Then, beauteous niggard, why dost thou abuse -The bounteous largess given thee to give? -Profitless usurer, why dost thou use -So great a sum of sums, yet canst not live? -For having traffic with thyself alone, -Thou of thyself thy sweet self dost deceive. -Then how, when nature calls thee to be gone, -What acceptable audit canst thou leave? -Thy unused beauty must be tomb'd with thee, -Which, used, lives th' executor to be. -Those hours, that with gentle work did frame -The lovely gaze where every eye doth dwell, -Will play the tyrants to the very same -And that unfair which fairly doth excel: -For never-resting time leads summer on -To hideous winter and confounds him there; -Sap cheque'd with frost and lusty leaves quite gone, -Beauty o'ersnow'd and bareness every where: -Then, were not summer's distillation left, -A liquid prisoner pent in walls of glass, -Beauty's effect with beauty were bereft, -Nor it nor no remembrance what it was: -But flowers distill'd though they with winter meet, -Leese but their show; their substance still lives sweet. -Then let not winter's ragged hand deface -In thee thy summer, ere thou be distill'd: -Make sweet some vial; treasure thou some place -With beauty's treasure, ere it be self-kill'd. -That use is not forbidden usury, -Which happies those that pay the willing loan; -That's for thyself to breed another thee, -Or ten times happier, be it ten for one; -Ten times thyself were happier than thou art, -If ten of thine ten times refigured thee: -Then what could death do, if thou shouldst depart, -Leaving thee living in posterity? -Be not self-will'd, for thou art much too fair -To be death's conquest and make worms thine heir. -Lo! in the orient when the gracious light -Lifts up his burning head, each under eye -Doth homage to his new-appearing sight, -Serving with looks his sacred majesty; -And having climb'd the steep-up heavenly hill, -Resembling strong youth in his middle age, -yet mortal looks adore his beauty still, -Attending on his golden pilgrimage; -But when from highmost pitch, with weary car, -Like feeble age, he reeleth from the day, -The eyes, 'fore duteous, now converted are -From his low tract and look another way: -So thou, thyself out-going in thy noon, -Unlook'd on diest, unless thou get a son. -Music to hear, why hear'st thou music sadly? -Sweets with sweets war not, joy delights in joy. -Why lovest thou that which thou receivest not gladly, -Or else receivest with pleasure thine annoy? -If the true concord of well-tuned sounds, -By unions married, do offend thine ear, -They do but sweetly chide thee, who confounds -In singleness the parts that thou shouldst bear. -Mark how one string, sweet husband to another, -Strikes each in each by mutual ordering, -Resembling sire and child and happy mother -Who all in one, one pleasing note do sing: -Whose speechless song, being many, seeming one, -Sings this to thee: 'thou single wilt prove none.' -Is it for fear to wet a widow's eye -That thou consumest thyself in single life? -Ah! if thou issueless shalt hap to die. -The world will wail thee, like a makeless wife; -The world will be thy widow and still weep -That thou no form of thee hast left behind, -When every private widow well may keep -By children's eyes her husband's shape in mind. -Look, what an unthrift in the world doth spend -Shifts but his place, for still the world enjoys it; -But beauty's waste hath in the world an end, -And kept unused, the user so destroys it. -No love toward others in that bosom sits -That on himself such murderous shame commits. -For shame! deny that thou bear'st love to any, -Who for thyself art so unprovident. -Grant, if thou wilt, thou art beloved of many, -But that thou none lovest is most evident; -For thou art so possess'd with murderous hate -That 'gainst thyself thou stick'st not to conspire. -Seeking that beauteous roof to ruinate -Which to repair should be thy chief desire. -O, change thy thought, that I may change my mind! -Shall hate be fairer lodged than gentle love? -Be, as thy presence is, gracious and kind, -Or to thyself at least kind-hearted prove: -Make thee another self, for love of me, -That beauty still may live in thine or thee. -As fast as thou shalt wane, so fast thou growest -In one of thine, from that which thou departest; -And that fresh blood which youngly thou bestowest -Thou mayst call thine when thou from youth convertest. -Herein lives wisdom, beauty and increase: -Without this, folly, age and cold decay: -If all were minded so, the times should cease -And threescore year would make the world away. -Let those whom Nature hath not made for store, -Harsh featureless and rude, barrenly perish: -Look, whom she best endow'd she gave the more; -Which bounteous gift thou shouldst in bounty cherish: -She carved thee for her seal, and meant thereby -Thou shouldst print more, not let that copy die. -When I do count the clock that tells the time, -And see the brave day sunk in hideous night; -When I behold the violet past prime, -And sable curls all silver'd o'er with white; -When lofty trees I see barren of leaves -Which erst from heat did canopy the herd, -And summer's green all girded up in sheaves -Borne on the bier with white and bristly beard, -Then of thy beauty do I question make, -That thou among the wastes of time must go, -Since sweets and beauties do themselves forsake -And die as fast as they see others grow; -And nothing 'gainst Time's scythe can make defence -Save breed, to brave him when he takes thee hence. -O, that you were yourself! but, love, you are -No longer yours than you yourself here live: -Against this coming end you should prepare, -And your sweet semblance to some other give. -So should that beauty which you hold in lease -Find no determination: then you were -Yourself again after yourself's decease, -When your sweet issue your sweet form should bear. -Who lets so fair a house fall to decay, -Which husbandry in honour might uphold -Against the stormy gusts of winter's day -And barren rage of death's eternal cold? -O, none but unthrifts! Dear my love, you know -You had a father: let your son say so. -Not from the stars do I my judgment pluck; -And yet methinks I have astronomy, -But not to tell of good or evil luck, -Of plagues, of dearths, or seasons' quality; -Nor can I fortune to brief minutes tell, -Pointing to each his thunder, rain and wind, -Or say with princes if it shall go well, -By oft predict that I in heaven find: -But from thine eyes my knowledge I derive, -And, constant stars, in them I read such art -As truth and beauty shall together thrive, -If from thyself to store thou wouldst convert; -Or else of thee this I prognosticate: -Thy end is truth's and beauty's doom and date. -When I consider every thing that grows -Holds in perfection but a little moment, -That this huge stage presenteth nought but shows -Whereon the stars in secret influence comment; -When I perceive that men as plants increase, -Cheered and cheque'd even by the self-same sky, -Vaunt in their youthful sap, at height decrease, -And wear their brave state out of memory; -Then the conceit of this inconstant stay -Sets you most rich in youth before my sight, -Where wasteful Time debateth with Decay, -To change your day of youth to sullied night; -And all in war with Time for love of you, -As he takes from you, I engraft you new. -But wherefore do not you a mightier way -Make war upon this bloody tyrant, Time? -And fortify yourself in your decay -With means more blessed than my barren rhyme? -Now stand you on the top of happy hours, -And many maiden gardens yet unset -With virtuous wish would bear your living flowers, -Much liker than your painted counterfeit: -So should the lines of life that life repair, -Which this, Time's pencil, or my pupil pen, -Neither in inward worth nor outward fair, -Can make you live yourself in eyes of men. -To give away yourself keeps yourself still, -And you must live, drawn by your own sweet skill. -Who will believe my verse in time to come, -If it were fill'd with your most high deserts? -Though yet, heaven knows, it is but as a tomb -Which hides your life and shows not half your parts. -If I could write the beauty of your eyes -And in fresh numbers number all your graces, -The age to come would say 'This poet lies: -Such heavenly touches ne'er touch'd earthly faces.' -So should my papers yellow'd with their age -Be scorn'd like old men of less truth than tongue, -And your true rights be term'd a poet's rage -And stretched metre of an antique song: -But were some child of yours alive that time, -You should live twice; in it and in my rhyme. -Shall I compare thee to a summer's day? -Thou art more lovely and more temperate: -Rough winds do shake the darling buds of May, -And summer's lease hath all too short a date: -Sometime too hot the eye of heaven shines, -And often is his gold complexion dimm'd; -And every fair from fair sometime declines, -By chance or nature's changing course untrimm'd; -But thy eternal summer shall not fade -Nor lose possession of that fair thou owest; -Nor shall Death brag thou wander'st in his shade, -When in eternal lines to time thou growest: -So long as men can breathe or eyes can see, -So long lives this and this gives life to thee. -Devouring Time, blunt thou the lion's paws, -And make the earth devour her own sweet brood; -Pluck the keen teeth from the fierce tiger's jaws, -And burn the long-lived phoenix in her blood; -Make glad and sorry seasons as thou fleets, -And do whate'er thou wilt, swift-footed Time, -To the wide world and all her fading sweets; -But I forbid thee one most heinous crime: -O, carve not with thy hours my love's fair brow, -Nor draw no lines there with thine antique pen; -Him in thy course untainted do allow -For beauty's pattern to succeeding men. -Yet, do thy worst, old Time: despite thy wrong, -My love shall in my verse ever live young. -A woman's face with Nature's own hand painted -Hast thou, the master-mistress of my passion; -A woman's gentle heart, but not acquainted -With shifting change, as is false women's fashion; -An eye more bright than theirs, less false in rolling, -Gilding the object whereupon it gazeth; -A man in hue, all 'hues' in his controlling, -Much steals men's eyes and women's souls amazeth. -And for a woman wert thou first created; -Till Nature, as she wrought thee, fell a-doting, -And by addition me of thee defeated, -By adding one thing to my purpose nothing. -But since she prick'd thee out for women's pleasure, -Mine be thy love and thy love's use their treasure. -So is it not with me as with that Muse -Stirr'd by a painted beauty to his verse, -Who heaven itself for ornament doth use -And every fair with his fair doth rehearse -Making a couplement of proud compare, -With sun and moon, with earth and sea's rich gems, -With April's first-born flowers, and all things rare -That heaven's air in this huge rondure hems. -O' let me, true in love, but truly write, -And then believe me, my love is as fair -As any mother's child, though not so bright -As those gold candles fix'd in heaven's air: -Let them say more than like of hearsay well; -I will not praise that purpose not to sell. -My glass shall not persuade me I am old, -So long as youth and thou are of one date; -But when in thee time's furrows I behold, -Then look I death my days should expiate. -For all that beauty that doth cover thee -Is but the seemly raiment of my heart, -Which in thy breast doth live, as thine in me: -How can I then be elder than thou art? -O, therefore, love, be of thyself so wary -As I, not for myself, but for thee will; -Bearing thy heart, which I will keep so chary -As tender nurse her babe from faring ill. -Presume not on thy heart when mine is slain; -Thou gavest me thine, not to give back again. -As an unperfect actor on the stage -Who with his fear is put besides his part, -Or some fierce thing replete with too much rage, -Whose strength's abundance weakens his own heart. -So I, for fear of trust, forget to say -The perfect ceremony of love's rite, -And in mine own love's strength seem to decay, -O'ercharged with burden of mine own love's might. -O, let my books be then the eloquence -And dumb presagers of my speaking breast, -Who plead for love and look for recompense -More than that tongue that more hath more express'd. -O, learn to read what silent love hath writ: -To hear with eyes belongs to love's fine wit. -Mine eye hath play'd the painter and hath stell'd -Thy beauty's form in table of my heart; -My body is the frame wherein 'tis held, -And perspective it is the painter's art. -For through the painter must you see his skill, -To find where your true image pictured lies; -Which in my bosom's shop is hanging still, -That hath his windows glazed with thine eyes. -Now see what good turns eyes for eyes have done: -Mine eyes have drawn thy shape, and thine for me -Are windows to my breast, where-through the sun -Delights to peep, to gaze therein on thee; -Yet eyes this cunning want to grace their art; -They draw but what they see, know not the heart. -Let those who are in favour with their stars -Of public honour and proud titles boast, -Whilst I, whom fortune of such triumph bars, -Unlook'd for joy in that I honour most. -Great princes' favourites their fair leaves spread -But as the marigold at the sun's eye, -And in themselves their pride lies buried, -For at a frown they in their glory die. -The painful warrior famoused for fight, -After a thousand victories once foil'd, -Is from the book of honour razed quite, -And all the rest forgot for which he toil'd: -Then happy I, that love and am beloved -Where I may not remove nor be removed. -Lord of my love, to whom in vassalage -Thy merit hath my duty strongly knit, -To thee I send this written embassage, -To witness duty, not to show my wit: -Duty so great, which wit so poor as mine -May make seem bare, in wanting words to show it, -But that I hope some good conceit of thine -In thy soul's thought, all naked, will bestow it; -Till whatsoever star that guides my moving -Points on me graciously with fair aspect -And puts apparel on my tatter'd loving, -To show me worthy of thy sweet respect: -Then may I dare to boast how I do love thee; -Till then not show my head where thou mayst prove me. -Weary with toil, I haste me to my bed, -The dear repose for limbs with travel tired; -But then begins a journey in my head, -To work my mind, when body's work's expired: -For then my thoughts, from far where I abide, -Intend a zealous pilgrimage to thee, -And keep my drooping eyelids open wide, -Looking on darkness which the blind do see -Save that my soul's imaginary sight -Presents thy shadow to my sightless view, -Which, like a jewel hung in ghastly night, -Makes black night beauteous and her old face new. -Lo! thus, by day my limbs, by night my mind, -For thee and for myself no quiet find. -How can I then return in happy plight, -That am debarr'd the benefit of rest? -When day's oppression is not eased by night, -But day by night, and night by day, oppress'd? -And each, though enemies to either's reign, -Do in consent shake hands to torture me; -The one by toil, the other to complain -How far I toil, still farther off from thee. -I tell the day, to please them thou art bright -And dost him grace when clouds do blot the heaven: -So flatter I the swart-complexion'd night, -When sparkling stars twire not thou gild'st the even. -But day doth daily draw my sorrows longer -And night doth nightly make grief's strength seem stronger. -When, in disgrace with fortune and men's eyes, -I all alone beweep my outcast state -And trouble deal heaven with my bootless cries -And look upon myself and curse my fate, -Wishing me like to one more rich in hope, -Featured like him, like him with friends possess'd, -Desiring this man's art and that man's scope, -With what I most enjoy contented least; -Yet in these thoughts myself almost despising, -Haply I think on thee, and then my state, -Like to the lark at break of day arising -From sullen earth, sings hymns at heaven's gate; -For thy sweet love remember'd such wealth brings -That then I scorn to change my state with kings. -When to the sessions of sweet silent thought -I summon up remembrance of things past, -I sigh the lack of many a thing I sought, -And with old woes new wail my dear time's waste: -Then can I drown an eye, unused to flow, -For precious friends hid in death's dateless night, -And weep afresh love's long since cancell'd woe, -And moan the expense of many a vanish'd sight: -Then can I grieve at grievances foregone, -And heavily from woe to woe tell o'er -The sad account of fore-bemoaned moan, -Which I new pay as if not paid before. -But if the while I think on thee, dear friend, -All losses are restored and sorrows end. -Thy bosom is endeared with all hearts, -Which I by lacking have supposed dead, -And there reigns love and all love's loving parts, -And all those friends which I thought buried. -How many a holy and obsequious tear -Hath dear religious love stol'n from mine eye -As interest of the dead, which now appear -But things removed that hidden in thee lie! -Thou art the grave where buried love doth live, -Hung with the trophies of my lovers gone, -Who all their parts of me to thee did give; -That due of many now is thine alone: -Their images I loved I view in thee, -And thou, all they, hast all the all of me. -If thou survive my well-contented day, -When that churl Death my bones with dust shall cover, -And shalt by fortune once more re-survey -These poor rude lines of thy deceased lover, -Compare them with the bettering of the time, -And though they be outstripp'd by every pen, -Reserve them for my love, not for their rhyme, -Exceeded by the height of happier men. -O, then vouchsafe me but this loving thought: -'Had my friend's Muse grown with this growing age, -A dearer birth than this his love had brought, -To march in ranks of better equipage: -But since he died and poets better prove, -Theirs for their style I'll read, his for his love.' -Full many a glorious morning have I seen -Flatter the mountain-tops with sovereign eye, -Kissing with golden face the meadows green, -Gilding pale streams with heavenly alchemy; -Anon permit the basest clouds to ride -With ugly rack on his celestial face, -And from the forlorn world his visage hide, -Stealing unseen to west with this disgrace: -Even so my sun one early morn did shine -With all triumphant splendor on my brow; -But out, alack! he was but one hour mine; -The region cloud hath mask'd him from me now. -Yet him for this my love no whit disdaineth; -Suns of the world may stain when heaven's sun staineth. -Why didst thou promise such a beauteous day, -And make me travel forth without my cloak, -To let base clouds o'ertake me in my way, -Hiding thy bravery in their rotten smoke? -'Tis not enough that through the cloud thou break, -To dry the rain on my storm-beaten face, -For no man well of such a salve can speak -That heals the wound and cures not the disgrace: -Nor can thy shame give physic to my grief; -Though thou repent, yet I have still the loss: -The offender's sorrow lends but weak relief -To him that bears the strong offence's cross. -Ah! but those tears are pearl which thy love sheds, -And they are rich and ransom all ill deeds. -No more be grieved at that which thou hast done: -Roses have thorns, and silver fountains mud; -Clouds and eclipses stain both moon and sun, -And loathsome canker lives in sweetest bud. -All men make faults, and even I in this, -Authorizing thy trespass with compare, -Myself corrupting, salving thy amiss, -Excusing thy sins more than thy sins are; -For to thy sensual fault I bring in sense-- -Thy adverse party is thy advocate-- -And 'gainst myself a lawful plea commence: -Such civil war is in my love and hate -That I an accessary needs must be -To that sweet thief which sourly robs from me. -Let me confess that we two must be twain, -Although our undivided loves are one: -So shall those blots that do with me remain -Without thy help by me be borne alone. -In our two loves there is but one respect, -Though in our lives a separable spite, -Which though it alter not love's sole effect, -Yet doth it steal sweet hours from love's delight. -I may not evermore acknowledge thee, -Lest my bewailed guilt should do thee shame, -Nor thou with public kindness honour me, -Unless thou take that honour from thy name: -But do not so; I love thee in such sort -As, thou being mine, mine is thy good report. -As a decrepit father takes delight -To see his active child do deeds of youth, -So I, made lame by fortune's dearest spite, -Take all my comfort of thy worth and truth. -For whether beauty, birth, or wealth, or wit, -Or any of these all, or all, or more, -Entitled in thy parts do crowned sit, -I make my love engrafted to this store: -So then I am not lame, poor, nor despised, -Whilst that this shadow doth such substance give -That I in thy abundance am sufficed -And by a part of all thy glory live. -Look, what is best, that best I wish in thee: -This wish I have; then ten times happy me!FROM fairest creatures we desire increase, -That thereby beauty's rose might never die, -But as the riper should by time decease, -His tender heir might bear his memory: -But thou, contracted to thine own bright eyes, -Feed'st thy light'st flame with self-substantial fuel, -Making a famine where abundance lies, -Thyself thy foe, to thy sweet self too cruel. -Thou that art now the world's fresh ornament -And only herald to the gaudy spring, -Within thine own bud buriest thy content -And, tender churl, makest waste in niggarding. -Pity the world, or else this glutton be, -To eat the world's due, by the grave and thee. -When forty winters shall beseige thy brow, -And dig deep trenches in thy beauty's field, -Thy youth's proud livery, so gazed on now, -Will be a tatter'd weed, of small worth held: -Then being ask'd where all thy beauty lies, -Where all the treasure of thy lusty days, -To say, within thine own deep-sunken eyes, -Were an all-eating shame and thriftless praise. -How much more praise deserved thy beauty's use, -If thou couldst answer 'This fair child of mine -Shall sum my count and make my old excuse,' -Proving his beauty by succession thine! -This were to be new made when thou art old, -And see thy blood warm when thou feel'st it cold. -Look in thy glass, and tell the face thou viewest -Now is the time that face should form another; -Whose fresh repair if now thou not renewest, -Thou dost beguile the world, unbless some mother. -For where is she so fair whose unear'd womb -Disdains the tillage of thy husbandry? -Or who is he so fond will be the tomb -Of his self-love, to stop posterity? -Thou art thy mother's glass, and she in thee -Calls back the lovely April of her prime: -So thou through windows of thine age shall see -Despite of wrinkles this thy golden time. -But if thou live, remember'd not to be, -Die single, and thine image dies with thee. -Unthrifty loveliness, why dost thou spend -Upon thyself thy beauty's legacy? -Nature's bequest gives nothing but doth lend, -And being frank she lends to those are free. -Then, beauteous niggard, why dost thou abuse -The bounteous largess given thee to give? -Profitless usurer, why dost thou use -So great a sum of sums, yet canst not live? -For having traffic with thyself alone, -Thou of thyself thy sweet self dost deceive. -Then how, when nature calls thee to be gone, -What acceptable audit canst thou leave? -Thy unused beauty must be tomb'd with thee, -Which, used, lives th' executor to be. -Those hours, that with gentle work did frame -The lovely gaze where every eye doth dwell, -Will play the tyrants to the very same -And that unfair which fairly doth excel: -For never-resting time leads summer on -To hideous winter and confounds him there; -Sap cheque'd with frost and lusty leaves quite gone, -Beauty o'ersnow'd and bareness every where: -Then, were not summer's distillation left, -A liquid prisoner pent in walls of glass, -Beauty's effect with beauty were bereft, -Nor it nor no remembrance what it was: -But flowers distill'd though they with winter meet, -Leese but their show; their substance still lives sweet. -Then let not winter's ragged hand deface -In thee thy summer, ere thou be distill'd: -Make sweet some vial; treasure thou some place -With beauty's treasure, ere it be self-kill'd. -That use is not forbidden usury, -Which happies those that pay the willing loan; -That's for thyself to breed another thee, -Or ten times happier, be it ten for one; -Ten times thyself were happier than thou art, -If ten of thine ten times refigured thee: -Then what could death do, if thou shouldst depart, -Leaving thee living in posterity? -Be not self-will'd, for thou art much too fair -To be death's conquest and make worms thine heir. -Lo! in the orient when the gracious light -Lifts up his burning head, each under eye -Doth homage to his new-appearing sight, -Serving with looks his sacred majesty; -And having climb'd the steep-up heavenly hill, -Resembling strong youth in his middle age, -yet mortal looks adore his beauty still, -Attending on his golden pilgrimage; -But when from highmost pitch, with weary car, -Like feeble age, he reeleth from the day, -The eyes, 'fore duteous, now converted are -From his low tract and look another way: -So thou, thyself out-going in thy noon, -Unlook'd on diest, unless thou get a son. -Music to hear, why hear'st thou music sadly? -Sweets with sweets war not, joy delights in joy. -Why lovest thou that which thou receivest not gladly, -Or else receivest with pleasure thine annoy? -If the true concord of well-tuned sounds, -By unions married, do offend thine ear, -They do but sweetly chide thee, who confounds -In singleness the parts that thou shouldst bear. -Mark how one string, sweet husband to another, -Strikes each in each by mutual ordering, -Resembling sire and child and happy mother -Who all in one, one pleasing note do sing: -Whose speechless song, being many, seeming one, -Sings this to thee: 'thou single wilt prove none.' -Is it for fear to wet a widow's eye -That thou consumest thyself in single life? -Ah! if thou issueless shalt hap to die. -The world will wail thee, like a makeless wife; -The world will be thy widow and still weep -That thou no form of thee hast left behind, -When every private widow well may keep -By children's eyes her husband's shape in mind. -Look, what an unthrift in the world doth spend -Shifts but his place, for still the world enjoys it; -But beauty's waste hath in the world an end, -And kept unused, the user so destroys it. -No love toward others in that bosom sits -That on himself such murderous shame commits. -For shame! deny that thou bear'st love to any, -Who for thyself art so unprovident. -Grant, if thou wilt, thou art beloved of many, -But that thou none lovest is most evident; -For thou art so possess'd with murderous hate -That 'gainst thyself thou stick'st not to conspire. -Seeking that beauteous roof to ruinate -Which to repair should be thy chief desire. -O, change thy thought, that I may change my mind! -Shall hate be fairer lodged than gentle love? -Be, as thy presence is, gracious and kind, -Or to thyself at least kind-hearted prove: -Make thee another self, for love of me, -That beauty still may live in thine or thee. -As fast as thou shalt wane, so fast thou growest -In one of thine, from that which thou departest; -And that fresh blood which youngly thou bestowest -Thou mayst call thine when thou from youth convertest. -Herein lives wisdom, beauty and increase: -Without this, folly, age and cold decay: -If all were minded so, the times should cease -And threescore year would make the world away. -Let those whom Nature hath not made for store, -Harsh featureless and rude, barrenly perish: -Look, whom she best endow'd she gave the more; -Which bounteous gift thou shouldst in bounty cherish: -She carved thee for her seal, and meant thereby -Thou shouldst print more, not let that copy die. -When I do count the clock that tells the time, -And see the brave day sunk in hideous night; -When I behold the violet past prime, -And sable curls all silver'd o'er with white; -When lofty trees I see barren of leaves -Which erst from heat did canopy the herd, -And summer's green all girded up in sheaves -Borne on the bier with white and bristly beard, -Then of thy beauty do I question make, -That thou among the wastes of time must go, -Since sweets and beauties do themselves forsake -And die as fast as they see others grow; -And nothing 'gainst Time's scythe can make defence -Save breed, to brave him when he takes thee hence. -O, that you were yourself! but, love, you are -No longer yours than you yourself here live: -Against this coming end you should prepare, -And your sweet semblance to some other give. -So should that beauty which you hold in lease -Find no determination: then you were -Yourself again after yourself's decease, -When your sweet issue your sweet form should bear. -Who lets so fair a house fall to decay, -Which husbandry in honour might uphold -Against the stormy gusts of winter's day -And barren rage of death's eternal cold? -O, none but unthrifts! Dear my love, you know -You had a father: let your son say so. -Not from the stars do I my judgment pluck; -And yet methinks I have astronomy, -But not to tell of good or evil luck, -Of plagues, of dearths, or seasons' quality; -Nor can I fortune to brief minutes tell, -Pointing to each his thunder, rain and wind, -Or say with princes if it shall go well, -By oft predict that I in heaven find: -But from thine eyes my knowledge I derive, -And, constant stars, in them I read such art -As truth and beauty shall together thrive, -If from thyself to store thou wouldst convert; -Or else of thee this I prognosticate: -Thy end is truth's and beauty's doom and date. -When I consider every thing that grows -Holds in perfection but a little moment, -That this huge stage presenteth nought but shows -Whereon the stars in secret influence comment; -When I perceive that men as plants increase, -Cheered and cheque'd even by the self-same sky, -Vaunt in their youthful sap, at height decrease, -And wear their brave state out of memory; -Then the conceit of this inconstant stay -Sets you most rich in youth before my sight, -Where wasteful Time debateth with Decay, -To change your day of youth to sullied night; -And all in war with Time for love of you, -As he takes from you, I engraft you new. -But wherefore do not you a mightier way -Make war upon this bloody tyrant, Time? -And fortify yourself in your decay -With means more blessed than my barren rhyme? -Now stand you on the top of happy hours, -And many maiden gardens yet unset -With virtuous wish would bear your living flowers, -Much liker than your painted counterfeit: -So should the lines of life that life repair, -Which this, Time's pencil, or my pupil pen, -Neither in inward worth nor outward fair, -Can make you live yourself in eyes of men. -To give away yourself keeps yourself still, -And you must live, drawn by your own sweet skill. -Who will believe my verse in time to come, -If it were fill'd with your most high deserts? -Though yet, heaven knows, it is but as a tomb -Which hides your life and shows not half your parts. -If I could write the beauty of your eyes -And in fresh numbers number all your graces, -The age to come would say 'This poet lies: -Such heavenly touches ne'er touch'd earthly faces.' -So should my papers yellow'd with their age -Be scorn'd like old men of less truth than tongue, -And your true rights be term'd a poet's rage -And stretched metre of an antique song: -But were some child of yours alive that time, -You should live twice; in it and in my rhyme. -Shall I compare thee to a summer's day? -Thou art more lovely and more temperate: -Rough winds do shake the darling buds of May, -And summer's lease hath all too short a date: -Sometime too hot the eye of heaven shines, -And often is his gold complexion dimm'd; -And every fair from fair sometime declines, -By chance or nature's changing course untrimm'd; -But thy eternal summer shall not fade -Nor lose possession of that fair thou owest; -Nor shall Death brag thou wander'st in his shade, -When in eternal lines to time thou growest: -So long as men can breathe or eyes can see, -So long lives this and this gives life to thee. -Devouring Time, blunt thou the lion's paws, -And make the earth devour her own sweet brood; -Pluck the keen teeth from the fierce tiger's jaws, -And burn the long-lived phoenix in her blood; -Make glad and sorry seasons as thou fleets, -And do whate'er thou wilt, swift-footed Time, -To the wide world and all her fading sweets; -But I forbid thee one most heinous crime: -O, carve not with thy hours my love's fair brow, -Nor draw no lines there with thine antique pen; -Him in thy course untainted do allow -For beauty's pattern to succeeding men. -Yet, do thy worst, old Time: despite thy wrong, -My love shall in my verse ever live young. -A woman's face with Nature's own hand painted -Hast thou, the master-mistress of my passion; -A woman's gentle heart, but not acquainted -With shifting change, as is false women's fashion; -An eye more bright than theirs, less false in rolling, -Gilding the object whereupon it gazeth; -A man in hue, all 'hues' in his controlling, -Much steals men's eyes and women's souls amazeth. -And for a woman wert thou first created; -Till Nature, as she wrought thee, fell a-doting, -And by addition me of thee defeated, -By adding one thing to my purpose nothing. -But since she prick'd thee out for women's pleasure, -Mine be thy love and thy love's use their treasure. -So is it not with me as with that Muse -Stirr'd by a painted beauty to his verse, -Who heaven itself for ornament doth use -And every fair with his fair doth rehearse -Making a couplement of proud compare, -With sun and moon, with earth and sea's rich gems, -With April's first-born flowers, and all things rare -That heaven's air in this huge rondure hems. -O' let me, true in love, but truly write, -And then believe me, my love is as fair -As any mother's child, though not so bright -As those gold candles fix'd in heaven's air: -Let them say more than like of hearsay well; -I will not praise that purpose not to sell. -My glass shall not persuade me I am old, -So long as youth and thou are of one date; -But when in thee time's furrows I behold, -Then look I death my days should expiate. -For all that beauty that doth cover thee -Is but the seemly raiment of my heart, -Which in thy breast doth live, as thine in me: -How can I then be elder than thou art? -O, therefore, love, be of thyself so wary -As I, not for myself, but for thee will; -Bearing thy heart, which I will keep so chary -As tender nurse her babe from faring ill. -Presume not on thy heart when mine is slain; -Thou gavest me thine, not to give back again. -As an unperfect actor on the stage -Who with his fear is put besides his part, -Or some fierce thing replete with too much rage, -Whose strength's abundance weakens his own heart. -So I, for fear of trust, forget to say -The perfect ceremony of love's rite, -And in mine own love's strength seem to decay, -O'ercharged with burden of mine own love's might. -O, let my books be then the eloquence -And dumb presagers of my speaking breast, -Who plead for love and look for recompense -More than that tongue that more hath more express'd. -O, learn to read what silent love hath writ: -To hear with eyes belongs to love's fine wit. -Mine eye hath play'd the painter and hath stell'd -Thy beauty's form in table of my heart; -My body is the frame wherein 'tis held, -And perspective it is the painter's art. -For through the painter must you see his skill, -To find where your true image pictured lies; -Which in my bosom's shop is hanging still, -That hath his windows glazed with thine eyes. -Now see what good turns eyes for eyes have done: -Mine eyes have drawn thy shape, and thine for me -Are windows to my breast, where-through the sun -Delights to peep, to gaze therein on thee; -Yet eyes this cunning want to grace their art; -They draw but what they see, know not the heart. -Let those who are in favour with their stars -Of public honour and proud titles boast, -Whilst I, whom fortune of such triumph bars, -Unlook'd for joy in that I honour most. -Great princes' favourites their fair leaves spread -But as the marigold at the sun's eye, -And in themselves their pride lies buried, -For at a frown they in their glory die. -The painful warrior famoused for fight, -After a thousand victories once foil'd, -Is from the book of honour razed quite, -And all the rest forgot for which he toil'd: -Then happy I, that love and am beloved -Where I may not remove nor be removed. -Lord of my love, to whom in vassalage -Thy merit hath my duty strongly knit, -To thee I send this written embassage, -To witness duty, not to show my wit: -Duty so great, which wit so poor as mine -May make seem bare, in wanting words to show it, -But that I hope some good conceit of thine -In thy soul's thought, all naked, will bestow it; -Till whatsoever star that guides my moving -Points on me graciously with fair aspect -And puts apparel on my tatter'd loving, -To show me worthy of thy sweet respect: -Then may I dare to boast how I do love thee; -Till then not show my head where thou mayst prove me. -Weary with toil, I haste me to my bed, -The dear repose for limbs with travel tired; -But then begins a journey in my head, -To work my mind, when body's work's expired: -For then my thoughts, from far where I abide, -Intend a zealous pilgrimage to thee, -And keep my drooping eyelids open wide, -Looking on darkness which the blind do see -Save that my soul's imaginary sight -Presents thy shadow to my sightless view, -Which, like a jewel hung in ghastly night, -Makes black night beauteous and her old face new. -Lo! thus, by day my limbs, by night my mind, -For thee and for myself no quiet find. -How can I then return in happy plight, -That am debarr'd the benefit of rest? -When day's oppression is not eased by night, -But day by night, and night by day, oppress'd? -And each, though enemies to either's reign, -Do in consent shake hands to torture me; -The one by toil, the other to complain -How far I toil, still farther off from thee. -I tell the day, to please them thou art bright -And dost him grace when clouds do blot the heaven: -So flatter I the swart-complexion'd night, -When sparkling stars twire not thou gild'st the even. -But day doth daily draw my sorrows longer -And night doth nightly make grief's strength seem stronger. -When, in disgrace with fortune and men's eyes, -I all alone beweep my outcast state -And trouble deal heaven with my bootless cries -And look upon myself and curse my fate, -Wishing me like to one more rich in hope, -Featured like him, like him with friends possess'd, -Desiring this man's art and that man's scope, -With what I most enjoy contented least; -Yet in these thoughts myself almost despising, -Haply I think on thee, and then my state, -Like to the lark at break of day arising -From sullen earth, sings hymns at heaven's gate; -For thy sweet love remember'd such wealth brings -That then I scorn to change my state with kings. -When to the sessions of sweet silent thought -I summon up remembrance of things past, -I sigh the lack of many a thing I sought, -And with old woes new wail my dear time's waste: -Then can I drown an eye, unused to flow, -For precious friends hid in death's dateless night, -And weep afresh love's long since cancell'd woe, -And moan the expense of many a vanish'd sight: -Then can I grieve at grievances foregone, -And heavily from woe to woe tell o'er -The sad account of fore-bemoaned moan, -Which I new pay as if not paid before. -But if the while I think on thee, dear friend, -All losses are restored and sorrows end. -Thy bosom is endeared with all hearts, -Which I by lacking have supposed dead, -And there reigns love and all love's loving parts, -And all those friends which I thought buried. -How many a holy and obsequious tear -Hath dear religious love stol'n from mine eye -As interest of the dead, which now appear -But things removed that hidden in thee lie! -Thou art the grave where buried love doth live, -Hung with the trophies of my lovers gone, -Who all their parts of me to thee did give; -That due of many now is thine alone: -Their images I loved I view in thee, -And thou, all they, hast all the all of me. -If thou survive my well-contented day, -When that churl Death my bones with dust shall cover, -And shalt by fortune once more re-survey -These poor rude lines of thy deceased lover, -Compare them with the bettering of the time, -And though they be outstripp'd by every pen, -Reserve them for my love, not for their rhyme, -Exceeded by the height of happier men. -O, then vouchsafe me but this loving thought: -'Had my friend's Muse grown with this growing age, -A dearer birth than this his love had brought, -To march in ranks of better equipage: -But since he died and poets better prove, -Theirs for their style I'll read, his for his love.' -Full many a glorious morning have I seen -Flatter the mountain-tops with sovereign eye, -Kissing with golden face the meadows green, -Gilding pale streams with heavenly alchemy; -Anon permit the basest clouds to ride -With ugly rack on his celestial face, -And from the forlorn world his visage hide, -Stealing unseen to west with this disgrace: -Even so my sun one early morn did shine -With all triumphant splendor on my brow; -But out, alack! he was but one hour mine; -The region cloud hath mask'd him from me now. -Yet him for this my love no whit disdaineth; -Suns of the world may stain when heaven's sun staineth. -Why didst thou promise such a beauteous day, -And make me travel forth without my cloak, -To let base clouds o'ertake me in my way, -Hiding thy bravery in their rotten smoke? -'Tis not enough that through the cloud thou break, -To dry the rain on my storm-beaten face, -For no man well of such a salve can speak -That heals the wound and cures not the disgrace: -Nor can thy shame give physic to my grief; -Though thou repent, yet I have still the loss: -The offender's sorrow lends but weak relief -To him that bears the strong offence's cross. -Ah! but those tears are pearl which thy love sheds, -And they are rich and ransom all ill deeds. -No more be grieved at that which thou hast done: -Roses have thorns, and silver fountains mud; -Clouds and eclipses stain both moon and sun, -And loathsome canker lives in sweetest bud. -All men make faults, and even I in this, -Authorizing thy trespass with compare, -Myself corrupting, salving thy amiss, -Excusing thy sins more than thy sins are; -For to thy sensual fault I bring in sense-- -Thy adverse party is thy advocate-- -And 'gainst myself a lawful plea commence: -Such civil war is in my love and hate -That I an accessary needs must be -To that sweet thief which sourly robs from me. -Let me confess that we two must be twain, -Although our undivided loves are one: -So shall those blots that do with me remain -Without thy help by me be borne alone. -In our two loves there is but one respect, -Though in our lives a separable spite, -Which though it alter not love's sole effect, -Yet doth it steal sweet hours from love's delight. -I may not evermore acknowledge thee, -Lest my bewailed guilt should do thee shame, -Nor thou with public kindness honour me, -Unless thou take that honour from thy name: -But do not so; I love thee in such sort -As, thou being mine, mine is thy good report. -As a decrepit father takes delight -To see his active child do deeds of youth, -So I, made lame by fortune's dearest spite, -Take all my comfort of thy worth and truth. -For whether beauty, birth, or wealth, or wit, -Or any of these all, or all, or more, -Entitled in thy parts do crowned sit, -I make my love engrafted to this store: -So then I am not lame, poor, nor despised, -Whilst that this shadow doth such substance give -That I in thy abundance am sufficed -And by a part of all thy glory live. -Look, what is best, that best I wish in thee: -This wish I have; then ten times happy me!FROM fairest creatures we desire increase, -That thereby beauty's rose might never die, -But as the riper should by time decease, -His tender heir might bear his memory: -But thou, contracted to thine own bright eyes, -Feed'st thy light'st flame with self-substantial fuel, -Making a famine where abundance lies, -Thyself thy foe, to thy sweet self too cruel. -Thou that art now the world's fresh ornament -And only herald to the gaudy spring, -Within thine own bud buriest thy content -And, tender churl, makest waste in niggarding. -Pity the world, or else this glutton be, -To eat the world's due, by the grave and thee. -When forty winters shall beseige thy brow, -And dig deep trenches in thy beauty's field, -Thy youth's proud livery, so gazed on now, -Will be a tatter'd weed, of small worth held: -Then being ask'd where all thy beauty lies, -Where all the treasure of thy lusty days, -To say, within thine own deep-sunken eyes, -Were an all-eating shame and thriftless praise. -How much more praise deserved thy beauty's use, -If thou couldst answer 'This fair child of mine -Shall sum my count and make my old excuse,' -Proving his beauty by succession thine! -This were to be new made when thou art old, -And see thy blood warm when thou feel'st it cold. -Look in thy glass, and tell the face thou viewest -Now is the time that face should form another; -Whose fresh repair if now thou not renewest, -Thou dost beguile the world, unbless some mother. -For where is she so fair whose unear'd womb -Disdains the tillage of thy husbandry? -Or who is he so fond will be the tomb -Of his self-love, to stop posterity? -Thou art thy mother's glass, and she in thee -Calls back the lovely April of her prime: -So thou through windows of thine age shall see -Despite of wrinkles this thy golden time. -But if thou live, remember'd not to be, -Die single, and thine image dies with thee. -Unthrifty loveliness, why dost thou spend -Upon thyself thy beauty's legacy? -Nature's bequest gives nothing but doth lend, -And being frank she lends to those are free. -Then, beauteous niggard, why dost thou abuse -The bounteous largess given thee to give? -Profitless usurer, why dost thou use -So great a sum of sums, yet canst not live? -For having traffic with thyself alone, -Thou of thyself thy sweet self dost deceive. -Then how, when nature calls thee to be gone, -What acceptable audit canst thou leave? -Thy unused beauty must be tomb'd with thee, -Which, used, lives th' executor to be. -Those hours, that with gentle work did frame -The lovely gaze where every eye doth dwell, -Will play the tyrants to the very same -And that unfair which fairly doth excel: -For never-resting time leads summer on -To hideous winter and confounds him there; -Sap cheque'd with frost and lusty leaves quite gone, -Beauty o'ersnow'd and bareness every where: -Then, were not summer's distillation left, -A liquid prisoner pent in walls of glass, -Beauty's effect with beauty were bereft, -Nor it nor no remembrance what it was: -But flowers distill'd though they with winter meet, -Leese but their show; their substance still lives sweet. -Then let not winter's ragged hand deface -In thee thy summer, ere thou be distill'd: -Make sweet some vial; treasure thou some place -With beauty's treasure, ere it be self-kill'd. -That use is not forbidden usury, -Which happies those that pay the willing loan; -That's for thyself to breed another thee, -Or ten times happier, be it ten for one; -Ten times thyself were happier than thou art, -If ten of thine ten times refigured thee: -Then what could death do, if thou shouldst depart, -Leaving thee living in posterity? -Be not self-will'd, for thou art much too fair -To be death's conquest and make worms thine heir. -Lo! in the orient when the gracious light -Lifts up his burning head, each under eye -Doth homage to his new-appearing sight, -Serving with looks his sacred majesty; -And having climb'd the steep-up heavenly hill, -Resembling strong youth in his middle age, -yet mortal looks adore his beauty still, -Attending on his golden pilgrimage; -But when from highmost pitch, with weary car, -Like feeble age, he reeleth from the day, -The eyes, 'fore duteous, now converted are -From his low tract and look another way: -So thou, thyself out-going in thy noon, -Unlook'd on diest, unless thou get a son. -Music to hear, why hear'st thou music sadly? -Sweets with sweets war not, joy delights in joy. -Why lovest thou that which thou receivest not gladly, -Or else receivest with pleasure thine annoy? -If the true concord of well-tuned sounds, -By unions married, do offend thine ear, -They do but sweetly chide thee, who confounds -In singleness the parts that thou shouldst bear. -Mark how one string, sweet husband to another, -Strikes each in each by mutual ordering, -Resembling sire and child and happy mother -Who all in one, one pleasing note do sing: -Whose speechless song, being many, seeming one, -Sings this to thee: 'thou single wilt prove none.' -Is it for fear to wet a widow's eye -That thou consumest thyself in single life? -Ah! if thou issueless shalt hap to die. -The world will wail thee, like a makeless wife; -The world will be thy widow and still weep -That thou no form of thee hast left behind, -When every private widow well may keep -By children's eyes her husband's shape in mind. -Look, what an unthrift in the world doth spend -Shifts but his place, for still the world enjoys it; -But beauty's waste hath in the world an end, -And kept unused, the user so destroys it. -No love toward others in that bosom sits -That on himself such murderous shame commits. -For shame! deny that thou bear'st love to any, -Who for thyself art so unprovident. -Grant, if thou wilt, thou art beloved of many, -But that thou none lovest is most evident; -For thou art so possess'd with murderous hate -That 'gainst thyself thou stick'st not to conspire. -Seeking that beauteous roof to ruinate -Which to repair should be thy chief desire. -O, change thy thought, that I may change my mind! -Shall hate be fairer lodged than gentle love? -Be, as thy presence is, gracious and kind, -Or to thyself at least kind-hearted prove: -Make thee another self, for love of me, -That beauty still may live in thine or thee. -As fast as thou shalt wane, so fast thou growest -In one of thine, from that which thou departest; -And that fresh blood which youngly thou bestowest -Thou mayst call thine when thou from youth convertest. -Herein lives wisdom, beauty and increase: -Without this, folly, age and cold decay: -If all were minded so, the times should cease -And threescore year would make the world away. -Let those whom Nature hath not made for store, -Harsh featureless and rude, barrenly perish: -Look, whom she best endow'd she gave the more; -Which bounteous gift thou shouldst in bounty cherish: -She carved thee for her seal, and meant thereby -Thou shouldst print more, not let that copy die. -When I do count the clock that tells the time, -And see the brave day sunk in hideous night; -When I behold the violet past prime, -And sable curls all silver'd o'er with white; -When lofty trees I see barren of leaves -Which erst from heat did canopy the herd, -And summer's green all girded up in sheaves -Borne on the bier with white and bristly beard, -Then of thy beauty do I question make, -That thou among the wastes of time must go, -Since sweets and beauties do themselves forsake -And die as fast as they see others grow; -And nothing 'gainst Time's scythe can make defence -Save breed, to brave him when he takes thee hence. -O, that you were yourself! but, love, you are -No longer yours than you yourself here live: -Against this coming end you should prepare, -And your sweet semblance to some other give. -So should that beauty which you hold in lease -Find no determination: then you were -Yourself again after yourself's decease, -When your sweet issue your sweet form should bear. -Who lets so fair a house fall to decay, -Which husbandry in honour might uphold -Against the stormy gusts of winter's day -And barren rage of death's eternal cold? -O, none but unthrifts! Dear my love, you know -You had a father: let your son say so. -Not from the stars do I my judgment pluck; -And yet methinks I have astronomy, -But not to tell of good or evil luck, -Of plagues, of dearths, or seasons' quality; -Nor can I fortune to brief minutes tell, -Pointing to each his thunder, rain and wind, -Or say with princes if it shall go well, -By oft predict that I in heaven find: -But from thine eyes my knowledge I derive, -And, constant stars, in them I read such art -As truth and beauty shall together thrive, -If from thyself to store thou wouldst convert; -Or else of thee this I prognosticate: -Thy end is truth's and beauty's doom and date. -When I consider every thing that grows -Holds in perfection but a little moment, -That this huge stage presenteth nought but shows -Whereon the stars in secret influence comment; -When I perceive that men as plants increase, -Cheered and cheque'd even by the self-same sky, -Vaunt in their youthful sap, at height decrease, -And wear their brave state out of memory; -Then the conceit of this inconstant stay -Sets you most rich in youth before my sight, -Where wasteful Time debateth with Decay, -To change your day of youth to sullied night; -And all in war with Time for love of you, -As he takes from you, I engraft you new. -But wherefore do not you a mightier way -Make war upon this bloody tyrant, Time? -And fortify yourself in your decay -With means more blessed than my barren rhyme? -Now stand you on the top of happy hours, -And many maiden gardens yet unset -With virtuous wish would bear your living flowers, -Much liker than your painted counterfeit: -So should the lines of life that life repair, -Which this, Time's pencil, or my pupil pen, -Neither in inward worth nor outward fair, -Can make you live yourself in eyes of men. -To give away yourself keeps yourself still, -And you must live, drawn by your own sweet skill. -Who will believe my verse in time to come, -If it were fill'd with your most high deserts? -Though yet, heaven knows, it is but as a tomb -Which hides your life and shows not half your parts. -If I could write the beauty of your eyes -And in fresh numbers number all your graces, -The age to come would say 'This poet lies: -Such heavenly touches ne'er touch'd earthly faces.' -So should my papers yellow'd with their age -Be scorn'd like old men of less truth than tongue, -And your true rights be term'd a poet's rage -And stretched metre of an antique song: -But were some child of yours alive that time, -You should live twice; in it and in my rhyme. -Shall I compare thee to a summer's day? -Thou art more lovely and more temperate: -Rough winds do shake the darling buds of May, -And summer's lease hath all too short a date: -Sometime too hot the eye of heaven shines, -And often is his gold complexion dimm'd; -And every fair from fair sometime declines, -By chance or nature's changing course untrimm'd; -But thy eternal summer shall not fade -Nor lose possession of that fair thou owest; -Nor shall Death brag thou wander'st in his shade, -When in eternal lines to time thou growest: -So long as men can breathe or eyes can see, -So long lives this and this gives life to thee. -Devouring Time, blunt thou the lion's paws, -And make the earth devour her own sweet brood; -Pluck the keen teeth from the fierce tiger's jaws, -And burn the long-lived phoenix in her blood; -Make glad and sorry seasons as thou fleets, -And do whate'er thou wilt, swift-footed Time, -To the wide world and all her fading sweets; -But I forbid thee one most heinous crime: -O, carve not with thy hours my love's fair brow, -Nor draw no lines there with thine antique pen; -Him in thy course untainted do allow -For beauty's pattern to succeeding men. -Yet, do thy worst, old Time: despite thy wrong, -My love shall in my verse ever live young. -A woman's face with Nature's own hand painted -Hast thou, the master-mistress of my passion; -A woman's gentle heart, but not acquainted -With shifting change, as is false women's fashion; -An eye more bright than theirs, less false in rolling, -Gilding the object whereupon it gazeth; -A man in hue, all 'hues' in his controlling, -Much steals men's eyes and women's souls amazeth. -And for a woman wert thou first created; -Till Nature, as she wrought thee, fell a-doting, -And by addition me of thee defeated, -By adding one thing to my purpose nothing. -But since she prick'd thee out for women's pleasure, -Mine be thy love and thy love's use their treasure. -So is it not with me as with that Muse -Stirr'd by a painted beauty to his verse, -Who heaven itself for ornament doth use -And every fair with his fair doth rehearse -Making a couplement of proud compare, -With sun and moon, with earth and sea's rich gems, -With April's first-born flowers, and all things rare -That heaven's air in this huge rondure hems. -O' let me, true in love, but truly write, -And then believe me, my love is as fair -As any mother's child, though not so bright -As those gold candles fix'd in heaven's air: -Let them say more than like of hearsay well; -I will not praise that purpose not to sell. -My glass shall not persuade me I am old, -So long as youth and thou are of one date; -But when in thee time's furrows I behold, -Then look I death my days should expiate. -For all that beauty that doth cover thee -Is but the seemly raiment of my heart, -Which in thy breast doth live, as thine in me: -How can I then be elder than thou art? -O, therefore, love, be of thyself so wary -As I, not for myself, but for thee will; -Bearing thy heart, which I will keep so chary -As tender nurse her babe from faring ill. -Presume not on thy heart when mine is slain; -Thou gavest me thine, not to give back again. -As an unperfect actor on the stage -Who with his fear is put besides his part, -Or some fierce thing replete with too much rage, -Whose strength's abundance weakens his own heart. -So I, for fear of trust, forget to say -The perfect ceremony of love's rite, -And in mine own love's strength seem to decay, -O'ercharged with burden of mine own love's might. -O, let my books be then the eloquence -And dumb presagers of my speaking breast, -Who plead for love and look for recompense -More than that tongue that more hath more express'd. -O, learn to read what silent love hath writ: -To hear with eyes belongs to love's fine wit. -Mine eye hath play'd the painter and hath stell'd -Thy beauty's form in table of my heart; -My body is the frame wherein 'tis held, -And perspective it is the painter's art. -For through the painter must you see his skill, -To find where your true image pictured lies; -Which in my bosom's shop is hanging still, -That hath his windows glazed with thine eyes. -Now see what good turns eyes for eyes have done: -Mine eyes have drawn thy shape, and thine for me -Are windows to my breast, where-through the sun -Delights to peep, to gaze therein on thee; -Yet eyes this cunning want to grace their art; -They draw but what they see, know not the heart. -Let those who are in favour with their stars -Of public honour and proud titles boast, -Whilst I, whom fortune of such triumph bars, -Unlook'd for joy in that I honour most. -Great princes' favourites their fair leaves spread -But as the marigold at the sun's eye, -And in themselves their pride lies buried, -For at a frown they in their glory die. -The painful warrior famoused for fight, -After a thousand victories once foil'd, -Is from the book of honour razed quite, -And all the rest forgot for which he toil'd: -Then happy I, that love and am beloved -Where I may not remove nor be removed. -Lord of my love, to whom in vassalage -Thy merit hath my duty strongly knit, -To thee I send this written embassage, -To witness duty, not to show my wit: -Duty so great, which wit so poor as mine -May make seem bare, in wanting words to show it, -But that I hope some good conceit of thine -In thy soul's thought, all naked, will bestow it; -Till whatsoever star that guides my moving -Points on me graciously with fair aspect -And puts apparel on my tatter'd loving, -To show me worthy of thy sweet respect: -Then may I dare to boast how I do love thee; -Till then not show my head where thou mayst prove me. -Weary with toil, I haste me to my bed, -The dear repose for limbs with travel tired; -But then begins a journey in my head, -To work my mind, when body's work's expired: -For then my thoughts, from far where I abide, -Intend a zealous pilgrimage to thee, -And keep my drooping eyelids open wide, -Looking on darkness which the blind do see -Save that my soul's imaginary sight -Presents thy shadow to my sightless view, -Which, like a jewel hung in ghastly night, -Makes black night beauteous and her old face new. -Lo! thus, by day my limbs, by night my mind, -For thee and for myself no quiet find. -How can I then return in happy plight, -That am debarr'd the benefit of rest? -When day's oppression is not eased by night, -But day by night, and night by day, oppress'd? -And each, though enemies to either's reign, -Do in consent shake hands to torture me; -The one by toil, the other to complain -How far I toil, still farther off from thee. -I tell the day, to please them thou art bright -And dost him grace when clouds do blot the heaven: -So flatter I the swart-complexion'd night, -When sparkling stars twire not thou gild'st the even. -But day doth daily draw my sorrows longer -And night doth nightly make grief's strength seem stronger. -When, in disgrace with fortune and men's eyes, -I all alone beweep my outcast state -And trouble deal heaven with my bootless cries -And look upon myself and curse my fate, -Wishing me like to one more rich in hope, -Featured like him, like him with friends possess'd, -Desiring this man's art and that man's scope, -With what I most enjoy contented least; -Yet in these thoughts myself almost despising, -Haply I think on thee, and then my state, -Like to the lark at break of day arising -From sullen earth, sings hymns at heaven's gate; -For thy sweet love remember'd such wealth brings -That then I scorn to change my state with kings. -When to the sessions of sweet silent thought -I summon up remembrance of things past, -I sigh the lack of many a thing I sought, -And with old woes new wail my dear time's waste: -Then can I drown an eye, unused to flow, -For precious friends hid in death's dateless night, -And weep afresh love's long since cancell'd woe, -And moan the expense of many a vanish'd sight: -Then can I grieve at grievances foregone, -And heavily from woe to woe tell o'er -The sad account of fore-bemoaned moan, -Which I new pay as if not paid before. -But if the while I think on thee, dear friend, -All losses are restored and sorrows end. -Thy bosom is endeared with all hearts, -Which I by lacking have supposed dead, -And there reigns love and all love's loving parts, -And all those friends which I thought buried. -How many a holy and obsequious tear -Hath dear religious love stol'n from mine eye -As interest of the dead, which now appear -But things removed that hidden in thee lie! -Thou art the grave where buried love doth live, -Hung with the trophies of my lovers gone, -Who all their parts of me to thee did give; -That due of many now is thine alone: -Their images I loved I view in thee, -And thou, all they, hast all the all of me. -If thou survive my well-contented day, -When that churl Death my bones with dust shall cover, -And shalt by fortune once more re-survey -These poor rude lines of thy deceased lover, -Compare them with the bettering of the time, -And though they be outstripp'd by every pen, -Reserve them for my love, not for their rhyme, -Exceeded by the height of happier men. -O, then vouchsafe me but this loving thought: -'Had my friend's Muse grown with this growing age, -A dearer birth than this his love had brought, -To march in ranks of better equipage: -But since he died and poets better prove, -Theirs for their style I'll read, his for his love.' -Full many a glorious morning have I seen -Flatter the mountain-tops with sovereign eye, -Kissing with golden face the meadows green, -Gilding pale streams with heavenly alchemy; -Anon permit the basest clouds to ride -With ugly rack on his celestial face, -And from the forlorn world his visage hide, -Stealing unseen to west with this disgrace: -Even so my sun one early morn did shine -With all triumphant splendor on my brow; -But out, alack! he was but one hour mine; -The region cloud hath mask'd him from me now. -Yet him for this my love no whit disdaineth; -Suns of the world may stain when heaven's sun staineth. -Why didst thou promise such a beauteous day, -And make me travel forth without my cloak, -To let base clouds o'ertake me in my way, -Hiding thy bravery in their rotten smoke? -'Tis not enough that through the cloud thou break, -To dry the rain on my storm-beaten face, -For no man well of such a salve can speak -That heals the wound and cures not the disgrace: -Nor can thy shame give physic to my grief; -Though thou repent, yet I have still the loss: -The offender's sorrow lends but weak relief -To him that bears the strong offence's cross. -Ah! but those tears are pearl which thy love sheds, -And they are rich and ransom all ill deeds. -No more be grieved at that which thou hast done: -Roses have thorns, and silver fountains mud; -Clouds and eclipses stain both moon and sun, -And loathsome canker lives in sweetest bud. -All men make faults, and even I in this, -Authorizing thy trespass with compare, -Myself corrupting, salving thy amiss, -Excusing thy sins more than thy sins are; -For to thy sensual fault I bring in sense-- -Thy adverse party is thy advocate-- -And 'gainst myself a lawful plea commence: -Such civil war is in my love and hate -That I an accessary needs must be -To that sweet thief which sourly robs from me. -Let me confess that we two must be twain, -Although our undivided loves are one: -So shall those blots that do with me remain -Without thy help by me be borne alone. -In our two loves there is but one respect, -Though in our lives a separable spite, -Which though it alter not love's sole effect, -Yet doth it steal sweet hours from love's delight. -I may not evermore acknowledge thee, -Lest my bewailed guilt should do thee shame, -Nor thou with public kindness honour me, -Unless thou take that honour from thy name: -But do not so; I love thee in such sort -As, thou being mine, mine is thy good report. -As a decrepit father takes delight -To see his active child do deeds of youth, -So I, made lame by fortune's dearest spite, -Take all my comfort of thy worth and truth. -For whether beauty, birth, or wealth, or wit, -Or any of these all, or all, or more, -Entitled in thy parts do crowned sit, -I make my love engrafted to this store: -So then I am not lame, poor, nor despised, -Whilst that this shadow doth such substance give -That I in thy abundance am sufficed -And by a part of all thy glory live. -Look, what is best, that best I wish in thee: -This wish I have; then ten times happy me!FROM fairest creatures we desire increase, -That thereby beauty's rose might never die, -But as the riper should by time decease, -His tender heir might bear his memory: -But thou, contracted to thine own bright eyes, -Feed'st thy light'st flame with self-substantial fuel, -Making a famine where abundance lies, -Thyself thy foe, to thy sweet self too cruel. -Thou that art now the world's fresh ornament -And only herald to the gaudy spring, -Within thine own bud buriest thy content -And, tender churl, makest waste in niggarding. -Pity the world, or else this glutton be, -To eat the world's due, by the grave and thee. -When forty winters shall beseige thy brow, -And dig deep trenches in thy beauty's field, -Thy youth's proud livery, so gazed on now, -Will be a tatter'd weed, of small worth held: -Then being ask'd where all thy beauty lies, -Where all the treasure of thy lusty days, -To say, within thine own deep-sunken eyes, -Were an all-eating shame and thriftless praise. -How much more praise deserved thy beauty's use, -If thou couldst answer 'This fair child of mine -Shall sum my count and make my old excuse,' -Proving his beauty by succession thine! -This were to be new made when thou art old, -And see thy blood warm when thou feel'st it cold. -Look in thy glass, and tell the face thou viewest -Now is the time that face should form another; -Whose fresh repair if now thou not renewest, -Thou dost beguile the world, unbless some mother. -For where is she so fair whose unear'd womb -Disdains the tillage of thy husbandry? -Or who is he so fond will be the tomb -Of his self-love, to stop posterity? -Thou art thy mother's glass, and she in thee -Calls back the lovely April of her prime: -So thou through windows of thine age shall see -Despite of wrinkles this thy golden time. -But if thou live, remember'd not to be, -Die single, and thine image dies with thee. -Unthrifty loveliness, why dost thou spend -Upon thyself thy beauty's legacy? -Nature's bequest gives nothing but doth lend, -And being frank she lends to those are free. -Then, beauteous niggard, why dost thou abuse -The bounteous largess given thee to give? -Profitless usurer, why dost thou use -So great a sum of sums, yet canst not live? -For having traffic with thyself alone, -Thou of thyself thy sweet self dost deceive. -Then how, when nature calls thee to be gone, -What acceptable audit canst thou leave? -Thy unused beauty must be tomb'd with thee, -Which, used, lives th' executor to be. -Those hours, that with gentle work did frame -The lovely gaze where every eye doth dwell, -Will play the tyrants to the very same -And that unfair which fairly doth excel: -For never-resting time leads summer on -To hideous winter and confounds him there; -Sap cheque'd with frost and lusty leaves quite gone, -Beauty o'ersnow'd and bareness every where: -Then, were not summer's distillation left, -A liquid prisoner pent in walls of glass, -Beauty's effect with beauty were bereft, -Nor it nor no remembrance what it was: -But flowers distill'd though they with winter meet, -Leese but their show; their substance still lives sweet. -Then let not winter's ragged hand deface -In thee thy summer, ere thou be distill'd: -Make sweet some vial; treasure thou some place -With beauty's treasure, ere it be self-kill'd. -That use is not forbidden usury, -Which happies those that pay the willing loan; -That's for thyself to breed another thee, -Or ten times happier, be it ten for one; -Ten times thyself were happier than thou art, -If ten of thine ten times refigured thee: -Then what could death do, if thou shouldst depart, -Leaving thee living in posterity? -Be not self-will'd, for thou art much too fair -To be death's conquest and make worms thine heir. -Lo! in the orient when the gracious light -Lifts up his burning head, each under eye -Doth homage to his new-appearing sight, -Serving with looks his sacred majesty; -And having climb'd the steep-up heavenly hill, -Resembling strong youth in his middle age, -yet mortal looks adore his beauty still, -Attending on his golden pilgrimage; -But when from highmost pitch, with weary car, -Like feeble age, he reeleth from the day, -The eyes, 'fore duteous, now converted are -From his low tract and look another way: -So thou, thyself out-going in thy noon, -Unlook'd on diest, unless thou get a son. -Music to hear, why hear'st thou music sadly? -Sweets with sweets war not, joy delights in joy. -Why lovest thou that which thou receivest not gladly, -Or else receivest with pleasure thine annoy? -If the true concord of well-tuned sounds, -By unions married, do offend thine ear, -They do but sweetly chide thee, who confounds -In singleness the parts that thou shouldst bear. -Mark how one string, sweet husband to another, -Strikes each in each by mutual ordering, -Resembling sire and child and happy mother -Who all in one, one pleasing note do sing: -Whose speechless song, being many, seeming one, -Sings this to thee: 'thou single wilt prove none.' -Is it for fear to wet a widow's eye -That thou consumest thyself in single life? -Ah! if thou issueless shalt hap to die. -The world will wail thee, like a makeless wife; -The world will be thy widow and still weep -That thou no form of thee hast left behind, -When every private widow well may keep -By children's eyes her husband's shape in mind. -Look, what an unthrift in the world doth spend -Shifts but his place, for still the world enjoys it; -But beauty's waste hath in the world an end, -And kept unused, the user so destroys it. -No love toward others in that bosom sits -That on himself such murderous shame commits. -For shame! deny that thou bear'st love to any, -Who for thyself art so unprovident. -Grant, if thou wilt, thou art beloved of many, -But that thou none lovest is most evident; -For thou art so possess'd with murderous hate -That 'gainst thyself thou stick'st not to conspire. -Seeking that beauteous roof to ruinate -Which to repair should be thy chief desire. -O, change thy thought, that I may change my mind! -Shall hate be fairer lodged than gentle love? -Be, as thy presence is, gracious and kind, -Or to thyself at least kind-hearted prove: -Make thee another self, for love of me, -That beauty still may live in thine or thee. -As fast as thou shalt wane, so fast thou growest -In one of thine, from that which thou departest; -And that fresh blood which youngly thou bestowest -Thou mayst call thine when thou from youth convertest. -Herein lives wisdom, beauty and increase: -Without this, folly, age and cold decay: -If all were minded so, the times should cease -And threescore year would make the world away. -Let those whom Nature hath not made for store, -Harsh featureless and rude, barrenly perish: -Look, whom she best endow'd she gave the more; -Which bounteous gift thou shouldst in bounty cherish: -She carved thee for her seal, and meant thereby -Thou shouldst print more, not let that copy die. -When I do count the clock that tells the time, -And see the brave day sunk in hideous night; -When I behold the violet past prime, -And sable curls all silver'd o'er with white; -When lofty trees I see barren of leaves -Which erst from heat did canopy the herd, -And summer's green all girded up in sheaves -Borne on the bier with white and bristly beard, -Then of thy beauty do I question make, -That thou among the wastes of time must go, -Since sweets and beauties do themselves forsake -And die as fast as they see others grow; -And nothing 'gainst Time's scythe can make defence -Save breed, to brave him when he takes thee hence. -O, that you were yourself! but, love, you are -No longer yours than you yourself here live: -Against this coming end you should prepare, -And your sweet semblance to some other give. -So should that beauty which you hold in lease -Find no determination: then you were -Yourself again after yourself's decease, -When your sweet issue your sweet form should bear. -Who lets so fair a house fall to decay, -Which husbandry in honour might uphold -Against the stormy gusts of winter's day -And barren rage of death's eternal cold? -O, none but unthrifts! Dear my love, you know -You had a father: let your son say so. -Not from the stars do I my judgment pluck; -And yet methinks I have astronomy, -But not to tell of good or evil luck, -Of plagues, of dearths, or seasons' quality; -Nor can I fortune to brief minutes tell, -Pointing to each his thunder, rain and wind, -Or say with princes if it shall go well, -By oft predict that I in heaven find: -But from thine eyes my knowledge I derive, -And, constant stars, in them I read such art -As truth and beauty shall together thrive, -If from thyself to store thou wouldst convert; -Or else of thee this I prognosticate: -Thy end is truth's and beauty's doom and date. -When I consider every thing that grows -Holds in perfection but a little moment, -That this huge stage presenteth nought but shows -Whereon the stars in secret influence comment; -When I perceive that men as plants increase, -Cheered and cheque'd even by the self-same sky, -Vaunt in their youthful sap, at height decrease, -And wear their brave state out of memory; -Then the conceit of this inconstant stay -Sets you most rich in youth before my sight, -Where wasteful Time debateth with Decay, -To change your day of youth to sullied night; -And all in war with Time for love of you, -As he takes from you, I engraft you new. -But wherefore do not you a mightier way -Make war upon this bloody tyrant, Time? -And fortify yourself in your decay -With means more blessed than my barren rhyme? -Now stand you on the top of happy hours, -And many maiden gardens yet unset -With virtuous wish would bear your living flowers, -Much liker than your painted counterfeit: -So should the lines of life that life repair, -Which this, Time's pencil, or my pupil pen, -Neither in inward worth nor outward fair, -Can make you live yourself in eyes of men. -To give away yourself keeps yourself still, -And you must live, drawn by your own sweet skill. -Who will believe my verse in time to come, -If it were fill'd with your most high deserts? -Though yet, heaven knows, it is but as a tomb -Which hides your life and shows not half your parts. -If I could write the beauty of your eyes -And in fresh numbers number all your graces, -The age to come would say 'This poet lies: -Such heavenly touches ne'er touch'd earthly faces.' -So should my papers yellow'd with their age -Be scorn'd like old men of less truth than tongue, -And your true rights be term'd a poet's rage -And stretched metre of an antique song: -But were some child of yours alive that time, -You should live twice; in it and in my rhyme. -Shall I compare thee to a summer's day? -Thou art more lovely and more temperate: -Rough winds do shake the darling buds of May, -And summer's lease hath all too short a date: -Sometime too hot the eye of heaven shines, -And often is his gold complexion dimm'd; -And every fair from fair sometime declines, -By chance or nature's changing course untrimm'd; -But thy eternal summer shall not fade -Nor lose possession of that fair thou owest; -Nor shall Death brag thou wander'st in his shade, -When in eternal lines to time thou growest: -So long as men can breathe or eyes can see, -So long lives this and this gives life to thee. -Devouring Time, blunt thou the lion's paws, -And make the earth devour her own sweet brood; -Pluck the keen teeth from the fierce tiger's jaws, -And burn the long-lived phoenix in her blood; -Make glad and sorry seasons as thou fleets, -And do whate'er thou wilt, swift-footed Time, -To the wide world and all her fading sweets; -But I forbid thee one most heinous crime: -O, carve not with thy hours my love's fair brow, -Nor draw no lines there with thine antique pen; -Him in thy course untainted do allow -For beauty's pattern to succeeding men. -Yet, do thy worst, old Time: despite thy wrong, -My love shall in my verse ever live young. -A woman's face with Nature's own hand painted -Hast thou, the master-mistress of my passion; -A woman's gentle heart, but not acquainted -With shifting change, as is false women's fashion; -An eye more bright than theirs, less false in rolling, -Gilding the object whereupon it gazeth; -A man in hue, all 'hues' in his controlling, -Much steals men's eyes and women's souls amazeth. -And for a woman wert thou first created; -Till Nature, as she wrought thee, fell a-doting, -And by addition me of thee defeated, -By adding one thing to my purpose nothing. -But since she prick'd thee out for women's pleasure, -Mine be thy love and thy love's use their treasure. -So is it not with me as with that Muse -Stirr'd by a painted beauty to his verse, -Who heaven itself for ornament doth use -And every fair with his fair doth rehearse -Making a couplement of proud compare, -With sun and moon, with earth and sea's rich gems, -With April's first-born flowers, and all things rare -That heaven's air in this huge rondure hems. -O' let me, true in love, but truly write, -And then believe me, my love is as fair -As any mother's child, though not so bright -As those gold candles fix'd in heaven's air: -Let them say more than like of hearsay well; -I will not praise that purpose not to sell. -My glass shall not persuade me I am old, -So long as youth and thou are of one date; -But when in thee time's furrows I behold, -Then look I death my days should expiate. -For all that beauty that doth cover thee -Is but the seemly raiment of my heart, -Which in thy breast doth live, as thine in me: -How can I then be elder than thou art? -O, therefore, love, be of thyself so wary -As I, not for myself, but for thee will; -Bearing thy heart, which I will keep so chary -As tender nurse her babe from faring ill. -Presume not on thy heart when mine is slain; -Thou gavest me thine, not to give back again. -As an unperfect actor on the stage -Who with his fear is put besides his part, -Or some fierce thing replete with too much rage, -Whose strength's abundance weakens his own heart. -So I, for fear of trust, forget to say -The perfect ceremony of love's rite, -And in mine own love's strength seem to decay, -O'ercharged with burden of mine own love's might. -O, let my books be then the eloquence -And dumb presagers of my speaking breast, -Who plead for love and look for recompense -More than that tongue that more hath more express'd. -O, learn to read what silent love hath writ: -To hear with eyes belongs to love's fine wit. -Mine eye hath play'd the painter and hath stell'd -Thy beauty's form in table of my heart; -My body is the frame wherein 'tis held, -And perspective it is the painter's art. -For through the painter must you see his skill, -To find where your true image pictured lies; -Which in my bosom's shop is hanging still, -That hath his windows glazed with thine eyes. -Now see what good turns eyes for eyes have done: -Mine eyes have drawn thy shape, and thine for me -Are windows to my breast, where-through the sun -Delights to peep, to gaze therein on thee; -Yet eyes this cunning want to grace their art; -They draw but what they see, know not the heart. -Let those who are in favour with their stars -Of public honour and proud titles boast, -Whilst I, whom fortune of such triumph bars, -Unlook'd for joy in that I honour most. -Great princes' favourites their fair leaves spread -But as the marigold at the sun's eye, -And in themselves their pride lies buried, -For at a frown they in their glory die. -The painful warrior famoused for fight, -After a thousand victories once foil'd, -Is from the book of honour razed quite, -And all the rest forgot for which he toil'd: -Then happy I, that love and am beloved -Where I may not remove nor be removed. -Lord of my love, to whom in vassalage -Thy merit hath my duty strongly knit, -To thee I send this written embassage, -To witness duty, not to show my wit: -Duty so great, which wit so poor as mine -May make seem bare, in wanting words to show it, -But that I hope some good conceit of thine -In thy soul's thought, all naked, will bestow it; -Till whatsoever star that guides my moving -Points on me graciously with fair aspect -And puts apparel on my tatter'd loving, -To show me worthy of thy sweet respect: -Then may I dare to boast how I do love thee; -Till then not show my head where thou mayst prove me. -Weary with toil, I haste me to my bed, -The dear repose for limbs with travel tired; -But then begins a journey in my head, -To work my mind, when body's work's expired: -For then my thoughts, from far where I abide, -Intend a zealous pilgrimage to thee, -And keep my drooping eyelids open wide, -Looking on darkness which the blind do see -Save that my soul's imaginary sight -Presents thy shadow to my sightless view, -Which, like a jewel hung in ghastly night, -Makes black night beauteous and her old face new. -Lo! thus, by day my limbs, by night my mind, -For thee and for myself no quiet find. -How can I then return in happy plight, -That am debarr'd the benefit of rest? -When day's oppression is not eased by night, -But day by night, and night by day, oppress'd? -And each, though enemies to either's reign, -Do in consent shake hands to torture me; -The one by toil, the other to complain -How far I toil, still farther off from thee. -I tell the day, to please them thou art bright -And dost him grace when clouds do blot the heaven: -So flatter I the swart-complexion'd night, -When sparkling stars twire not thou gild'st the even. -But day doth daily draw my sorrows longer -And night doth nightly make grief's strength seem stronger. -When, in disgrace with fortune and men's eyes, -I all alone beweep my outcast state -And trouble deal heaven with my bootless cries -And look upon myself and curse my fate, -Wishing me like to one more rich in hope, -Featured like him, like him with friends possess'd, -Desiring this man's art and that man's scope, -With what I most enjoy contented least; -Yet in these thoughts myself almost despising, -Haply I think on thee, and then my state, -Like to the lark at break of day arising -From sullen earth, sings hymns at heaven's gate; -For thy sweet love remember'd such wealth brings -That then I scorn to change my state with kings. -When to the sessions of sweet silent thought -I summon up remembrance of things past, -I sigh the lack of many a thing I sought, -And with old woes new wail my dear time's waste: -Then can I drown an eye, unused to flow, -For precious friends hid in death's dateless night, -And weep afresh love's long since cancell'd woe, -And moan the expense of many a vanish'd sight: -Then can I grieve at grievances foregone, -And heavily from woe to woe tell o'er -The sad account of fore-bemoaned moan, -Which I new pay as if not paid before. -But if the while I think on thee, dear friend, -All losses are restored and sorrows end. -Thy bosom is endeared with all hearts, -Which I by lacking have supposed dead, -And there reigns love and all love's loving parts, -And all those friends which I thought buried. -How many a holy and obsequious tear -Hath dear religious love stol'n from mine eye -As interest of the dead, which now appear -But things removed that hidden in thee lie! -Thou art the grave where buried love doth live, -Hung with the trophies of my lovers gone, -Who all their parts of me to thee did give; -That due of many now is thine alone: -Their images I loved I view in thee, -And thou, all they, hast all the all of me. -If thou survive my well-contented day, -When that churl Death my bones with dust shall cover, -And shalt by fortune once more re-survey -These poor rude lines of thy deceased lover, -Compare them with the bettering of the time, -And though they be outstripp'd by every pen, -Reserve them for my love, not for their rhyme, -Exceeded by the height of happier men. -O, then vouchsafe me but this loving thought: -'Had my friend's Muse grown with this growing age, -A dearer birth than this his love had brought, -To march in ranks of better equipage: -But since he died and poets better prove, -Theirs for their style I'll read, his for his love.' -Full many a glorious morning have I seen -Flatter the mountain-tops with sovereign eye, -Kissing with golden face the meadows green, -Gilding pale streams with heavenly alchemy; -Anon permit the basest clouds to ride -With ugly rack on his celestial face, -And from the forlorn world his visage hide, -Stealing unseen to west with this disgrace: -Even so my sun one early morn did shine -With all triumphant splendor on my brow; -But out, alack! he was but one hour mine; -The region cloud hath mask'd him from me now. -Yet him for this my love no whit disdaineth; -Suns of the world may stain when heaven's sun staineth. -Why didst thou promise such a beauteous day, -And make me travel forth without my cloak, -To let base clouds o'ertake me in my way, -Hiding thy bravery in their rotten smoke? -'Tis not enough that through the cloud thou break, -To dry the rain on my storm-beaten face, -For no man well of such a salve can speak -That heals the wound and cures not the disgrace: -Nor can thy shame give physic to my grief; -Though thou repent, yet I have still the loss: -The offender's sorrow lends but weak relief -To him that bears the strong offence's cross. -Ah! but those tears are pearl which thy love sheds, -And they are rich and ransom all ill deeds. -No more be grieved at that which thou hast done: -Roses have thorns, and silver fountains mud; -Clouds and eclipses stain both moon and sun, -And loathsome canker lives in sweetest bud. -All men make faults, and even I in this, -Authorizing thy trespass with compare, -Myself corrupting, salving thy amiss, -Excusing thy sins more than thy sins are; -For to thy sensual fault I bring in sense-- -Thy adverse party is thy advocate-- -And 'gainst myself a lawful plea commence: -Such civil war is in my love and hate -That I an accessary needs must be -To that sweet thief which sourly robs from me. -Let me confess that we two must be twain, -Although our undivided loves are one: -So shall those blots that do with me remain -Without thy help by me be borne alone. -In our two loves there is but one respect, -Though in our lives a separable spite, -Which though it alter not love's sole effect, -Yet doth it steal sweet hours from love's delight. -I may not evermore acknowledge thee, -Lest my bewailed guilt should do thee shame, -Nor thou with public kindness honour me, -Unless thou take that honour from thy name: -But do not so; I love thee in such sort -As, thou being mine, mine is thy good report. -As a decrepit father takes delight -To see his active child do deeds of youth, -So I, made lame by fortune's dearest spite, -Take all my comfort of thy worth and truth. -For whether beauty, birth, or wealth, or wit, -Or any of these all, or all, or more, -Entitled in thy parts do crowned sit, -I make my love engrafted to this store: -So then I am not lame, poor, nor despised, -Whilst that this shadow doth such substance give -That I in thy abundance am sufficed -And by a part of all thy glory live. -Look, what is best, that best I wish in thee: -This wish I have; then ten times happy me! \ No newline at end of file From 168452fb026bc3cb35e476d74fca9f46afcb4f4e Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 22:26:11 +0000 Subject: [PATCH 205/303] revert change in flash attn and flash infer to clean up the diff --- vllm/attention/backends/flash_attn.py | 3 --- vllm/attention/backends/flashinfer.py | 1 - 2 files changed, 4 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 09456ca8d7b61..bf883987bd80b 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -15,9 +15,6 @@ is_block_tables_empty) from vllm.utils import async_tensor_h2d, make_tensor_with_pad -from vllm.distributed import get_disagg_group -import vllm.envs as envs - if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 852c5cd8dc180..4054d337316fe 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -823,4 +823,3 @@ def forward( k_scale=k_scale, v_scale=v_scale) return output.view(num_tokens, hidden_size) - \ No newline at end of file From 784d9058d5a8f483024d23887942583bac2ff238 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 22:35:12 +0000 Subject: [PATCH 206/303] update the example --- .../disagg_prefill/disagg_prefill_example.sh | 35 +++++++++++++++---- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index f57f5fd86d89c..56b6f44c7418a 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -16,7 +16,7 @@ wait_for_server() { } # prefilling instance -VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 python3 \ +VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8100 \ @@ -24,7 +24,7 @@ VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 pytho --gpu-memory-utilization 0.8 & # decoding instance -VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ +VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8200 \ @@ -36,18 +36,39 @@ wait_for_server 8100 wait_for_server 8200 # launch a proxy server that opens the service at port 8000 +# the workflow of this proxy: +# - send the request to prefill vLLM instance (port 8100), change max_tokens to 1 +# - after the prefill vLLM finishes prefill, send the request to decode vLLM instance python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & sleep 1 -# serve an example request -curl http://localhost:8000/v1/completions \ +# serve two example requests +output1=$(curl -s http://localhost:8000/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "prompt": "San Francisco is a", "max_tokens": 10, "temperature": 0 -}' +}') -# clean up -ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 \ No newline at end of file +output2=$(curl -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "Santa Clara is a", +"max_tokens": 10, +"temperature": 0 +}') + +# Print the outputs of the curl requests +echo "" +echo "Output of first request: $output1" +echo "Output of second request: $output2" + +echo "Successfully finished 2 test requests!" +echo "" + +# Cleanup commands, suppressing their output +ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 > /dev/null 2>&1 +pkill -f python3 > /dev/null 2>&1 From 17d2505f1b810c68ec8e8f9ff8cde425647013f9 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 23:42:30 +0000 Subject: [PATCH 207/303] make format checker happy --- .../disagg_prefill_proxy_server.py | 40 +-- .../disagg_benchmarks/round_robin_proxy.py | 25 +- .../visualize_benchmark_results.py | 39 ++- .../kv_transfer/kv_lookup_buffer/base.py | 19 +- .../simple_kv_lookup_buffer.py | 122 ++++----- vllm/distributed/kv_transfer/kv_pipe/base.py | 16 +- .../kv_pipe/torch_distributed_pipe.py | 76 +++--- vllm/distributed/kv_transfer/vllm_adapter.py | 252 ++++++++++-------- vllm/distributed/parallel_state.py | 33 ++- vllm/executor/gpu_executor.py | 3 +- vllm/executor/multiproc_gpu_executor.py | 3 +- vllm/executor/ray_gpu_executor.py | 2 +- vllm/worker/model_runner.py | 46 ++-- vllm/worker/worker_base.py | 87 +++--- 14 files changed, 386 insertions(+), 377 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index 5750df7735ad1..4058b1c0a3b79 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -1,28 +1,31 @@ -from quart import Quart, request, Response, jsonify, make_response -import aiohttp -import sys -import traceback import os +import aiohttp +from quart import Quart, make_response, request + AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) app = Quart(__name__) + async def forward_request(url, data): - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" } - async with session.post(url=url, json=data, headers=headers) as response: + async with session.post(url=url, json=data, + headers=headers) as response: if response.status == 200: # if response.headers.get('Transfer-Encoding') == 'chunked': if True: - async for chunk_bytes in response.content.iter_chunked(1024): + async for chunk_bytes in response.content.iter_chunked( + 1024): yield chunk_bytes else: content = await response.read() yield content + @app.route('/v1/completions', methods=['POST']) async def handle_request(): try: @@ -31,25 +34,28 @@ async def handle_request(): prefill_request = original_request_data.copy() # change max_tokens = 1 to let it only do prefill prefill_request['max_tokens'] = 1 - + # finish prefill - async for _ in forward_request('http://localhost:8100/v1/completions', prefill_request): + async for _ in forward_request('http://localhost:8100/v1/completions', + prefill_request): continue - print(f"Prefill done. proceeding to decode.") - # return decode - generator = forward_request('http://localhost:8200/v1/completions', original_request_data) + generator = forward_request('http://localhost:8200/v1/completions', + original_request_data) response = await make_response(generator) response.timeout = None return response - + except Exception as e: - pass - # exc_info = sys.exc_info() - # print(e) - # print("".join(traceback.format_exception(*exc_info))) + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server") + print(e) + print("".join(traceback.format_exception(*exc_info))) + if __name__ == '__main__': app.run(port=8000) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py index 8751e24a08d33..6eb5f63980070 100644 --- a/benchmarks/disagg_benchmarks/round_robin_proxy.py +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -1,9 +1,12 @@ import asyncio +import itertools + import aiohttp from aiohttp import web -import itertools + class RoundRobinProxy: + def __init__(self, target_ports): self.target_ports = target_ports self.port_cycle = itertools.cycle(self.target_ports) @@ -16,16 +19,14 @@ async def handle_request(self, request): try: # Forward the request async with session.request( - method=request.method, - url=target_url, - headers=request.headers, - data=request.content, + method=request.method, + url=target_url, + headers=request.headers, + data=request.content, ) as response: # Start sending the response - resp = web.StreamResponse( - status=response.status, - headers=response.headers - ) + resp = web.StreamResponse(status=response.status, + headers=response.headers) await resp.prepare(request) # Stream the response content @@ -38,6 +39,7 @@ async def handle_request(self, request): except Exception as e: return web.Response(text=f"Error: {str(e)}", status=500) + async def main(): proxy = RoundRobinProxy([8100, 8200]) app = web.Application() @@ -49,9 +51,10 @@ async def main(): await site.start() print("Proxy server started on http://localhost:8000") - + # Keep the server running await asyncio.Event().wait() + if __name__ == '__main__': - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py index 192f26a1e3cd2..6c5bf5c791dc9 100644 --- a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -1,40 +1,42 @@ +import json import matplotlib.pyplot as plt -import yaml import pandas as pd -import json - - if __name__ == "__main__": data = [] for name in ['disagg_prefill', 'chunked_prefill']: - for qps in [2,4,6,8]: + for qps in [2, 4, 6, 8]: with open(f"results/{name}-qps-{qps}.json", "r") as f: x = json.load(f) x['name'] = name x['qps'] = qps data.append(x) - + df = pd.DataFrame.from_dict(data) dis_df = df[df['name'] == 'disagg_prefill'] chu_df = df[df['name'] == 'chunked_prefill'] - + plt.style.use('bmh') plt.rcParams['font.size'] = 20 - - for key in ['mean_ttft_ms', - 'median_ttft_ms', - 'p99_ttft_ms', - 'mean_itl_ms', - 'median_itl_ms', - 'p99_itl_ms']: - + for key in [ + 'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms', + 'median_itl_ms', 'p99_itl_ms' + ]: + fig, ax = plt.subplots(figsize=(11, 7)) - plt.plot(dis_df['qps'], dis_df[key], label='disagg_prefill', marker='o', linewidth=4) - plt.plot(chu_df['qps'], chu_df[key], label='chunked_prefill', marker='o', linewidth=4) + plt.plot(dis_df['qps'], + dis_df[key], + label='disagg_prefill', + marker='o', + linewidth=4) + plt.plot(chu_df['qps'], + chu_df[key], + label='chunked_prefill', + marker='o', + linewidth=4) ax.legend() ax.set_xlabel('QPS') @@ -42,6 +44,3 @@ ax.set_ylim(bottom=0) fig.savefig(f'results/{key}.png') plt.close(fig) - - - \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py index 733bc82bf53f9..80802f87987ac 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -1,21 +1,22 @@ - from abc import ABC, abstractmethod -from typing import Optional +from typing import List, Optional + import torch class KVLookupBufferBase(ABC): - + @abstractmethod - def insert(self, - input_tokens: torch.Tensor, - kv: torch.Tensor, roi) -> None: + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: raise NotImplementedError - + @abstractmethod - def drop_select(self, input_tokens, roi) -> Optional[torch.Tensor]: + def drop_select(self, input_tokens: torch.Tensor, + roi: torch.Tensor) -> List[Optional[torch.Tensor]]: raise NotImplementedError - + @abstractmethod def close(self): """ diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py index 6172bf092fb03..9696032002fda 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py @@ -1,22 +1,21 @@ - -from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \ - KVLookupBufferBase -from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from typing import Dict, Tuple, List, Optional, Union import threading -import torch -from collections import deque import time +from collections import deque +from typing import Deque, List, Optional, Union + +import torch +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVLookupBufferBase) +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger logger = init_logger(__name__) + class SimpleKVLookupBuffer(KVLookupBufferBase): - - def __init__(self, - signal_pipe: KVPipeBase, - data_pipe: KVPipeBase, + + def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: int): """ signal_pipe: on CPU @@ -28,72 +27,66 @@ def __init__(self, data_pipe: on device (e.g. GPU) """ - - self.buffer = deque() - + + self.buffer: Deque[List[torch.Tensor]] = deque() + self.buffer_size = 0 self.buffer_size_threshold = buffer_size_thresh self.buffer_lock = threading.Lock() self.signal_pipe = signal_pipe self.data_pipe = data_pipe - self.request_handling_thread = None + self.request_handling_thread: Optional[threading.Thread] = None self.normal_signal = torch.tensor([0]) self.end_signal = None - - def _matches(self, - tokens_roi_sender: List[torch.Tensor], + def _matches(self, tokens_roi_sender: List[torch.Tensor], tokens_roi_recver: List[torch.Tensor]): # tokens_roi_sender: tokens and roi of the producer (in the buffer) # tokens_roi_recver: tokens and roi of the consumer (query) - + tokens_sender = tokens_roi_sender[0] tokens_recver = tokens_roi_recver[0] roi_sender = tokens_roi_sender[1] roi_recver = tokens_roi_recver[1] - + if tokens_recver is None: # consumer sends an empty request # semantics: DROP SELECT * LIMIT 1 # so any of the data in the buffer can be drop-selected return True - # Assuming that roi is a mask on tokens tokens_sender = tokens_sender[roi_sender] tokens_recver = tokens_recver[roi_recver] - - + # simple common prefix matching min_length = min(len(tokens_sender), len(tokens_recver)) - if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]): + if torch.allclose(tokens_sender[:min_length], + tokens_recver[:min_length]): return min_length - + return 0 - - def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None: + def _send_tensor_and_dec_size(self, + tensor: Optional[torch.Tensor]) -> None: assert tensor is not None, "Use self.data_pipe.send(None) instead" self.buffer_size -= tensor.element_size() * tensor.numel() self.data_pipe.send_tensor(tensor) def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): - + if data == [] or data is None: return 0 if isinstance(data, torch.Tensor): return data.element_size() * data.numel() + else: + raise AssertionError("Unknown data type %s" % type(data)) - assert False, "Unknown data type %s" % type(data) - - def _add_to_buffer(self, - input_tokens: torch.Tensor, - roi: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, + def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor): if isinstance(input_tokens, torch.Tensor): @@ -107,21 +100,20 @@ def _add_to_buffer(self, if isinstance(hidden, torch.Tensor): hidden = hidden.clone() - buffer_item = [input_tokens, roi, key, value, hidden] - + with self.buffer_lock: for data in buffer_item: self.buffer_size += self._get_element_size(data) self.buffer.append(buffer_item) - + def _is_end_signal(self, signal): return signal is None - + def drop_select_handler(self): try: - + while True: signal = self.signal_pipe.recv_tensor() if self._is_end_signal(signal): @@ -132,28 +124,29 @@ def drop_select_handler(self): roi = self.data_pipe.recv_tensor() tokens_roi_recver = [input_tokens, roi] - + matched_length = 0 - + # perform input tokens and roi matching with self.buffer_lock: for _ in range(len(self.buffer)): - - temp_length = self._matches(self.buffer[0], tokens_roi_recver) + + temp_length = self._matches(self.buffer[0], + tokens_roi_recver) if temp_length > 0: matched_length = temp_length break # rotate the element we just accessed to the end self.buffer.rotate(-1) - + if matched_length > 0: # need to clone the tensor # in case the tensor is freed before sending finishes matched_item = self.buffer.popleft() for tensor in matched_item: self._send_tensor_and_dec_size(tensor) - + else: # no match, just send None for _ in range(5): @@ -164,60 +157,57 @@ def drop_select_handler(self): raise e logger.debug("Closing drop_select_handler") - - - def drop_select(self, - input_tokens: torch.Tensor, - roi: torch.Tensor): - + + def drop_select(self, input_tokens: torch.Tensor, + roi: torch.Tensor) -> List[Optional[torch.Tensor]]: + assert self.request_handling_thread is None, \ "drop_select should be called by the receiver" - if isinstance(input_tokens, torch.Tensor): input_tokens = input_tokens.clone() if isinstance(roi, torch.Tensor): roi = roi.clone() - + self.signal_pipe.send_tensor(self.normal_signal) self.data_pipe.send_tensor(input_tokens) self.data_pipe.send_tensor(roi) - + input_tokens = self.data_pipe.recv_tensor() roi = self.data_pipe.recv_tensor() key = self.data_pipe.recv_tensor() value = self.data_pipe.recv_tensor() hidden = self.data_pipe.recv_tensor() - + return [input_tokens, roi, key, value, hidden] - def full_handler(self): time.sleep(0.001) - - - def insert(self, input_tokens, roi, key, value, hidden) -> None: + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: while self.buffer_size > self.buffer_size_threshold: # logger.debug("KV transfer buffer is full. Handling...") self.full_handler() - self._add_to_buffer(input_tokens, roi, key, value, hidden) - + # when calling the insert, the current process is a sender # need to launch the request handler and start listening to request. if self.request_handling_thread is None: self.request_handling_thread = threading.Thread( target=self.drop_select_handler) self.request_handling_thread.start() - - + def close(self): - if hasattr(self, "request_handling_thread") and self.request_handling_thread is not None: + if hasattr(self, "request_handling_thread" + ) and self.request_handling_thread is not None: self.request_handling_thread.join() else: - # TODO: have a explicit close signal and have a explicit way to check if it's requester + # TODO: have a explicit close signal and have a explicit way to + # check if it's requester self.signal_pipe.send_tensor(self.end_signal) diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py index 7662a5893ceb2..0955b4e838896 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -1,15 +1,17 @@ - from abc import ABC, abstractmethod +from typing import Optional + +import torch class KVPipeBase(ABC): - - @abstractmethod - def send_tensor(self, tensor): + + @abstractmethod + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: raise NotImplementedError - - @abstractmethod - def recv_tensor(self): + + @abstractmethod + def recv_tensor(self) -> Optional[torch.Tensor]: raise NotImplementedError @abstractmethod diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index 3a6a94bb0e752..911bce88a38f1 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -1,15 +1,16 @@ -from torch.distributed import Backend -import torch -from typing import List, Optional, Union import threading -from concurrent.futures import ThreadPoolExecutor import time +from concurrent.futures import ThreadPoolExecutor +from typing import List, Optional, Union +import torch +from torch.distributed import Backend + +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger logger = init_logger(__name__) - # if the tensor is only one-element and only contains NONE_INT # this means that the sended object is None. NONE_INT = -150886311 @@ -42,17 +43,17 @@ class BrokenPipeException(Exception): + def __init__(self, message): self.message = message super().__init__(self.message) -class TorchDistributedPipe: +class TorchDistributedPipe(KVPipeBase): METADATA_LENGTH = 16 MAX_TENSOR_DIMENSIONS = 14 METADATA_DTYPE = torch.int64 - def __init__( self, group_ranks: List[List[int]], @@ -65,8 +66,7 @@ def __init__( for ranks in group_ranks: device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) + ranks, backend=torch_distributed_backend) if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) @@ -78,26 +78,24 @@ def __init__( self.device = self._select_device(torch_distributed_backend) - self.target_rank_for_send = self.ranks[ - (self.rank_in_group + 1) % self.world_size - ] - self.target_rank_for_recv = self.ranks[ - (self.rank_in_group - 1) % self.world_size - ] + self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % + self.world_size] + self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % + self.world_size] # FIXME: why we need this? torch.set_default_device(self.device) - self.transport_thread = None + self.transport_thread: Optional[ThreadPoolExecutor] = None self.buffer_size = 0 self.buffer_size_lock = threading.Lock() self.none_tensor = torch.tensor([NONE_INT], device=self.device) # On-device tensors to be reused for recv - self.rcv_metadata_buffer = torch.zeros( - self.METADATA_LENGTH, dtype=self.METADATA_DTYPE, device=self.device - ) + self.rcv_metadata_buffer = torch.zeros(self.METADATA_LENGTH, + dtype=self.METADATA_DTYPE, + device=self.device) def _select_device(self, backend: Union[str, Backend]): if torch.cuda.is_available() and backend == Backend.NCCL: @@ -129,14 +127,12 @@ def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: buffer[0] = DTYPE2INT[tensor.dtype] ndims = len(tensor.shape) buffer[1] = len(tensor.shape) - buffer[2 : 2 + ndims] = torch.tensor( - tensor.shape, dtype=self.METADATA_DTYPE - ) + buffer[2:2 + ndims] = torch.tensor(tensor.shape, + dtype=self.METADATA_DTYPE) return buffer.to(self.device) - def _prepare_recv_buffer( - self, d_metadata_buffer: torch.Tensor - ) -> torch.Tensor: + def _prepare_recv_buffer(self, + d_metadata_buffer: torch.Tensor) -> torch.Tensor: """ Create a buffer to receive the tensor based on the metadata. @@ -149,7 +145,7 @@ def _prepare_recv_buffer( h_buffer = d_metadata_buffer.cpu().numpy() dtype = INT2DTYPE[h_buffer[0]] ndims = h_buffer[1] - shape = tuple(h_buffer[2 : 2 + ndims]) + shape = tuple(h_buffer[2:2 + ndims]) return torch.empty(shape, dtype=dtype, device=self.device) def _send_metadata(self, d_metadata_buffer: torch.Tensor): @@ -174,7 +170,7 @@ def _recv_metadata(self) -> torch.Tensor: race conditions during sending/receiving. Therefore, the metadata buffer can be reused """ - task = torch.distributed.recv( + torch.distributed.recv( self.rcv_metadata_buffer, src=self.target_rank_for_recv, group=self.device_group, @@ -194,9 +190,9 @@ def _send_impl(self, tensor): metadata = self._make_metadata(tensor) self._send_metadata(metadata) - torch.distributed.send( - tensor, dst=self.target_rank_for_send, group=self.device_group - ) + torch.distributed.send(tensor, + dst=self.target_rank_for_send, + group=self.device_group) def _recv_impl(self) -> torch.Tensor: """ @@ -211,9 +207,9 @@ def _recv_impl(self) -> torch.Tensor: d_metadata = self._recv_metadata() buffer = self._prepare_recv_buffer(d_metadata) - torch.distributed.recv( - buffer, src=self.target_rank_for_recv, group=self.device_group - ) + torch.distributed.recv(buffer, + src=self.target_rank_for_recv, + group=self.device_group) return buffer @@ -227,13 +223,9 @@ def send_tensor_wrapper(self, tensor): self.buffer_size = self.buffer_size - tensor_size except Exception as e: logger.error("[rank%d]: Exception when trying to send %s, msg: %s", - torch.distributed.get_rank(), - str(tensor), - str(e)) + torch.distributed.get_rank(), str(tensor), str(e)) import traceback traceback.print_exc() - - def block_if_full(self): """ @@ -268,13 +260,11 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: with self.buffer_size_lock: self.buffer_size = self.buffer_size + tensor_size - self.transport_thread.submit( self.send_tensor_wrapper, tensor, ) - def recv_tensor(self) -> Optional[torch.Tensor]: """Receives a tensor from the src rank. Blocking.""" if self.transport_thread is None: @@ -300,8 +290,6 @@ def close(self): """ Close the pipe and release the resources. """ - if ( - hasattr(self, "transport_thread") - and self.transport_thread is not None - ): + if (hasattr(self, "transport_thread") + and self.transport_thread is not None): self.transport_thread.shutdown() diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 9a6b55cbbe660..03392ec13f10b 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -1,59 +1,59 @@ """vLLM distributed KV cache transfer API. These APIs are used in `vllm/worker/worker_base.py`. -Currently supporting TP. The TP between prefill and decode instance needs to be the same. +Currently supporting TP. The TP between prefill and decode instance needs to be +the same. Workflow (disaggregated prefill) - In prefill instance - After prefill, vLLM `insert` its KV caches into a lookup buffer. - - The prefill instance will also open up a thread that listens to `drop_select` request. + - The prefill instance will also open up a thread that listens to + `drop_select` request. - In decode instance - - vLLM first runs `drop_select` to send input tokens and a mask on input tokens (we call it roi, region of interest) to prefill instance + - vLLM first runs `drop_select` to send input tokens and a mask on input + tokens (we call it roi, region of interest) to prefill instance - The prefill instance then respond to `drop_select` request by - Finding a match in current lookup buffer. - Clone and send the matched item out - Delete the matched item in the lookup buffer to free up GPU memory. - The decode vLLM then store the KV cache into paged memory. """ -from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING -from collections import defaultdict, deque -from concurrent.futures import ThreadPoolExecutor -from threading import Lock +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + from copy import deepcopy -import time -import threading import torch -from torch.distributed import Backend, ProcessGroup +from torch.distributed import Backend import vllm.envs as envs -from vllm.logger import init_logger -import vllm.distributed.parallel_state as ps from vllm import _custom_ops as ops +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVLookupBufferBase) +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer import ( + SimpleKVLookupBuffer) +from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import ( + TorchDistributedPipe) +from vllm.logger import init_logger from vllm.sequence import IntermediateTensors -from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import TorchDistributedPipe -from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer import SimpleKVLookupBuffer - -from copy import deepcopy -assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode", "lmcache"], \ +assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode", "lmcache"],\ "VLLM_DISAGG_PREFILL_ROLE can only be prefill, decode or lmcache." - # currently the connections are hard-coded. # we only handle 2 cases: # - prefill vLLM --> decode vLLM # - vLLM --> LMCache -IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"]) +IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE + in ["prefill", "decode"]) IS_KV_PREFILL_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "prefill") IS_KV_DECODE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "decode") IS_LMCACHE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "lmcache") - logger = init_logger(__name__) -import logging - class KV_transfer_agent: """ @@ -70,11 +70,13 @@ def __init__( local_rank: int, torch_distributed_backend: Union[str, Backend], # FIXME(Kuntai): remove this hardcoding - lookup_buffer_size: int = 1e10 - ): - + lookup_buffer_size: int = int(1e10)): + self.lookup_buffer_size = lookup_buffer_size - + + self.send_buffer: Optional[KVLookupBufferBase] = None + self.recv_buffer: Optional[KVLookupBufferBase] = None + if IS_LMCACHE_INSTANCE: # when vLLM is connected with LMCache # it needs to both send and recv KV cache @@ -98,14 +100,12 @@ def __init__( local_rank, "gloo", ) - self.send_buffer = SimpleKVLookupBuffer( - self.send_signal_pipe, - self.send_pipe, - self.lookup_buffer_size) - self.recv_buffer = SimpleKVLookupBuffer( - self.recv_signal_pipe, - self.recv_pipe, - self.lookup_buffer_size) + self.send_buffer = SimpleKVLookupBuffer(self.send_signal_pipe, + self.send_pipe, + self.lookup_buffer_size) + self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, + self.recv_pipe, + self.lookup_buffer_size) else: # when performing disaggregated prefill, only 1 pipe is needed # at prefill instance this pipe is used for send KV cache @@ -120,24 +120,25 @@ def __init__( local_rank, "gloo", ) - buffer = SimpleKVLookupBuffer( - self.signal_pipe, - self.pipe, - self.lookup_buffer_size) + buffer = SimpleKVLookupBuffer(self.signal_pipe, self.pipe, + self.lookup_buffer_size) self.send_buffer = buffer self.recv_buffer = buffer - + def send_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", kv_caches: List[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], ) -> None: input_tokens_tensor = model_input.input_tokens seq_lens = model_input.attn_metadata.seq_lens slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer # query_lens contains new KV caches that are added to vLLM. # so we will send them to decode instance @@ -146,13 +147,11 @@ def send_kv_caches_and_hidden_states( start_pos = sum(seq_lens[:idx]) end_pos = start_pos + slen current_tokens = input_tokens_tensor[start_pos:end_pos] - + keys, values = [], [] - - - for l in range(model_executable.model.start_layer, - model_executable.model.end_layer): - kv_cache = kv_caches[l - model_executable.model.start_layer] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] _, _, num_heads, head_size = kv_cache[0].shape @@ -163,29 +162,31 @@ def send_kv_caches_and_hidden_states( keys.append(key_cache[current_slot_mapping].unsqueeze(0)) values.append(value_cache[current_slot_mapping].unsqueeze(0)) - + keys = torch.cat(keys, dim=0) values = torch.cat(values, dim=0) - self.send_buffer.insert( - current_tokens, - torch.ones_like(current_tokens, dtype=bool), - keys, - values, - hidden_or_intermediate_states[start_pos:end_pos] - ) - + if self.send_buffer is not None: + self.send_buffer.insert( + current_tokens, torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + def destroy(self) -> None: + if self.send_buffer is not None: + self.send_buffer.close() + if self.recv_buffer is not None: + self.recv_buffer.close() def recv_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, + self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", kv_caches: List[torch.Tensor] - ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: - # When this flag is set to False, it means that + # When this flag is set to False, it means that bypass_model_exec = True # This is disagg decode instance, during prefill state @@ -199,7 +200,7 @@ def recv_kv_caches_and_hidden_states( input_tokens_list = [] num_computed_tokens_list = [] start_pos_list = [] - + # enumerate different requests # FIXME(Kuntai): This impl assumes that all requests are prefill. for idx, slen in enumerate(seq_lens): @@ -211,28 +212,34 @@ def recv_kv_caches_and_hidden_states( input_tokens_list.append(current_tokens) start_pos_list.append(start_pos) - + + if self.recv_buffer is None: + bypass_model_exec = False + break + ret = self.recv_buffer.drop_select( - current_tokens, - torch.ones_like(current_tokens, dtype=bool)) + current_tokens, torch.ones_like(current_tokens, dtype=bool)) if ret[0] is None: # didn't find any match. bypass_model_exec = False num_computed_tokens_list.append(0) continue - + # TODO(Jiayi): change the logic here (need roi) - _, roi, keys, values, hidden = ret - + roi: torch.Tensor = ret[1] + keys: torch.Tensor = ret[2] + values: torch.Tensor = ret[3] + hidden: torch.Tensor = ret[4] + # Jiayi: currently assume roi is a prefix - num_computed_tokens = len(roi) + num_computed_tokens = roi.shape[0] num_computed_tokens_list.append(num_computed_tokens) is_complete = (num_computed_tokens == num_tokens) end_pos = start_pos + num_computed_tokens - + # receive KV cache from disaggregated prefill instance for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): + model_executable.model.end_layer): kv_cache = kv_caches[i - model_executable.model.start_layer] layer = model_executable.model.layers[i] @@ -251,12 +258,13 @@ def recv_kv_caches_and_hidden_states( hidden_or_intermediate_states_for_one_req.append(hidden) - # FIXME(Jiayi): we need to support only skip m out of n reqs in a batch + # FIXME(Jiayi): we need to support only skip m out of n reqs in a batch # same for prefix caching if not bypass_model_exec: # Some of the KV cache is not retrieved # so we need to recompute the hidden state - logger.debug("[rank%d]: KV EMPTY recv DONE.", torch.distributed.get_rank()) + logger.debug("[rank%d]: KV EMPTY recv DONE.", + torch.distributed.get_rank()) return None, bypass_model_exec, None if not is_complete: @@ -268,17 +276,17 @@ def recv_kv_caches_and_hidden_states( slot_mapping, device=kv_cache[0].device, ) - logger.debug("[rank%d]: KV PARTIAL recv DONE.", torch.distributed.get_rank()) + logger.debug("[rank%d]: KV PARTIAL recv DONE.", + torch.distributed.get_rank()) return None, bypass_model_exec, rebuilt_model_input - + # concatenate hidden states from different requests hidden_or_intermediate_states = torch.cat( hidden_or_intermediate_states_for_one_req, dim=0) logger.debug("[rank%d]: KV recv DONE.", torch.distributed.get_rank()) return hidden_or_intermediate_states, bypass_model_exec, model_input - - + def build_partial_prefill_input( self, model_input: "ModelInputForGPUWithSamplingMetadata", @@ -289,70 +297,77 @@ def build_partial_prefill_input( device: torch.device, ) -> "ModelInputForGPUWithSamplingMetadata": rebuilt_input_tokens = [] - rebuilt_input_positions= [] + rebuilt_input_positions = [] rebuilt_query_lens = [] - + rebuilt_num_prefills = 0 rebuilt_num_prefill_tokens = 0 rebuilt_slot_mapping = [] rebuilt_max_query_len = 0 - + rebuilt_block_tables = [] - + rebuilt_query_start_loc = [0] rebuilt_context_lens_tensor = [] rebuilt_selected_token_indices = [] - + # recounting query and context lengths for idx in range(len(input_tokens_list)): token_tensor = input_tokens_list[idx] num_token = len(token_tensor) num_computed_token = num_computed_tokens_list[idx] start_pos = start_pos_list[idx] - + rebuilt_input_tokens.append(token_tensor[num_computed_token:]) # TODO(Jiayi): please check the correctness of next line - rebuilt_input_positions.append(model_input.input_positions[start_pos+num_computed_token:start_pos+num_token]) + rebuilt_input_positions.append( + model_input.input_positions[start_pos + + num_computed_token:start_pos + + num_token]) q_len = num_token - num_computed_token rebuilt_query_lens.append(q_len) - + # Attn metadata-related rebuilt_num_prefills += 1 rebuilt_num_prefill_tokens += q_len - rebuilt_slot_mapping.append(slot_mapping_flat[start_pos+num_computed_token:start_pos+num_token]) + rebuilt_slot_mapping.append( + slot_mapping_flat[start_pos + num_computed_token:start_pos + + num_token]) rebuilt_max_query_len = max(q_len, rebuilt_max_query_len) # TODO(Jiayi): remove hard-code (block_size=16) blk_size = 16 - temp_block_table = [i//blk_size for i in range(start_pos, start_pos+num_token, blk_size)] + temp_block_table = [ + i // blk_size + for i in range(start_pos, start_pos + num_token, blk_size) + ] rebuilt_block_tables.append(temp_block_table) - rebuilt_query_start_loc.append(q_len) #start with 0 + rebuilt_query_start_loc.append(q_len) #start with 0 rebuilt_context_lens_tensor.append(num_computed_token) - + # Sampling metadata related #seq_groups (use rebuilt query lens) - rebuilt_selected_token_indices.append(start_pos+q_len-1) - - + rebuilt_selected_token_indices.append(start_pos + q_len - 1) + # rebuilt attn_metadata rebuilt_attn_metadata = deepcopy(model_input.attn_metadata) rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens - rebuilt_attn_metadata.slot_mapping = torch.cat(rebuilt_slot_mapping).to(device) + rebuilt_attn_metadata.slot_mapping = torch.cat( + rebuilt_slot_mapping).to(device) rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len - + rebuilt_attn_metadata.block_tables = torch.tensor( rebuilt_block_tables, - dtype=model_input.attn_metadata.block_tables.dtype - ).to(device) - + dtype=model_input.attn_metadata.block_tables.dtype).to(device) + rebuilt_attn_metadata.query_start_loc = torch.tensor( rebuilt_query_start_loc, dtype=model_input.attn_metadata.query_start_loc.dtype).to(device) rebuilt_attn_metadata.context_lens_tensor = torch.tensor( - rebuilt_context_lens_tensor, + rebuilt_context_lens_tensor, dtype=model_input.attn_metadata.context_lens_tensor.dtype, - ).to(device) - + ).to(device) + rebuilt_attn_metadata._cached_prefill_metadata = None # rebuilt sampling_metadata @@ -362,26 +377,27 @@ def build_partial_prefill_input( rebuilt_sampling_metadata.selected_token_indices = torch.tensor( rebuilt_selected_token_indices, dtype=model_input.sampling_metadata.selected_token_indices.dtype, - ).to(device) - + ).to(device) + # import here to avoid circular import. - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + from vllm.worker.model_runner import ( + ModelInputForGPUWithSamplingMetadata) rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( - input_tokens = torch.cat(rebuilt_input_tokens).to(device), - input_positions = torch.cat(rebuilt_input_positions).to(device), - seq_lens = model_input.seq_lens, - query_lens = rebuilt_query_lens, - lora_mapping = model_input.lora_mapping, - lora_requests = model_input.lora_requests, - attn_metadata = rebuilt_attn_metadata, - prompt_adapter_mapping = model_input.prompt_adapter_mapping, - prompt_adapter_requests = model_input.prompt_adapter_requests, - multi_modal_kwargs = model_input.multi_modal_kwargs, - request_ids_to_seq_ids = model_input.request_ids_to_seq_ids, - finished_requests_ids = model_input.finished_requests_ids, - virtual_engine = model_input.virtual_engine, - sampling_metadata = rebuilt_sampling_metadata, - is_prompt = model_input.is_prompt, + input_tokens=torch.cat(rebuilt_input_tokens).to(device), + input_positions=torch.cat(rebuilt_input_positions).to(device), + seq_lens=model_input.seq_lens, + query_lens=rebuilt_query_lens, + lora_mapping=model_input.lora_mapping, + lora_requests=model_input.lora_requests, + attn_metadata=rebuilt_attn_metadata, + prompt_adapter_mapping=model_input.prompt_adapter_mapping, + prompt_adapter_requests=model_input.prompt_adapter_requests, + multi_modal_kwargs=model_input.multi_modal_kwargs, + request_ids_to_seq_ids=model_input.request_ids_to_seq_ids, + finished_requests_ids=model_input.finished_requests_ids, + virtual_engine=model_input.virtual_engine, + sampling_metadata=rebuilt_sampling_metadata, + is_prompt=model_input.is_prompt, ) - - return rebuilt_model_input \ No newline at end of file + + return rebuilt_model_input diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 13527110a2232..3615fa6af399c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -20,29 +20,25 @@ parallelism, you can skip the model parallel initialization and destruction steps. """ -import time import contextlib import pickle -import logging +import time from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory from typing import Any, Dict, List, Optional, Tuple, Union from unittest.mock import patch -import queue import torch import torch.distributed from torch.distributed import Backend, ProcessGroup -import vllm.envs as envs -from vllm.logger import init_logger - - # Use this import to check if disagg prefill is enabled. # if enabled, need to adjust distributed group correspondingly. import vllm.distributed.kv_transfer.vllm_adapter as dist_kv +import vllm.envs as envs +from vllm.logger import init_logger @dataclass @@ -865,7 +861,8 @@ def include_decoding_groups_if_disagg_enabled( Extended: [ [0,1], [2,3], [4,5], [6,7] ] Arguments: groups: original distributed group - world_size: the vLLM world size, which is half of torch.distributed.get_world_size() + world_size: the vLLM world size, which is half of + torch.distributed.get_world_size() """ if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: @@ -908,9 +905,8 @@ def init_distributed_environment( # offset global rank by tp * pp (which is world_size) maybe_disagg_rank = rank + world_size - logger.debug( - f"Before: world size {maybe_disagg_world_size}, rank {maybe_disagg_rank}" - ) + logger.debug("Before: world size %d, rank %d", maybe_disagg_world_size, + maybe_disagg_rank) torch.distributed.init_process_group( backend=backend, @@ -974,17 +970,18 @@ def initialize_model_parallel( ranks 8 to 15 belong to the second box. - Disaggregated prefill will also initialize its process group using this function. + Disaggregated prefill will also init its process group using this function. Changes: - vLLM world size: unchanged (tp * pp) - torch.distributed.get_world_size(): - 2 * tp * pp - - Why: torch.distributed package sees 2 vLLM instances (prefill and decode) + - Why: both prefill vLLM and decode vLLM is in the world - Global rank: - [0, tp * pp) for prefill - [tp * pp, 2 * tp * pp) for decode - Parallel groups - - Extend _WORLD, _TP and _PP using `include_decoding_groups_if_disagg_enabled` + - Extend _WORLD, _TP and _PP using + `include_decoding_groups_if_disagg_enabled` - Add a new parallel group `_DISAGG` for disaggregated prefill - [ [0, tp * pp], [1, tp * pp + 1], .. ] - Local rank: unchanged @@ -997,12 +994,14 @@ def initialize_model_parallel( get_world_group().device_group) if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE: # Disaggregated prefill enabled - # The world_size for this vLLM instance is tp * pp, but torch.distributed contains 2 vLLM instances, its world size is 2 * tp * pp + # The world_size for this vLLM instance is tp * pp, but + # torch.distributed contains 2 vLLM instances, + # its world size is 2 * tp * pp # Adjust the world_size to match. world_size = world_size // 2 - if (world_size - != tensor_model_parallel_size * pipeline_model_parallel_size): + if (world_size != + tensor_model_parallel_size * pipeline_model_parallel_size): raise RuntimeError( f"world_size ({world_size}) is not equal to " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 679f8394688e8..b774a649d39f5 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -48,7 +48,8 @@ def _get_worker_kwargs( """Return worker init args for a given rank.""" if distributed_init_method is None: distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) + get_ip(), + get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) return dict( model_config=self.model_config, parallel_config=self.parallel_config, diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 9448228879453..499e891d98fc0 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -70,7 +70,8 @@ def _init_executor(self) -> None: # Since it only works for single node, we can use the loopback address # 127.0.0.1 for communication. distributed_init_method = get_distributed_init_method( - "127.0.0.1", get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) + "127.0.0.1", + get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) self.workers: List[ProcessWorkerWrapper] = [] # This is the list of workers that are rank 0 of each TP group EXCEPT diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index c646e8536ba15..0cca5db1677ed 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -6,8 +6,8 @@ import msgspec -import vllm.envs as envs import vllm.distributed.kv_transfer.vllm_adapter as dist_kv +import vllm.envs as envs from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.msgspec_utils import encode_hook diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ab38302b3321a..b846d1d707db0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,7 +14,6 @@ import torch.distributed import torch.nn as nn - import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState @@ -55,7 +54,6 @@ _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict, dump_input_when_exception) -from vllm import _custom_ops as ops if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -1546,30 +1544,20 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) - + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) + if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() - return hidden_or_intermediate_states - - @torch.inference_mode() - def postprocess_model( - self, - model_input, - hidden_or_intermediate_states, - - ): if not get_pp_group().is_last_rank: if (self.is_driver_worker and hidden_or_intermediate_states is not None @@ -1587,7 +1575,16 @@ def postprocess_model( hidden_or_intermediate_states.tensors["model_forward_time"] = ( torch.tensor(model_forward_time + orig_model_forward_time)) return hidden_or_intermediate_states - + + return hidden_or_intermediate_states + + @torch.inference_mode() + def postprocess_model( + self, + model_input: ModelInputForGPUWithSamplingMetadata, + hidden_or_intermediate_states, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) @@ -1603,6 +1600,7 @@ def postprocess_model( sampling_metadata=model_input.sampling_metadata, ) + assert model_input.attn_metadata is not None decode_meta = model_input.attn_metadata.decode_metadata if self.return_hidden_states: # we only need to pass hidden states of most recent token @@ -1620,9 +1618,7 @@ def postprocess_model( output.hidden_states = hidden_states return [output] - - - + class CUDAGraphRunner: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 22577ebf69492..7908fc466eb38 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -7,6 +7,8 @@ import torch +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv +import vllm.distributed.parallel_state as ps from vllm.config import ObservabilityConfig from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.logger import init_logger @@ -16,13 +18,11 @@ from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) +from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) -import vllm.distributed.kv_transfer.vllm_adapter as dist_kv -import vllm.distributed.parallel_state as ps - logger = init_logger(__name__) @@ -223,7 +223,6 @@ def execute_worker(self, worker_input: WorkerInput) -> None: Process an execution request. """ raise NotImplementedError - def _get_worker_input_from_broadcast( self @@ -327,19 +326,14 @@ def execute_model( and self.observability_config.collect_model_execute_time): orig_model_execute_time = intermediate_tensors.tensors.get( "model_execute_time", torch.tensor(0)).item() - - + # for disaggregated prefilling: allow bypassing model execution bypass_model_exec = False - - - # receive KV cache. - # NOTE(kuntai): - # If only a part of KV cache is received, we will adjust model_input - # to avoid prefill on the part of KV caches that are already received. - # This will not happen for disaggregated prefill, but will happen - # when connecting to a KV cache database (like LMCache). + + # receive KV cache from prefill instance, or from LMCache if self.need_recv_kv(model_input, worker_input): + assert isinstance(self.model_runner, GPUModelRunnerBase), \ + "Distributed KV transfer only support GPU modelrunner" hidden_or_intermediate_states, bypass_model_exec, model_input = \ ps.get_disagg_group().recv_kv_caches_and_hidden_states( # model is used to know which layer the current worker @@ -347,11 +341,12 @@ def execute_model( # layers. self.model_runner.model, model_input, - self.kv_cache[worker_input.virtual_engine], + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, ) #assert bypass_model_exec - - if not bypass_model_exec: + + if not bypass_model_exec: hidden_or_intermediate_states = self.model_runner.execute_model( model_input=model_input, kv_caches=self.kv_cache[worker_input.virtual_engine] @@ -360,24 +355,31 @@ def execute_model( num_steps=num_steps, **kwargs, ) - + # sending out KV cache if self.need_send_kv(model_input, worker_input): + assert isinstance(self.model_runner, GPUModelRunnerBase), \ + "Distributed KV transfer only support GPU modelrunner" ps.get_disagg_group().send_kv_caches_and_hidden_states( # model is used to know which layer the current worker # is working on, so that we can send KV for only those # layers. self.model_runner.model, model_input, - self.kv_cache[worker_input.virtual_engine], + self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, hidden_or_intermediate_states, ) - - # Get model output based on hidden state. - output = self.model_runner.postprocess_model( - model_input, - hidden_or_intermediate_states, - ) + + # separating postprocessing steps out from execute_model + # so that disaggregated prefill can completely bypass model forwarding + if isinstance(self.model_runner, ModelRunner): + output = self.model_runner.postprocess_model( + model_input, + hidden_or_intermediate_states, + ) + else: + output = hidden_or_intermediate_states model_execute_time = time.perf_counter() - start_time if not get_pp_group().is_last_rank: @@ -400,38 +402,43 @@ def execute_model( return output def need_recv_kv(self, model_input, worker_input) -> bool: - + + if self.kv_cache is None: + return False + kv_caches = self.kv_cache[worker_input.virtual_engine] prefill_meta = model_input.attn_metadata.prefill_metadata - + # check if the current run is profiling is_profile_run = (kv_caches is None) or (kv_caches[0] is None) # check if the current run is prefill is_prefill_run = prefill_meta is not None # for disaggregated prefilling: allow bypassing model execution - + return all([ - is_prefill_run, - dist_kv.IS_KV_DECODE_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE, - not is_profile_run]) + is_prefill_run, dist_kv.IS_KV_DECODE_INSTANCE + or dist_kv.IS_LMCACHE_INSTANCE, not is_profile_run + ]) - def need_send_kv(self, model_input, worker_input) -> bool: - + + if self.kv_cache is None: + return False + kv_caches = self.kv_cache[worker_input.virtual_engine] prefill_meta = model_input.attn_metadata.prefill_metadata - model_executable = self.model_runner.model - + if not isinstance(self.model_runner, GPUModelRunnerBase): + return False + # check if the current run is profiling is_profile_run = (kv_caches is None) or (kv_caches[0] is None) # check if the current run is prefill is_prefill_run = prefill_meta is not None - + return all([ - is_prefill_run, - dist_kv.IS_KV_PREFILL_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE, - not is_profile_run]) - + is_prefill_run, dist_kv.IS_KV_PREFILL_INSTANCE + or dist_kv.IS_LMCACHE_INSTANCE, not is_profile_run + ]) def _execute_model_spmd( self, From 36a382c961c75aa00733a5d04022a2ad1a17b229 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 23:48:40 +0000 Subject: [PATCH 208/303] resolve circular import --- vllm/utils.py | 2 +- vllm/worker/worker_base.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 1adab61917265..8e27e1f73f4b4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -535,7 +535,7 @@ def get_open_port(force: bool = False) -> int: if force and port is not None: # force vLLM to use envs.VLLM_PORT for torch.distributed init # This is because this port will binded by prefill instance - # But both prefill and decode instance need to use this port to + # But both prefill and decode instance need to use this port to # initialize torch.distributed return port while True: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 7908fc466eb38..d55400a402400 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -18,7 +18,6 @@ from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) -from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) @@ -332,6 +331,7 @@ def execute_model( # receive KV cache from prefill instance, or from LMCache if self.need_recv_kv(model_input, worker_input): + from vllm.worker.model_runner import GPUModelRunnerBase assert isinstance(self.model_runner, GPUModelRunnerBase), \ "Distributed KV transfer only support GPU modelrunner" hidden_or_intermediate_states, bypass_model_exec, model_input = \ @@ -358,6 +358,7 @@ def execute_model( # sending out KV cache if self.need_send_kv(model_input, worker_input): + from vllm.worker.model_runner import GPUModelRunnerBase assert isinstance(self.model_runner, GPUModelRunnerBase), \ "Distributed KV transfer only support GPU modelrunner" ps.get_disagg_group().send_kv_caches_and_hidden_states( @@ -373,6 +374,7 @@ def execute_model( # separating postprocessing steps out from execute_model # so that disaggregated prefill can completely bypass model forwarding + from vllm.worker.model_runner import ModelRunner if isinstance(self.model_runner, ModelRunner): output = self.model_runner.postprocess_model( model_input, @@ -427,6 +429,7 @@ def need_send_kv(self, model_input, worker_input) -> bool: kv_caches = self.kv_cache[worker_input.virtual_engine] prefill_meta = model_input.attn_metadata.prefill_metadata + from vllm.worker.model_runner import GPUModelRunnerBase if not isinstance(self.model_runner, GPUModelRunnerBase): return False From a0867dd1cf73ae998f6051875a7949026f98cf26 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 23:49:01 +0000 Subject: [PATCH 209/303] fix redundant import --- tests/kv_transfer/test_send_recv.py | 81 ++++++++++++++--------------- 1 file changed, 39 insertions(+), 42 deletions(-) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 4bf757d7c8492..994b907e0c899 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -1,10 +1,11 @@ +import os +import time +from typing import List -import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp import torch -import os -import random from tqdm import tqdm -import time + +import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp def test_run(my_rank, pipe): @@ -35,20 +36,19 @@ def test_run(my_rank, pipe): assert torch.allclose(y, y2) - def stress_test(my_rank, pipe): - + torch.distributed.barrier() - - tensors = [] - - + + tensors: List[torch.Tensor] = [] + for i in tqdm(range(2000)): mean = torch.rand(1).item() std = torch.rand(1).item() - size = torch.randint(900, 1000, (2,)) - x = torch.normal(mean * 1.0, std * 1.0, size=size.tolist()).to(pipe.device) - + size = torch.randint(900, 1000, (2, )) + x = torch.normal(mean * 1.0, std * 1.0, + size=size.tolist()).to(pipe.device) + # 5% probability of sending a None if torch.rand(1).item() < 0.05: tensors.append(None) @@ -59,15 +59,13 @@ def stress_test(my_rank, pipe): tensors.append(x.mean().unsqueeze(0)) tensors.append(x.std().unsqueeze(0)) - - torch.distributed.barrier() - + for i in tqdm(range(2000)): if my_rank == int((i % 10) > 3): - pipe.send_tensor(tensors[3*i]) - pipe.send_tensor(tensors[3*i+1]) - pipe.send_tensor(tensors[3*i+2]) + pipe.send_tensor(tensors[3 * i]) + pipe.send_tensor(tensors[3 * i + 1]) + pipe.send_tensor(tensors[3 * i + 2]) else: x = pipe.recv_tensor() mean = pipe.recv_tensor() @@ -76,34 +74,36 @@ def stress_test(my_rank, pipe): assert mean is None assert std is None else: - assert torch.allclose(x, tensors[3*i]) + assert torch.allclose(x, tensors[3 * i]) assert x.mean() == mean[0] assert x.std() == std[0] torch.distributed.barrier() print("Stress test passed.") - - - + + def latency_test(my_rank, pipe, nelement, ntensor): - + latencies = [] - + torch.distributed.barrier() - + for i in tqdm(range(1000)): - + tensors = [] - + if my_rank == 0: # create tensor - tensors = [torch.rand(nelement).to(pipe.device) for _ in range(ntensor)] - + tensors = [ + torch.rand(nelement).to(pipe.device) for _ in range(ntensor) + ] + torch.distributed.barrier() - + if my_rank == 0: - t = torch.tensor([time.time()], dtype=torch.float64).to(pipe.device) + t = torch.tensor([time.time()], + dtype=torch.float64).to(pipe.device) for tensor in tensors: pipe.send_tensor(tensor) pipe.send_tensor(t) @@ -114,7 +114,7 @@ def latency_test(my_rank, pipe, nelement, ntensor): latencies.append(time.time() - t.item()) torch.distributed.barrier() - + print('Latency test passed.') print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms') @@ -123,18 +123,15 @@ def latency_test(my_rank, pipe, nelement, ntensor): my_rank = int(os.environ['RANK']) - - torch.distributed.init_process_group( - init_method="tcp://127.0.0.1:23456", - world_size=2, - rank=my_rank) + torch.distributed.init_process_group(init_method="tcp://127.0.0.1:23456", + world_size=2, + rank=my_rank) print("initialized! My rank is %d" % my_rank) + pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "nccl") - pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "nccl") - - torch.manual_seed(0) + torch.manual_seed(0) test_run(my_rank, pipe) stress_test(my_rank, pipe) - latency_test(my_rank, pipe, 1024*8*128, 80) + latency_test(my_rank, pipe, 1024 * 8 * 128, 80) From 7f90903a448755fe9c00b252ec9d17a4c6566f61 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Sun, 15 Sep 2024 23:55:50 +0000 Subject: [PATCH 210/303] rename to a shorter name --- tests/kv_transfer/test_lookup_buffer.py | 80 +++++++++---------- ...e_kv_lookup_buffer.py => simple_buffer.py} | 0 vllm/distributed/kv_transfer/vllm_adapter.py | 5 +- 3 files changed, 42 insertions(+), 43 deletions(-) rename vllm/distributed/kv_transfer/kv_lookup_buffer/{simple_kv_lookup_buffer.py => simple_buffer.py} (100%) diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index ae19d068be9fa..0730f091a34b8 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -1,24 +1,25 @@ - -import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp -import vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer as sklb -import torch import os import random + +import torch from tqdm import tqdm -import time -# TODO: the test depends on a lot of fields in the current implementation. We should have standard interface instead direct field access +import vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer as sklb +import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp + +# TODO: the test depends on a lot of fields in the current implementation. +# We should have standard interface instead direct field access + def test_run(my_rank, buffer, device): - - # buffer should be empty in the beginning + + # buffer should be empty in the beginning if my_rank == 0: assert buffer.buffer_size == 0 assert len(buffer.buffer) == 0 - # insert - tokens = torch.tensor([1,2,3]).to(device) + tokens = torch.tensor([1, 2, 3]).to(device) roi = (tokens > 0) if my_rank == 0: key = 2.0 * torch.ones([5, 6]).to(device) @@ -27,45 +28,47 @@ def test_run(my_rank, buffer, device): placeholder = torch.tensor([1]).to(device) buffer.insert(tokens, roi, key, value, placeholder) - + torch.distributed.barrier() - + # drop_select if my_rank == 1: tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi) assert torch.allclose(tokens, tok) assert torch.allclose(roi, roi_) - assert torch.allclose(key, 2.0 * torch.ones([5, 6], device = device)) - assert torch.allclose(value, 3.0 * torch.ones([5, 6], device = device)) + assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device)) + assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device)) torch.distributed.barrier() - + if my_rank == 0: assert buffer.buffer_size == 0 assert len(buffer.buffer) == 0 - + print("Test run passed!") + def stress_test(my_rank, buf, device): - + torch.distributed.barrier() torch.manual_seed(100) reqs = [ ( - torch.rand(100).to(device), # tokens - torch.ones(100).bool().to(device), # roi - torch.rand(100).to(device), # key - torch.rand(100).to(device), # value - torch.rand(100).to(device), # hidden - ) for i in tqdm(range(200))] + torch.rand(100).to(device), # tokens + torch.ones(100).bool().to(device), # roi + torch.rand(100).to(device), # key + torch.rand(100).to(device), # value + torch.rand(100).to(device), # hidden + ) for i in tqdm(range(200)) + ] random.seed(my_rank) random.shuffle(reqs) - + torch.distributed.barrier() - + n = 0 - + # the buffer size can only store 100 reqs # so the sender will occasionally block to wait for the receiver. for req in tqdm(reqs): @@ -74,7 +77,7 @@ def stress_test(my_rank, buf, device): else: tok, roi, k, v, h = req tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi) - + if tok_ is None: assert roi_ is None assert k_ is None @@ -89,8 +92,7 @@ def stress_test(my_rank, buf, device): assert torch.allclose(h, h_) print('Rank %d done' % my_rank) torch.distributed.barrier() - - + if my_rank == 0: x = torch.tensor([0]) torch.distributed.recv(x, 1) @@ -103,30 +105,26 @@ def stress_test(my_rank, buf, device): torch.distributed.send(torch.tensor([n]), 0) print("Passed stress test!") - - + if __name__ == "__main__": my_rank = int(os.environ['RANK']) - - torch.distributed.init_process_group( - init_method="tcp://127.0.0.1:23456", - world_size=2, - rank=my_rank) + torch.distributed.init_process_group(init_method="tcp://127.0.0.1:23456", + world_size=2, + rank=my_rank) print("initialized! My rank is %d" % my_rank) - - pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "nccl") - cpu_pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "gloo") + pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "nccl") + cpu_pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "gloo") buffer = sklb.SimpleKVLookupBuffer(cpu_pipe, pipe, 170000) test_run(my_rank, buffer, pipe.device) - + stress_test(my_rank, buffer, pipe.device) - + buffer.close() pipe.close() cpu_pipe.close() diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py similarity index 100% rename from vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py rename to vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 03392ec13f10b..2edb426c5c8da 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -28,12 +28,11 @@ import torch from torch.distributed import Backend +import vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer as sklb import vllm.envs as envs from vllm import _custom_ops as ops from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( KVLookupBufferBase) -from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer import ( - SimpleKVLookupBuffer) from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import ( TorchDistributedPipe) from vllm.logger import init_logger @@ -77,6 +76,8 @@ def __init__( self.send_buffer: Optional[KVLookupBufferBase] = None self.recv_buffer: Optional[KVLookupBufferBase] = None + SimpleKVLookupBuffer = sklb.SimpleKVLookupBuffer + if IS_LMCACHE_INSTANCE: # when vLLM is connected with LMCache # it needs to both send and recv KV cache From 5ca22fb44866e9b5154fa09844d71a9e34c18729 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Mon, 16 Sep 2024 00:06:43 +0000 Subject: [PATCH 211/303] remove unnecessary file --- tests/test_send_recv.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests/test_send_recv.sh diff --git a/tests/test_send_recv.sh b/tests/test_send_recv.sh deleted file mode 100644 index e69de29bb2d1d..0000000000000 From 073642bde894301ced000178eaec7f76562e78be Mon Sep 17 00:00:00 2001 From: ApostaC Date: Mon, 16 Sep 2024 00:40:26 +0000 Subject: [PATCH 212/303] update kv transfer test --- tests/kv_transfer/test_launcher.py | 52 +++++++++++++++++++++++++ tests/kv_transfer/test_lookup_buffer.sh | 3 -- tests/kv_transfer/test_send_recv.py | 10 +++-- tests/kv_transfer/test_send_recv.sh | 3 -- 4 files changed, 58 insertions(+), 10 deletions(-) create mode 100644 tests/kv_transfer/test_launcher.py delete mode 100644 tests/kv_transfer/test_lookup_buffer.sh delete mode 100644 tests/kv_transfer/test_send_recv.sh diff --git a/tests/kv_transfer/test_launcher.py b/tests/kv_transfer/test_launcher.py new file mode 100644 index 0000000000000..5c0aeb04b43fa --- /dev/null +++ b/tests/kv_transfer/test_launcher.py @@ -0,0 +1,52 @@ +import subprocess +import pytest +import sys +import torch + +def run_python_script(script_name, timeout): + try: + # Start both processes asynchronously using Popen + process0 = subprocess.Popen( + [sys.executable, script_name], + env={"RANK": "0"}, # Set the RANK environment variable for process 0 + stdout=sys.stdout, # Pipe stdout to current stdout + stderr=sys.stderr, # Pipe stderr to current stderr + ) + + process1 = subprocess.Popen( + [sys.executable, script_name], + env={"RANK": "1"}, # Set the RANK environment variable for process 1 + stdout=sys.stdout, # Pipe stdout to current stdout + stderr=sys.stderr, # Pipe stderr to current stderr + ) + + # Wait for both processes to complete, with a timeout + process0.wait(timeout=timeout) + process1.wait(timeout=timeout) + + # Check the return status of both processes + if process0.returncode != 0: + pytest.fail(f"Test {script_name} failed for RANK=0 with return code {process0.returncode}") + if process1.returncode != 0: + pytest.fail(f"Test {script_name} failed for RANK=1 with return code {process1.returncode}") + + except subprocess.TimeoutExpired: + # If either process times out, terminate both and fail the test + process0.terminate() + process1.terminate() + pytest.fail(f"Test {script_name} timed out") + except Exception as e: + pytest.fail(f"Test {script_name} failed with error: {str(e)}") + +# Define the test cases using pytest's parametrize +@pytest.mark.parametrize("script_name,timeout", [ + ("test_lookup_buffer.py", 60), # Second test case with a 60-second timeout + ("test_send_recv.py", 120) # First test case with a 120-second timeout +]) +def test_run_python_script(script_name, timeout): + # Check the number of GPUs + if torch.cuda.device_count() < 2: + pytest.skip(f"Skipping test {script_name} because fewer than 2 GPUs are available") + + # Run the test if there are at least 2 GPUs + run_python_script(script_name, timeout) diff --git a/tests/kv_transfer/test_lookup_buffer.sh b/tests/kv_transfer/test_lookup_buffer.sh deleted file mode 100644 index 336b540e70542..0000000000000 --- a/tests/kv_transfer/test_lookup_buffer.sh +++ /dev/null @@ -1,3 +0,0 @@ - -RANK=0 python3 test_lookup_buffer.py & -RANK=1 python3 test_lookup_buffer.py & diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 994b907e0c899..f6da7f88d5f5c 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -42,7 +42,7 @@ def stress_test(my_rank, pipe): tensors: List[torch.Tensor] = [] - for i in tqdm(range(2000)): + for i in tqdm(range(500)): mean = torch.rand(1).item() std = torch.rand(1).item() size = torch.randint(900, 1000, (2, )) @@ -61,7 +61,7 @@ def stress_test(my_rank, pipe): torch.distributed.barrier() - for i in tqdm(range(2000)): + for i in tqdm(range(500)): if my_rank == int((i % 10) > 3): pipe.send_tensor(tensors[3 * i]) pipe.send_tensor(tensors[3 * i + 1]) @@ -89,7 +89,7 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.distributed.barrier() - for i in tqdm(range(1000)): + for i in tqdm(range(500)): tensors = [] @@ -134,4 +134,6 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.manual_seed(0) test_run(my_rank, pipe) stress_test(my_rank, pipe) - latency_test(my_rank, pipe, 1024 * 8 * 128, 80) + + # Use this function if you want to test the latency of pipe impl. + # latency_test(my_rank, pipe, 1024 * 8 * 128, 80) diff --git a/tests/kv_transfer/test_send_recv.sh b/tests/kv_transfer/test_send_recv.sh deleted file mode 100644 index 2a478871bd0e7..0000000000000 --- a/tests/kv_transfer/test_send_recv.sh +++ /dev/null @@ -1,3 +0,0 @@ - -RANK=0 python3 test_send_recv.py & -RANK=1 python3 test_send_recv.py & From 70d6571c936d4b8baf351a7161512ca67cb1e079 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Mon, 16 Sep 2024 01:12:57 +0000 Subject: [PATCH 213/303] update tests --- tests/kv_transfer/disagg_test.py | 107 ++++++++++++++++++ .../{test_launcher.py => module_test.py} | 0 2 files changed, 107 insertions(+) create mode 100644 tests/kv_transfer/disagg_test.py rename tests/kv_transfer/{test_launcher.py => module_test.py} (100%) diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py new file mode 100644 index 0000000000000..2e8414a9f4642 --- /dev/null +++ b/tests/kv_transfer/disagg_test.py @@ -0,0 +1,107 @@ +import os +import sys +import subprocess +import time +import pytest +import requests +import signal +from subprocess import Popen +import torch + + +# Fixture to set up environment variables and teardown servers after tests +@pytest.fixture(scope="module", autouse=True) +def setup_servers(): + if torch.cuda.device_count() < 4: + pytest.skip("Skipping test: fewer than 4 GPUs available") + + # Set up environment variables + VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", shell=True).decode().strip() + os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP + os.environ["VLLM_PORT"] = "12345" + + # Start prefill instance + prefill_cmd = [ + sys.executable, "-m", "vllm.entrypoints.openai.api_server", + "-tp", "2", + "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", "8100", + "--gpu-memory-utilization", "0.8" + ] + prefill_env = os.environ.copy() + prefill_env["VLLM_DISAGG_PREFILL_ROLE"] = "prefill" + prefill_env["CUDA_VISIBLE_DEVICES"] = "0,1" + prefill_proc = Popen(prefill_cmd, env=prefill_env) + + # Start decode instance + decode_cmd = [ + sys.executable, "-m", "vllm.entrypoints.openai.api_server", + "-tp", "2", + "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", "8200", + "--gpu-memory-utilization", "0.8" + ] + decode_env = os.environ.copy() + decode_env["VLLM_DISAGG_PREFILL_ROLE"] = "decode" + decode_env["CUDA_VISIBLE_DEVICES"] = "2,3" + decode_proc = Popen(decode_cmd, env=decode_env) + + # Wait for servers to be ready + assert wait_for_server(8100), "Prefill server did not start in time" + assert wait_for_server(8200), "Decode server did not start in time" + + # Yield to the test function and handle teardown after tests + yield + + # Cleanup: kill the processes + prefill_proc.terminate() + decode_proc.terminate() + + # Additional cleanup if needed + prefill_proc.wait() + decode_proc.wait() + +# Helper function to wait for server +def wait_for_server(port, timeout=120): + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = requests.get(f"http://localhost:{port}/v1/completions") + if response.status_code in [200, 405]: + return True + except requests.ConnectionError: + time.sleep(1) + return False + +# Test function to send curl requests and validate responses +@pytest.mark.parametrize("prompt", [ + "San Francisco is a", + "Santa Clara is a" +]) +def test_disaggregated_prefilling(prompt): + # Send to prefill + response = requests.post( + "http://localhost:8100/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 1, + "temperature": 0 + } + ) + assert response.status_code == 200 + + # Send to decode + response = requests.post( + "http://localhost:8200/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 10, + "temperature": 0 + } + ) + assert response.status_code == 200 + \ No newline at end of file diff --git a/tests/kv_transfer/test_launcher.py b/tests/kv_transfer/module_test.py similarity index 100% rename from tests/kv_transfer/test_launcher.py rename to tests/kv_transfer/module_test.py From 4d6b00a6c370e5005a8e92745de603938a3b341a Mon Sep 17 00:00:00 2001 From: ApostaC Date: Mon, 16 Sep 2024 01:14:18 +0000 Subject: [PATCH 214/303] make fmt checker happy --- tests/kv_transfer/disagg_test.py | 67 +++++++++++++---------------- tests/kv_transfer/module_test.py | 35 +++++++++------ tests/kv_transfer/test_send_recv.py | 2 +- 3 files changed, 54 insertions(+), 50 deletions(-) diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py index 2e8414a9f4642..fa6a527574cf4 100644 --- a/tests/kv_transfer/disagg_test.py +++ b/tests/kv_transfer/disagg_test.py @@ -1,11 +1,11 @@ import os -import sys import subprocess +import sys import time +from subprocess import Popen + import pytest import requests -import signal -from subprocess import Popen import torch @@ -16,16 +16,15 @@ def setup_servers(): pytest.skip("Skipping test: fewer than 4 GPUs available") # Set up environment variables - VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", shell=True).decode().strip() + VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", + shell=True).decode().strip() os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP os.environ["VLLM_PORT"] = "12345" # Start prefill instance prefill_cmd = [ - sys.executable, "-m", "vllm.entrypoints.openai.api_server", - "-tp", "2", - "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", - "--port", "8100", + sys.executable, "-m", "vllm.entrypoints.openai.api_server", "-tp", "2", + "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--port", "8100", "--gpu-memory-utilization", "0.8" ] prefill_env = os.environ.copy() @@ -35,10 +34,8 @@ def setup_servers(): # Start decode instance decode_cmd = [ - sys.executable, "-m", "vllm.entrypoints.openai.api_server", - "-tp", "2", - "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", - "--port", "8200", + sys.executable, "-m", "vllm.entrypoints.openai.api_server", "-tp", "2", + "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--port", "8200", "--gpu-memory-utilization", "0.8" ] decode_env = os.environ.copy() @@ -61,6 +58,7 @@ def setup_servers(): prefill_proc.wait() decode_proc.wait() + # Helper function to wait for server def wait_for_server(port, timeout=120): start_time = time.time() @@ -73,35 +71,30 @@ def wait_for_server(port, timeout=120): time.sleep(1) return False + # Test function to send curl requests and validate responses -@pytest.mark.parametrize("prompt", [ - "San Francisco is a", - "Santa Clara is a" -]) +@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"]) def test_disaggregated_prefilling(prompt): # Send to prefill - response = requests.post( - "http://localhost:8100/v1/completions", - headers={"Content-Type": "application/json"}, - json={ - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", - "prompt": prompt, - "max_tokens": 1, - "temperature": 0 - } - ) + response = requests.post("http://localhost:8100/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 1, + "temperature": 0 + }) assert response.status_code == 200 # Send to decode - response = requests.post( - "http://localhost:8200/v1/completions", - headers={"Content-Type": "application/json"}, - json={ - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", - "prompt": prompt, - "max_tokens": 10, - "temperature": 0 - } - ) + response = requests.post("http://localhost:8200/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 10, + "temperature": 0 + }) assert response.status_code == 200 - \ No newline at end of file diff --git a/tests/kv_transfer/module_test.py b/tests/kv_transfer/module_test.py index 5c0aeb04b43fa..10fb19a3128e2 100644 --- a/tests/kv_transfer/module_test.py +++ b/tests/kv_transfer/module_test.py @@ -1,21 +1,25 @@ import subprocess -import pytest import sys + +import pytest import torch + def run_python_script(script_name, timeout): try: # Start both processes asynchronously using Popen process0 = subprocess.Popen( [sys.executable, script_name], - env={"RANK": "0"}, # Set the RANK environment variable for process 0 + env={"RANK": + "0"}, # Set the RANK environment variable for process 0 stdout=sys.stdout, # Pipe stdout to current stdout stderr=sys.stderr, # Pipe stderr to current stderr ) - + process1 = subprocess.Popen( [sys.executable, script_name], - env={"RANK": "1"}, # Set the RANK environment variable for process 1 + env={"RANK": + "1"}, # Set the RANK environment variable for process 1 stdout=sys.stdout, # Pipe stdout to current stdout stderr=sys.stderr, # Pipe stderr to current stderr ) @@ -26,9 +30,11 @@ def run_python_script(script_name, timeout): # Check the return status of both processes if process0.returncode != 0: - pytest.fail(f"Test {script_name} failed for RANK=0 with return code {process0.returncode}") + pytest.fail( + f"Test {script_name} failed for RANK=0, {process0.returncode}") if process1.returncode != 0: - pytest.fail(f"Test {script_name} failed for RANK=1 with return code {process1.returncode}") + pytest.fail( + f"Test {script_name} failed for RANK=1, {process1.returncode}") except subprocess.TimeoutExpired: # If either process times out, terminate both and fail the test @@ -38,15 +44,20 @@ def run_python_script(script_name, timeout): except Exception as e: pytest.fail(f"Test {script_name} failed with error: {str(e)}") + # Define the test cases using pytest's parametrize -@pytest.mark.parametrize("script_name,timeout", [ - ("test_lookup_buffer.py", 60), # Second test case with a 60-second timeout - ("test_send_recv.py", 120) # First test case with a 120-second timeout -]) +@pytest.mark.parametrize( + "script_name,timeout", + [ + ("test_lookup_buffer.py", + 60), # Second test case with a 60-second timeout + ("test_send_recv.py", 120) # First test case with a 120-second timeout + ]) def test_run_python_script(script_name, timeout): # Check the number of GPUs if torch.cuda.device_count() < 2: - pytest.skip(f"Skipping test {script_name} because fewer than 2 GPUs are available") - + pytest.skip( + f"Skipping test {script_name} because <2 GPUs are available") + # Run the test if there are at least 2 GPUs run_python_script(script_name, timeout) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index f6da7f88d5f5c..ff771f34c0325 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -134,6 +134,6 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.manual_seed(0) test_run(my_rank, pipe) stress_test(my_rank, pipe) - + # Use this function if you want to test the latency of pipe impl. # latency_test(my_rank, pipe, 1024 * 8 * 128, 80) From 7c13e03847a37417277483def8827cc2190749f8 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Mon, 16 Sep 2024 01:18:19 +0000 Subject: [PATCH 215/303] constraint the model length --- tests/kv_transfer/disagg_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py index fa6a527574cf4..fffd9ab6f42a7 100644 --- a/tests/kv_transfer/disagg_test.py +++ b/tests/kv_transfer/disagg_test.py @@ -25,7 +25,7 @@ def setup_servers(): prefill_cmd = [ sys.executable, "-m", "vllm.entrypoints.openai.api_server", "-tp", "2", "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--port", "8100", - "--gpu-memory-utilization", "0.8" + "--gpu-memory-utilization", "0.8", "--max-model-len", "1000", ] prefill_env = os.environ.copy() prefill_env["VLLM_DISAGG_PREFILL_ROLE"] = "prefill" @@ -36,7 +36,7 @@ def setup_servers(): decode_cmd = [ sys.executable, "-m", "vllm.entrypoints.openai.api_server", "-tp", "2", "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--port", "8200", - "--gpu-memory-utilization", "0.8" + "--gpu-memory-utilization", "0.8", "--max-model-len", "1000", ] decode_env = os.environ.copy() decode_env["VLLM_DISAGG_PREFILL_ROLE"] = "decode" From cf5b84c2a9916618c3b5bb1f304dc0f2ff27a471 Mon Sep 17 00:00:00 2001 From: ApostaC Date: Mon, 16 Sep 2024 01:25:17 +0000 Subject: [PATCH 216/303] adjust path --- tests/kv_transfer/module_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kv_transfer/module_test.py b/tests/kv_transfer/module_test.py index 10fb19a3128e2..355461919cd7c 100644 --- a/tests/kv_transfer/module_test.py +++ b/tests/kv_transfer/module_test.py @@ -6,6 +6,7 @@ def run_python_script(script_name, timeout): + script_name = f'kv_transfer/{script_name}' try: # Start both processes asynchronously using Popen process0 = subprocess.Popen( From eb751d642b3bbb1a7f215754611f45bba298a54f Mon Sep 17 00:00:00 2001 From: ApostaC Date: Mon, 16 Sep 2024 01:25:32 +0000 Subject: [PATCH 217/303] add disagg prefill test to test pipeline --- .buildkite/test-pipeline.yaml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9b0cb6663a55b..da79fd86b767d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -390,6 +390,18 @@ steps: - pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pipeline_parallel.py +- label: Disaggregated Prefill Test # 4min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/parallel_state.py + - vllm/distributed/kv_transfer + - vllm/worker/worker_base.py + - vllm/worker/model_runner.py + commands: + - pytest -v -s kv_transfer/module_test.py + - pytest -v -s kv_transfer/disagg_test.py + - label: LoRA Long Context (Distributed) # 11min # This test runs llama 13B, so it is required to run on 4 GPUs. num_gpus: 4 From 1e23e99ecf68ce7ef97cdbf30340a0de8ad0bc9b Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 15 Sep 2024 21:15:00 +0000 Subject: [PATCH 218/303] use new round robin proxy in performance benchmark --- benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index dde9a80b59b37..715fe56d6c597 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -19,7 +19,6 @@ kill_gpu_processes() { # kill all processes on GPU. pkill -f pt_main_thread pkill -f python3 - pkill -f round_robin_proxy.sh ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done sleep 1 @@ -61,7 +60,7 @@ launch_chunked_prefill() { --gpu-memory-utilization 0.8 & wait_for_server 8100 wait_for_server 8200 - bash round_robin_proxy.sh & + python3 round_robin_proxy.py & sleep 1 } @@ -149,7 +148,7 @@ main() { mkdir results default_qps=10 - default_output_len=150 + default_output_len=10 export VLLM_LOGGING_LEVEL=DEBUG export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') From b4225f80eb1fa8934ab68fbacabfc01a8e081f25 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 15 Sep 2024 21:19:42 +0000 Subject: [PATCH 219/303] update --- .../disagg_benchmarks/round_robin_proxy.py | 117 ++++++------------ 1 file changed, 40 insertions(+), 77 deletions(-) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py index 04a30f774670a..8751e24a08d33 100644 --- a/benchmarks/disagg_benchmarks/round_robin_proxy.py +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -3,92 +3,55 @@ from aiohttp import web import itertools -class AsyncRoundRobinProxy: - def __init__(self, backend_ports): - self.backend_ports = itertools.cycle(backend_ports) - self.session = None - - async def start(self): - self.session = aiohttp.ClientSession() - - async def stop(self): - if self.session: - await self.session.close() +class RoundRobinProxy: + def __init__(self, target_ports): + self.target_ports = target_ports + self.port_cycle = itertools.cycle(self.target_ports) async def handle_request(self, request): - backend_port = next(self.backend_ports) - print("forwarding to port", backend_port) - backend_url = f"http://localhost:{backend_port}{request.path_qs}" - - try: - async with self.session.request( - method=request.method, - url=backend_url, - headers=request.headers, - data=await request.read() - ) as backend_response: - response = web.StreamResponse( - status=backend_response.status, - headers=backend_response.headers - ) - await response.prepare(request) - - async for chunk in backend_response.content.iter_any(): - await response.write(chunk) - - await response.write_eof() - return response - - except aiohttp.ClientError as e: - return web.Response(text=f"Backend error: {str(e)}", status=502) - -async def run_backend(port): - async def handle(request): - if request.path == '/stream': - response = web.StreamResponse( - status=200, - headers={'Content-Type': 'text/plain'} - ) - await response.prepare(request) - for i in range(10): - await response.write(f"Chunk {i}\n".encode()) - await asyncio.sleep(0.5) # Simulate delay between chunks - return response - else: - return web.Response(text=f"Response from backend on port {port}") - - app = web.Application() - app.router.add_route('*', '/{tail:.*}', handle) - runner = web.AppRunner(app) - await runner.setup() - site = web.TCPSite(runner, 'localhost', port) - await site.start() - print(f"Backend running on http://localhost:{port}") + target_port = next(self.port_cycle) + target_url = f"http://localhost:{target_port}{request.path_qs}" + + async with aiohttp.ClientSession() as session: + try: + # Forward the request + async with session.request( + method=request.method, + url=target_url, + headers=request.headers, + data=request.content, + ) as response: + # Start sending the response + resp = web.StreamResponse( + status=response.status, + headers=response.headers + ) + await resp.prepare(request) + + # Stream the response content + async for chunk in response.content.iter_any(): + await resp.write(chunk) + + await resp.write_eof() + return resp + + except Exception as e: + return web.Response(text=f"Error: {str(e)}", status=500) async def main(): - proxy = AsyncRoundRobinProxy([8100, 8200]) - await proxy.start() - + proxy = RoundRobinProxy([8100, 8200]) app = web.Application() - app.router.add_route('*', '/{tail:.*}', proxy.handle_request) + app.router.add_route('*', '/{path:.*}', proxy.handle_request) runner = web.AppRunner(app) await runner.setup() site = web.TCPSite(runner, 'localhost', 8000) + await site.start() - await asyncio.gather( - site.start(), - run_backend(8100), - run_backend(8200) - ) - - print("Proxy running on http://localhost:8000") - - try: - await asyncio.Future() # Run forever - finally: - await proxy.stop() - await runner.cleanup() + print("Proxy server started on http://localhost:8000") + + # Keep the server running + await asyncio.Event().wait() -if __name__ == "__main__": +if __name__ == '__main__': asyncio.run(main()) \ No newline at end of file From fa4785788f4dfcf8e72f0c31923ba28f1f53c132 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 15 Sep 2024 22:13:26 +0000 Subject: [PATCH 220/303] update benchmarking script --- .../analyze_benchmark_results.py | 48 ------------------- .../disagg_performance_benchmark.sh | 17 +++---- 2 files changed, 9 insertions(+), 56 deletions(-) delete mode 100644 benchmarks/disagg_benchmarks/analyze_benchmark_results.py diff --git a/benchmarks/disagg_benchmarks/analyze_benchmark_results.py b/benchmarks/disagg_benchmarks/analyze_benchmark_results.py deleted file mode 100644 index 4b675c675d25f..0000000000000 --- a/benchmarks/disagg_benchmarks/analyze_benchmark_results.py +++ /dev/null @@ -1,48 +0,0 @@ - -import argparse -import json -import yaml -import os -from pathlib import Path - -def load(path): - - with open(str(path), 'r') as f: - return json.loads(f.read()) - -def main(args): - - results = Path(args.results_folder) - - chunk = load(results / "chunked_prefill_tp4.json") - prefill = load(results / "disagg_prefill_tp4.json") - decode = load(results / "disagg_decode_tp4.json") - - ttft_ratio = chunk["mean_ttft_ms"] / prefill["mean_ttft_ms"] - itl_ratio = chunk["mean_itl_ms"] / decode["mean_itl_ms"] - prefill_decode_ratio = prefill["mean_ttft_ms"] / (decode["mean_itl_ms"] * args.output_len) - - with open(results / args.output_file, 'a') as f: - f.write(yaml.dump([{ - 'qps': args.qps, - 'output_len': args.output_len, - 'prefill_decode_ratio': prefill_decode_ratio, - 'ttft_ratio': ttft_ratio, - 'itl_ratio': itl_ratio, - "chunk_ttft": chunk["mean_ttft_ms"], - "chunk_itl": chunk["mean_itl_ms"], - "disagg_ttft": prefill["mean_ttft_ms"], - "disagg_itl": decode["mean_itl_ms"] - }])) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Analyze benchmark results") - parser.add_argument("--results-folder", required=True, help="Path to the results folder") - parser.add_argument("--output-len", type=int, required=True, help="Target output length") - parser.add_argument("--qps", type=int, required=True, help="Target QPS") - parser.add_argument("--output-file", type=str, default="chunk_vs_disagg.yaml") - - args = parser.parse_args() - main(args) - \ No newline at end of file diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index 715fe56d6c597..734679660c233 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -43,7 +43,7 @@ launch_chunked_prefill() { --model $model \ --port 8100 \ -tp 4 \ - --max-model-len 30000 \ + --max-model-len 10000 \ --disable-log-stats \ --disable-log-requests \ --enable-chunked-prefill \ @@ -53,7 +53,7 @@ launch_chunked_prefill() { --model $model \ --port 8200 \ -tp 4 \ - --max-model-len 30000 \ + --max-model-len 10000 \ --disable-log-stats \ --disable-log-requests \ --enable-chunked-prefill \ @@ -73,7 +73,7 @@ launch_disagg_prefill() { --model $model \ --port 8100 \ -tp 4 \ - --max-model-len 30000 \ + --max-model-len 10000 \ --disable-log-stats \ --disable-log-requests \ --gpu-memory-utilization 0.8 & @@ -82,7 +82,7 @@ launch_disagg_prefill() { --model $model \ --port 8200 \ -tp 4 \ - --max-model-len 30000 \ + --max-model-len 10000 \ --disable-log-stats \ --disable-log-requests \ --gpu-memory-utilization 0.8 & @@ -98,10 +98,10 @@ benchmark() { model="meta-llama/Meta-Llama-3.1-70B-Instruct" dataset_name="sonnet" dataset_path="../sonnet_4x.txt" - num_prompts=400 + num_prompts=200 qps=$1 prefix_len=50 - input_len=2048 + input_len=1024 output_len=$2 tag=$3 @@ -131,7 +131,7 @@ main() { (which jq) || (apt-get -y install jq) (which socat) || (apt-get -y install socat) - pip install quart httpx + pip install quart httpx matplotlib aiohttp cd "$(dirname "$0")" @@ -147,7 +147,6 @@ main() { rm -rf results mkdir results - default_qps=10 default_output_len=10 export VLLM_LOGGING_LEVEL=DEBUG @@ -165,6 +164,8 @@ main() { done kill_gpu_processes + python3 visualize_benchmark_results.py + } From 46f82a4cfa03b86453d465ffc883a26c66e519d3 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 15 Sep 2024 22:18:40 +0000 Subject: [PATCH 221/303] revert changes in model_runner.py --- no change needed for disagg prefill --- vllm/worker/model_runner.py | 39 +++++++++++++------------------------ 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8c4899a8b7f50..447d303a57fd8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,7 +14,6 @@ import torch.distributed import torch.nn as nn - import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState @@ -1545,30 +1544,21 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) - + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) + if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() - return hidden_or_intermediate_states - - @torch.inference_mode() - def postprocess_model( - self, - model_input, - hidden_or_intermediate_states, - - ): + # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: if (self.is_driver_worker and hidden_or_intermediate_states is not None @@ -1586,7 +1576,7 @@ def postprocess_model( hidden_or_intermediate_states.tensors["model_forward_time"] = ( torch.tensor(model_forward_time + orig_model_forward_time)) return hidden_or_intermediate_states - + logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) @@ -1618,7 +1608,6 @@ def postprocess_model( output.model_forward_time = (orig_model_forward_time + model_forward_time) - decode_meta = model_input.attn_metadata.decode_metadata if self.return_hidden_states: # we only need to pass hidden states of most recent token assert model_input.sampling_metadata is not None @@ -1635,9 +1624,7 @@ def postprocess_model( output.hidden_states = hidden_states return [output] - - - + class CUDAGraphRunner: @@ -1808,4 +1795,4 @@ def _get_max_graph_batch_size(max_num_seqs: int) -> int: if padded_size in _BATCH_SIZES_TO_CAPTURE: return padded_size assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] - return _BATCH_SIZES_TO_CAPTURE[-1] + return _BATCH_SIZES_TO_CAPTURE[-1] \ No newline at end of file From 8d7bb78895c7c42648dbabeff6dfd85318aa3924 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 15 Sep 2024 22:21:02 +0000 Subject: [PATCH 222/303] no I was wrong --- vllm/worker/model_runner.py | 56 ++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 447d303a57fd8..ab38302b3321a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,6 +14,7 @@ import torch.distributed import torch.nn as nn + import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState @@ -54,6 +55,7 @@ _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict, dump_input_when_exception) +from vllm import _custom_ops as ops if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -1544,21 +1546,30 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) + if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() - # Compute the logits in the last pipeline stage. + return hidden_or_intermediate_states + + @torch.inference_mode() + def postprocess_model( + self, + model_input, + hidden_or_intermediate_states, + + ): if not get_pp_group().is_last_rank: if (self.is_driver_worker and hidden_or_intermediate_states is not None @@ -1576,7 +1587,7 @@ def execute_model( hidden_or_intermediate_states.tensors["model_forward_time"] = ( torch.tensor(model_forward_time + orig_model_forward_time)) return hidden_or_intermediate_states - + logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) @@ -1591,23 +1602,8 @@ def execute_model( logits=logits, sampling_metadata=model_input.sampling_metadata, ) - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time - and output is not None): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - # If there are multiple workers, we are still tracking the latency - # from the start time of the driver worker to the end time of the - # driver worker. The model forward time will then end up covering - # the communication time as well. - output.model_forward_time = (orig_model_forward_time + - model_forward_time) + decode_meta = model_input.attn_metadata.decode_metadata if self.return_hidden_states: # we only need to pass hidden states of most recent token assert model_input.sampling_metadata is not None @@ -1624,7 +1620,9 @@ def execute_model( output.hidden_states = hidden_states return [output] - + + + class CUDAGraphRunner: @@ -1795,4 +1793,4 @@ def _get_max_graph_batch_size(max_num_seqs: int) -> int: if padded_size in _BATCH_SIZES_TO_CAPTURE: return padded_size assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] - return _BATCH_SIZES_TO_CAPTURE[-1] \ No newline at end of file + return _BATCH_SIZES_TO_CAPTURE[-1] From b5f9db5a45cab2a21a9a53b75318fbbc85a28e10 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 15 Sep 2024 22:23:20 +0000 Subject: [PATCH 223/303] update benchmark --- benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index 734679660c233..1da5669dd1cd0 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -147,7 +147,7 @@ main() { rm -rf results mkdir results - default_output_len=10 + default_output_len=6 export VLLM_LOGGING_LEVEL=DEBUG export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') From 0fc00918aa47dcd7d00f55940f72ffefcf666c15 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 15 Sep 2024 22:23:48 +0000 Subject: [PATCH 224/303] remove sonnet 4x --- it can be automatically generated via benchmarking script --- benchmarks/sonnet_4x.txt | 2070 -------------------------------------- 1 file changed, 2070 deletions(-) delete mode 100644 benchmarks/sonnet_4x.txt diff --git a/benchmarks/sonnet_4x.txt b/benchmarks/sonnet_4x.txt deleted file mode 100644 index 02f39a9fb14fb..0000000000000 --- a/benchmarks/sonnet_4x.txt +++ /dev/null @@ -1,2070 +0,0 @@ - -FROM fairest creatures we desire increase, -That thereby beauty's rose might never die, -But as the riper should by time decease, -His tender heir might bear his memory: -But thou, contracted to thine own bright eyes, -Feed'st thy light'st flame with self-substantial fuel, -Making a famine where abundance lies, -Thyself thy foe, to thy sweet self too cruel. -Thou that art now the world's fresh ornament -And only herald to the gaudy spring, -Within thine own bud buriest thy content -And, tender churl, makest waste in niggarding. -Pity the world, or else this glutton be, -To eat the world's due, by the grave and thee. -When forty winters shall beseige thy brow, -And dig deep trenches in thy beauty's field, -Thy youth's proud livery, so gazed on now, -Will be a tatter'd weed, of small worth held: -Then being ask'd where all thy beauty lies, -Where all the treasure of thy lusty days, -To say, within thine own deep-sunken eyes, -Were an all-eating shame and thriftless praise. -How much more praise deserved thy beauty's use, -If thou couldst answer 'This fair child of mine -Shall sum my count and make my old excuse,' -Proving his beauty by succession thine! -This were to be new made when thou art old, -And see thy blood warm when thou feel'st it cold. -Look in thy glass, and tell the face thou viewest -Now is the time that face should form another; -Whose fresh repair if now thou not renewest, -Thou dost beguile the world, unbless some mother. -For where is she so fair whose unear'd womb -Disdains the tillage of thy husbandry? -Or who is he so fond will be the tomb -Of his self-love, to stop posterity? -Thou art thy mother's glass, and she in thee -Calls back the lovely April of her prime: -So thou through windows of thine age shall see -Despite of wrinkles this thy golden time. -But if thou live, remember'd not to be, -Die single, and thine image dies with thee. -Unthrifty loveliness, why dost thou spend -Upon thyself thy beauty's legacy? -Nature's bequest gives nothing but doth lend, -And being frank she lends to those are free. -Then, beauteous niggard, why dost thou abuse -The bounteous largess given thee to give? -Profitless usurer, why dost thou use -So great a sum of sums, yet canst not live? -For having traffic with thyself alone, -Thou of thyself thy sweet self dost deceive. -Then how, when nature calls thee to be gone, -What acceptable audit canst thou leave? -Thy unused beauty must be tomb'd with thee, -Which, used, lives th' executor to be. -Those hours, that with gentle work did frame -The lovely gaze where every eye doth dwell, -Will play the tyrants to the very same -And that unfair which fairly doth excel: -For never-resting time leads summer on -To hideous winter and confounds him there; -Sap cheque'd with frost and lusty leaves quite gone, -Beauty o'ersnow'd and bareness every where: -Then, were not summer's distillation left, -A liquid prisoner pent in walls of glass, -Beauty's effect with beauty were bereft, -Nor it nor no remembrance what it was: -But flowers distill'd though they with winter meet, -Leese but their show; their substance still lives sweet. -Then let not winter's ragged hand deface -In thee thy summer, ere thou be distill'd: -Make sweet some vial; treasure thou some place -With beauty's treasure, ere it be self-kill'd. -That use is not forbidden usury, -Which happies those that pay the willing loan; -That's for thyself to breed another thee, -Or ten times happier, be it ten for one; -Ten times thyself were happier than thou art, -If ten of thine ten times refigured thee: -Then what could death do, if thou shouldst depart, -Leaving thee living in posterity? -Be not self-will'd, for thou art much too fair -To be death's conquest and make worms thine heir. -Lo! in the orient when the gracious light -Lifts up his burning head, each under eye -Doth homage to his new-appearing sight, -Serving with looks his sacred majesty; -And having climb'd the steep-up heavenly hill, -Resembling strong youth in his middle age, -yet mortal looks adore his beauty still, -Attending on his golden pilgrimage; -But when from highmost pitch, with weary car, -Like feeble age, he reeleth from the day, -The eyes, 'fore duteous, now converted are -From his low tract and look another way: -So thou, thyself out-going in thy noon, -Unlook'd on diest, unless thou get a son. -Music to hear, why hear'st thou music sadly? -Sweets with sweets war not, joy delights in joy. -Why lovest thou that which thou receivest not gladly, -Or else receivest with pleasure thine annoy? -If the true concord of well-tuned sounds, -By unions married, do offend thine ear, -They do but sweetly chide thee, who confounds -In singleness the parts that thou shouldst bear. -Mark how one string, sweet husband to another, -Strikes each in each by mutual ordering, -Resembling sire and child and happy mother -Who all in one, one pleasing note do sing: -Whose speechless song, being many, seeming one, -Sings this to thee: 'thou single wilt prove none.' -Is it for fear to wet a widow's eye -That thou consumest thyself in single life? -Ah! if thou issueless shalt hap to die. -The world will wail thee, like a makeless wife; -The world will be thy widow and still weep -That thou no form of thee hast left behind, -When every private widow well may keep -By children's eyes her husband's shape in mind. -Look, what an unthrift in the world doth spend -Shifts but his place, for still the world enjoys it; -But beauty's waste hath in the world an end, -And kept unused, the user so destroys it. -No love toward others in that bosom sits -That on himself such murderous shame commits. -For shame! deny that thou bear'st love to any, -Who for thyself art so unprovident. -Grant, if thou wilt, thou art beloved of many, -But that thou none lovest is most evident; -For thou art so possess'd with murderous hate -That 'gainst thyself thou stick'st not to conspire. -Seeking that beauteous roof to ruinate -Which to repair should be thy chief desire. -O, change thy thought, that I may change my mind! -Shall hate be fairer lodged than gentle love? -Be, as thy presence is, gracious and kind, -Or to thyself at least kind-hearted prove: -Make thee another self, for love of me, -That beauty still may live in thine or thee. -As fast as thou shalt wane, so fast thou growest -In one of thine, from that which thou departest; -And that fresh blood which youngly thou bestowest -Thou mayst call thine when thou from youth convertest. -Herein lives wisdom, beauty and increase: -Without this, folly, age and cold decay: -If all were minded so, the times should cease -And threescore year would make the world away. -Let those whom Nature hath not made for store, -Harsh featureless and rude, barrenly perish: -Look, whom she best endow'd she gave the more; -Which bounteous gift thou shouldst in bounty cherish: -She carved thee for her seal, and meant thereby -Thou shouldst print more, not let that copy die. -When I do count the clock that tells the time, -And see the brave day sunk in hideous night; -When I behold the violet past prime, -And sable curls all silver'd o'er with white; -When lofty trees I see barren of leaves -Which erst from heat did canopy the herd, -And summer's green all girded up in sheaves -Borne on the bier with white and bristly beard, -Then of thy beauty do I question make, -That thou among the wastes of time must go, -Since sweets and beauties do themselves forsake -And die as fast as they see others grow; -And nothing 'gainst Time's scythe can make defence -Save breed, to brave him when he takes thee hence. -O, that you were yourself! but, love, you are -No longer yours than you yourself here live: -Against this coming end you should prepare, -And your sweet semblance to some other give. -So should that beauty which you hold in lease -Find no determination: then you were -Yourself again after yourself's decease, -When your sweet issue your sweet form should bear. -Who lets so fair a house fall to decay, -Which husbandry in honour might uphold -Against the stormy gusts of winter's day -And barren rage of death's eternal cold? -O, none but unthrifts! Dear my love, you know -You had a father: let your son say so. -Not from the stars do I my judgment pluck; -And yet methinks I have astronomy, -But not to tell of good or evil luck, -Of plagues, of dearths, or seasons' quality; -Nor can I fortune to brief minutes tell, -Pointing to each his thunder, rain and wind, -Or say with princes if it shall go well, -By oft predict that I in heaven find: -But from thine eyes my knowledge I derive, -And, constant stars, in them I read such art -As truth and beauty shall together thrive, -If from thyself to store thou wouldst convert; -Or else of thee this I prognosticate: -Thy end is truth's and beauty's doom and date. -When I consider every thing that grows -Holds in perfection but a little moment, -That this huge stage presenteth nought but shows -Whereon the stars in secret influence comment; -When I perceive that men as plants increase, -Cheered and cheque'd even by the self-same sky, -Vaunt in their youthful sap, at height decrease, -And wear their brave state out of memory; -Then the conceit of this inconstant stay -Sets you most rich in youth before my sight, -Where wasteful Time debateth with Decay, -To change your day of youth to sullied night; -And all in war with Time for love of you, -As he takes from you, I engraft you new. -But wherefore do not you a mightier way -Make war upon this bloody tyrant, Time? -And fortify yourself in your decay -With means more blessed than my barren rhyme? -Now stand you on the top of happy hours, -And many maiden gardens yet unset -With virtuous wish would bear your living flowers, -Much liker than your painted counterfeit: -So should the lines of life that life repair, -Which this, Time's pencil, or my pupil pen, -Neither in inward worth nor outward fair, -Can make you live yourself in eyes of men. -To give away yourself keeps yourself still, -And you must live, drawn by your own sweet skill. -Who will believe my verse in time to come, -If it were fill'd with your most high deserts? -Though yet, heaven knows, it is but as a tomb -Which hides your life and shows not half your parts. -If I could write the beauty of your eyes -And in fresh numbers number all your graces, -The age to come would say 'This poet lies: -Such heavenly touches ne'er touch'd earthly faces.' -So should my papers yellow'd with their age -Be scorn'd like old men of less truth than tongue, -And your true rights be term'd a poet's rage -And stretched metre of an antique song: -But were some child of yours alive that time, -You should live twice; in it and in my rhyme. -Shall I compare thee to a summer's day? -Thou art more lovely and more temperate: -Rough winds do shake the darling buds of May, -And summer's lease hath all too short a date: -Sometime too hot the eye of heaven shines, -And often is his gold complexion dimm'd; -And every fair from fair sometime declines, -By chance or nature's changing course untrimm'd; -But thy eternal summer shall not fade -Nor lose possession of that fair thou owest; -Nor shall Death brag thou wander'st in his shade, -When in eternal lines to time thou growest: -So long as men can breathe or eyes can see, -So long lives this and this gives life to thee. -Devouring Time, blunt thou the lion's paws, -And make the earth devour her own sweet brood; -Pluck the keen teeth from the fierce tiger's jaws, -And burn the long-lived phoenix in her blood; -Make glad and sorry seasons as thou fleets, -And do whate'er thou wilt, swift-footed Time, -To the wide world and all her fading sweets; -But I forbid thee one most heinous crime: -O, carve not with thy hours my love's fair brow, -Nor draw no lines there with thine antique pen; -Him in thy course untainted do allow -For beauty's pattern to succeeding men. -Yet, do thy worst, old Time: despite thy wrong, -My love shall in my verse ever live young. -A woman's face with Nature's own hand painted -Hast thou, the master-mistress of my passion; -A woman's gentle heart, but not acquainted -With shifting change, as is false women's fashion; -An eye more bright than theirs, less false in rolling, -Gilding the object whereupon it gazeth; -A man in hue, all 'hues' in his controlling, -Much steals men's eyes and women's souls amazeth. -And for a woman wert thou first created; -Till Nature, as she wrought thee, fell a-doting, -And by addition me of thee defeated, -By adding one thing to my purpose nothing. -But since she prick'd thee out for women's pleasure, -Mine be thy love and thy love's use their treasure. -So is it not with me as with that Muse -Stirr'd by a painted beauty to his verse, -Who heaven itself for ornament doth use -And every fair with his fair doth rehearse -Making a couplement of proud compare, -With sun and moon, with earth and sea's rich gems, -With April's first-born flowers, and all things rare -That heaven's air in this huge rondure hems. -O' let me, true in love, but truly write, -And then believe me, my love is as fair -As any mother's child, though not so bright -As those gold candles fix'd in heaven's air: -Let them say more than like of hearsay well; -I will not praise that purpose not to sell. -My glass shall not persuade me I am old, -So long as youth and thou are of one date; -But when in thee time's furrows I behold, -Then look I death my days should expiate. -For all that beauty that doth cover thee -Is but the seemly raiment of my heart, -Which in thy breast doth live, as thine in me: -How can I then be elder than thou art? -O, therefore, love, be of thyself so wary -As I, not for myself, but for thee will; -Bearing thy heart, which I will keep so chary -As tender nurse her babe from faring ill. -Presume not on thy heart when mine is slain; -Thou gavest me thine, not to give back again. -As an unperfect actor on the stage -Who with his fear is put besides his part, -Or some fierce thing replete with too much rage, -Whose strength's abundance weakens his own heart. -So I, for fear of trust, forget to say -The perfect ceremony of love's rite, -And in mine own love's strength seem to decay, -O'ercharged with burden of mine own love's might. -O, let my books be then the eloquence -And dumb presagers of my speaking breast, -Who plead for love and look for recompense -More than that tongue that more hath more express'd. -O, learn to read what silent love hath writ: -To hear with eyes belongs to love's fine wit. -Mine eye hath play'd the painter and hath stell'd -Thy beauty's form in table of my heart; -My body is the frame wherein 'tis held, -And perspective it is the painter's art. -For through the painter must you see his skill, -To find where your true image pictured lies; -Which in my bosom's shop is hanging still, -That hath his windows glazed with thine eyes. -Now see what good turns eyes for eyes have done: -Mine eyes have drawn thy shape, and thine for me -Are windows to my breast, where-through the sun -Delights to peep, to gaze therein on thee; -Yet eyes this cunning want to grace their art; -They draw but what they see, know not the heart. -Let those who are in favour with their stars -Of public honour and proud titles boast, -Whilst I, whom fortune of such triumph bars, -Unlook'd for joy in that I honour most. -Great princes' favourites their fair leaves spread -But as the marigold at the sun's eye, -And in themselves their pride lies buried, -For at a frown they in their glory die. -The painful warrior famoused for fight, -After a thousand victories once foil'd, -Is from the book of honour razed quite, -And all the rest forgot for which he toil'd: -Then happy I, that love and am beloved -Where I may not remove nor be removed. -Lord of my love, to whom in vassalage -Thy merit hath my duty strongly knit, -To thee I send this written embassage, -To witness duty, not to show my wit: -Duty so great, which wit so poor as mine -May make seem bare, in wanting words to show it, -But that I hope some good conceit of thine -In thy soul's thought, all naked, will bestow it; -Till whatsoever star that guides my moving -Points on me graciously with fair aspect -And puts apparel on my tatter'd loving, -To show me worthy of thy sweet respect: -Then may I dare to boast how I do love thee; -Till then not show my head where thou mayst prove me. -Weary with toil, I haste me to my bed, -The dear repose for limbs with travel tired; -But then begins a journey in my head, -To work my mind, when body's work's expired: -For then my thoughts, from far where I abide, -Intend a zealous pilgrimage to thee, -And keep my drooping eyelids open wide, -Looking on darkness which the blind do see -Save that my soul's imaginary sight -Presents thy shadow to my sightless view, -Which, like a jewel hung in ghastly night, -Makes black night beauteous and her old face new. -Lo! thus, by day my limbs, by night my mind, -For thee and for myself no quiet find. -How can I then return in happy plight, -That am debarr'd the benefit of rest? -When day's oppression is not eased by night, -But day by night, and night by day, oppress'd? -And each, though enemies to either's reign, -Do in consent shake hands to torture me; -The one by toil, the other to complain -How far I toil, still farther off from thee. -I tell the day, to please them thou art bright -And dost him grace when clouds do blot the heaven: -So flatter I the swart-complexion'd night, -When sparkling stars twire not thou gild'st the even. -But day doth daily draw my sorrows longer -And night doth nightly make grief's strength seem stronger. -When, in disgrace with fortune and men's eyes, -I all alone beweep my outcast state -And trouble deal heaven with my bootless cries -And look upon myself and curse my fate, -Wishing me like to one more rich in hope, -Featured like him, like him with friends possess'd, -Desiring this man's art and that man's scope, -With what I most enjoy contented least; -Yet in these thoughts myself almost despising, -Haply I think on thee, and then my state, -Like to the lark at break of day arising -From sullen earth, sings hymns at heaven's gate; -For thy sweet love remember'd such wealth brings -That then I scorn to change my state with kings. -When to the sessions of sweet silent thought -I summon up remembrance of things past, -I sigh the lack of many a thing I sought, -And with old woes new wail my dear time's waste: -Then can I drown an eye, unused to flow, -For precious friends hid in death's dateless night, -And weep afresh love's long since cancell'd woe, -And moan the expense of many a vanish'd sight: -Then can I grieve at grievances foregone, -And heavily from woe to woe tell o'er -The sad account of fore-bemoaned moan, -Which I new pay as if not paid before. -But if the while I think on thee, dear friend, -All losses are restored and sorrows end. -Thy bosom is endeared with all hearts, -Which I by lacking have supposed dead, -And there reigns love and all love's loving parts, -And all those friends which I thought buried. -How many a holy and obsequious tear -Hath dear religious love stol'n from mine eye -As interest of the dead, which now appear -But things removed that hidden in thee lie! -Thou art the grave where buried love doth live, -Hung with the trophies of my lovers gone, -Who all their parts of me to thee did give; -That due of many now is thine alone: -Their images I loved I view in thee, -And thou, all they, hast all the all of me. -If thou survive my well-contented day, -When that churl Death my bones with dust shall cover, -And shalt by fortune once more re-survey -These poor rude lines of thy deceased lover, -Compare them with the bettering of the time, -And though they be outstripp'd by every pen, -Reserve them for my love, not for their rhyme, -Exceeded by the height of happier men. -O, then vouchsafe me but this loving thought: -'Had my friend's Muse grown with this growing age, -A dearer birth than this his love had brought, -To march in ranks of better equipage: -But since he died and poets better prove, -Theirs for their style I'll read, his for his love.' -Full many a glorious morning have I seen -Flatter the mountain-tops with sovereign eye, -Kissing with golden face the meadows green, -Gilding pale streams with heavenly alchemy; -Anon permit the basest clouds to ride -With ugly rack on his celestial face, -And from the forlorn world his visage hide, -Stealing unseen to west with this disgrace: -Even so my sun one early morn did shine -With all triumphant splendor on my brow; -But out, alack! he was but one hour mine; -The region cloud hath mask'd him from me now. -Yet him for this my love no whit disdaineth; -Suns of the world may stain when heaven's sun staineth. -Why didst thou promise such a beauteous day, -And make me travel forth without my cloak, -To let base clouds o'ertake me in my way, -Hiding thy bravery in their rotten smoke? -'Tis not enough that through the cloud thou break, -To dry the rain on my storm-beaten face, -For no man well of such a salve can speak -That heals the wound and cures not the disgrace: -Nor can thy shame give physic to my grief; -Though thou repent, yet I have still the loss: -The offender's sorrow lends but weak relief -To him that bears the strong offence's cross. -Ah! but those tears are pearl which thy love sheds, -And they are rich and ransom all ill deeds. -No more be grieved at that which thou hast done: -Roses have thorns, and silver fountains mud; -Clouds and eclipses stain both moon and sun, -And loathsome canker lives in sweetest bud. -All men make faults, and even I in this, -Authorizing thy trespass with compare, -Myself corrupting, salving thy amiss, -Excusing thy sins more than thy sins are; -For to thy sensual fault I bring in sense-- -Thy adverse party is thy advocate-- -And 'gainst myself a lawful plea commence: -Such civil war is in my love and hate -That I an accessary needs must be -To that sweet thief which sourly robs from me. -Let me confess that we two must be twain, -Although our undivided loves are one: -So shall those blots that do with me remain -Without thy help by me be borne alone. -In our two loves there is but one respect, -Though in our lives a separable spite, -Which though it alter not love's sole effect, -Yet doth it steal sweet hours from love's delight. -I may not evermore acknowledge thee, -Lest my bewailed guilt should do thee shame, -Nor thou with public kindness honour me, -Unless thou take that honour from thy name: -But do not so; I love thee in such sort -As, thou being mine, mine is thy good report. -As a decrepit father takes delight -To see his active child do deeds of youth, -So I, made lame by fortune's dearest spite, -Take all my comfort of thy worth and truth. -For whether beauty, birth, or wealth, or wit, -Or any of these all, or all, or more, -Entitled in thy parts do crowned sit, -I make my love engrafted to this store: -So then I am not lame, poor, nor despised, -Whilst that this shadow doth such substance give -That I in thy abundance am sufficed -And by a part of all thy glory live. -Look, what is best, that best I wish in thee: -This wish I have; then ten times happy me!FROM fairest creatures we desire increase, -That thereby beauty's rose might never die, -But as the riper should by time decease, -His tender heir might bear his memory: -But thou, contracted to thine own bright eyes, -Feed'st thy light'st flame with self-substantial fuel, -Making a famine where abundance lies, -Thyself thy foe, to thy sweet self too cruel. -Thou that art now the world's fresh ornament -And only herald to the gaudy spring, -Within thine own bud buriest thy content -And, tender churl, makest waste in niggarding. -Pity the world, or else this glutton be, -To eat the world's due, by the grave and thee. -When forty winters shall beseige thy brow, -And dig deep trenches in thy beauty's field, -Thy youth's proud livery, so gazed on now, -Will be a tatter'd weed, of small worth held: -Then being ask'd where all thy beauty lies, -Where all the treasure of thy lusty days, -To say, within thine own deep-sunken eyes, -Were an all-eating shame and thriftless praise. -How much more praise deserved thy beauty's use, -If thou couldst answer 'This fair child of mine -Shall sum my count and make my old excuse,' -Proving his beauty by succession thine! -This were to be new made when thou art old, -And see thy blood warm when thou feel'st it cold. -Look in thy glass, and tell the face thou viewest -Now is the time that face should form another; -Whose fresh repair if now thou not renewest, -Thou dost beguile the world, unbless some mother. -For where is she so fair whose unear'd womb -Disdains the tillage of thy husbandry? -Or who is he so fond will be the tomb -Of his self-love, to stop posterity? -Thou art thy mother's glass, and she in thee -Calls back the lovely April of her prime: -So thou through windows of thine age shall see -Despite of wrinkles this thy golden time. -But if thou live, remember'd not to be, -Die single, and thine image dies with thee. -Unthrifty loveliness, why dost thou spend -Upon thyself thy beauty's legacy? -Nature's bequest gives nothing but doth lend, -And being frank she lends to those are free. -Then, beauteous niggard, why dost thou abuse -The bounteous largess given thee to give? -Profitless usurer, why dost thou use -So great a sum of sums, yet canst not live? -For having traffic with thyself alone, -Thou of thyself thy sweet self dost deceive. -Then how, when nature calls thee to be gone, -What acceptable audit canst thou leave? -Thy unused beauty must be tomb'd with thee, -Which, used, lives th' executor to be. -Those hours, that with gentle work did frame -The lovely gaze where every eye doth dwell, -Will play the tyrants to the very same -And that unfair which fairly doth excel: -For never-resting time leads summer on -To hideous winter and confounds him there; -Sap cheque'd with frost and lusty leaves quite gone, -Beauty o'ersnow'd and bareness every where: -Then, were not summer's distillation left, -A liquid prisoner pent in walls of glass, -Beauty's effect with beauty were bereft, -Nor it nor no remembrance what it was: -But flowers distill'd though they with winter meet, -Leese but their show; their substance still lives sweet. -Then let not winter's ragged hand deface -In thee thy summer, ere thou be distill'd: -Make sweet some vial; treasure thou some place -With beauty's treasure, ere it be self-kill'd. -That use is not forbidden usury, -Which happies those that pay the willing loan; -That's for thyself to breed another thee, -Or ten times happier, be it ten for one; -Ten times thyself were happier than thou art, -If ten of thine ten times refigured thee: -Then what could death do, if thou shouldst depart, -Leaving thee living in posterity? -Be not self-will'd, for thou art much too fair -To be death's conquest and make worms thine heir. -Lo! in the orient when the gracious light -Lifts up his burning head, each under eye -Doth homage to his new-appearing sight, -Serving with looks his sacred majesty; -And having climb'd the steep-up heavenly hill, -Resembling strong youth in his middle age, -yet mortal looks adore his beauty still, -Attending on his golden pilgrimage; -But when from highmost pitch, with weary car, -Like feeble age, he reeleth from the day, -The eyes, 'fore duteous, now converted are -From his low tract and look another way: -So thou, thyself out-going in thy noon, -Unlook'd on diest, unless thou get a son. -Music to hear, why hear'st thou music sadly? -Sweets with sweets war not, joy delights in joy. -Why lovest thou that which thou receivest not gladly, -Or else receivest with pleasure thine annoy? -If the true concord of well-tuned sounds, -By unions married, do offend thine ear, -They do but sweetly chide thee, who confounds -In singleness the parts that thou shouldst bear. -Mark how one string, sweet husband to another, -Strikes each in each by mutual ordering, -Resembling sire and child and happy mother -Who all in one, one pleasing note do sing: -Whose speechless song, being many, seeming one, -Sings this to thee: 'thou single wilt prove none.' -Is it for fear to wet a widow's eye -That thou consumest thyself in single life? -Ah! if thou issueless shalt hap to die. -The world will wail thee, like a makeless wife; -The world will be thy widow and still weep -That thou no form of thee hast left behind, -When every private widow well may keep -By children's eyes her husband's shape in mind. -Look, what an unthrift in the world doth spend -Shifts but his place, for still the world enjoys it; -But beauty's waste hath in the world an end, -And kept unused, the user so destroys it. -No love toward others in that bosom sits -That on himself such murderous shame commits. -For shame! deny that thou bear'st love to any, -Who for thyself art so unprovident. -Grant, if thou wilt, thou art beloved of many, -But that thou none lovest is most evident; -For thou art so possess'd with murderous hate -That 'gainst thyself thou stick'st not to conspire. -Seeking that beauteous roof to ruinate -Which to repair should be thy chief desire. -O, change thy thought, that I may change my mind! -Shall hate be fairer lodged than gentle love? -Be, as thy presence is, gracious and kind, -Or to thyself at least kind-hearted prove: -Make thee another self, for love of me, -That beauty still may live in thine or thee. -As fast as thou shalt wane, so fast thou growest -In one of thine, from that which thou departest; -And that fresh blood which youngly thou bestowest -Thou mayst call thine when thou from youth convertest. -Herein lives wisdom, beauty and increase: -Without this, folly, age and cold decay: -If all were minded so, the times should cease -And threescore year would make the world away. -Let those whom Nature hath not made for store, -Harsh featureless and rude, barrenly perish: -Look, whom she best endow'd she gave the more; -Which bounteous gift thou shouldst in bounty cherish: -She carved thee for her seal, and meant thereby -Thou shouldst print more, not let that copy die. -When I do count the clock that tells the time, -And see the brave day sunk in hideous night; -When I behold the violet past prime, -And sable curls all silver'd o'er with white; -When lofty trees I see barren of leaves -Which erst from heat did canopy the herd, -And summer's green all girded up in sheaves -Borne on the bier with white and bristly beard, -Then of thy beauty do I question make, -That thou among the wastes of time must go, -Since sweets and beauties do themselves forsake -And die as fast as they see others grow; -And nothing 'gainst Time's scythe can make defence -Save breed, to brave him when he takes thee hence. -O, that you were yourself! but, love, you are -No longer yours than you yourself here live: -Against this coming end you should prepare, -And your sweet semblance to some other give. -So should that beauty which you hold in lease -Find no determination: then you were -Yourself again after yourself's decease, -When your sweet issue your sweet form should bear. -Who lets so fair a house fall to decay, -Which husbandry in honour might uphold -Against the stormy gusts of winter's day -And barren rage of death's eternal cold? -O, none but unthrifts! Dear my love, you know -You had a father: let your son say so. -Not from the stars do I my judgment pluck; -And yet methinks I have astronomy, -But not to tell of good or evil luck, -Of plagues, of dearths, or seasons' quality; -Nor can I fortune to brief minutes tell, -Pointing to each his thunder, rain and wind, -Or say with princes if it shall go well, -By oft predict that I in heaven find: -But from thine eyes my knowledge I derive, -And, constant stars, in them I read such art -As truth and beauty shall together thrive, -If from thyself to store thou wouldst convert; -Or else of thee this I prognosticate: -Thy end is truth's and beauty's doom and date. -When I consider every thing that grows -Holds in perfection but a little moment, -That this huge stage presenteth nought but shows -Whereon the stars in secret influence comment; -When I perceive that men as plants increase, -Cheered and cheque'd even by the self-same sky, -Vaunt in their youthful sap, at height decrease, -And wear their brave state out of memory; -Then the conceit of this inconstant stay -Sets you most rich in youth before my sight, -Where wasteful Time debateth with Decay, -To change your day of youth to sullied night; -And all in war with Time for love of you, -As he takes from you, I engraft you new. -But wherefore do not you a mightier way -Make war upon this bloody tyrant, Time? -And fortify yourself in your decay -With means more blessed than my barren rhyme? -Now stand you on the top of happy hours, -And many maiden gardens yet unset -With virtuous wish would bear your living flowers, -Much liker than your painted counterfeit: -So should the lines of life that life repair, -Which this, Time's pencil, or my pupil pen, -Neither in inward worth nor outward fair, -Can make you live yourself in eyes of men. -To give away yourself keeps yourself still, -And you must live, drawn by your own sweet skill. -Who will believe my verse in time to come, -If it were fill'd with your most high deserts? -Though yet, heaven knows, it is but as a tomb -Which hides your life and shows not half your parts. -If I could write the beauty of your eyes -And in fresh numbers number all your graces, -The age to come would say 'This poet lies: -Such heavenly touches ne'er touch'd earthly faces.' -So should my papers yellow'd with their age -Be scorn'd like old men of less truth than tongue, -And your true rights be term'd a poet's rage -And stretched metre of an antique song: -But were some child of yours alive that time, -You should live twice; in it and in my rhyme. -Shall I compare thee to a summer's day? -Thou art more lovely and more temperate: -Rough winds do shake the darling buds of May, -And summer's lease hath all too short a date: -Sometime too hot the eye of heaven shines, -And often is his gold complexion dimm'd; -And every fair from fair sometime declines, -By chance or nature's changing course untrimm'd; -But thy eternal summer shall not fade -Nor lose possession of that fair thou owest; -Nor shall Death brag thou wander'st in his shade, -When in eternal lines to time thou growest: -So long as men can breathe or eyes can see, -So long lives this and this gives life to thee. -Devouring Time, blunt thou the lion's paws, -And make the earth devour her own sweet brood; -Pluck the keen teeth from the fierce tiger's jaws, -And burn the long-lived phoenix in her blood; -Make glad and sorry seasons as thou fleets, -And do whate'er thou wilt, swift-footed Time, -To the wide world and all her fading sweets; -But I forbid thee one most heinous crime: -O, carve not with thy hours my love's fair brow, -Nor draw no lines there with thine antique pen; -Him in thy course untainted do allow -For beauty's pattern to succeeding men. -Yet, do thy worst, old Time: despite thy wrong, -My love shall in my verse ever live young. -A woman's face with Nature's own hand painted -Hast thou, the master-mistress of my passion; -A woman's gentle heart, but not acquainted -With shifting change, as is false women's fashion; -An eye more bright than theirs, less false in rolling, -Gilding the object whereupon it gazeth; -A man in hue, all 'hues' in his controlling, -Much steals men's eyes and women's souls amazeth. -And for a woman wert thou first created; -Till Nature, as she wrought thee, fell a-doting, -And by addition me of thee defeated, -By adding one thing to my purpose nothing. -But since she prick'd thee out for women's pleasure, -Mine be thy love and thy love's use their treasure. -So is it not with me as with that Muse -Stirr'd by a painted beauty to his verse, -Who heaven itself for ornament doth use -And every fair with his fair doth rehearse -Making a couplement of proud compare, -With sun and moon, with earth and sea's rich gems, -With April's first-born flowers, and all things rare -That heaven's air in this huge rondure hems. -O' let me, true in love, but truly write, -And then believe me, my love is as fair -As any mother's child, though not so bright -As those gold candles fix'd in heaven's air: -Let them say more than like of hearsay well; -I will not praise that purpose not to sell. -My glass shall not persuade me I am old, -So long as youth and thou are of one date; -But when in thee time's furrows I behold, -Then look I death my days should expiate. -For all that beauty that doth cover thee -Is but the seemly raiment of my heart, -Which in thy breast doth live, as thine in me: -How can I then be elder than thou art? -O, therefore, love, be of thyself so wary -As I, not for myself, but for thee will; -Bearing thy heart, which I will keep so chary -As tender nurse her babe from faring ill. -Presume not on thy heart when mine is slain; -Thou gavest me thine, not to give back again. -As an unperfect actor on the stage -Who with his fear is put besides his part, -Or some fierce thing replete with too much rage, -Whose strength's abundance weakens his own heart. -So I, for fear of trust, forget to say -The perfect ceremony of love's rite, -And in mine own love's strength seem to decay, -O'ercharged with burden of mine own love's might. -O, let my books be then the eloquence -And dumb presagers of my speaking breast, -Who plead for love and look for recompense -More than that tongue that more hath more express'd. -O, learn to read what silent love hath writ: -To hear with eyes belongs to love's fine wit. -Mine eye hath play'd the painter and hath stell'd -Thy beauty's form in table of my heart; -My body is the frame wherein 'tis held, -And perspective it is the painter's art. -For through the painter must you see his skill, -To find where your true image pictured lies; -Which in my bosom's shop is hanging still, -That hath his windows glazed with thine eyes. -Now see what good turns eyes for eyes have done: -Mine eyes have drawn thy shape, and thine for me -Are windows to my breast, where-through the sun -Delights to peep, to gaze therein on thee; -Yet eyes this cunning want to grace their art; -They draw but what they see, know not the heart. -Let those who are in favour with their stars -Of public honour and proud titles boast, -Whilst I, whom fortune of such triumph bars, -Unlook'd for joy in that I honour most. -Great princes' favourites their fair leaves spread -But as the marigold at the sun's eye, -And in themselves their pride lies buried, -For at a frown they in their glory die. -The painful warrior famoused for fight, -After a thousand victories once foil'd, -Is from the book of honour razed quite, -And all the rest forgot for which he toil'd: -Then happy I, that love and am beloved -Where I may not remove nor be removed. -Lord of my love, to whom in vassalage -Thy merit hath my duty strongly knit, -To thee I send this written embassage, -To witness duty, not to show my wit: -Duty so great, which wit so poor as mine -May make seem bare, in wanting words to show it, -But that I hope some good conceit of thine -In thy soul's thought, all naked, will bestow it; -Till whatsoever star that guides my moving -Points on me graciously with fair aspect -And puts apparel on my tatter'd loving, -To show me worthy of thy sweet respect: -Then may I dare to boast how I do love thee; -Till then not show my head where thou mayst prove me. -Weary with toil, I haste me to my bed, -The dear repose for limbs with travel tired; -But then begins a journey in my head, -To work my mind, when body's work's expired: -For then my thoughts, from far where I abide, -Intend a zealous pilgrimage to thee, -And keep my drooping eyelids open wide, -Looking on darkness which the blind do see -Save that my soul's imaginary sight -Presents thy shadow to my sightless view, -Which, like a jewel hung in ghastly night, -Makes black night beauteous and her old face new. -Lo! thus, by day my limbs, by night my mind, -For thee and for myself no quiet find. -How can I then return in happy plight, -That am debarr'd the benefit of rest? -When day's oppression is not eased by night, -But day by night, and night by day, oppress'd? -And each, though enemies to either's reign, -Do in consent shake hands to torture me; -The one by toil, the other to complain -How far I toil, still farther off from thee. -I tell the day, to please them thou art bright -And dost him grace when clouds do blot the heaven: -So flatter I the swart-complexion'd night, -When sparkling stars twire not thou gild'st the even. -But day doth daily draw my sorrows longer -And night doth nightly make grief's strength seem stronger. -When, in disgrace with fortune and men's eyes, -I all alone beweep my outcast state -And trouble deal heaven with my bootless cries -And look upon myself and curse my fate, -Wishing me like to one more rich in hope, -Featured like him, like him with friends possess'd, -Desiring this man's art and that man's scope, -With what I most enjoy contented least; -Yet in these thoughts myself almost despising, -Haply I think on thee, and then my state, -Like to the lark at break of day arising -From sullen earth, sings hymns at heaven's gate; -For thy sweet love remember'd such wealth brings -That then I scorn to change my state with kings. -When to the sessions of sweet silent thought -I summon up remembrance of things past, -I sigh the lack of many a thing I sought, -And with old woes new wail my dear time's waste: -Then can I drown an eye, unused to flow, -For precious friends hid in death's dateless night, -And weep afresh love's long since cancell'd woe, -And moan the expense of many a vanish'd sight: -Then can I grieve at grievances foregone, -And heavily from woe to woe tell o'er -The sad account of fore-bemoaned moan, -Which I new pay as if not paid before. -But if the while I think on thee, dear friend, -All losses are restored and sorrows end. -Thy bosom is endeared with all hearts, -Which I by lacking have supposed dead, -And there reigns love and all love's loving parts, -And all those friends which I thought buried. -How many a holy and obsequious tear -Hath dear religious love stol'n from mine eye -As interest of the dead, which now appear -But things removed that hidden in thee lie! -Thou art the grave where buried love doth live, -Hung with the trophies of my lovers gone, -Who all their parts of me to thee did give; -That due of many now is thine alone: -Their images I loved I view in thee, -And thou, all they, hast all the all of me. -If thou survive my well-contented day, -When that churl Death my bones with dust shall cover, -And shalt by fortune once more re-survey -These poor rude lines of thy deceased lover, -Compare them with the bettering of the time, -And though they be outstripp'd by every pen, -Reserve them for my love, not for their rhyme, -Exceeded by the height of happier men. -O, then vouchsafe me but this loving thought: -'Had my friend's Muse grown with this growing age, -A dearer birth than this his love had brought, -To march in ranks of better equipage: -But since he died and poets better prove, -Theirs for their style I'll read, his for his love.' -Full many a glorious morning have I seen -Flatter the mountain-tops with sovereign eye, -Kissing with golden face the meadows green, -Gilding pale streams with heavenly alchemy; -Anon permit the basest clouds to ride -With ugly rack on his celestial face, -And from the forlorn world his visage hide, -Stealing unseen to west with this disgrace: -Even so my sun one early morn did shine -With all triumphant splendor on my brow; -But out, alack! he was but one hour mine; -The region cloud hath mask'd him from me now. -Yet him for this my love no whit disdaineth; -Suns of the world may stain when heaven's sun staineth. -Why didst thou promise such a beauteous day, -And make me travel forth without my cloak, -To let base clouds o'ertake me in my way, -Hiding thy bravery in their rotten smoke? -'Tis not enough that through the cloud thou break, -To dry the rain on my storm-beaten face, -For no man well of such a salve can speak -That heals the wound and cures not the disgrace: -Nor can thy shame give physic to my grief; -Though thou repent, yet I have still the loss: -The offender's sorrow lends but weak relief -To him that bears the strong offence's cross. -Ah! but those tears are pearl which thy love sheds, -And they are rich and ransom all ill deeds. -No more be grieved at that which thou hast done: -Roses have thorns, and silver fountains mud; -Clouds and eclipses stain both moon and sun, -And loathsome canker lives in sweetest bud. -All men make faults, and even I in this, -Authorizing thy trespass with compare, -Myself corrupting, salving thy amiss, -Excusing thy sins more than thy sins are; -For to thy sensual fault I bring in sense-- -Thy adverse party is thy advocate-- -And 'gainst myself a lawful plea commence: -Such civil war is in my love and hate -That I an accessary needs must be -To that sweet thief which sourly robs from me. -Let me confess that we two must be twain, -Although our undivided loves are one: -So shall those blots that do with me remain -Without thy help by me be borne alone. -In our two loves there is but one respect, -Though in our lives a separable spite, -Which though it alter not love's sole effect, -Yet doth it steal sweet hours from love's delight. -I may not evermore acknowledge thee, -Lest my bewailed guilt should do thee shame, -Nor thou with public kindness honour me, -Unless thou take that honour from thy name: -But do not so; I love thee in such sort -As, thou being mine, mine is thy good report. -As a decrepit father takes delight -To see his active child do deeds of youth, -So I, made lame by fortune's dearest spite, -Take all my comfort of thy worth and truth. -For whether beauty, birth, or wealth, or wit, -Or any of these all, or all, or more, -Entitled in thy parts do crowned sit, -I make my love engrafted to this store: -So then I am not lame, poor, nor despised, -Whilst that this shadow doth such substance give -That I in thy abundance am sufficed -And by a part of all thy glory live. -Look, what is best, that best I wish in thee: -This wish I have; then ten times happy me!FROM fairest creatures we desire increase, -That thereby beauty's rose might never die, -But as the riper should by time decease, -His tender heir might bear his memory: -But thou, contracted to thine own bright eyes, -Feed'st thy light'st flame with self-substantial fuel, -Making a famine where abundance lies, -Thyself thy foe, to thy sweet self too cruel. -Thou that art now the world's fresh ornament -And only herald to the gaudy spring, -Within thine own bud buriest thy content -And, tender churl, makest waste in niggarding. -Pity the world, or else this glutton be, -To eat the world's due, by the grave and thee. -When forty winters shall beseige thy brow, -And dig deep trenches in thy beauty's field, -Thy youth's proud livery, so gazed on now, -Will be a tatter'd weed, of small worth held: -Then being ask'd where all thy beauty lies, -Where all the treasure of thy lusty days, -To say, within thine own deep-sunken eyes, -Were an all-eating shame and thriftless praise. -How much more praise deserved thy beauty's use, -If thou couldst answer 'This fair child of mine -Shall sum my count and make my old excuse,' -Proving his beauty by succession thine! -This were to be new made when thou art old, -And see thy blood warm when thou feel'st it cold. -Look in thy glass, and tell the face thou viewest -Now is the time that face should form another; -Whose fresh repair if now thou not renewest, -Thou dost beguile the world, unbless some mother. -For where is she so fair whose unear'd womb -Disdains the tillage of thy husbandry? -Or who is he so fond will be the tomb -Of his self-love, to stop posterity? -Thou art thy mother's glass, and she in thee -Calls back the lovely April of her prime: -So thou through windows of thine age shall see -Despite of wrinkles this thy golden time. -But if thou live, remember'd not to be, -Die single, and thine image dies with thee. -Unthrifty loveliness, why dost thou spend -Upon thyself thy beauty's legacy? -Nature's bequest gives nothing but doth lend, -And being frank she lends to those are free. -Then, beauteous niggard, why dost thou abuse -The bounteous largess given thee to give? -Profitless usurer, why dost thou use -So great a sum of sums, yet canst not live? -For having traffic with thyself alone, -Thou of thyself thy sweet self dost deceive. -Then how, when nature calls thee to be gone, -What acceptable audit canst thou leave? -Thy unused beauty must be tomb'd with thee, -Which, used, lives th' executor to be. -Those hours, that with gentle work did frame -The lovely gaze where every eye doth dwell, -Will play the tyrants to the very same -And that unfair which fairly doth excel: -For never-resting time leads summer on -To hideous winter and confounds him there; -Sap cheque'd with frost and lusty leaves quite gone, -Beauty o'ersnow'd and bareness every where: -Then, were not summer's distillation left, -A liquid prisoner pent in walls of glass, -Beauty's effect with beauty were bereft, -Nor it nor no remembrance what it was: -But flowers distill'd though they with winter meet, -Leese but their show; their substance still lives sweet. -Then let not winter's ragged hand deface -In thee thy summer, ere thou be distill'd: -Make sweet some vial; treasure thou some place -With beauty's treasure, ere it be self-kill'd. -That use is not forbidden usury, -Which happies those that pay the willing loan; -That's for thyself to breed another thee, -Or ten times happier, be it ten for one; -Ten times thyself were happier than thou art, -If ten of thine ten times refigured thee: -Then what could death do, if thou shouldst depart, -Leaving thee living in posterity? -Be not self-will'd, for thou art much too fair -To be death's conquest and make worms thine heir. -Lo! in the orient when the gracious light -Lifts up his burning head, each under eye -Doth homage to his new-appearing sight, -Serving with looks his sacred majesty; -And having climb'd the steep-up heavenly hill, -Resembling strong youth in his middle age, -yet mortal looks adore his beauty still, -Attending on his golden pilgrimage; -But when from highmost pitch, with weary car, -Like feeble age, he reeleth from the day, -The eyes, 'fore duteous, now converted are -From his low tract and look another way: -So thou, thyself out-going in thy noon, -Unlook'd on diest, unless thou get a son. -Music to hear, why hear'st thou music sadly? -Sweets with sweets war not, joy delights in joy. -Why lovest thou that which thou receivest not gladly, -Or else receivest with pleasure thine annoy? -If the true concord of well-tuned sounds, -By unions married, do offend thine ear, -They do but sweetly chide thee, who confounds -In singleness the parts that thou shouldst bear. -Mark how one string, sweet husband to another, -Strikes each in each by mutual ordering, -Resembling sire and child and happy mother -Who all in one, one pleasing note do sing: -Whose speechless song, being many, seeming one, -Sings this to thee: 'thou single wilt prove none.' -Is it for fear to wet a widow's eye -That thou consumest thyself in single life? -Ah! if thou issueless shalt hap to die. -The world will wail thee, like a makeless wife; -The world will be thy widow and still weep -That thou no form of thee hast left behind, -When every private widow well may keep -By children's eyes her husband's shape in mind. -Look, what an unthrift in the world doth spend -Shifts but his place, for still the world enjoys it; -But beauty's waste hath in the world an end, -And kept unused, the user so destroys it. -No love toward others in that bosom sits -That on himself such murderous shame commits. -For shame! deny that thou bear'st love to any, -Who for thyself art so unprovident. -Grant, if thou wilt, thou art beloved of many, -But that thou none lovest is most evident; -For thou art so possess'd with murderous hate -That 'gainst thyself thou stick'st not to conspire. -Seeking that beauteous roof to ruinate -Which to repair should be thy chief desire. -O, change thy thought, that I may change my mind! -Shall hate be fairer lodged than gentle love? -Be, as thy presence is, gracious and kind, -Or to thyself at least kind-hearted prove: -Make thee another self, for love of me, -That beauty still may live in thine or thee. -As fast as thou shalt wane, so fast thou growest -In one of thine, from that which thou departest; -And that fresh blood which youngly thou bestowest -Thou mayst call thine when thou from youth convertest. -Herein lives wisdom, beauty and increase: -Without this, folly, age and cold decay: -If all were minded so, the times should cease -And threescore year would make the world away. -Let those whom Nature hath not made for store, -Harsh featureless and rude, barrenly perish: -Look, whom she best endow'd she gave the more; -Which bounteous gift thou shouldst in bounty cherish: -She carved thee for her seal, and meant thereby -Thou shouldst print more, not let that copy die. -When I do count the clock that tells the time, -And see the brave day sunk in hideous night; -When I behold the violet past prime, -And sable curls all silver'd o'er with white; -When lofty trees I see barren of leaves -Which erst from heat did canopy the herd, -And summer's green all girded up in sheaves -Borne on the bier with white and bristly beard, -Then of thy beauty do I question make, -That thou among the wastes of time must go, -Since sweets and beauties do themselves forsake -And die as fast as they see others grow; -And nothing 'gainst Time's scythe can make defence -Save breed, to brave him when he takes thee hence. -O, that you were yourself! but, love, you are -No longer yours than you yourself here live: -Against this coming end you should prepare, -And your sweet semblance to some other give. -So should that beauty which you hold in lease -Find no determination: then you were -Yourself again after yourself's decease, -When your sweet issue your sweet form should bear. -Who lets so fair a house fall to decay, -Which husbandry in honour might uphold -Against the stormy gusts of winter's day -And barren rage of death's eternal cold? -O, none but unthrifts! Dear my love, you know -You had a father: let your son say so. -Not from the stars do I my judgment pluck; -And yet methinks I have astronomy, -But not to tell of good or evil luck, -Of plagues, of dearths, or seasons' quality; -Nor can I fortune to brief minutes tell, -Pointing to each his thunder, rain and wind, -Or say with princes if it shall go well, -By oft predict that I in heaven find: -But from thine eyes my knowledge I derive, -And, constant stars, in them I read such art -As truth and beauty shall together thrive, -If from thyself to store thou wouldst convert; -Or else of thee this I prognosticate: -Thy end is truth's and beauty's doom and date. -When I consider every thing that grows -Holds in perfection but a little moment, -That this huge stage presenteth nought but shows -Whereon the stars in secret influence comment; -When I perceive that men as plants increase, -Cheered and cheque'd even by the self-same sky, -Vaunt in their youthful sap, at height decrease, -And wear their brave state out of memory; -Then the conceit of this inconstant stay -Sets you most rich in youth before my sight, -Where wasteful Time debateth with Decay, -To change your day of youth to sullied night; -And all in war with Time for love of you, -As he takes from you, I engraft you new. -But wherefore do not you a mightier way -Make war upon this bloody tyrant, Time? -And fortify yourself in your decay -With means more blessed than my barren rhyme? -Now stand you on the top of happy hours, -And many maiden gardens yet unset -With virtuous wish would bear your living flowers, -Much liker than your painted counterfeit: -So should the lines of life that life repair, -Which this, Time's pencil, or my pupil pen, -Neither in inward worth nor outward fair, -Can make you live yourself in eyes of men. -To give away yourself keeps yourself still, -And you must live, drawn by your own sweet skill. -Who will believe my verse in time to come, -If it were fill'd with your most high deserts? -Though yet, heaven knows, it is but as a tomb -Which hides your life and shows not half your parts. -If I could write the beauty of your eyes -And in fresh numbers number all your graces, -The age to come would say 'This poet lies: -Such heavenly touches ne'er touch'd earthly faces.' -So should my papers yellow'd with their age -Be scorn'd like old men of less truth than tongue, -And your true rights be term'd a poet's rage -And stretched metre of an antique song: -But were some child of yours alive that time, -You should live twice; in it and in my rhyme. -Shall I compare thee to a summer's day? -Thou art more lovely and more temperate: -Rough winds do shake the darling buds of May, -And summer's lease hath all too short a date: -Sometime too hot the eye of heaven shines, -And often is his gold complexion dimm'd; -And every fair from fair sometime declines, -By chance or nature's changing course untrimm'd; -But thy eternal summer shall not fade -Nor lose possession of that fair thou owest; -Nor shall Death brag thou wander'st in his shade, -When in eternal lines to time thou growest: -So long as men can breathe or eyes can see, -So long lives this and this gives life to thee. -Devouring Time, blunt thou the lion's paws, -And make the earth devour her own sweet brood; -Pluck the keen teeth from the fierce tiger's jaws, -And burn the long-lived phoenix in her blood; -Make glad and sorry seasons as thou fleets, -And do whate'er thou wilt, swift-footed Time, -To the wide world and all her fading sweets; -But I forbid thee one most heinous crime: -O, carve not with thy hours my love's fair brow, -Nor draw no lines there with thine antique pen; -Him in thy course untainted do allow -For beauty's pattern to succeeding men. -Yet, do thy worst, old Time: despite thy wrong, -My love shall in my verse ever live young. -A woman's face with Nature's own hand painted -Hast thou, the master-mistress of my passion; -A woman's gentle heart, but not acquainted -With shifting change, as is false women's fashion; -An eye more bright than theirs, less false in rolling, -Gilding the object whereupon it gazeth; -A man in hue, all 'hues' in his controlling, -Much steals men's eyes and women's souls amazeth. -And for a woman wert thou first created; -Till Nature, as she wrought thee, fell a-doting, -And by addition me of thee defeated, -By adding one thing to my purpose nothing. -But since she prick'd thee out for women's pleasure, -Mine be thy love and thy love's use their treasure. -So is it not with me as with that Muse -Stirr'd by a painted beauty to his verse, -Who heaven itself for ornament doth use -And every fair with his fair doth rehearse -Making a couplement of proud compare, -With sun and moon, with earth and sea's rich gems, -With April's first-born flowers, and all things rare -That heaven's air in this huge rondure hems. -O' let me, true in love, but truly write, -And then believe me, my love is as fair -As any mother's child, though not so bright -As those gold candles fix'd in heaven's air: -Let them say more than like of hearsay well; -I will not praise that purpose not to sell. -My glass shall not persuade me I am old, -So long as youth and thou are of one date; -But when in thee time's furrows I behold, -Then look I death my days should expiate. -For all that beauty that doth cover thee -Is but the seemly raiment of my heart, -Which in thy breast doth live, as thine in me: -How can I then be elder than thou art? -O, therefore, love, be of thyself so wary -As I, not for myself, but for thee will; -Bearing thy heart, which I will keep so chary -As tender nurse her babe from faring ill. -Presume not on thy heart when mine is slain; -Thou gavest me thine, not to give back again. -As an unperfect actor on the stage -Who with his fear is put besides his part, -Or some fierce thing replete with too much rage, -Whose strength's abundance weakens his own heart. -So I, for fear of trust, forget to say -The perfect ceremony of love's rite, -And in mine own love's strength seem to decay, -O'ercharged with burden of mine own love's might. -O, let my books be then the eloquence -And dumb presagers of my speaking breast, -Who plead for love and look for recompense -More than that tongue that more hath more express'd. -O, learn to read what silent love hath writ: -To hear with eyes belongs to love's fine wit. -Mine eye hath play'd the painter and hath stell'd -Thy beauty's form in table of my heart; -My body is the frame wherein 'tis held, -And perspective it is the painter's art. -For through the painter must you see his skill, -To find where your true image pictured lies; -Which in my bosom's shop is hanging still, -That hath his windows glazed with thine eyes. -Now see what good turns eyes for eyes have done: -Mine eyes have drawn thy shape, and thine for me -Are windows to my breast, where-through the sun -Delights to peep, to gaze therein on thee; -Yet eyes this cunning want to grace their art; -They draw but what they see, know not the heart. -Let those who are in favour with their stars -Of public honour and proud titles boast, -Whilst I, whom fortune of such triumph bars, -Unlook'd for joy in that I honour most. -Great princes' favourites their fair leaves spread -But as the marigold at the sun's eye, -And in themselves their pride lies buried, -For at a frown they in their glory die. -The painful warrior famoused for fight, -After a thousand victories once foil'd, -Is from the book of honour razed quite, -And all the rest forgot for which he toil'd: -Then happy I, that love and am beloved -Where I may not remove nor be removed. -Lord of my love, to whom in vassalage -Thy merit hath my duty strongly knit, -To thee I send this written embassage, -To witness duty, not to show my wit: -Duty so great, which wit so poor as mine -May make seem bare, in wanting words to show it, -But that I hope some good conceit of thine -In thy soul's thought, all naked, will bestow it; -Till whatsoever star that guides my moving -Points on me graciously with fair aspect -And puts apparel on my tatter'd loving, -To show me worthy of thy sweet respect: -Then may I dare to boast how I do love thee; -Till then not show my head where thou mayst prove me. -Weary with toil, I haste me to my bed, -The dear repose for limbs with travel tired; -But then begins a journey in my head, -To work my mind, when body's work's expired: -For then my thoughts, from far where I abide, -Intend a zealous pilgrimage to thee, -And keep my drooping eyelids open wide, -Looking on darkness which the blind do see -Save that my soul's imaginary sight -Presents thy shadow to my sightless view, -Which, like a jewel hung in ghastly night, -Makes black night beauteous and her old face new. -Lo! thus, by day my limbs, by night my mind, -For thee and for myself no quiet find. -How can I then return in happy plight, -That am debarr'd the benefit of rest? -When day's oppression is not eased by night, -But day by night, and night by day, oppress'd? -And each, though enemies to either's reign, -Do in consent shake hands to torture me; -The one by toil, the other to complain -How far I toil, still farther off from thee. -I tell the day, to please them thou art bright -And dost him grace when clouds do blot the heaven: -So flatter I the swart-complexion'd night, -When sparkling stars twire not thou gild'st the even. -But day doth daily draw my sorrows longer -And night doth nightly make grief's strength seem stronger. -When, in disgrace with fortune and men's eyes, -I all alone beweep my outcast state -And trouble deal heaven with my bootless cries -And look upon myself and curse my fate, -Wishing me like to one more rich in hope, -Featured like him, like him with friends possess'd, -Desiring this man's art and that man's scope, -With what I most enjoy contented least; -Yet in these thoughts myself almost despising, -Haply I think on thee, and then my state, -Like to the lark at break of day arising -From sullen earth, sings hymns at heaven's gate; -For thy sweet love remember'd such wealth brings -That then I scorn to change my state with kings. -When to the sessions of sweet silent thought -I summon up remembrance of things past, -I sigh the lack of many a thing I sought, -And with old woes new wail my dear time's waste: -Then can I drown an eye, unused to flow, -For precious friends hid in death's dateless night, -And weep afresh love's long since cancell'd woe, -And moan the expense of many a vanish'd sight: -Then can I grieve at grievances foregone, -And heavily from woe to woe tell o'er -The sad account of fore-bemoaned moan, -Which I new pay as if not paid before. -But if the while I think on thee, dear friend, -All losses are restored and sorrows end. -Thy bosom is endeared with all hearts, -Which I by lacking have supposed dead, -And there reigns love and all love's loving parts, -And all those friends which I thought buried. -How many a holy and obsequious tear -Hath dear religious love stol'n from mine eye -As interest of the dead, which now appear -But things removed that hidden in thee lie! -Thou art the grave where buried love doth live, -Hung with the trophies of my lovers gone, -Who all their parts of me to thee did give; -That due of many now is thine alone: -Their images I loved I view in thee, -And thou, all they, hast all the all of me. -If thou survive my well-contented day, -When that churl Death my bones with dust shall cover, -And shalt by fortune once more re-survey -These poor rude lines of thy deceased lover, -Compare them with the bettering of the time, -And though they be outstripp'd by every pen, -Reserve them for my love, not for their rhyme, -Exceeded by the height of happier men. -O, then vouchsafe me but this loving thought: -'Had my friend's Muse grown with this growing age, -A dearer birth than this his love had brought, -To march in ranks of better equipage: -But since he died and poets better prove, -Theirs for their style I'll read, his for his love.' -Full many a glorious morning have I seen -Flatter the mountain-tops with sovereign eye, -Kissing with golden face the meadows green, -Gilding pale streams with heavenly alchemy; -Anon permit the basest clouds to ride -With ugly rack on his celestial face, -And from the forlorn world his visage hide, -Stealing unseen to west with this disgrace: -Even so my sun one early morn did shine -With all triumphant splendor on my brow; -But out, alack! he was but one hour mine; -The region cloud hath mask'd him from me now. -Yet him for this my love no whit disdaineth; -Suns of the world may stain when heaven's sun staineth. -Why didst thou promise such a beauteous day, -And make me travel forth without my cloak, -To let base clouds o'ertake me in my way, -Hiding thy bravery in their rotten smoke? -'Tis not enough that through the cloud thou break, -To dry the rain on my storm-beaten face, -For no man well of such a salve can speak -That heals the wound and cures not the disgrace: -Nor can thy shame give physic to my grief; -Though thou repent, yet I have still the loss: -The offender's sorrow lends but weak relief -To him that bears the strong offence's cross. -Ah! but those tears are pearl which thy love sheds, -And they are rich and ransom all ill deeds. -No more be grieved at that which thou hast done: -Roses have thorns, and silver fountains mud; -Clouds and eclipses stain both moon and sun, -And loathsome canker lives in sweetest bud. -All men make faults, and even I in this, -Authorizing thy trespass with compare, -Myself corrupting, salving thy amiss, -Excusing thy sins more than thy sins are; -For to thy sensual fault I bring in sense-- -Thy adverse party is thy advocate-- -And 'gainst myself a lawful plea commence: -Such civil war is in my love and hate -That I an accessary needs must be -To that sweet thief which sourly robs from me. -Let me confess that we two must be twain, -Although our undivided loves are one: -So shall those blots that do with me remain -Without thy help by me be borne alone. -In our two loves there is but one respect, -Though in our lives a separable spite, -Which though it alter not love's sole effect, -Yet doth it steal sweet hours from love's delight. -I may not evermore acknowledge thee, -Lest my bewailed guilt should do thee shame, -Nor thou with public kindness honour me, -Unless thou take that honour from thy name: -But do not so; I love thee in such sort -As, thou being mine, mine is thy good report. -As a decrepit father takes delight -To see his active child do deeds of youth, -So I, made lame by fortune's dearest spite, -Take all my comfort of thy worth and truth. -For whether beauty, birth, or wealth, or wit, -Or any of these all, or all, or more, -Entitled in thy parts do crowned sit, -I make my love engrafted to this store: -So then I am not lame, poor, nor despised, -Whilst that this shadow doth such substance give -That I in thy abundance am sufficed -And by a part of all thy glory live. -Look, what is best, that best I wish in thee: -This wish I have; then ten times happy me!FROM fairest creatures we desire increase, -That thereby beauty's rose might never die, -But as the riper should by time decease, -His tender heir might bear his memory: -But thou, contracted to thine own bright eyes, -Feed'st thy light'st flame with self-substantial fuel, -Making a famine where abundance lies, -Thyself thy foe, to thy sweet self too cruel. -Thou that art now the world's fresh ornament -And only herald to the gaudy spring, -Within thine own bud buriest thy content -And, tender churl, makest waste in niggarding. -Pity the world, or else this glutton be, -To eat the world's due, by the grave and thee. -When forty winters shall beseige thy brow, -And dig deep trenches in thy beauty's field, -Thy youth's proud livery, so gazed on now, -Will be a tatter'd weed, of small worth held: -Then being ask'd where all thy beauty lies, -Where all the treasure of thy lusty days, -To say, within thine own deep-sunken eyes, -Were an all-eating shame and thriftless praise. -How much more praise deserved thy beauty's use, -If thou couldst answer 'This fair child of mine -Shall sum my count and make my old excuse,' -Proving his beauty by succession thine! -This were to be new made when thou art old, -And see thy blood warm when thou feel'st it cold. -Look in thy glass, and tell the face thou viewest -Now is the time that face should form another; -Whose fresh repair if now thou not renewest, -Thou dost beguile the world, unbless some mother. -For where is she so fair whose unear'd womb -Disdains the tillage of thy husbandry? -Or who is he so fond will be the tomb -Of his self-love, to stop posterity? -Thou art thy mother's glass, and she in thee -Calls back the lovely April of her prime: -So thou through windows of thine age shall see -Despite of wrinkles this thy golden time. -But if thou live, remember'd not to be, -Die single, and thine image dies with thee. -Unthrifty loveliness, why dost thou spend -Upon thyself thy beauty's legacy? -Nature's bequest gives nothing but doth lend, -And being frank she lends to those are free. -Then, beauteous niggard, why dost thou abuse -The bounteous largess given thee to give? -Profitless usurer, why dost thou use -So great a sum of sums, yet canst not live? -For having traffic with thyself alone, -Thou of thyself thy sweet self dost deceive. -Then how, when nature calls thee to be gone, -What acceptable audit canst thou leave? -Thy unused beauty must be tomb'd with thee, -Which, used, lives th' executor to be. -Those hours, that with gentle work did frame -The lovely gaze where every eye doth dwell, -Will play the tyrants to the very same -And that unfair which fairly doth excel: -For never-resting time leads summer on -To hideous winter and confounds him there; -Sap cheque'd with frost and lusty leaves quite gone, -Beauty o'ersnow'd and bareness every where: -Then, were not summer's distillation left, -A liquid prisoner pent in walls of glass, -Beauty's effect with beauty were bereft, -Nor it nor no remembrance what it was: -But flowers distill'd though they with winter meet, -Leese but their show; their substance still lives sweet. -Then let not winter's ragged hand deface -In thee thy summer, ere thou be distill'd: -Make sweet some vial; treasure thou some place -With beauty's treasure, ere it be self-kill'd. -That use is not forbidden usury, -Which happies those that pay the willing loan; -That's for thyself to breed another thee, -Or ten times happier, be it ten for one; -Ten times thyself were happier than thou art, -If ten of thine ten times refigured thee: -Then what could death do, if thou shouldst depart, -Leaving thee living in posterity? -Be not self-will'd, for thou art much too fair -To be death's conquest and make worms thine heir. -Lo! in the orient when the gracious light -Lifts up his burning head, each under eye -Doth homage to his new-appearing sight, -Serving with looks his sacred majesty; -And having climb'd the steep-up heavenly hill, -Resembling strong youth in his middle age, -yet mortal looks adore his beauty still, -Attending on his golden pilgrimage; -But when from highmost pitch, with weary car, -Like feeble age, he reeleth from the day, -The eyes, 'fore duteous, now converted are -From his low tract and look another way: -So thou, thyself out-going in thy noon, -Unlook'd on diest, unless thou get a son. -Music to hear, why hear'st thou music sadly? -Sweets with sweets war not, joy delights in joy. -Why lovest thou that which thou receivest not gladly, -Or else receivest with pleasure thine annoy? -If the true concord of well-tuned sounds, -By unions married, do offend thine ear, -They do but sweetly chide thee, who confounds -In singleness the parts that thou shouldst bear. -Mark how one string, sweet husband to another, -Strikes each in each by mutual ordering, -Resembling sire and child and happy mother -Who all in one, one pleasing note do sing: -Whose speechless song, being many, seeming one, -Sings this to thee: 'thou single wilt prove none.' -Is it for fear to wet a widow's eye -That thou consumest thyself in single life? -Ah! if thou issueless shalt hap to die. -The world will wail thee, like a makeless wife; -The world will be thy widow and still weep -That thou no form of thee hast left behind, -When every private widow well may keep -By children's eyes her husband's shape in mind. -Look, what an unthrift in the world doth spend -Shifts but his place, for still the world enjoys it; -But beauty's waste hath in the world an end, -And kept unused, the user so destroys it. -No love toward others in that bosom sits -That on himself such murderous shame commits. -For shame! deny that thou bear'st love to any, -Who for thyself art so unprovident. -Grant, if thou wilt, thou art beloved of many, -But that thou none lovest is most evident; -For thou art so possess'd with murderous hate -That 'gainst thyself thou stick'st not to conspire. -Seeking that beauteous roof to ruinate -Which to repair should be thy chief desire. -O, change thy thought, that I may change my mind! -Shall hate be fairer lodged than gentle love? -Be, as thy presence is, gracious and kind, -Or to thyself at least kind-hearted prove: -Make thee another self, for love of me, -That beauty still may live in thine or thee. -As fast as thou shalt wane, so fast thou growest -In one of thine, from that which thou departest; -And that fresh blood which youngly thou bestowest -Thou mayst call thine when thou from youth convertest. -Herein lives wisdom, beauty and increase: -Without this, folly, age and cold decay: -If all were minded so, the times should cease -And threescore year would make the world away. -Let those whom Nature hath not made for store, -Harsh featureless and rude, barrenly perish: -Look, whom she best endow'd she gave the more; -Which bounteous gift thou shouldst in bounty cherish: -She carved thee for her seal, and meant thereby -Thou shouldst print more, not let that copy die. -When I do count the clock that tells the time, -And see the brave day sunk in hideous night; -When I behold the violet past prime, -And sable curls all silver'd o'er with white; -When lofty trees I see barren of leaves -Which erst from heat did canopy the herd, -And summer's green all girded up in sheaves -Borne on the bier with white and bristly beard, -Then of thy beauty do I question make, -That thou among the wastes of time must go, -Since sweets and beauties do themselves forsake -And die as fast as they see others grow; -And nothing 'gainst Time's scythe can make defence -Save breed, to brave him when he takes thee hence. -O, that you were yourself! but, love, you are -No longer yours than you yourself here live: -Against this coming end you should prepare, -And your sweet semblance to some other give. -So should that beauty which you hold in lease -Find no determination: then you were -Yourself again after yourself's decease, -When your sweet issue your sweet form should bear. -Who lets so fair a house fall to decay, -Which husbandry in honour might uphold -Against the stormy gusts of winter's day -And barren rage of death's eternal cold? -O, none but unthrifts! Dear my love, you know -You had a father: let your son say so. -Not from the stars do I my judgment pluck; -And yet methinks I have astronomy, -But not to tell of good or evil luck, -Of plagues, of dearths, or seasons' quality; -Nor can I fortune to brief minutes tell, -Pointing to each his thunder, rain and wind, -Or say with princes if it shall go well, -By oft predict that I in heaven find: -But from thine eyes my knowledge I derive, -And, constant stars, in them I read such art -As truth and beauty shall together thrive, -If from thyself to store thou wouldst convert; -Or else of thee this I prognosticate: -Thy end is truth's and beauty's doom and date. -When I consider every thing that grows -Holds in perfection but a little moment, -That this huge stage presenteth nought but shows -Whereon the stars in secret influence comment; -When I perceive that men as plants increase, -Cheered and cheque'd even by the self-same sky, -Vaunt in their youthful sap, at height decrease, -And wear their brave state out of memory; -Then the conceit of this inconstant stay -Sets you most rich in youth before my sight, -Where wasteful Time debateth with Decay, -To change your day of youth to sullied night; -And all in war with Time for love of you, -As he takes from you, I engraft you new. -But wherefore do not you a mightier way -Make war upon this bloody tyrant, Time? -And fortify yourself in your decay -With means more blessed than my barren rhyme? -Now stand you on the top of happy hours, -And many maiden gardens yet unset -With virtuous wish would bear your living flowers, -Much liker than your painted counterfeit: -So should the lines of life that life repair, -Which this, Time's pencil, or my pupil pen, -Neither in inward worth nor outward fair, -Can make you live yourself in eyes of men. -To give away yourself keeps yourself still, -And you must live, drawn by your own sweet skill. -Who will believe my verse in time to come, -If it were fill'd with your most high deserts? -Though yet, heaven knows, it is but as a tomb -Which hides your life and shows not half your parts. -If I could write the beauty of your eyes -And in fresh numbers number all your graces, -The age to come would say 'This poet lies: -Such heavenly touches ne'er touch'd earthly faces.' -So should my papers yellow'd with their age -Be scorn'd like old men of less truth than tongue, -And your true rights be term'd a poet's rage -And stretched metre of an antique song: -But were some child of yours alive that time, -You should live twice; in it and in my rhyme. -Shall I compare thee to a summer's day? -Thou art more lovely and more temperate: -Rough winds do shake the darling buds of May, -And summer's lease hath all too short a date: -Sometime too hot the eye of heaven shines, -And often is his gold complexion dimm'd; -And every fair from fair sometime declines, -By chance or nature's changing course untrimm'd; -But thy eternal summer shall not fade -Nor lose possession of that fair thou owest; -Nor shall Death brag thou wander'st in his shade, -When in eternal lines to time thou growest: -So long as men can breathe or eyes can see, -So long lives this and this gives life to thee. -Devouring Time, blunt thou the lion's paws, -And make the earth devour her own sweet brood; -Pluck the keen teeth from the fierce tiger's jaws, -And burn the long-lived phoenix in her blood; -Make glad and sorry seasons as thou fleets, -And do whate'er thou wilt, swift-footed Time, -To the wide world and all her fading sweets; -But I forbid thee one most heinous crime: -O, carve not with thy hours my love's fair brow, -Nor draw no lines there with thine antique pen; -Him in thy course untainted do allow -For beauty's pattern to succeeding men. -Yet, do thy worst, old Time: despite thy wrong, -My love shall in my verse ever live young. -A woman's face with Nature's own hand painted -Hast thou, the master-mistress of my passion; -A woman's gentle heart, but not acquainted -With shifting change, as is false women's fashion; -An eye more bright than theirs, less false in rolling, -Gilding the object whereupon it gazeth; -A man in hue, all 'hues' in his controlling, -Much steals men's eyes and women's souls amazeth. -And for a woman wert thou first created; -Till Nature, as she wrought thee, fell a-doting, -And by addition me of thee defeated, -By adding one thing to my purpose nothing. -But since she prick'd thee out for women's pleasure, -Mine be thy love and thy love's use their treasure. -So is it not with me as with that Muse -Stirr'd by a painted beauty to his verse, -Who heaven itself for ornament doth use -And every fair with his fair doth rehearse -Making a couplement of proud compare, -With sun and moon, with earth and sea's rich gems, -With April's first-born flowers, and all things rare -That heaven's air in this huge rondure hems. -O' let me, true in love, but truly write, -And then believe me, my love is as fair -As any mother's child, though not so bright -As those gold candles fix'd in heaven's air: -Let them say more than like of hearsay well; -I will not praise that purpose not to sell. -My glass shall not persuade me I am old, -So long as youth and thou are of one date; -But when in thee time's furrows I behold, -Then look I death my days should expiate. -For all that beauty that doth cover thee -Is but the seemly raiment of my heart, -Which in thy breast doth live, as thine in me: -How can I then be elder than thou art? -O, therefore, love, be of thyself so wary -As I, not for myself, but for thee will; -Bearing thy heart, which I will keep so chary -As tender nurse her babe from faring ill. -Presume not on thy heart when mine is slain; -Thou gavest me thine, not to give back again. -As an unperfect actor on the stage -Who with his fear is put besides his part, -Or some fierce thing replete with too much rage, -Whose strength's abundance weakens his own heart. -So I, for fear of trust, forget to say -The perfect ceremony of love's rite, -And in mine own love's strength seem to decay, -O'ercharged with burden of mine own love's might. -O, let my books be then the eloquence -And dumb presagers of my speaking breast, -Who plead for love and look for recompense -More than that tongue that more hath more express'd. -O, learn to read what silent love hath writ: -To hear with eyes belongs to love's fine wit. -Mine eye hath play'd the painter and hath stell'd -Thy beauty's form in table of my heart; -My body is the frame wherein 'tis held, -And perspective it is the painter's art. -For through the painter must you see his skill, -To find where your true image pictured lies; -Which in my bosom's shop is hanging still, -That hath his windows glazed with thine eyes. -Now see what good turns eyes for eyes have done: -Mine eyes have drawn thy shape, and thine for me -Are windows to my breast, where-through the sun -Delights to peep, to gaze therein on thee; -Yet eyes this cunning want to grace their art; -They draw but what they see, know not the heart. -Let those who are in favour with their stars -Of public honour and proud titles boast, -Whilst I, whom fortune of such triumph bars, -Unlook'd for joy in that I honour most. -Great princes' favourites their fair leaves spread -But as the marigold at the sun's eye, -And in themselves their pride lies buried, -For at a frown they in their glory die. -The painful warrior famoused for fight, -After a thousand victories once foil'd, -Is from the book of honour razed quite, -And all the rest forgot for which he toil'd: -Then happy I, that love and am beloved -Where I may not remove nor be removed. -Lord of my love, to whom in vassalage -Thy merit hath my duty strongly knit, -To thee I send this written embassage, -To witness duty, not to show my wit: -Duty so great, which wit so poor as mine -May make seem bare, in wanting words to show it, -But that I hope some good conceit of thine -In thy soul's thought, all naked, will bestow it; -Till whatsoever star that guides my moving -Points on me graciously with fair aspect -And puts apparel on my tatter'd loving, -To show me worthy of thy sweet respect: -Then may I dare to boast how I do love thee; -Till then not show my head where thou mayst prove me. -Weary with toil, I haste me to my bed, -The dear repose for limbs with travel tired; -But then begins a journey in my head, -To work my mind, when body's work's expired: -For then my thoughts, from far where I abide, -Intend a zealous pilgrimage to thee, -And keep my drooping eyelids open wide, -Looking on darkness which the blind do see -Save that my soul's imaginary sight -Presents thy shadow to my sightless view, -Which, like a jewel hung in ghastly night, -Makes black night beauteous and her old face new. -Lo! thus, by day my limbs, by night my mind, -For thee and for myself no quiet find. -How can I then return in happy plight, -That am debarr'd the benefit of rest? -When day's oppression is not eased by night, -But day by night, and night by day, oppress'd? -And each, though enemies to either's reign, -Do in consent shake hands to torture me; -The one by toil, the other to complain -How far I toil, still farther off from thee. -I tell the day, to please them thou art bright -And dost him grace when clouds do blot the heaven: -So flatter I the swart-complexion'd night, -When sparkling stars twire not thou gild'st the even. -But day doth daily draw my sorrows longer -And night doth nightly make grief's strength seem stronger. -When, in disgrace with fortune and men's eyes, -I all alone beweep my outcast state -And trouble deal heaven with my bootless cries -And look upon myself and curse my fate, -Wishing me like to one more rich in hope, -Featured like him, like him with friends possess'd, -Desiring this man's art and that man's scope, -With what I most enjoy contented least; -Yet in these thoughts myself almost despising, -Haply I think on thee, and then my state, -Like to the lark at break of day arising -From sullen earth, sings hymns at heaven's gate; -For thy sweet love remember'd such wealth brings -That then I scorn to change my state with kings. -When to the sessions of sweet silent thought -I summon up remembrance of things past, -I sigh the lack of many a thing I sought, -And with old woes new wail my dear time's waste: -Then can I drown an eye, unused to flow, -For precious friends hid in death's dateless night, -And weep afresh love's long since cancell'd woe, -And moan the expense of many a vanish'd sight: -Then can I grieve at grievances foregone, -And heavily from woe to woe tell o'er -The sad account of fore-bemoaned moan, -Which I new pay as if not paid before. -But if the while I think on thee, dear friend, -All losses are restored and sorrows end. -Thy bosom is endeared with all hearts, -Which I by lacking have supposed dead, -And there reigns love and all love's loving parts, -And all those friends which I thought buried. -How many a holy and obsequious tear -Hath dear religious love stol'n from mine eye -As interest of the dead, which now appear -But things removed that hidden in thee lie! -Thou art the grave where buried love doth live, -Hung with the trophies of my lovers gone, -Who all their parts of me to thee did give; -That due of many now is thine alone: -Their images I loved I view in thee, -And thou, all they, hast all the all of me. -If thou survive my well-contented day, -When that churl Death my bones with dust shall cover, -And shalt by fortune once more re-survey -These poor rude lines of thy deceased lover, -Compare them with the bettering of the time, -And though they be outstripp'd by every pen, -Reserve them for my love, not for their rhyme, -Exceeded by the height of happier men. -O, then vouchsafe me but this loving thought: -'Had my friend's Muse grown with this growing age, -A dearer birth than this his love had brought, -To march in ranks of better equipage: -But since he died and poets better prove, -Theirs for their style I'll read, his for his love.' -Full many a glorious morning have I seen -Flatter the mountain-tops with sovereign eye, -Kissing with golden face the meadows green, -Gilding pale streams with heavenly alchemy; -Anon permit the basest clouds to ride -With ugly rack on his celestial face, -And from the forlorn world his visage hide, -Stealing unseen to west with this disgrace: -Even so my sun one early morn did shine -With all triumphant splendor on my brow; -But out, alack! he was but one hour mine; -The region cloud hath mask'd him from me now. -Yet him for this my love no whit disdaineth; -Suns of the world may stain when heaven's sun staineth. -Why didst thou promise such a beauteous day, -And make me travel forth without my cloak, -To let base clouds o'ertake me in my way, -Hiding thy bravery in their rotten smoke? -'Tis not enough that through the cloud thou break, -To dry the rain on my storm-beaten face, -For no man well of such a salve can speak -That heals the wound and cures not the disgrace: -Nor can thy shame give physic to my grief; -Though thou repent, yet I have still the loss: -The offender's sorrow lends but weak relief -To him that bears the strong offence's cross. -Ah! but those tears are pearl which thy love sheds, -And they are rich and ransom all ill deeds. -No more be grieved at that which thou hast done: -Roses have thorns, and silver fountains mud; -Clouds and eclipses stain both moon and sun, -And loathsome canker lives in sweetest bud. -All men make faults, and even I in this, -Authorizing thy trespass with compare, -Myself corrupting, salving thy amiss, -Excusing thy sins more than thy sins are; -For to thy sensual fault I bring in sense-- -Thy adverse party is thy advocate-- -And 'gainst myself a lawful plea commence: -Such civil war is in my love and hate -That I an accessary needs must be -To that sweet thief which sourly robs from me. -Let me confess that we two must be twain, -Although our undivided loves are one: -So shall those blots that do with me remain -Without thy help by me be borne alone. -In our two loves there is but one respect, -Though in our lives a separable spite, -Which though it alter not love's sole effect, -Yet doth it steal sweet hours from love's delight. -I may not evermore acknowledge thee, -Lest my bewailed guilt should do thee shame, -Nor thou with public kindness honour me, -Unless thou take that honour from thy name: -But do not so; I love thee in such sort -As, thou being mine, mine is thy good report. -As a decrepit father takes delight -To see his active child do deeds of youth, -So I, made lame by fortune's dearest spite, -Take all my comfort of thy worth and truth. -For whether beauty, birth, or wealth, or wit, -Or any of these all, or all, or more, -Entitled in thy parts do crowned sit, -I make my love engrafted to this store: -So then I am not lame, poor, nor despised, -Whilst that this shadow doth such substance give -That I in thy abundance am sufficed -And by a part of all thy glory live. -Look, what is best, that best I wish in thee: -This wish I have; then ten times happy me! \ No newline at end of file From afd7a29ba1fd7fabb33eae9cb8b5e1961d629cf6 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 15 Sep 2024 22:26:11 +0000 Subject: [PATCH 225/303] revert change in flash attn and flash infer to clean up the diff --- vllm/attention/backends/flash_attn.py | 3 --- vllm/attention/backends/flashinfer.py | 1 - 2 files changed, 4 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 09456ca8d7b61..bf883987bd80b 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -15,9 +15,6 @@ is_block_tables_empty) from vllm.utils import async_tensor_h2d, make_tensor_with_pad -from vllm.distributed import get_disagg_group -import vllm.envs as envs - if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 852c5cd8dc180..4054d337316fe 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -823,4 +823,3 @@ def forward( k_scale=k_scale, v_scale=v_scale) return output.view(num_tokens, hidden_size) - \ No newline at end of file From cbf24b34bbcf74b268beeda26550cfe5ea8639e1 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 15 Sep 2024 22:35:12 +0000 Subject: [PATCH 226/303] update the example --- .../disagg_prefill/disagg_prefill_example.sh | 35 +++++++++++++++---- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index f57f5fd86d89c..56b6f44c7418a 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -16,7 +16,7 @@ wait_for_server() { } # prefilling instance -VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 python3 \ +VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8100 \ @@ -24,7 +24,7 @@ VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 pytho --gpu-memory-utilization 0.8 & # decoding instance -VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ +VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8200 \ @@ -36,18 +36,39 @@ wait_for_server 8100 wait_for_server 8200 # launch a proxy server that opens the service at port 8000 +# the workflow of this proxy: +# - send the request to prefill vLLM instance (port 8100), change max_tokens to 1 +# - after the prefill vLLM finishes prefill, send the request to decode vLLM instance python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & sleep 1 -# serve an example request -curl http://localhost:8000/v1/completions \ +# serve two example requests +output1=$(curl -s http://localhost:8000/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "prompt": "San Francisco is a", "max_tokens": 10, "temperature": 0 -}' +}') -# clean up -ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 \ No newline at end of file +output2=$(curl -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "Santa Clara is a", +"max_tokens": 10, +"temperature": 0 +}') + +# Print the outputs of the curl requests +echo "" +echo "Output of first request: $output1" +echo "Output of second request: $output2" + +echo "Successfully finished 2 test requests!" +echo "" + +# Cleanup commands, suppressing their output +ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 > /dev/null 2>&1 +pkill -f python3 > /dev/null 2>&1 From 4f4ea5053abf40d304cb4281a46e27365da6ec75 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 15 Sep 2024 23:42:30 +0000 Subject: [PATCH 227/303] make format checker happy --- .../disagg_prefill_proxy_server.py | 40 +-- .../disagg_benchmarks/round_robin_proxy.py | 25 +- .../visualize_benchmark_results.py | 39 ++- .../kv_transfer/kv_lookup_buffer/base.py | 19 +- .../simple_kv_lookup_buffer.py | 122 ++++----- vllm/distributed/kv_transfer/kv_pipe/base.py | 16 +- .../kv_pipe/torch_distributed_pipe.py | 76 +++--- vllm/distributed/kv_transfer/vllm_adapter.py | 252 ++++++++++-------- vllm/distributed/parallel_state.py | 33 ++- vllm/executor/gpu_executor.py | 3 +- vllm/executor/multiproc_gpu_executor.py | 3 +- vllm/executor/ray_gpu_executor.py | 2 +- vllm/worker/model_runner.py | 46 ++-- vllm/worker/worker_base.py | 87 +++--- 14 files changed, 386 insertions(+), 377 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index 5750df7735ad1..4058b1c0a3b79 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -1,28 +1,31 @@ -from quart import Quart, request, Response, jsonify, make_response -import aiohttp -import sys -import traceback import os +import aiohttp +from quart import Quart, make_response, request + AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) app = Quart(__name__) + async def forward_request(url, data): - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" } - async with session.post(url=url, json=data, headers=headers) as response: + async with session.post(url=url, json=data, + headers=headers) as response: if response.status == 200: # if response.headers.get('Transfer-Encoding') == 'chunked': if True: - async for chunk_bytes in response.content.iter_chunked(1024): + async for chunk_bytes in response.content.iter_chunked( + 1024): yield chunk_bytes else: content = await response.read() yield content + @app.route('/v1/completions', methods=['POST']) async def handle_request(): try: @@ -31,25 +34,28 @@ async def handle_request(): prefill_request = original_request_data.copy() # change max_tokens = 1 to let it only do prefill prefill_request['max_tokens'] = 1 - + # finish prefill - async for _ in forward_request('http://localhost:8100/v1/completions', prefill_request): + async for _ in forward_request('http://localhost:8100/v1/completions', + prefill_request): continue - print(f"Prefill done. proceeding to decode.") - # return decode - generator = forward_request('http://localhost:8200/v1/completions', original_request_data) + generator = forward_request('http://localhost:8200/v1/completions', + original_request_data) response = await make_response(generator) response.timeout = None return response - + except Exception as e: - pass - # exc_info = sys.exc_info() - # print(e) - # print("".join(traceback.format_exception(*exc_info))) + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server") + print(e) + print("".join(traceback.format_exception(*exc_info))) + if __name__ == '__main__': app.run(port=8000) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py index 8751e24a08d33..6eb5f63980070 100644 --- a/benchmarks/disagg_benchmarks/round_robin_proxy.py +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -1,9 +1,12 @@ import asyncio +import itertools + import aiohttp from aiohttp import web -import itertools + class RoundRobinProxy: + def __init__(self, target_ports): self.target_ports = target_ports self.port_cycle = itertools.cycle(self.target_ports) @@ -16,16 +19,14 @@ async def handle_request(self, request): try: # Forward the request async with session.request( - method=request.method, - url=target_url, - headers=request.headers, - data=request.content, + method=request.method, + url=target_url, + headers=request.headers, + data=request.content, ) as response: # Start sending the response - resp = web.StreamResponse( - status=response.status, - headers=response.headers - ) + resp = web.StreamResponse(status=response.status, + headers=response.headers) await resp.prepare(request) # Stream the response content @@ -38,6 +39,7 @@ async def handle_request(self, request): except Exception as e: return web.Response(text=f"Error: {str(e)}", status=500) + async def main(): proxy = RoundRobinProxy([8100, 8200]) app = web.Application() @@ -49,9 +51,10 @@ async def main(): await site.start() print("Proxy server started on http://localhost:8000") - + # Keep the server running await asyncio.Event().wait() + if __name__ == '__main__': - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py index 192f26a1e3cd2..6c5bf5c791dc9 100644 --- a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -1,40 +1,42 @@ +import json import matplotlib.pyplot as plt -import yaml import pandas as pd -import json - - if __name__ == "__main__": data = [] for name in ['disagg_prefill', 'chunked_prefill']: - for qps in [2,4,6,8]: + for qps in [2, 4, 6, 8]: with open(f"results/{name}-qps-{qps}.json", "r") as f: x = json.load(f) x['name'] = name x['qps'] = qps data.append(x) - + df = pd.DataFrame.from_dict(data) dis_df = df[df['name'] == 'disagg_prefill'] chu_df = df[df['name'] == 'chunked_prefill'] - + plt.style.use('bmh') plt.rcParams['font.size'] = 20 - - for key in ['mean_ttft_ms', - 'median_ttft_ms', - 'p99_ttft_ms', - 'mean_itl_ms', - 'median_itl_ms', - 'p99_itl_ms']: - + for key in [ + 'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms', + 'median_itl_ms', 'p99_itl_ms' + ]: + fig, ax = plt.subplots(figsize=(11, 7)) - plt.plot(dis_df['qps'], dis_df[key], label='disagg_prefill', marker='o', linewidth=4) - plt.plot(chu_df['qps'], chu_df[key], label='chunked_prefill', marker='o', linewidth=4) + plt.plot(dis_df['qps'], + dis_df[key], + label='disagg_prefill', + marker='o', + linewidth=4) + plt.plot(chu_df['qps'], + chu_df[key], + label='chunked_prefill', + marker='o', + linewidth=4) ax.legend() ax.set_xlabel('QPS') @@ -42,6 +44,3 @@ ax.set_ylim(bottom=0) fig.savefig(f'results/{key}.png') plt.close(fig) - - - \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py index 733bc82bf53f9..80802f87987ac 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -1,21 +1,22 @@ - from abc import ABC, abstractmethod -from typing import Optional +from typing import List, Optional + import torch class KVLookupBufferBase(ABC): - + @abstractmethod - def insert(self, - input_tokens: torch.Tensor, - kv: torch.Tensor, roi) -> None: + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: raise NotImplementedError - + @abstractmethod - def drop_select(self, input_tokens, roi) -> Optional[torch.Tensor]: + def drop_select(self, input_tokens: torch.Tensor, + roi: torch.Tensor) -> List[Optional[torch.Tensor]]: raise NotImplementedError - + @abstractmethod def close(self): """ diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py index 6172bf092fb03..9696032002fda 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py @@ -1,22 +1,21 @@ - -from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \ - KVLookupBufferBase -from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from typing import Dict, Tuple, List, Optional, Union import threading -import torch -from collections import deque import time +from collections import deque +from typing import Deque, List, Optional, Union + +import torch +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVLookupBufferBase) +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger logger = init_logger(__name__) + class SimpleKVLookupBuffer(KVLookupBufferBase): - - def __init__(self, - signal_pipe: KVPipeBase, - data_pipe: KVPipeBase, + + def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: int): """ signal_pipe: on CPU @@ -28,72 +27,66 @@ def __init__(self, data_pipe: on device (e.g. GPU) """ - - self.buffer = deque() - + + self.buffer: Deque[List[torch.Tensor]] = deque() + self.buffer_size = 0 self.buffer_size_threshold = buffer_size_thresh self.buffer_lock = threading.Lock() self.signal_pipe = signal_pipe self.data_pipe = data_pipe - self.request_handling_thread = None + self.request_handling_thread: Optional[threading.Thread] = None self.normal_signal = torch.tensor([0]) self.end_signal = None - - def _matches(self, - tokens_roi_sender: List[torch.Tensor], + def _matches(self, tokens_roi_sender: List[torch.Tensor], tokens_roi_recver: List[torch.Tensor]): # tokens_roi_sender: tokens and roi of the producer (in the buffer) # tokens_roi_recver: tokens and roi of the consumer (query) - + tokens_sender = tokens_roi_sender[0] tokens_recver = tokens_roi_recver[0] roi_sender = tokens_roi_sender[1] roi_recver = tokens_roi_recver[1] - + if tokens_recver is None: # consumer sends an empty request # semantics: DROP SELECT * LIMIT 1 # so any of the data in the buffer can be drop-selected return True - # Assuming that roi is a mask on tokens tokens_sender = tokens_sender[roi_sender] tokens_recver = tokens_recver[roi_recver] - - + # simple common prefix matching min_length = min(len(tokens_sender), len(tokens_recver)) - if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]): + if torch.allclose(tokens_sender[:min_length], + tokens_recver[:min_length]): return min_length - + return 0 - - def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None: + def _send_tensor_and_dec_size(self, + tensor: Optional[torch.Tensor]) -> None: assert tensor is not None, "Use self.data_pipe.send(None) instead" self.buffer_size -= tensor.element_size() * tensor.numel() self.data_pipe.send_tensor(tensor) def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): - + if data == [] or data is None: return 0 if isinstance(data, torch.Tensor): return data.element_size() * data.numel() + else: + raise AssertionError("Unknown data type %s" % type(data)) - assert False, "Unknown data type %s" % type(data) - - def _add_to_buffer(self, - input_tokens: torch.Tensor, - roi: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, + def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor): if isinstance(input_tokens, torch.Tensor): @@ -107,21 +100,20 @@ def _add_to_buffer(self, if isinstance(hidden, torch.Tensor): hidden = hidden.clone() - buffer_item = [input_tokens, roi, key, value, hidden] - + with self.buffer_lock: for data in buffer_item: self.buffer_size += self._get_element_size(data) self.buffer.append(buffer_item) - + def _is_end_signal(self, signal): return signal is None - + def drop_select_handler(self): try: - + while True: signal = self.signal_pipe.recv_tensor() if self._is_end_signal(signal): @@ -132,28 +124,29 @@ def drop_select_handler(self): roi = self.data_pipe.recv_tensor() tokens_roi_recver = [input_tokens, roi] - + matched_length = 0 - + # perform input tokens and roi matching with self.buffer_lock: for _ in range(len(self.buffer)): - - temp_length = self._matches(self.buffer[0], tokens_roi_recver) + + temp_length = self._matches(self.buffer[0], + tokens_roi_recver) if temp_length > 0: matched_length = temp_length break # rotate the element we just accessed to the end self.buffer.rotate(-1) - + if matched_length > 0: # need to clone the tensor # in case the tensor is freed before sending finishes matched_item = self.buffer.popleft() for tensor in matched_item: self._send_tensor_and_dec_size(tensor) - + else: # no match, just send None for _ in range(5): @@ -164,60 +157,57 @@ def drop_select_handler(self): raise e logger.debug("Closing drop_select_handler") - - - def drop_select(self, - input_tokens: torch.Tensor, - roi: torch.Tensor): - + + def drop_select(self, input_tokens: torch.Tensor, + roi: torch.Tensor) -> List[Optional[torch.Tensor]]: + assert self.request_handling_thread is None, \ "drop_select should be called by the receiver" - if isinstance(input_tokens, torch.Tensor): input_tokens = input_tokens.clone() if isinstance(roi, torch.Tensor): roi = roi.clone() - + self.signal_pipe.send_tensor(self.normal_signal) self.data_pipe.send_tensor(input_tokens) self.data_pipe.send_tensor(roi) - + input_tokens = self.data_pipe.recv_tensor() roi = self.data_pipe.recv_tensor() key = self.data_pipe.recv_tensor() value = self.data_pipe.recv_tensor() hidden = self.data_pipe.recv_tensor() - + return [input_tokens, roi, key, value, hidden] - def full_handler(self): time.sleep(0.001) - - - def insert(self, input_tokens, roi, key, value, hidden) -> None: + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: while self.buffer_size > self.buffer_size_threshold: # logger.debug("KV transfer buffer is full. Handling...") self.full_handler() - self._add_to_buffer(input_tokens, roi, key, value, hidden) - + # when calling the insert, the current process is a sender # need to launch the request handler and start listening to request. if self.request_handling_thread is None: self.request_handling_thread = threading.Thread( target=self.drop_select_handler) self.request_handling_thread.start() - - + def close(self): - if hasattr(self, "request_handling_thread") and self.request_handling_thread is not None: + if hasattr(self, "request_handling_thread" + ) and self.request_handling_thread is not None: self.request_handling_thread.join() else: - # TODO: have a explicit close signal and have a explicit way to check if it's requester + # TODO: have a explicit close signal and have a explicit way to + # check if it's requester self.signal_pipe.send_tensor(self.end_signal) diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py index 7662a5893ceb2..0955b4e838896 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -1,15 +1,17 @@ - from abc import ABC, abstractmethod +from typing import Optional + +import torch class KVPipeBase(ABC): - - @abstractmethod - def send_tensor(self, tensor): + + @abstractmethod + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: raise NotImplementedError - - @abstractmethod - def recv_tensor(self): + + @abstractmethod + def recv_tensor(self) -> Optional[torch.Tensor]: raise NotImplementedError @abstractmethod diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index 3a6a94bb0e752..911bce88a38f1 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -1,15 +1,16 @@ -from torch.distributed import Backend -import torch -from typing import List, Optional, Union import threading -from concurrent.futures import ThreadPoolExecutor import time +from concurrent.futures import ThreadPoolExecutor +from typing import List, Optional, Union +import torch +from torch.distributed import Backend + +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger logger = init_logger(__name__) - # if the tensor is only one-element and only contains NONE_INT # this means that the sended object is None. NONE_INT = -150886311 @@ -42,17 +43,17 @@ class BrokenPipeException(Exception): + def __init__(self, message): self.message = message super().__init__(self.message) -class TorchDistributedPipe: +class TorchDistributedPipe(KVPipeBase): METADATA_LENGTH = 16 MAX_TENSOR_DIMENSIONS = 14 METADATA_DTYPE = torch.int64 - def __init__( self, group_ranks: List[List[int]], @@ -65,8 +66,7 @@ def __init__( for ranks in group_ranks: device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) + ranks, backend=torch_distributed_backend) if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) @@ -78,26 +78,24 @@ def __init__( self.device = self._select_device(torch_distributed_backend) - self.target_rank_for_send = self.ranks[ - (self.rank_in_group + 1) % self.world_size - ] - self.target_rank_for_recv = self.ranks[ - (self.rank_in_group - 1) % self.world_size - ] + self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % + self.world_size] + self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % + self.world_size] # FIXME: why we need this? torch.set_default_device(self.device) - self.transport_thread = None + self.transport_thread: Optional[ThreadPoolExecutor] = None self.buffer_size = 0 self.buffer_size_lock = threading.Lock() self.none_tensor = torch.tensor([NONE_INT], device=self.device) # On-device tensors to be reused for recv - self.rcv_metadata_buffer = torch.zeros( - self.METADATA_LENGTH, dtype=self.METADATA_DTYPE, device=self.device - ) + self.rcv_metadata_buffer = torch.zeros(self.METADATA_LENGTH, + dtype=self.METADATA_DTYPE, + device=self.device) def _select_device(self, backend: Union[str, Backend]): if torch.cuda.is_available() and backend == Backend.NCCL: @@ -129,14 +127,12 @@ def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: buffer[0] = DTYPE2INT[tensor.dtype] ndims = len(tensor.shape) buffer[1] = len(tensor.shape) - buffer[2 : 2 + ndims] = torch.tensor( - tensor.shape, dtype=self.METADATA_DTYPE - ) + buffer[2:2 + ndims] = torch.tensor(tensor.shape, + dtype=self.METADATA_DTYPE) return buffer.to(self.device) - def _prepare_recv_buffer( - self, d_metadata_buffer: torch.Tensor - ) -> torch.Tensor: + def _prepare_recv_buffer(self, + d_metadata_buffer: torch.Tensor) -> torch.Tensor: """ Create a buffer to receive the tensor based on the metadata. @@ -149,7 +145,7 @@ def _prepare_recv_buffer( h_buffer = d_metadata_buffer.cpu().numpy() dtype = INT2DTYPE[h_buffer[0]] ndims = h_buffer[1] - shape = tuple(h_buffer[2 : 2 + ndims]) + shape = tuple(h_buffer[2:2 + ndims]) return torch.empty(shape, dtype=dtype, device=self.device) def _send_metadata(self, d_metadata_buffer: torch.Tensor): @@ -174,7 +170,7 @@ def _recv_metadata(self) -> torch.Tensor: race conditions during sending/receiving. Therefore, the metadata buffer can be reused """ - task = torch.distributed.recv( + torch.distributed.recv( self.rcv_metadata_buffer, src=self.target_rank_for_recv, group=self.device_group, @@ -194,9 +190,9 @@ def _send_impl(self, tensor): metadata = self._make_metadata(tensor) self._send_metadata(metadata) - torch.distributed.send( - tensor, dst=self.target_rank_for_send, group=self.device_group - ) + torch.distributed.send(tensor, + dst=self.target_rank_for_send, + group=self.device_group) def _recv_impl(self) -> torch.Tensor: """ @@ -211,9 +207,9 @@ def _recv_impl(self) -> torch.Tensor: d_metadata = self._recv_metadata() buffer = self._prepare_recv_buffer(d_metadata) - torch.distributed.recv( - buffer, src=self.target_rank_for_recv, group=self.device_group - ) + torch.distributed.recv(buffer, + src=self.target_rank_for_recv, + group=self.device_group) return buffer @@ -227,13 +223,9 @@ def send_tensor_wrapper(self, tensor): self.buffer_size = self.buffer_size - tensor_size except Exception as e: logger.error("[rank%d]: Exception when trying to send %s, msg: %s", - torch.distributed.get_rank(), - str(tensor), - str(e)) + torch.distributed.get_rank(), str(tensor), str(e)) import traceback traceback.print_exc() - - def block_if_full(self): """ @@ -268,13 +260,11 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: with self.buffer_size_lock: self.buffer_size = self.buffer_size + tensor_size - self.transport_thread.submit( self.send_tensor_wrapper, tensor, ) - def recv_tensor(self) -> Optional[torch.Tensor]: """Receives a tensor from the src rank. Blocking.""" if self.transport_thread is None: @@ -300,8 +290,6 @@ def close(self): """ Close the pipe and release the resources. """ - if ( - hasattr(self, "transport_thread") - and self.transport_thread is not None - ): + if (hasattr(self, "transport_thread") + and self.transport_thread is not None): self.transport_thread.shutdown() diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 9a6b55cbbe660..03392ec13f10b 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -1,59 +1,59 @@ """vLLM distributed KV cache transfer API. These APIs are used in `vllm/worker/worker_base.py`. -Currently supporting TP. The TP between prefill and decode instance needs to be the same. +Currently supporting TP. The TP between prefill and decode instance needs to be +the same. Workflow (disaggregated prefill) - In prefill instance - After prefill, vLLM `insert` its KV caches into a lookup buffer. - - The prefill instance will also open up a thread that listens to `drop_select` request. + - The prefill instance will also open up a thread that listens to + `drop_select` request. - In decode instance - - vLLM first runs `drop_select` to send input tokens and a mask on input tokens (we call it roi, region of interest) to prefill instance + - vLLM first runs `drop_select` to send input tokens and a mask on input + tokens (we call it roi, region of interest) to prefill instance - The prefill instance then respond to `drop_select` request by - Finding a match in current lookup buffer. - Clone and send the matched item out - Delete the matched item in the lookup buffer to free up GPU memory. - The decode vLLM then store the KV cache into paged memory. """ -from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING -from collections import defaultdict, deque -from concurrent.futures import ThreadPoolExecutor -from threading import Lock +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + from copy import deepcopy -import time -import threading import torch -from torch.distributed import Backend, ProcessGroup +from torch.distributed import Backend import vllm.envs as envs -from vllm.logger import init_logger -import vllm.distributed.parallel_state as ps from vllm import _custom_ops as ops +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVLookupBufferBase) +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer import ( + SimpleKVLookupBuffer) +from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import ( + TorchDistributedPipe) +from vllm.logger import init_logger from vllm.sequence import IntermediateTensors -from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import TorchDistributedPipe -from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer import SimpleKVLookupBuffer - -from copy import deepcopy -assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode", "lmcache"], \ +assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode", "lmcache"],\ "VLLM_DISAGG_PREFILL_ROLE can only be prefill, decode or lmcache." - # currently the connections are hard-coded. # we only handle 2 cases: # - prefill vLLM --> decode vLLM # - vLLM --> LMCache -IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE in ["prefill", "decode"]) +IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE + in ["prefill", "decode"]) IS_KV_PREFILL_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "prefill") IS_KV_DECODE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "decode") IS_LMCACHE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "lmcache") - logger = init_logger(__name__) -import logging - class KV_transfer_agent: """ @@ -70,11 +70,13 @@ def __init__( local_rank: int, torch_distributed_backend: Union[str, Backend], # FIXME(Kuntai): remove this hardcoding - lookup_buffer_size: int = 1e10 - ): - + lookup_buffer_size: int = int(1e10)): + self.lookup_buffer_size = lookup_buffer_size - + + self.send_buffer: Optional[KVLookupBufferBase] = None + self.recv_buffer: Optional[KVLookupBufferBase] = None + if IS_LMCACHE_INSTANCE: # when vLLM is connected with LMCache # it needs to both send and recv KV cache @@ -98,14 +100,12 @@ def __init__( local_rank, "gloo", ) - self.send_buffer = SimpleKVLookupBuffer( - self.send_signal_pipe, - self.send_pipe, - self.lookup_buffer_size) - self.recv_buffer = SimpleKVLookupBuffer( - self.recv_signal_pipe, - self.recv_pipe, - self.lookup_buffer_size) + self.send_buffer = SimpleKVLookupBuffer(self.send_signal_pipe, + self.send_pipe, + self.lookup_buffer_size) + self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, + self.recv_pipe, + self.lookup_buffer_size) else: # when performing disaggregated prefill, only 1 pipe is needed # at prefill instance this pipe is used for send KV cache @@ -120,24 +120,25 @@ def __init__( local_rank, "gloo", ) - buffer = SimpleKVLookupBuffer( - self.signal_pipe, - self.pipe, - self.lookup_buffer_size) + buffer = SimpleKVLookupBuffer(self.signal_pipe, self.pipe, + self.lookup_buffer_size) self.send_buffer = buffer self.recv_buffer = buffer - + def send_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", kv_caches: List[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], ) -> None: input_tokens_tensor = model_input.input_tokens seq_lens = model_input.attn_metadata.seq_lens slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer # query_lens contains new KV caches that are added to vLLM. # so we will send them to decode instance @@ -146,13 +147,11 @@ def send_kv_caches_and_hidden_states( start_pos = sum(seq_lens[:idx]) end_pos = start_pos + slen current_tokens = input_tokens_tensor[start_pos:end_pos] - + keys, values = [], [] - - - for l in range(model_executable.model.start_layer, - model_executable.model.end_layer): - kv_cache = kv_caches[l - model_executable.model.start_layer] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] _, _, num_heads, head_size = kv_cache[0].shape @@ -163,29 +162,31 @@ def send_kv_caches_and_hidden_states( keys.append(key_cache[current_slot_mapping].unsqueeze(0)) values.append(value_cache[current_slot_mapping].unsqueeze(0)) - + keys = torch.cat(keys, dim=0) values = torch.cat(values, dim=0) - self.send_buffer.insert( - current_tokens, - torch.ones_like(current_tokens, dtype=bool), - keys, - values, - hidden_or_intermediate_states[start_pos:end_pos] - ) - + if self.send_buffer is not None: + self.send_buffer.insert( + current_tokens, torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + def destroy(self) -> None: + if self.send_buffer is not None: + self.send_buffer.close() + if self.recv_buffer is not None: + self.recv_buffer.close() def recv_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, + self, model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", kv_caches: List[torch.Tensor] - ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: - # When this flag is set to False, it means that + # When this flag is set to False, it means that bypass_model_exec = True # This is disagg decode instance, during prefill state @@ -199,7 +200,7 @@ def recv_kv_caches_and_hidden_states( input_tokens_list = [] num_computed_tokens_list = [] start_pos_list = [] - + # enumerate different requests # FIXME(Kuntai): This impl assumes that all requests are prefill. for idx, slen in enumerate(seq_lens): @@ -211,28 +212,34 @@ def recv_kv_caches_and_hidden_states( input_tokens_list.append(current_tokens) start_pos_list.append(start_pos) - + + if self.recv_buffer is None: + bypass_model_exec = False + break + ret = self.recv_buffer.drop_select( - current_tokens, - torch.ones_like(current_tokens, dtype=bool)) + current_tokens, torch.ones_like(current_tokens, dtype=bool)) if ret[0] is None: # didn't find any match. bypass_model_exec = False num_computed_tokens_list.append(0) continue - + # TODO(Jiayi): change the logic here (need roi) - _, roi, keys, values, hidden = ret - + roi: torch.Tensor = ret[1] + keys: torch.Tensor = ret[2] + values: torch.Tensor = ret[3] + hidden: torch.Tensor = ret[4] + # Jiayi: currently assume roi is a prefix - num_computed_tokens = len(roi) + num_computed_tokens = roi.shape[0] num_computed_tokens_list.append(num_computed_tokens) is_complete = (num_computed_tokens == num_tokens) end_pos = start_pos + num_computed_tokens - + # receive KV cache from disaggregated prefill instance for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): + model_executable.model.end_layer): kv_cache = kv_caches[i - model_executable.model.start_layer] layer = model_executable.model.layers[i] @@ -251,12 +258,13 @@ def recv_kv_caches_and_hidden_states( hidden_or_intermediate_states_for_one_req.append(hidden) - # FIXME(Jiayi): we need to support only skip m out of n reqs in a batch + # FIXME(Jiayi): we need to support only skip m out of n reqs in a batch # same for prefix caching if not bypass_model_exec: # Some of the KV cache is not retrieved # so we need to recompute the hidden state - logger.debug("[rank%d]: KV EMPTY recv DONE.", torch.distributed.get_rank()) + logger.debug("[rank%d]: KV EMPTY recv DONE.", + torch.distributed.get_rank()) return None, bypass_model_exec, None if not is_complete: @@ -268,17 +276,17 @@ def recv_kv_caches_and_hidden_states( slot_mapping, device=kv_cache[0].device, ) - logger.debug("[rank%d]: KV PARTIAL recv DONE.", torch.distributed.get_rank()) + logger.debug("[rank%d]: KV PARTIAL recv DONE.", + torch.distributed.get_rank()) return None, bypass_model_exec, rebuilt_model_input - + # concatenate hidden states from different requests hidden_or_intermediate_states = torch.cat( hidden_or_intermediate_states_for_one_req, dim=0) logger.debug("[rank%d]: KV recv DONE.", torch.distributed.get_rank()) return hidden_or_intermediate_states, bypass_model_exec, model_input - - + def build_partial_prefill_input( self, model_input: "ModelInputForGPUWithSamplingMetadata", @@ -289,70 +297,77 @@ def build_partial_prefill_input( device: torch.device, ) -> "ModelInputForGPUWithSamplingMetadata": rebuilt_input_tokens = [] - rebuilt_input_positions= [] + rebuilt_input_positions = [] rebuilt_query_lens = [] - + rebuilt_num_prefills = 0 rebuilt_num_prefill_tokens = 0 rebuilt_slot_mapping = [] rebuilt_max_query_len = 0 - + rebuilt_block_tables = [] - + rebuilt_query_start_loc = [0] rebuilt_context_lens_tensor = [] rebuilt_selected_token_indices = [] - + # recounting query and context lengths for idx in range(len(input_tokens_list)): token_tensor = input_tokens_list[idx] num_token = len(token_tensor) num_computed_token = num_computed_tokens_list[idx] start_pos = start_pos_list[idx] - + rebuilt_input_tokens.append(token_tensor[num_computed_token:]) # TODO(Jiayi): please check the correctness of next line - rebuilt_input_positions.append(model_input.input_positions[start_pos+num_computed_token:start_pos+num_token]) + rebuilt_input_positions.append( + model_input.input_positions[start_pos + + num_computed_token:start_pos + + num_token]) q_len = num_token - num_computed_token rebuilt_query_lens.append(q_len) - + # Attn metadata-related rebuilt_num_prefills += 1 rebuilt_num_prefill_tokens += q_len - rebuilt_slot_mapping.append(slot_mapping_flat[start_pos+num_computed_token:start_pos+num_token]) + rebuilt_slot_mapping.append( + slot_mapping_flat[start_pos + num_computed_token:start_pos + + num_token]) rebuilt_max_query_len = max(q_len, rebuilt_max_query_len) # TODO(Jiayi): remove hard-code (block_size=16) blk_size = 16 - temp_block_table = [i//blk_size for i in range(start_pos, start_pos+num_token, blk_size)] + temp_block_table = [ + i // blk_size + for i in range(start_pos, start_pos + num_token, blk_size) + ] rebuilt_block_tables.append(temp_block_table) - rebuilt_query_start_loc.append(q_len) #start with 0 + rebuilt_query_start_loc.append(q_len) #start with 0 rebuilt_context_lens_tensor.append(num_computed_token) - + # Sampling metadata related #seq_groups (use rebuilt query lens) - rebuilt_selected_token_indices.append(start_pos+q_len-1) - - + rebuilt_selected_token_indices.append(start_pos + q_len - 1) + # rebuilt attn_metadata rebuilt_attn_metadata = deepcopy(model_input.attn_metadata) rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens - rebuilt_attn_metadata.slot_mapping = torch.cat(rebuilt_slot_mapping).to(device) + rebuilt_attn_metadata.slot_mapping = torch.cat( + rebuilt_slot_mapping).to(device) rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len - + rebuilt_attn_metadata.block_tables = torch.tensor( rebuilt_block_tables, - dtype=model_input.attn_metadata.block_tables.dtype - ).to(device) - + dtype=model_input.attn_metadata.block_tables.dtype).to(device) + rebuilt_attn_metadata.query_start_loc = torch.tensor( rebuilt_query_start_loc, dtype=model_input.attn_metadata.query_start_loc.dtype).to(device) rebuilt_attn_metadata.context_lens_tensor = torch.tensor( - rebuilt_context_lens_tensor, + rebuilt_context_lens_tensor, dtype=model_input.attn_metadata.context_lens_tensor.dtype, - ).to(device) - + ).to(device) + rebuilt_attn_metadata._cached_prefill_metadata = None # rebuilt sampling_metadata @@ -362,26 +377,27 @@ def build_partial_prefill_input( rebuilt_sampling_metadata.selected_token_indices = torch.tensor( rebuilt_selected_token_indices, dtype=model_input.sampling_metadata.selected_token_indices.dtype, - ).to(device) - + ).to(device) + # import here to avoid circular import. - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + from vllm.worker.model_runner import ( + ModelInputForGPUWithSamplingMetadata) rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( - input_tokens = torch.cat(rebuilt_input_tokens).to(device), - input_positions = torch.cat(rebuilt_input_positions).to(device), - seq_lens = model_input.seq_lens, - query_lens = rebuilt_query_lens, - lora_mapping = model_input.lora_mapping, - lora_requests = model_input.lora_requests, - attn_metadata = rebuilt_attn_metadata, - prompt_adapter_mapping = model_input.prompt_adapter_mapping, - prompt_adapter_requests = model_input.prompt_adapter_requests, - multi_modal_kwargs = model_input.multi_modal_kwargs, - request_ids_to_seq_ids = model_input.request_ids_to_seq_ids, - finished_requests_ids = model_input.finished_requests_ids, - virtual_engine = model_input.virtual_engine, - sampling_metadata = rebuilt_sampling_metadata, - is_prompt = model_input.is_prompt, + input_tokens=torch.cat(rebuilt_input_tokens).to(device), + input_positions=torch.cat(rebuilt_input_positions).to(device), + seq_lens=model_input.seq_lens, + query_lens=rebuilt_query_lens, + lora_mapping=model_input.lora_mapping, + lora_requests=model_input.lora_requests, + attn_metadata=rebuilt_attn_metadata, + prompt_adapter_mapping=model_input.prompt_adapter_mapping, + prompt_adapter_requests=model_input.prompt_adapter_requests, + multi_modal_kwargs=model_input.multi_modal_kwargs, + request_ids_to_seq_ids=model_input.request_ids_to_seq_ids, + finished_requests_ids=model_input.finished_requests_ids, + virtual_engine=model_input.virtual_engine, + sampling_metadata=rebuilt_sampling_metadata, + is_prompt=model_input.is_prompt, ) - - return rebuilt_model_input \ No newline at end of file + + return rebuilt_model_input diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 13527110a2232..3615fa6af399c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -20,29 +20,25 @@ parallelism, you can skip the model parallel initialization and destruction steps. """ -import time import contextlib import pickle -import logging +import time from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory from typing import Any, Dict, List, Optional, Tuple, Union from unittest.mock import patch -import queue import torch import torch.distributed from torch.distributed import Backend, ProcessGroup -import vllm.envs as envs -from vllm.logger import init_logger - - # Use this import to check if disagg prefill is enabled. # if enabled, need to adjust distributed group correspondingly. import vllm.distributed.kv_transfer.vllm_adapter as dist_kv +import vllm.envs as envs +from vllm.logger import init_logger @dataclass @@ -865,7 +861,8 @@ def include_decoding_groups_if_disagg_enabled( Extended: [ [0,1], [2,3], [4,5], [6,7] ] Arguments: groups: original distributed group - world_size: the vLLM world size, which is half of torch.distributed.get_world_size() + world_size: the vLLM world size, which is half of + torch.distributed.get_world_size() """ if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: @@ -908,9 +905,8 @@ def init_distributed_environment( # offset global rank by tp * pp (which is world_size) maybe_disagg_rank = rank + world_size - logger.debug( - f"Before: world size {maybe_disagg_world_size}, rank {maybe_disagg_rank}" - ) + logger.debug("Before: world size %d, rank %d", maybe_disagg_world_size, + maybe_disagg_rank) torch.distributed.init_process_group( backend=backend, @@ -974,17 +970,18 @@ def initialize_model_parallel( ranks 8 to 15 belong to the second box. - Disaggregated prefill will also initialize its process group using this function. + Disaggregated prefill will also init its process group using this function. Changes: - vLLM world size: unchanged (tp * pp) - torch.distributed.get_world_size(): - 2 * tp * pp - - Why: torch.distributed package sees 2 vLLM instances (prefill and decode) + - Why: both prefill vLLM and decode vLLM is in the world - Global rank: - [0, tp * pp) for prefill - [tp * pp, 2 * tp * pp) for decode - Parallel groups - - Extend _WORLD, _TP and _PP using `include_decoding_groups_if_disagg_enabled` + - Extend _WORLD, _TP and _PP using + `include_decoding_groups_if_disagg_enabled` - Add a new parallel group `_DISAGG` for disaggregated prefill - [ [0, tp * pp], [1, tp * pp + 1], .. ] - Local rank: unchanged @@ -997,12 +994,14 @@ def initialize_model_parallel( get_world_group().device_group) if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE: # Disaggregated prefill enabled - # The world_size for this vLLM instance is tp * pp, but torch.distributed contains 2 vLLM instances, its world size is 2 * tp * pp + # The world_size for this vLLM instance is tp * pp, but + # torch.distributed contains 2 vLLM instances, + # its world size is 2 * tp * pp # Adjust the world_size to match. world_size = world_size // 2 - if (world_size - != tensor_model_parallel_size * pipeline_model_parallel_size): + if (world_size != + tensor_model_parallel_size * pipeline_model_parallel_size): raise RuntimeError( f"world_size ({world_size}) is not equal to " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 679f8394688e8..b774a649d39f5 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -48,7 +48,8 @@ def _get_worker_kwargs( """Return worker init args for a given rank.""" if distributed_init_method is None: distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) + get_ip(), + get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) return dict( model_config=self.model_config, parallel_config=self.parallel_config, diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 9448228879453..499e891d98fc0 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -70,7 +70,8 @@ def _init_executor(self) -> None: # Since it only works for single node, we can use the loopback address # 127.0.0.1 for communication. distributed_init_method = get_distributed_init_method( - "127.0.0.1", get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) + "127.0.0.1", + get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) self.workers: List[ProcessWorkerWrapper] = [] # This is the list of workers that are rank 0 of each TP group EXCEPT diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index c646e8536ba15..0cca5db1677ed 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -6,8 +6,8 @@ import msgspec -import vllm.envs as envs import vllm.distributed.kv_transfer.vllm_adapter as dist_kv +import vllm.envs as envs from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.msgspec_utils import encode_hook diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ab38302b3321a..b846d1d707db0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,7 +14,6 @@ import torch.distributed import torch.nn as nn - import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState @@ -55,7 +54,6 @@ _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict, dump_input_when_exception) -from vllm import _custom_ops as ops if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -1546,30 +1544,20 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) - + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) + if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() - return hidden_or_intermediate_states - - @torch.inference_mode() - def postprocess_model( - self, - model_input, - hidden_or_intermediate_states, - - ): if not get_pp_group().is_last_rank: if (self.is_driver_worker and hidden_or_intermediate_states is not None @@ -1587,7 +1575,16 @@ def postprocess_model( hidden_or_intermediate_states.tensors["model_forward_time"] = ( torch.tensor(model_forward_time + orig_model_forward_time)) return hidden_or_intermediate_states - + + return hidden_or_intermediate_states + + @torch.inference_mode() + def postprocess_model( + self, + model_input: ModelInputForGPUWithSamplingMetadata, + hidden_or_intermediate_states, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) @@ -1603,6 +1600,7 @@ def postprocess_model( sampling_metadata=model_input.sampling_metadata, ) + assert model_input.attn_metadata is not None decode_meta = model_input.attn_metadata.decode_metadata if self.return_hidden_states: # we only need to pass hidden states of most recent token @@ -1620,9 +1618,7 @@ def postprocess_model( output.hidden_states = hidden_states return [output] - - - + class CUDAGraphRunner: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 22577ebf69492..7908fc466eb38 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -7,6 +7,8 @@ import torch +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv +import vllm.distributed.parallel_state as ps from vllm.config import ObservabilityConfig from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.logger import init_logger @@ -16,13 +18,11 @@ from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) +from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) -import vllm.distributed.kv_transfer.vllm_adapter as dist_kv -import vllm.distributed.parallel_state as ps - logger = init_logger(__name__) @@ -223,7 +223,6 @@ def execute_worker(self, worker_input: WorkerInput) -> None: Process an execution request. """ raise NotImplementedError - def _get_worker_input_from_broadcast( self @@ -327,19 +326,14 @@ def execute_model( and self.observability_config.collect_model_execute_time): orig_model_execute_time = intermediate_tensors.tensors.get( "model_execute_time", torch.tensor(0)).item() - - + # for disaggregated prefilling: allow bypassing model execution bypass_model_exec = False - - - # receive KV cache. - # NOTE(kuntai): - # If only a part of KV cache is received, we will adjust model_input - # to avoid prefill on the part of KV caches that are already received. - # This will not happen for disaggregated prefill, but will happen - # when connecting to a KV cache database (like LMCache). + + # receive KV cache from prefill instance, or from LMCache if self.need_recv_kv(model_input, worker_input): + assert isinstance(self.model_runner, GPUModelRunnerBase), \ + "Distributed KV transfer only support GPU modelrunner" hidden_or_intermediate_states, bypass_model_exec, model_input = \ ps.get_disagg_group().recv_kv_caches_and_hidden_states( # model is used to know which layer the current worker @@ -347,11 +341,12 @@ def execute_model( # layers. self.model_runner.model, model_input, - self.kv_cache[worker_input.virtual_engine], + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, ) #assert bypass_model_exec - - if not bypass_model_exec: + + if not bypass_model_exec: hidden_or_intermediate_states = self.model_runner.execute_model( model_input=model_input, kv_caches=self.kv_cache[worker_input.virtual_engine] @@ -360,24 +355,31 @@ def execute_model( num_steps=num_steps, **kwargs, ) - + # sending out KV cache if self.need_send_kv(model_input, worker_input): + assert isinstance(self.model_runner, GPUModelRunnerBase), \ + "Distributed KV transfer only support GPU modelrunner" ps.get_disagg_group().send_kv_caches_and_hidden_states( # model is used to know which layer the current worker # is working on, so that we can send KV for only those # layers. self.model_runner.model, model_input, - self.kv_cache[worker_input.virtual_engine], + self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, hidden_or_intermediate_states, ) - - # Get model output based on hidden state. - output = self.model_runner.postprocess_model( - model_input, - hidden_or_intermediate_states, - ) + + # separating postprocessing steps out from execute_model + # so that disaggregated prefill can completely bypass model forwarding + if isinstance(self.model_runner, ModelRunner): + output = self.model_runner.postprocess_model( + model_input, + hidden_or_intermediate_states, + ) + else: + output = hidden_or_intermediate_states model_execute_time = time.perf_counter() - start_time if not get_pp_group().is_last_rank: @@ -400,38 +402,43 @@ def execute_model( return output def need_recv_kv(self, model_input, worker_input) -> bool: - + + if self.kv_cache is None: + return False + kv_caches = self.kv_cache[worker_input.virtual_engine] prefill_meta = model_input.attn_metadata.prefill_metadata - + # check if the current run is profiling is_profile_run = (kv_caches is None) or (kv_caches[0] is None) # check if the current run is prefill is_prefill_run = prefill_meta is not None # for disaggregated prefilling: allow bypassing model execution - + return all([ - is_prefill_run, - dist_kv.IS_KV_DECODE_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE, - not is_profile_run]) + is_prefill_run, dist_kv.IS_KV_DECODE_INSTANCE + or dist_kv.IS_LMCACHE_INSTANCE, not is_profile_run + ]) - def need_send_kv(self, model_input, worker_input) -> bool: - + + if self.kv_cache is None: + return False + kv_caches = self.kv_cache[worker_input.virtual_engine] prefill_meta = model_input.attn_metadata.prefill_metadata - model_executable = self.model_runner.model - + if not isinstance(self.model_runner, GPUModelRunnerBase): + return False + # check if the current run is profiling is_profile_run = (kv_caches is None) or (kv_caches[0] is None) # check if the current run is prefill is_prefill_run = prefill_meta is not None - + return all([ - is_prefill_run, - dist_kv.IS_KV_PREFILL_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE, - not is_profile_run]) - + is_prefill_run, dist_kv.IS_KV_PREFILL_INSTANCE + or dist_kv.IS_LMCACHE_INSTANCE, not is_profile_run + ]) def _execute_model_spmd( self, From f78a2eb59f30c84060c34f9b0d623e285e45aeee Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 15 Sep 2024 23:48:40 +0000 Subject: [PATCH 228/303] resolve circular import --- vllm/utils.py | 2 +- vllm/worker/worker_base.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 1adab61917265..8e27e1f73f4b4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -535,7 +535,7 @@ def get_open_port(force: bool = False) -> int: if force and port is not None: # force vLLM to use envs.VLLM_PORT for torch.distributed init # This is because this port will binded by prefill instance - # But both prefill and decode instance need to use this port to + # But both prefill and decode instance need to use this port to # initialize torch.distributed return port while True: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 7908fc466eb38..d55400a402400 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -18,7 +18,6 @@ from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) -from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) @@ -332,6 +331,7 @@ def execute_model( # receive KV cache from prefill instance, or from LMCache if self.need_recv_kv(model_input, worker_input): + from vllm.worker.model_runner import GPUModelRunnerBase assert isinstance(self.model_runner, GPUModelRunnerBase), \ "Distributed KV transfer only support GPU modelrunner" hidden_or_intermediate_states, bypass_model_exec, model_input = \ @@ -358,6 +358,7 @@ def execute_model( # sending out KV cache if self.need_send_kv(model_input, worker_input): + from vllm.worker.model_runner import GPUModelRunnerBase assert isinstance(self.model_runner, GPUModelRunnerBase), \ "Distributed KV transfer only support GPU modelrunner" ps.get_disagg_group().send_kv_caches_and_hidden_states( @@ -373,6 +374,7 @@ def execute_model( # separating postprocessing steps out from execute_model # so that disaggregated prefill can completely bypass model forwarding + from vllm.worker.model_runner import ModelRunner if isinstance(self.model_runner, ModelRunner): output = self.model_runner.postprocess_model( model_input, @@ -427,6 +429,7 @@ def need_send_kv(self, model_input, worker_input) -> bool: kv_caches = self.kv_cache[worker_input.virtual_engine] prefill_meta = model_input.attn_metadata.prefill_metadata + from vllm.worker.model_runner import GPUModelRunnerBase if not isinstance(self.model_runner, GPUModelRunnerBase): return False From 44dfa3f7142a455b2f6cfb36bf4c57465106723e Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 15 Sep 2024 23:49:01 +0000 Subject: [PATCH 229/303] fix redundant import --- tests/kv_transfer/test_send_recv.py | 81 ++++++++++++++--------------- 1 file changed, 39 insertions(+), 42 deletions(-) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 4bf757d7c8492..994b907e0c899 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -1,10 +1,11 @@ +import os +import time +from typing import List -import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp import torch -import os -import random from tqdm import tqdm -import time + +import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp def test_run(my_rank, pipe): @@ -35,20 +36,19 @@ def test_run(my_rank, pipe): assert torch.allclose(y, y2) - def stress_test(my_rank, pipe): - + torch.distributed.barrier() - - tensors = [] - - + + tensors: List[torch.Tensor] = [] + for i in tqdm(range(2000)): mean = torch.rand(1).item() std = torch.rand(1).item() - size = torch.randint(900, 1000, (2,)) - x = torch.normal(mean * 1.0, std * 1.0, size=size.tolist()).to(pipe.device) - + size = torch.randint(900, 1000, (2, )) + x = torch.normal(mean * 1.0, std * 1.0, + size=size.tolist()).to(pipe.device) + # 5% probability of sending a None if torch.rand(1).item() < 0.05: tensors.append(None) @@ -59,15 +59,13 @@ def stress_test(my_rank, pipe): tensors.append(x.mean().unsqueeze(0)) tensors.append(x.std().unsqueeze(0)) - - torch.distributed.barrier() - + for i in tqdm(range(2000)): if my_rank == int((i % 10) > 3): - pipe.send_tensor(tensors[3*i]) - pipe.send_tensor(tensors[3*i+1]) - pipe.send_tensor(tensors[3*i+2]) + pipe.send_tensor(tensors[3 * i]) + pipe.send_tensor(tensors[3 * i + 1]) + pipe.send_tensor(tensors[3 * i + 2]) else: x = pipe.recv_tensor() mean = pipe.recv_tensor() @@ -76,34 +74,36 @@ def stress_test(my_rank, pipe): assert mean is None assert std is None else: - assert torch.allclose(x, tensors[3*i]) + assert torch.allclose(x, tensors[3 * i]) assert x.mean() == mean[0] assert x.std() == std[0] torch.distributed.barrier() print("Stress test passed.") - - - + + def latency_test(my_rank, pipe, nelement, ntensor): - + latencies = [] - + torch.distributed.barrier() - + for i in tqdm(range(1000)): - + tensors = [] - + if my_rank == 0: # create tensor - tensors = [torch.rand(nelement).to(pipe.device) for _ in range(ntensor)] - + tensors = [ + torch.rand(nelement).to(pipe.device) for _ in range(ntensor) + ] + torch.distributed.barrier() - + if my_rank == 0: - t = torch.tensor([time.time()], dtype=torch.float64).to(pipe.device) + t = torch.tensor([time.time()], + dtype=torch.float64).to(pipe.device) for tensor in tensors: pipe.send_tensor(tensor) pipe.send_tensor(t) @@ -114,7 +114,7 @@ def latency_test(my_rank, pipe, nelement, ntensor): latencies.append(time.time() - t.item()) torch.distributed.barrier() - + print('Latency test passed.') print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms') @@ -123,18 +123,15 @@ def latency_test(my_rank, pipe, nelement, ntensor): my_rank = int(os.environ['RANK']) - - torch.distributed.init_process_group( - init_method="tcp://127.0.0.1:23456", - world_size=2, - rank=my_rank) + torch.distributed.init_process_group(init_method="tcp://127.0.0.1:23456", + world_size=2, + rank=my_rank) print("initialized! My rank is %d" % my_rank) + pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "nccl") - pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "nccl") - - torch.manual_seed(0) + torch.manual_seed(0) test_run(my_rank, pipe) stress_test(my_rank, pipe) - latency_test(my_rank, pipe, 1024*8*128, 80) + latency_test(my_rank, pipe, 1024 * 8 * 128, 80) From 822f3dc82a7950109933a4e31ee7e6d4709129e1 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 15 Sep 2024 23:55:50 +0000 Subject: [PATCH 230/303] rename to a shorter name --- tests/kv_transfer/test_lookup_buffer.py | 80 +++++++++---------- ...e_kv_lookup_buffer.py => simple_buffer.py} | 0 vllm/distributed/kv_transfer/vllm_adapter.py | 5 +- 3 files changed, 42 insertions(+), 43 deletions(-) rename vllm/distributed/kv_transfer/kv_lookup_buffer/{simple_kv_lookup_buffer.py => simple_buffer.py} (100%) diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index ae19d068be9fa..0730f091a34b8 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -1,24 +1,25 @@ - -import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp -import vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer as sklb -import torch import os import random + +import torch from tqdm import tqdm -import time -# TODO: the test depends on a lot of fields in the current implementation. We should have standard interface instead direct field access +import vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer as sklb +import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp + +# TODO: the test depends on a lot of fields in the current implementation. +# We should have standard interface instead direct field access + def test_run(my_rank, buffer, device): - - # buffer should be empty in the beginning + + # buffer should be empty in the beginning if my_rank == 0: assert buffer.buffer_size == 0 assert len(buffer.buffer) == 0 - # insert - tokens = torch.tensor([1,2,3]).to(device) + tokens = torch.tensor([1, 2, 3]).to(device) roi = (tokens > 0) if my_rank == 0: key = 2.0 * torch.ones([5, 6]).to(device) @@ -27,45 +28,47 @@ def test_run(my_rank, buffer, device): placeholder = torch.tensor([1]).to(device) buffer.insert(tokens, roi, key, value, placeholder) - + torch.distributed.barrier() - + # drop_select if my_rank == 1: tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi) assert torch.allclose(tokens, tok) assert torch.allclose(roi, roi_) - assert torch.allclose(key, 2.0 * torch.ones([5, 6], device = device)) - assert torch.allclose(value, 3.0 * torch.ones([5, 6], device = device)) + assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device)) + assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device)) torch.distributed.barrier() - + if my_rank == 0: assert buffer.buffer_size == 0 assert len(buffer.buffer) == 0 - + print("Test run passed!") + def stress_test(my_rank, buf, device): - + torch.distributed.barrier() torch.manual_seed(100) reqs = [ ( - torch.rand(100).to(device), # tokens - torch.ones(100).bool().to(device), # roi - torch.rand(100).to(device), # key - torch.rand(100).to(device), # value - torch.rand(100).to(device), # hidden - ) for i in tqdm(range(200))] + torch.rand(100).to(device), # tokens + torch.ones(100).bool().to(device), # roi + torch.rand(100).to(device), # key + torch.rand(100).to(device), # value + torch.rand(100).to(device), # hidden + ) for i in tqdm(range(200)) + ] random.seed(my_rank) random.shuffle(reqs) - + torch.distributed.barrier() - + n = 0 - + # the buffer size can only store 100 reqs # so the sender will occasionally block to wait for the receiver. for req in tqdm(reqs): @@ -74,7 +77,7 @@ def stress_test(my_rank, buf, device): else: tok, roi, k, v, h = req tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi) - + if tok_ is None: assert roi_ is None assert k_ is None @@ -89,8 +92,7 @@ def stress_test(my_rank, buf, device): assert torch.allclose(h, h_) print('Rank %d done' % my_rank) torch.distributed.barrier() - - + if my_rank == 0: x = torch.tensor([0]) torch.distributed.recv(x, 1) @@ -103,30 +105,26 @@ def stress_test(my_rank, buf, device): torch.distributed.send(torch.tensor([n]), 0) print("Passed stress test!") - - + if __name__ == "__main__": my_rank = int(os.environ['RANK']) - - torch.distributed.init_process_group( - init_method="tcp://127.0.0.1:23456", - world_size=2, - rank=my_rank) + torch.distributed.init_process_group(init_method="tcp://127.0.0.1:23456", + world_size=2, + rank=my_rank) print("initialized! My rank is %d" % my_rank) - - pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "nccl") - cpu_pipe = tdp.TorchDistributedPipe([[0,1]], my_rank, "gloo") + pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "nccl") + cpu_pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "gloo") buffer = sklb.SimpleKVLookupBuffer(cpu_pipe, pipe, 170000) test_run(my_rank, buffer, pipe.device) - + stress_test(my_rank, buffer, pipe.device) - + buffer.close() pipe.close() cpu_pipe.close() diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py similarity index 100% rename from vllm/distributed/kv_transfer/kv_lookup_buffer/simple_kv_lookup_buffer.py rename to vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 03392ec13f10b..2edb426c5c8da 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -28,12 +28,11 @@ import torch from torch.distributed import Backend +import vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer as sklb import vllm.envs as envs from vllm import _custom_ops as ops from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( KVLookupBufferBase) -from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_kv_lookup_buffer import ( - SimpleKVLookupBuffer) from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import ( TorchDistributedPipe) from vllm.logger import init_logger @@ -77,6 +76,8 @@ def __init__( self.send_buffer: Optional[KVLookupBufferBase] = None self.recv_buffer: Optional[KVLookupBufferBase] = None + SimpleKVLookupBuffer = sklb.SimpleKVLookupBuffer + if IS_LMCACHE_INSTANCE: # when vLLM is connected with LMCache # it needs to both send and recv KV cache From 7682269f27d2c23384e6b1ceed6415c7edfe184d Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 16 Sep 2024 00:06:43 +0000 Subject: [PATCH 231/303] remove unnecessary file --- tests/test_send_recv.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests/test_send_recv.sh diff --git a/tests/test_send_recv.sh b/tests/test_send_recv.sh deleted file mode 100644 index e69de29bb2d1d..0000000000000 From b6e5eb35a21964942aba677e5eb6835b35fa09c3 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 16 Sep 2024 00:40:26 +0000 Subject: [PATCH 232/303] update kv transfer test --- tests/kv_transfer/test_launcher.py | 52 +++++++++++++++++++++++++ tests/kv_transfer/test_lookup_buffer.sh | 3 -- tests/kv_transfer/test_send_recv.py | 10 +++-- tests/kv_transfer/test_send_recv.sh | 3 -- 4 files changed, 58 insertions(+), 10 deletions(-) create mode 100644 tests/kv_transfer/test_launcher.py delete mode 100644 tests/kv_transfer/test_lookup_buffer.sh delete mode 100644 tests/kv_transfer/test_send_recv.sh diff --git a/tests/kv_transfer/test_launcher.py b/tests/kv_transfer/test_launcher.py new file mode 100644 index 0000000000000..5c0aeb04b43fa --- /dev/null +++ b/tests/kv_transfer/test_launcher.py @@ -0,0 +1,52 @@ +import subprocess +import pytest +import sys +import torch + +def run_python_script(script_name, timeout): + try: + # Start both processes asynchronously using Popen + process0 = subprocess.Popen( + [sys.executable, script_name], + env={"RANK": "0"}, # Set the RANK environment variable for process 0 + stdout=sys.stdout, # Pipe stdout to current stdout + stderr=sys.stderr, # Pipe stderr to current stderr + ) + + process1 = subprocess.Popen( + [sys.executable, script_name], + env={"RANK": "1"}, # Set the RANK environment variable for process 1 + stdout=sys.stdout, # Pipe stdout to current stdout + stderr=sys.stderr, # Pipe stderr to current stderr + ) + + # Wait for both processes to complete, with a timeout + process0.wait(timeout=timeout) + process1.wait(timeout=timeout) + + # Check the return status of both processes + if process0.returncode != 0: + pytest.fail(f"Test {script_name} failed for RANK=0 with return code {process0.returncode}") + if process1.returncode != 0: + pytest.fail(f"Test {script_name} failed for RANK=1 with return code {process1.returncode}") + + except subprocess.TimeoutExpired: + # If either process times out, terminate both and fail the test + process0.terminate() + process1.terminate() + pytest.fail(f"Test {script_name} timed out") + except Exception as e: + pytest.fail(f"Test {script_name} failed with error: {str(e)}") + +# Define the test cases using pytest's parametrize +@pytest.mark.parametrize("script_name,timeout", [ + ("test_lookup_buffer.py", 60), # Second test case with a 60-second timeout + ("test_send_recv.py", 120) # First test case with a 120-second timeout +]) +def test_run_python_script(script_name, timeout): + # Check the number of GPUs + if torch.cuda.device_count() < 2: + pytest.skip(f"Skipping test {script_name} because fewer than 2 GPUs are available") + + # Run the test if there are at least 2 GPUs + run_python_script(script_name, timeout) diff --git a/tests/kv_transfer/test_lookup_buffer.sh b/tests/kv_transfer/test_lookup_buffer.sh deleted file mode 100644 index 336b540e70542..0000000000000 --- a/tests/kv_transfer/test_lookup_buffer.sh +++ /dev/null @@ -1,3 +0,0 @@ - -RANK=0 python3 test_lookup_buffer.py & -RANK=1 python3 test_lookup_buffer.py & diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 994b907e0c899..f6da7f88d5f5c 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -42,7 +42,7 @@ def stress_test(my_rank, pipe): tensors: List[torch.Tensor] = [] - for i in tqdm(range(2000)): + for i in tqdm(range(500)): mean = torch.rand(1).item() std = torch.rand(1).item() size = torch.randint(900, 1000, (2, )) @@ -61,7 +61,7 @@ def stress_test(my_rank, pipe): torch.distributed.barrier() - for i in tqdm(range(2000)): + for i in tqdm(range(500)): if my_rank == int((i % 10) > 3): pipe.send_tensor(tensors[3 * i]) pipe.send_tensor(tensors[3 * i + 1]) @@ -89,7 +89,7 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.distributed.barrier() - for i in tqdm(range(1000)): + for i in tqdm(range(500)): tensors = [] @@ -134,4 +134,6 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.manual_seed(0) test_run(my_rank, pipe) stress_test(my_rank, pipe) - latency_test(my_rank, pipe, 1024 * 8 * 128, 80) + + # Use this function if you want to test the latency of pipe impl. + # latency_test(my_rank, pipe, 1024 * 8 * 128, 80) diff --git a/tests/kv_transfer/test_send_recv.sh b/tests/kv_transfer/test_send_recv.sh deleted file mode 100644 index 2a478871bd0e7..0000000000000 --- a/tests/kv_transfer/test_send_recv.sh +++ /dev/null @@ -1,3 +0,0 @@ - -RANK=0 python3 test_send_recv.py & -RANK=1 python3 test_send_recv.py & From 58f5080cb53a2491af0645ef4e6cbcbac7c47588 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 16 Sep 2024 01:12:57 +0000 Subject: [PATCH 233/303] update tests --- tests/kv_transfer/disagg_test.py | 107 ++++++++++++++++++ .../{test_launcher.py => module_test.py} | 0 2 files changed, 107 insertions(+) create mode 100644 tests/kv_transfer/disagg_test.py rename tests/kv_transfer/{test_launcher.py => module_test.py} (100%) diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py new file mode 100644 index 0000000000000..2e8414a9f4642 --- /dev/null +++ b/tests/kv_transfer/disagg_test.py @@ -0,0 +1,107 @@ +import os +import sys +import subprocess +import time +import pytest +import requests +import signal +from subprocess import Popen +import torch + + +# Fixture to set up environment variables and teardown servers after tests +@pytest.fixture(scope="module", autouse=True) +def setup_servers(): + if torch.cuda.device_count() < 4: + pytest.skip("Skipping test: fewer than 4 GPUs available") + + # Set up environment variables + VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", shell=True).decode().strip() + os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP + os.environ["VLLM_PORT"] = "12345" + + # Start prefill instance + prefill_cmd = [ + sys.executable, "-m", "vllm.entrypoints.openai.api_server", + "-tp", "2", + "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", "8100", + "--gpu-memory-utilization", "0.8" + ] + prefill_env = os.environ.copy() + prefill_env["VLLM_DISAGG_PREFILL_ROLE"] = "prefill" + prefill_env["CUDA_VISIBLE_DEVICES"] = "0,1" + prefill_proc = Popen(prefill_cmd, env=prefill_env) + + # Start decode instance + decode_cmd = [ + sys.executable, "-m", "vllm.entrypoints.openai.api_server", + "-tp", "2", + "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", "8200", + "--gpu-memory-utilization", "0.8" + ] + decode_env = os.environ.copy() + decode_env["VLLM_DISAGG_PREFILL_ROLE"] = "decode" + decode_env["CUDA_VISIBLE_DEVICES"] = "2,3" + decode_proc = Popen(decode_cmd, env=decode_env) + + # Wait for servers to be ready + assert wait_for_server(8100), "Prefill server did not start in time" + assert wait_for_server(8200), "Decode server did not start in time" + + # Yield to the test function and handle teardown after tests + yield + + # Cleanup: kill the processes + prefill_proc.terminate() + decode_proc.terminate() + + # Additional cleanup if needed + prefill_proc.wait() + decode_proc.wait() + +# Helper function to wait for server +def wait_for_server(port, timeout=120): + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = requests.get(f"http://localhost:{port}/v1/completions") + if response.status_code in [200, 405]: + return True + except requests.ConnectionError: + time.sleep(1) + return False + +# Test function to send curl requests and validate responses +@pytest.mark.parametrize("prompt", [ + "San Francisco is a", + "Santa Clara is a" +]) +def test_disaggregated_prefilling(prompt): + # Send to prefill + response = requests.post( + "http://localhost:8100/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 1, + "temperature": 0 + } + ) + assert response.status_code == 200 + + # Send to decode + response = requests.post( + "http://localhost:8200/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 10, + "temperature": 0 + } + ) + assert response.status_code == 200 + \ No newline at end of file diff --git a/tests/kv_transfer/test_launcher.py b/tests/kv_transfer/module_test.py similarity index 100% rename from tests/kv_transfer/test_launcher.py rename to tests/kv_transfer/module_test.py From 8f0538c5053a7c59c8c22da646d394fd7b3bf01e Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 16 Sep 2024 01:14:18 +0000 Subject: [PATCH 234/303] make fmt checker happy --- tests/kv_transfer/disagg_test.py | 67 +++++++++++++---------------- tests/kv_transfer/module_test.py | 35 +++++++++------ tests/kv_transfer/test_send_recv.py | 2 +- 3 files changed, 54 insertions(+), 50 deletions(-) diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py index 2e8414a9f4642..fa6a527574cf4 100644 --- a/tests/kv_transfer/disagg_test.py +++ b/tests/kv_transfer/disagg_test.py @@ -1,11 +1,11 @@ import os -import sys import subprocess +import sys import time +from subprocess import Popen + import pytest import requests -import signal -from subprocess import Popen import torch @@ -16,16 +16,15 @@ def setup_servers(): pytest.skip("Skipping test: fewer than 4 GPUs available") # Set up environment variables - VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", shell=True).decode().strip() + VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", + shell=True).decode().strip() os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP os.environ["VLLM_PORT"] = "12345" # Start prefill instance prefill_cmd = [ - sys.executable, "-m", "vllm.entrypoints.openai.api_server", - "-tp", "2", - "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", - "--port", "8100", + sys.executable, "-m", "vllm.entrypoints.openai.api_server", "-tp", "2", + "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--port", "8100", "--gpu-memory-utilization", "0.8" ] prefill_env = os.environ.copy() @@ -35,10 +34,8 @@ def setup_servers(): # Start decode instance decode_cmd = [ - sys.executable, "-m", "vllm.entrypoints.openai.api_server", - "-tp", "2", - "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", - "--port", "8200", + sys.executable, "-m", "vllm.entrypoints.openai.api_server", "-tp", "2", + "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--port", "8200", "--gpu-memory-utilization", "0.8" ] decode_env = os.environ.copy() @@ -61,6 +58,7 @@ def setup_servers(): prefill_proc.wait() decode_proc.wait() + # Helper function to wait for server def wait_for_server(port, timeout=120): start_time = time.time() @@ -73,35 +71,30 @@ def wait_for_server(port, timeout=120): time.sleep(1) return False + # Test function to send curl requests and validate responses -@pytest.mark.parametrize("prompt", [ - "San Francisco is a", - "Santa Clara is a" -]) +@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"]) def test_disaggregated_prefilling(prompt): # Send to prefill - response = requests.post( - "http://localhost:8100/v1/completions", - headers={"Content-Type": "application/json"}, - json={ - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", - "prompt": prompt, - "max_tokens": 1, - "temperature": 0 - } - ) + response = requests.post("http://localhost:8100/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 1, + "temperature": 0 + }) assert response.status_code == 200 # Send to decode - response = requests.post( - "http://localhost:8200/v1/completions", - headers={"Content-Type": "application/json"}, - json={ - "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", - "prompt": prompt, - "max_tokens": 10, - "temperature": 0 - } - ) + response = requests.post("http://localhost:8200/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 10, + "temperature": 0 + }) assert response.status_code == 200 - \ No newline at end of file diff --git a/tests/kv_transfer/module_test.py b/tests/kv_transfer/module_test.py index 5c0aeb04b43fa..10fb19a3128e2 100644 --- a/tests/kv_transfer/module_test.py +++ b/tests/kv_transfer/module_test.py @@ -1,21 +1,25 @@ import subprocess -import pytest import sys + +import pytest import torch + def run_python_script(script_name, timeout): try: # Start both processes asynchronously using Popen process0 = subprocess.Popen( [sys.executable, script_name], - env={"RANK": "0"}, # Set the RANK environment variable for process 0 + env={"RANK": + "0"}, # Set the RANK environment variable for process 0 stdout=sys.stdout, # Pipe stdout to current stdout stderr=sys.stderr, # Pipe stderr to current stderr ) - + process1 = subprocess.Popen( [sys.executable, script_name], - env={"RANK": "1"}, # Set the RANK environment variable for process 1 + env={"RANK": + "1"}, # Set the RANK environment variable for process 1 stdout=sys.stdout, # Pipe stdout to current stdout stderr=sys.stderr, # Pipe stderr to current stderr ) @@ -26,9 +30,11 @@ def run_python_script(script_name, timeout): # Check the return status of both processes if process0.returncode != 0: - pytest.fail(f"Test {script_name} failed for RANK=0 with return code {process0.returncode}") + pytest.fail( + f"Test {script_name} failed for RANK=0, {process0.returncode}") if process1.returncode != 0: - pytest.fail(f"Test {script_name} failed for RANK=1 with return code {process1.returncode}") + pytest.fail( + f"Test {script_name} failed for RANK=1, {process1.returncode}") except subprocess.TimeoutExpired: # If either process times out, terminate both and fail the test @@ -38,15 +44,20 @@ def run_python_script(script_name, timeout): except Exception as e: pytest.fail(f"Test {script_name} failed with error: {str(e)}") + # Define the test cases using pytest's parametrize -@pytest.mark.parametrize("script_name,timeout", [ - ("test_lookup_buffer.py", 60), # Second test case with a 60-second timeout - ("test_send_recv.py", 120) # First test case with a 120-second timeout -]) +@pytest.mark.parametrize( + "script_name,timeout", + [ + ("test_lookup_buffer.py", + 60), # Second test case with a 60-second timeout + ("test_send_recv.py", 120) # First test case with a 120-second timeout + ]) def test_run_python_script(script_name, timeout): # Check the number of GPUs if torch.cuda.device_count() < 2: - pytest.skip(f"Skipping test {script_name} because fewer than 2 GPUs are available") - + pytest.skip( + f"Skipping test {script_name} because <2 GPUs are available") + # Run the test if there are at least 2 GPUs run_python_script(script_name, timeout) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index f6da7f88d5f5c..ff771f34c0325 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -134,6 +134,6 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.manual_seed(0) test_run(my_rank, pipe) stress_test(my_rank, pipe) - + # Use this function if you want to test the latency of pipe impl. # latency_test(my_rank, pipe, 1024 * 8 * 128, 80) From dda1f312ee49ff32d843a141818b82d9f295cbea Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 16 Sep 2024 01:18:19 +0000 Subject: [PATCH 235/303] constraint the model length --- tests/kv_transfer/disagg_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py index fa6a527574cf4..fffd9ab6f42a7 100644 --- a/tests/kv_transfer/disagg_test.py +++ b/tests/kv_transfer/disagg_test.py @@ -25,7 +25,7 @@ def setup_servers(): prefill_cmd = [ sys.executable, "-m", "vllm.entrypoints.openai.api_server", "-tp", "2", "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--port", "8100", - "--gpu-memory-utilization", "0.8" + "--gpu-memory-utilization", "0.8", "--max-model-len", "1000", ] prefill_env = os.environ.copy() prefill_env["VLLM_DISAGG_PREFILL_ROLE"] = "prefill" @@ -36,7 +36,7 @@ def setup_servers(): decode_cmd = [ sys.executable, "-m", "vllm.entrypoints.openai.api_server", "-tp", "2", "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--port", "8200", - "--gpu-memory-utilization", "0.8" + "--gpu-memory-utilization", "0.8", "--max-model-len", "1000", ] decode_env = os.environ.copy() decode_env["VLLM_DISAGG_PREFILL_ROLE"] = "decode" From 85d72fa6d4db8987dbb6c1eceda48eecaf6bb90b Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 16 Sep 2024 01:25:17 +0000 Subject: [PATCH 236/303] adjust path --- tests/kv_transfer/module_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kv_transfer/module_test.py b/tests/kv_transfer/module_test.py index 10fb19a3128e2..355461919cd7c 100644 --- a/tests/kv_transfer/module_test.py +++ b/tests/kv_transfer/module_test.py @@ -6,6 +6,7 @@ def run_python_script(script_name, timeout): + script_name = f'kv_transfer/{script_name}' try: # Start both processes asynchronously using Popen process0 = subprocess.Popen( From 60ede08c67dc2bbfa42bfbcd2496a3bd47b10f75 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 16 Sep 2024 01:25:32 +0000 Subject: [PATCH 237/303] add disagg prefill test to test pipeline --- .buildkite/test-pipeline.yaml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9b0cb6663a55b..da79fd86b767d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -390,6 +390,18 @@ steps: - pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pipeline_parallel.py +- label: Disaggregated Prefill Test # 4min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/parallel_state.py + - vllm/distributed/kv_transfer + - vllm/worker/worker_base.py + - vllm/worker/model_runner.py + commands: + - pytest -v -s kv_transfer/module_test.py + - pytest -v -s kv_transfer/disagg_test.py + - label: LoRA Long Context (Distributed) # 11min # This test runs llama 13B, so it is required to run on 4 GPUs. num_gpus: 4 From 0df75665d81b7c842d36c58b6f42272cbe05d4b3 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Mon, 16 Sep 2024 07:45:25 -0500 Subject: [PATCH 238/303] bugfix --- .../kv_lookup_buffer/simple_buffer.py | 3 +++ .../kv_pipe/torch_distributed_pipe.py | 5 ++-- vllm/distributed/kv_transfer/vllm_adapter.py | 3 +++ vllm/distributed/parallel_state.py | 24 +++++++++++++------ 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 9696032002fda..bd9a122bdf404 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -169,8 +169,11 @@ def drop_select(self, input_tokens: torch.Tensor, if isinstance(roi, torch.Tensor): roi = roi.clone() + logger.debug(f"Sending signal {self.normal_signal}") self.signal_pipe.send_tensor(self.normal_signal) + logger.debug(f"Sending input tokens") self.data_pipe.send_tensor(input_tokens) + logger.debug(f"Sending roi") self.data_pipe.send_tensor(roi) input_tokens = self.data_pipe.recv_tensor() diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index 911bce88a38f1..d4080f4739cf2 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -189,10 +189,11 @@ def _send_impl(self, tensor): metadata = self._make_metadata(tensor) self._send_metadata(metadata) - - torch.distributed.send(tensor, + logger.debug(f"Sent meta {metadata}") + torch.distributed.send(tensor.to(self.device), dst=self.target_rank_for_send, group=self.device_group) + logger.debug(f"Sent tensor {tensor}") def _recv_impl(self) -> torch.Tensor: """ diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 2edb426c5c8da..4dc6e163abb7d 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -107,6 +107,7 @@ def __init__( self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, self.recv_pipe, self.lookup_buffer_size) + self.tensor_device = 'cpu' else: # when performing disaggregated prefill, only 1 pipe is needed # at prefill instance this pipe is used for send KV cache @@ -125,6 +126,8 @@ def __init__( self.lookup_buffer_size) self.send_buffer = buffer self.recv_buffer = buffer + + self.tensor_device = 'cuda' def send_kv_caches_and_hidden_states( self, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 3615fa6af399c..81322793f4f18 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1056,13 +1056,23 @@ def initialize_model_parallel( # decode global rank: i + world_size group_ranks.append([i, i + world_size]) logger.debug("Distributed group is %s", str(group_ranks)) - _DISAGG = dist_kv.KV_transfer_agent( - group_ranks=group_ranks, - local_rank=get_world_group().local_rank, - torch_distributed_backend=backend, - ) - logger.debug("_DISAGG initialized for rank %d", - torch.distributed.get_rank()) + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + _DISAGG = dist_kv.KV_transfer_agent( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + torch_distributed_backend=backend, + ) + logger.debug("_DISAGG initialized for rank %d", + torch.distributed.get_rank()) + elif dist_kv.IS_LMCACHE_INSTANCE: + _DISAGG = dist_kv.KV_transfer_agent( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + torch_distributed_backend="gloo", + ) + logger.debug("_DISAGG (LMC) initialized for rank %d", + torch.distributed.get_rank()) + def ensure_model_parallel_initialized( From 73c1683f4534009df47c11acca84d91ab0ccb743 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Mon, 16 Sep 2024 10:19:25 -0500 Subject: [PATCH 239/303] bugfix --- .../kv_transfer/kv_lookup_buffer/simple_buffer.py | 3 --- .../kv_transfer/kv_pipe/torch_distributed_pipe.py | 4 ++-- vllm/distributed/kv_transfer/vllm_adapter.py | 8 ++++---- vllm/worker/worker_base.py | 2 ++ 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index bd9a122bdf404..9696032002fda 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -169,11 +169,8 @@ def drop_select(self, input_tokens: torch.Tensor, if isinstance(roi, torch.Tensor): roi = roi.clone() - logger.debug(f"Sending signal {self.normal_signal}") self.signal_pipe.send_tensor(self.normal_signal) - logger.debug(f"Sending input tokens") self.data_pipe.send_tensor(input_tokens) - logger.debug(f"Sending roi") self.data_pipe.send_tensor(roi) input_tokens = self.data_pipe.recv_tensor() diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index d4080f4739cf2..c2c5cbbe95b0a 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -189,11 +189,11 @@ def _send_impl(self, tensor): metadata = self._make_metadata(tensor) self._send_metadata(metadata) - logger.debug(f"Sent meta {metadata}") + #logger.debug(f"Sent meta {metadata}") torch.distributed.send(tensor.to(self.device), dst=self.target_rank_for_send, group=self.device_group) - logger.debug(f"Sent tensor {tensor}") + #logger.debug(f"Sent tensor {tensor}") def _recv_impl(self) -> torch.Tensor: """ diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 4dc6e163abb7d..caebc15f09b2d 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -250,8 +250,8 @@ def recv_kv_caches_and_hidden_states( key_cache, value_cache = kv_cache[0], kv_cache[1] ops.reshape_and_cache_flash( - keys[i], - values[i], + keys[i - model_executable.model.start_layer].to(key_cache.device), + values[i - model_executable.model.start_layer].to(value_cache.device), key_cache, value_cache, slot_mapping[start_pos:end_pos], @@ -269,7 +269,7 @@ def recv_kv_caches_and_hidden_states( # so we need to recompute the hidden state logger.debug("[rank%d]: KV EMPTY recv DONE.", torch.distributed.get_rank()) - return None, bypass_model_exec, None + return None, bypass_model_exec, model_input if not is_complete: rebuilt_model_input = self.build_partial_prefill_input( @@ -282,7 +282,7 @@ def recv_kv_caches_and_hidden_states( ) logger.debug("[rank%d]: KV PARTIAL recv DONE.", torch.distributed.get_rank()) - return None, bypass_model_exec, rebuilt_model_input + return None, False, rebuilt_model_input # concatenate hidden states from different requests hidden_or_intermediate_states = torch.cat( diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index d55400a402400..ae2fb65cc455e 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -334,6 +334,7 @@ def execute_model( from vllm.worker.model_runner import GPUModelRunnerBase assert isinstance(self.model_runner, GPUModelRunnerBase), \ "Distributed KV transfer only support GPU modelrunner" + logger.debug("Receiving KV caches") hidden_or_intermediate_states, bypass_model_exec, model_input = \ ps.get_disagg_group().recv_kv_caches_and_hidden_states( # model is used to know which layer the current worker @@ -361,6 +362,7 @@ def execute_model( from vllm.worker.model_runner import GPUModelRunnerBase assert isinstance(self.model_runner, GPUModelRunnerBase), \ "Distributed KV transfer only support GPU modelrunner" + logger.debug("Sending KV caches") ps.get_disagg_group().send_kv_caches_and_hidden_states( # model is used to know which layer the current worker # is working on, so that we can send KV for only those From 70bec94bfed89a8b8d9eae207f6e06f6ea2c6447 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 19 Sep 2024 00:51:29 +0000 Subject: [PATCH 240/303] rename the environment variable to KV producer and KV consumer, for more clarity --- vllm/distributed/kv_transfer/vllm_adapter.py | 80 +- vllm/distributed/parallel_state.py | 36 +- vllm/envs.py | 6 +- vllm/worker/model_runner.py | 128 ++- vllm/worker/worker.py | 869 +++++++++---------- 5 files changed, 603 insertions(+), 516 deletions(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index caebc15f09b2d..17c2e52b1174b 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -1,5 +1,5 @@ """vLLM distributed KV cache transfer API. -These APIs are used in `vllm/worker/worker_base.py`. +These APIs are used in `vllm/worker/model_runner.py`. Currently supporting TP. The TP between prefill and decode instance needs to be the same. @@ -38,20 +38,22 @@ from vllm.logger import init_logger from vllm.sequence import IntermediateTensors -assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode", "lmcache"],\ - "VLLM_DISAGG_PREFILL_ROLE can only be prefill, decode or lmcache." +logger = init_logger(__name__) -# currently the connections are hard-coded. -# we only handle 2 cases: -# - prefill vLLM --> decode vLLM -# - vLLM --> LMCache -IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE - in ["prefill", "decode"]) -IS_KV_PREFILL_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "prefill") -IS_KV_DECODE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "decode") -IS_LMCACHE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "lmcache") +# check VLLM_DISTRIBUTERD_KV_ROLE and set corresponding flags +assert envs.VLLM_DISTRIBUTERD_KV_ROLE in [None, "producer", "consumer", "both"],\ + "VLLM_DISTRIBUTERD_KV_ROLE can only be producer, consumer or both." +IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISTRIBUTERD_KV_ROLE + in ["producer", "consumer", "both"]) +IS_KV_PRODUCER: bool = (envs.VLLM_DISTRIBUTERD_KV_ROLE in ["producer", "both"]) +IS_KV_CONSUMER: bool = (envs.VLLM_DISTRIBUTERD_KV_ROLE == ["consumer", "both"]) + +# When the current instance is both KV producer and KV consumer, +# it is likely connected to a KV storage service on CPU/disk +# so the communication backend needs to be "gloo" for that case. +DISTRIBUTED_BACKEND: str = "gloo" if (IS_KV_PRODUCER and IS_KV_CONSUMER) else "nccl" +DISTRIBUTED_DEVICE: str = "cpu" if (IS_KV_PRODUCER and IS_KV_CONSUMER) else "cuda" -logger = init_logger(__name__) class KV_transfer_agent: @@ -67,7 +69,7 @@ def __init__( self, group_ranks: List[List[int]], local_rank: int, - torch_distributed_backend: Union[str, Backend], + torch_distributed_backend: Union[str, Backend] = DISTRIBUTED_BACKEND, # FIXME(Kuntai): remove this hardcoding lookup_buffer_size: int = int(1e10)): @@ -78,13 +80,15 @@ def __init__( SimpleKVLookupBuffer = sklb.SimpleKVLookupBuffer - if IS_LMCACHE_INSTANCE: - # when vLLM is connected with LMCache - # it needs to both send and recv KV cache + # In disaggregated prefill, the prefill vLLM only uses send pipe + # and the decode vLLM only uses recv pipe + # In remote KV cache store, vLLM will use both send pipe and recv pipe + # So we build both send pipe and recv pipe for simplicity. + if IS_KV_PRODUCER: self.send_pipe = TorchDistributedPipe( group_ranks, local_rank, - torch_distributed_backend, + DISTRIBUTED_BACKEND, ) self.send_signal_pipe = TorchDistributedPipe( group_ranks, @@ -94,7 +98,7 @@ def __init__( self.recv_pipe = TorchDistributedPipe( group_ranks, local_rank, - torch_distributed_backend, + DISTRIBUTED_BACKEND, ) self.recv_signal_pipe = TorchDistributedPipe( group_ranks, @@ -107,27 +111,39 @@ def __init__( self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, self.recv_pipe, self.lookup_buffer_size) - self.tensor_device = 'cpu' + self.tensor_device = DISTRIBUTED_DEVICE else: - # when performing disaggregated prefill, only 1 pipe is needed - # at prefill instance this pipe is used for send KV cache - # at decode instance this pipe is used for recv KV cache - self.pipe = TorchDistributedPipe( + + # the current vLLM instance is KV consumer, so it needs to connect + # its recv pipe to the send pipe of KV producder + + self.recv_pipe = TorchDistributedPipe( group_ranks, local_rank, - torch_distributed_backend, + DISTRIBUTED_BACKEND, ) - self.signal_pipe = TorchDistributedPipe( + self.recv_signal_pipe = TorchDistributedPipe( group_ranks, local_rank, "gloo", ) - buffer = SimpleKVLookupBuffer(self.signal_pipe, self.pipe, - self.lookup_buffer_size) - self.send_buffer = buffer - self.recv_buffer = buffer - - self.tensor_device = 'cuda' + self.send_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + DISTRIBUTED_BACKEND, + ) + self.send_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + ) + self.send_buffer = SimpleKVLookupBuffer(self.send_signal_pipe, + self.send_pipe, + self.lookup_buffer_size) + self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, + self.recv_pipe, + self.lookup_buffer_size) + self.tensor_device = DISTRIBUTED_DEVICE def send_kv_caches_and_hidden_states( self, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b619fadb37c6a..54f7968908de0 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -981,10 +981,10 @@ def init_distributed_environment( # this backend is used for WORLD maybe_disagg_world_size = world_size maybe_disagg_rank = rank - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE: + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: maybe_disagg_world_size = world_size * 2 logger.debug("Disaggregated prefill enabled.") - if dist_kv.IS_KV_PREFILL_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE: + if dist_kv.IS_KV_PRODUCER: # for prefill, the ranks are [0, world_size) maybe_disagg_rank = rank else: @@ -1016,7 +1016,7 @@ def init_distributed_environment( if _WORLD is None: ranks = [[i for i in range(world_size)]] # offset the distributed group - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE: + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: ranks = include_decoding_groups_if_disagg_enabled( ranks, world_size) @@ -1079,9 +1079,9 @@ def initialize_model_parallel( world_size: int = torch.distributed.get_world_size() backend = backend or torch.distributed.get_backend( get_world_group().device_group) - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE: + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: # Disaggregated prefill enabled - # The world_size for this vLLM instance is tp * pp, but + # This vLLM instance thinks its word size is tp * pp, but # torch.distributed contains 2 vLLM instances, # its world size is 2 * tp * pp # Adjust the world_size to match. @@ -1135,8 +1135,8 @@ def initialize_model_parallel( group_name="pp") logger.debug("_PP initialized for rank %d", torch.distributed.get_rank()) - # TODO(Jiayi): perhaps we need to separate lmcache and disagg - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE: + + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: global _DISAGG logger.debug("Disaggregated prefill enabled, create _DISAGG group") group_ranks = [] @@ -1145,22 +1145,12 @@ def initialize_model_parallel( # decode global rank: i + world_size group_ranks.append([i, i + world_size]) logger.debug("Distributed group is %s", str(group_ranks)) - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: - _DISAGG = dist_kv.KV_transfer_agent( - group_ranks=group_ranks, - local_rank=get_world_group().local_rank, - torch_distributed_backend=backend, - ) - logger.debug("_DISAGG initialized for rank %d", - torch.distributed.get_rank()) - elif dist_kv.IS_LMCACHE_INSTANCE: - _DISAGG = dist_kv.KV_transfer_agent( - group_ranks=group_ranks, - local_rank=get_world_group().local_rank, - torch_distributed_backend="gloo", - ) - logger.debug("_DISAGG (LMC) initialized for rank %d", - torch.distributed.get_rank()) + _DISAGG = dist_kv.KV_transfer_agent( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + ) + logger.debug("_DISAGG initialized for rank %d", + torch.distributed.get_rank()) def ensure_model_parallel_initialized( diff --git a/vllm/envs.py b/vllm/envs.py index 65c67104349e5..407dd942cc9f1 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -367,9 +367,9 @@ def get_default_config_root(): lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), # Specify the role of current vllm instance - # Value can be "prefill", "decode". - "VLLM_DISAGG_PREFILL_ROLE": - lambda: os.getenv("VLLM_DISAGG_PREFILL_ROLE", None), + # Value can be "producer", "consumer" or "both". + "VLLM_DISTRIBUTERD_KV_ROLE": + lambda: os.getenv("VLLM_DISTRIBUTERD_KV_ROLE", None), # If set, vllm will skip the deprecation warnings. "VLLM_NO_DEPRECATION_WARNING": diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7293ffe86c0f8..78ffc40082fc8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -22,7 +22,7 @@ ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_pp_group +from vllm.distributed import get_pp_group, get_disagg_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger @@ -55,6 +55,8 @@ _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict, dump_input_when_exception) +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv + if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -1576,6 +1578,24 @@ def execute_model( else: model_executable = self.model + # Receive KV cache in distributed KV cache transfer setting + # In disagg prefill setting, it will also recv hidden states and bypass + # model forwarding + # In KV cache database setting, it will change the model input so that + # we can skip prefilling on tokens that successfully received KV caches + # NOTE: The receive operation is blocking + bypass_model_exec = False + if self.recv_kv_needed(model_input, kv_caches): + hidden_or_intermediate_states, bypass_model_exec, model_input = \ + get_disagg_group().recv_kv_caches_and_hidden_states( + # model is used to know which layer the current worker + # is working on, so that we can receive KV for only those + # layers. + model_executable, + model_input, + kv_caches=kv_caches + ) + multi_modal_kwargs = model_input.multi_modal_kwargs or {} seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, @@ -1587,20 +1607,36 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + if not bypass_model_exec: + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() + + # Sending KV cache in distributed KV cache transfer setting + # NOTE: the send operation is non-blocking + if self.send_kv_needed(model_input, kv_caches): + logger.debug("Sending KV caches") + get_disagg_group().send_kv_caches_and_hidden_states( + # model is used to know which layer the current worker + # is working on, so that we can send KV for only those + # layers. + model_executable, + model_input, + kv_caches, + hidden_or_intermediate_states, + ) + # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: if (self.is_driver_worker and hidden_or_intermediate_states is not None @@ -1619,15 +1655,6 @@ def execute_model( torch.tensor(model_forward_time + orig_model_forward_time)) return hidden_or_intermediate_states - return hidden_or_intermediate_states - - @torch.inference_mode() - def postprocess_model( - self, - model_input: ModelInputForGPUWithSamplingMetadata, - hidden_or_intermediate_states, - ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: - logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) @@ -1642,9 +1669,23 @@ def postprocess_model( logits=logits, sampling_metadata=model_input.sampling_metadata, ) + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time + and output is not None): + model_forward_end.synchronize() + model_forward_time = model_forward_start.elapsed_time( + model_forward_end) + orig_model_forward_time = 0.0 + if intermediate_tensors is not None: + orig_model_forward_time = intermediate_tensors.tensors.get( + "model_forward_time", torch.tensor(0.0)).item() + # If there are multiple workers, we are still tracking the latency + # from the start time of the driver worker to the end time of the + # driver worker. The model forward time will then end up covering + # the communication time as well. + output.model_forward_time = (orig_model_forward_time + + model_forward_time) - assert model_input.attn_metadata is not None - decode_meta = model_input.attn_metadata.decode_metadata if self.return_hidden_states: # we only need to pass hidden states of most recent token assert model_input.sampling_metadata is not None @@ -1661,6 +1702,49 @@ def postprocess_model( output.hidden_states = hidden_states return [output] + + + def recv_kv_needed(self, model_input, kv_caches) -> bool: + """ + Need to receive KV when + 1. current vLLM instance is KV cache *consumer* + 2. this batch is not a profiling run + 3. this batch is a prefill run + """ + + prefill_meta = model_input.attn_metadata.prefill_metadata + + # check if the current run is profiling + is_profile_run = (kv_caches is None) or (kv_caches[0] is None) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + return all([ + dist_kv.IS_KV_DECODE_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE, + not is_profile_run, + is_prefill_run, + ]) + + def send_kv_needed(self, model_input, kv_caches) -> bool: + """ + Need to receive KV when + 1. current vLLM instance is KV cache *producer* + 2. this batch is not a profiling run + 3. this batch is a prefill run + """ + + prefill_meta = model_input.attn_metadata.prefill_metadata + + # check if the current run is profiling + is_profile_run = (kv_caches is None) or (kv_caches[0] is None) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + return all([ + dist_kv.IS_KV_PREFILL_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE, + not is_profile_run, + is_prefill_run + ]) class CUDAGraphRunner: @@ -1844,4 +1928,4 @@ def _get_max_graph_batch_size(max_num_seqs: int) -> int: if padded_size in _BATCH_SIZES_TO_CAPTURE: return padded_size assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] - return _BATCH_SIZES_TO_CAPTURE[-1] + return _BATCH_SIZES_TO_CAPTURE[-1] \ No newline at end of file diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 3851843afc960..abc6f98b5f30a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,488 +1,485 @@ -"""A GPU worker class.""" -import gc +import dataclasses +import importlib import os -from typing import Dict, List, Optional, Set, Tuple, Type, Union +import time +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import torch -import torch.distributed - -import vllm.envs as envs -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce) + +from vllm.config import ObservabilityConfig +from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor import set_random_seed from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.platforms import current_platform -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, - SequenceGroupMetadata, SequenceGroupMetadataDelta) -from vllm.worker.cache_engine import CacheEngine -from vllm.worker.embedding_model_runner import EmbeddingModelRunner -from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner -from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner -from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput +from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.utils import (enable_trace_function_call_for_thread, + update_environment_variables) +from vllm.worker.model_runner_base import (BroadcastableModelInput, + ModelRunnerBase, + ModelRunnerInputBase) logger = init_logger(__name__) -class Worker(LocalOrDistributedWorkerBase): - """A worker class that executes (a partition of) the model on a GPU. - - Each worker is associated with a single GPU. The worker is responsible for - maintaining the KV cache and executing the model on the GPU. In case of - distributed inference, each worker is assigned a partition of the model. +class WorkerBase(ABC): + """Worker interface that allows vLLM to cleanly separate implementations for + different hardware. Also abstracts control plane communication, e.g., to + communicate request metadata to other workers. """ - def __init__( - self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, - speculative_config: Optional[SpeculativeConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - is_driver_worker: bool = False, - model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, - observability_config: Optional[ObservabilityConfig] = None, - ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config - self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.load_config = load_config - self.prompt_adapter_config = prompt_adapter_config - self.is_driver_worker = is_driver_worker - if parallel_config and is_driver_worker: - assert rank % parallel_config.tensor_parallel_size == 0, \ - "Driver worker should be rank 0 of tensor parallel group." - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - self.observability_config = observability_config - - # Return hidden states from target model if the draft model is an - # mlp_speculator - speculative_args = {} if speculative_config is None \ - or (speculative_config.draft_model_config.model == - model_config.model) \ - or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator", "eagle"]) \ - else {"return_hidden_states": True} - - ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner - if model_runner_cls is not None: - ModelRunnerClass = model_runner_cls - elif self._is_embedding_model(): - ModelRunnerClass = EmbeddingModelRunner - elif self._is_encoder_decoder_model(): - ModelRunnerClass = EncoderDecoderModelRunner - self.model_runner: GPUModelRunnerBase = ModelRunnerClass( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=load_config, - lora_config=self.lora_config, - kv_cache_dtype=self.cache_config.cache_dtype, - is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, - observability_config=observability_config, - **speculative_args, - ) - # Uninitialized cache engine. Will be initialized by - # initialize_cache. - self.cache_engine: List[CacheEngine] - # Initialize gpu_cache as embedding models don't initialize kv_caches - self.gpu_cache: Optional[List[List[torch.Tensor]]] = None - self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} - - # Torch profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - with_stack=True, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) - else: - self.profiler = None + @abstractmethod + def init_device(self) -> None: + """Initialize device state, such as loading the model or other on-device + memory allocations. + """ + raise NotImplementedError - def start_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.start() + @abstractmethod + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + The implementation may run profiling or other heuristics to determine + the size of caches. - def stop_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.stop() + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ + raise NotImplementedError - def _is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model + @abstractmethod + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks. + """ + raise NotImplementedError - def _is_embedding_model(self): - return self.model_config.is_embedding_model + @current_platform.inference_mode() + def start_worker_execution_loop(self) -> None: + """Execute model loop in parallel worker. - def init_device(self) -> None: - if self.device_config.device.type == "cuda": - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - # This env var set by Ray causes exceptions with graph building. - os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) - self.device = torch.device(f"cuda:{self.local_rank}") - torch.cuda.set_device(self.device) - - _check_if_gpu_supports_dtype(self.model_config.dtype) - gc.collect() - torch.cuda.empty_cache() - self.init_gpu_memory = torch.cuda.mem_get_info()[0] - else: - raise RuntimeError( - f"Not support device type: {self.device_config.device}") - # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method, - self.local_rank) - # Set random seed. - set_random_seed(self.model_config.seed) - - def load_model(self): - self.model_runner.load_model() - - def save_sharded_state( - self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, - ) -> None: - self.model_runner.save_sharded_state( - path, - pattern=pattern, - max_size=max_size, - ) + You can stop the loop by executing a driver worker with an empty output. + See `stop_remote_worker_execution_loop` for more details. + """ + while True: + output = self.execute_model(execute_model_req=None) + if output is None: + return None - def save_tensorized_model( + @abstractmethod + def execute_model( self, - tensorizer_config: TensorizerConfig, - ) -> None: - self.model_runner.save_tensorized_model( - tensorizer_config=tensorizer_config, ) + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[List[SamplerOutput]]: + raise NotImplementedError - @torch.inference_mode() - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. + @abstractmethod + def get_cache_block_size_bytes(self) -> int: + """Return the size of a single cache block, in bytes. Used in + speculative decoding. + """ + raise NotImplementedError - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. + @abstractmethod + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - # Profile the memory usage of the model and get the maximum number of - # cache blocks that can be allocated with the remaining free memory. - torch.cuda.empty_cache() - - # Execute a forward pass with dummy inputs to profile the memory usage - # of the model. - self.model_runner.profile_run() - - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. - torch.cuda.synchronize() - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - # NOTE(woosuk): Here we assume that the other processes using the same - # GPU did not change their memory usage during the profiling. - peak_memory = self.init_gpu_memory - free_gpu_memory - assert peak_memory > 0, ( - "Error in memory profiling. " - f"Initial free memory {self.init_gpu_memory}, current free memory" - f" {free_gpu_memory}. This happens when the GPU memory was " - "not properly cleaned up before initializing the vLLM instance.") - - cache_block_size = self.get_cache_block_size_bytes() - num_gpu_blocks = int( - (total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) - num_cpu_blocks = int(self.cache_config.swap_space_bytes // - cache_block_size) - num_gpu_blocks = max(num_gpu_blocks, 0) - num_cpu_blocks = max(num_cpu_blocks, 0) - if self.model_runner.lora_manager: - self.model_runner.remove_all_loras() - gc.collect() - torch.cuda.empty_cache() - return num_gpu_blocks, num_cpu_blocks + @abstractmethod + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def pin_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def list_loras(self) -> Set[int]: + raise NotImplementedError - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Allocate GPU and CPU KV cache with the specified number of blocks. - This also warms up the model, which may record CUDA graphs. +class LoraNotSupportedWorkerBase(WorkerBase): + """Partial implementation of WorkerBase that raises exceptions when LoRA + methods are invoked. + """ + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise ValueError(f"{type(self)} does not support LoRA") + + def remove_lora(self, lora_id: int) -> bool: + raise ValueError(f"{type(self)} does not support LoRA") + + def pin_lora(self, lora_id: int) -> bool: + return ValueError( + f"{type(self)} does not support LoRA") # type: ignore + + def list_loras(self) -> Set[int]: + raise ValueError(f"{type(self)} does not support LoRA") + + +@dataclasses.dataclass(frozen=True) +class WorkerInput: + """Local inputs to each worker. May contain device-specific data. These + fields should be broadcastable to other workers. + """ + + num_seq_groups: Optional[int] = None + blocks_to_swap_in: Optional[torch.Tensor] = None + blocks_to_swap_out: Optional[torch.Tensor] = None + blocks_to_copy: Optional[torch.Tensor] = None + virtual_engine: int = 0 + num_steps: int = 1 + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type["WorkerInput"], + tensor_dict: Dict[str, Any], + ) -> "WorkerInput": + """ + Pop fields from the given tensor_dict and populate a new instance of + WorkerInput. """ - raise_if_cache_size_invalid(num_gpu_blocks, - self.cache_config.block_size, - self.model_config.max_model_len) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self._init_cache_engine() - self._warm_up_model() - - def _init_cache_engine(self): - assert self.cache_config.num_gpu_blocks is not None - self.cache_engine = [ - CacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config) - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - self.gpu_cache = [ - self.cache_engine[ve].gpu_cache - for ve in range(self.parallel_config.pipeline_parallel_size) - ] - - def _warm_up_model(self) -> None: - if not self.model_config.enforce_eager: - self.model_runner.capture_model(self.gpu_cache) - # Reset the seed to ensure that the random state is not affected by - # the model initialization and profiling. - set_random_seed(self.model_config.seed) + return cls( + num_seq_groups=tensor_dict.pop("num_seq_groups"), + blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), + blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), + blocks_to_copy=tensor_dict.pop("blocks_to_copy"), + virtual_engine=tensor_dict["virtual_engine"], + num_steps=tensor_dict.pop("num_steps"), + ) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + """ + Extract broadcastable fields. + """ + tensor_dict = { + "num_seq_groups": self.num_seq_groups, + "blocks_to_swap_in": self.blocks_to_swap_in, + "blocks_to_swap_out": self.blocks_to_swap_out, + "blocks_to_copy": self.blocks_to_copy, + "virtual_engine": self.virtual_engine, + "num_steps": self.num_steps, + } + + return tensor_dict + + +class LocalOrDistributedWorkerBase(WorkerBase): + """ + Partial implementation of WorkerBase that has a default `execute_model` + definition to perform metadata transfer between workers when in distributed + mode. Subclasses of this interface should use model runners that inherit + from ModelRunnerBase, and should only need to implement worker-local logic. + If custom control plane logic is needed to transfer metadata, or if the + model runner cannot inherit from ModelRunnerBase, use WorkerBase instead. + """ + is_driver_worker: bool + model_runner: ModelRunnerBase + observability_config: Optional[ObservabilityConfig] = None @property + @abstractmethod def do_metadata_broadcast(self) -> bool: - return self.parallel_config.tensor_parallel_size > 1 + """ + Used by the default `execute_model` to check whether broadcast is + needed to transfer request inputs from the driver worker to other + workers in the TP group. If WorkerBase subclass only supports + single-worker execution, then this method should return False. + """ + raise NotImplementedError @property + @abstractmethod def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - return self.gpu_cache + """ + Gets the list of kv caches to pass to the worker's model runner. Each + element in the list is a kv cache corresponding to a particular virtual + engine (PP stream). Used by the default `execute_model`. If the worker's + model runner does not follow the ModelRunnerBase interface, then inherit + from WorkerBase instead. + """ + raise NotImplementedError - @torch.inference_mode() + @abstractmethod def prepare_worker_input( self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - virtual_engine = execute_model_req.virtual_engine - num_steps = execute_model_req.num_steps - num_seq_groups = len(execute_model_req.seq_group_metadata_list) - # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. - # they contain parameters to launch cudamemcpyasync. - blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, - device="cpu", - dtype=torch.int64).view(-1, 2) - blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, - device="cpu", - dtype=torch.int64).view(-1, 2) - # `blocks_to_copy` is a gpu tensor. The src and tgt of - # blocks to copy are in the same device, and `blocks_to_copy` - # can be used directly within cuda kernels. - blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, - device=self.device, - dtype=torch.int64).view(-1, 2) - - return WorkerInput( - num_seq_groups=num_seq_groups, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - virtual_engine=virtual_engine, - num_steps=num_steps, - ) + """ + Prepare the inputs to WorkerBase.execute_worker from an execution + request. This method may move data to the worker's local device. It is + not allowed to communicate with other workers or devices. + """ + raise NotImplementedError - @torch.inference_mode() + @abstractmethod def execute_worker(self, worker_input: WorkerInput) -> None: - virtual_engine = worker_input.virtual_engine - # Issue cache operations. - if (worker_input.blocks_to_swap_in is not None - and worker_input.blocks_to_swap_in.numel() > 0): - self.cache_engine[virtual_engine].swap_in( - worker_input.blocks_to_swap_in) - if (worker_input.blocks_to_swap_out is not None - and worker_input.blocks_to_swap_out.numel() > 0): - self.cache_engine[virtual_engine].swap_out( - worker_input.blocks_to_swap_out) - if (worker_input.blocks_to_copy is not None - and worker_input.blocks_to_copy.numel() > 0): - self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) - - def _get_cached_seq_group_metadata( - self, - seq_group_metadata_list: List[Union[SequenceGroupMetadata, - SequenceGroupMetadataDelta]], - finished_request_ids: List[str]) -> List[SequenceGroupMetadata]: - """Return a list of cached Sequence Group Metadata after updating its - state. - - It is used because scheduler only sends delta to workers to reduce - the data payload size. The function also cleans up cache based on - a given `finished_request_ids`. """ - new_seq_group_metadata_list = [] - for metadata_or_delta in seq_group_metadata_list: - request_id = metadata_or_delta.request_id - if request_id not in self._seq_group_metadata_cache: - # The first prefill. - assert isinstance(metadata_or_delta, SequenceGroupMetadata) - self._seq_group_metadata_cache[request_id] = metadata_or_delta - else: - # The first prefill is already cached. - if isinstance(metadata_or_delta, SequenceGroupMetadataDelta): - self._seq_group_metadata_cache[request_id].apply_delta( - metadata_or_delta) - else: - # If metadata snapshot is sent again, it is - # preempted. Reset the cache because we need to start - # from scratch. - assert isinstance(metadata_or_delta, SequenceGroupMetadata) - self._seq_group_metadata_cache[ - request_id] = metadata_or_delta - - new_seq_group_metadata_list.append( - self._seq_group_metadata_cache[request_id]) - - # Clean up finished ids - for finished_id in finished_request_ids: - del self._seq_group_metadata_cache[finished_id] - - return new_seq_group_metadata_list - - def _execute_model_spmd( - self, - execute_model_req: ExecuteModelRequest, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Optional[List[SamplerOutput]]: - if execute_model_req is not None: - new_seq_group_metadata_list = self._get_cached_seq_group_metadata( + Process an execution request. + """ + raise NotImplementedError + + def _get_worker_input_from_broadcast( + self + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ + str, torch.Tensor]]]: + """ Get the worker input from the broadcasted tensor dict. """ + assert self.do_metadata_broadcast + assert not self.is_driver_worker + broadcast_data = broadcast_tensor_dict(src=0) + if not broadcast_data: + return None + + worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) + model_input = ( + self.model_runner.make_model_input_from_broadcasted_tensor_dict( + broadcast_data)) + + kwargs = extract_previous_hidden_states(broadcast_data) + + return model_input, worker_input, kwargs + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: + """ Get the driver input and broadcast it to other workers. """ + assert self.is_driver_worker + + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( execute_model_req.seq_group_metadata_list, - execute_model_req.finished_requests_ids) + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) - execute_model_req.seq_group_metadata_list = ( - new_seq_group_metadata_list) - output = super()._execute_model_spmd(execute_model_req, - intermediate_tensors) - return output + kwargs = extract_previous_hidden_states(execute_model_req) - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_runner.add_lora(lora_request) + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_data.update(kwargs) + broadcast_tensor_dict(broadcast_data, src=0) - def remove_lora(self, lora_id: int) -> bool: - return self.model_runner.remove_lora(lora_id) + if execute_model_req.async_callback: + model_input = dataclasses.replace( # type: ignore + model_input, + async_callback=execute_model_req.async_callback) - def pin_lora(self, lora_id: int) -> bool: - return self.model_runner.pin_lora(lora_id) + return model_input, worker_input, kwargs - def list_loras(self) -> Set[int]: - return self.model_runner.list_loras() + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ + str, torch.Tensor]]]: + """ + Prepare the inputs to ModelRunner and workers. + """ + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) + return None + return self._get_driver_input_and_broadcast(execute_model_req) + else: + return self._get_worker_input_from_broadcast() + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[List[SamplerOutput]]: + """Executes at least one model step on the given sequences, unless no + sequences are provided.""" + start_time = time.perf_counter() + + inputs = self.prepare_input(execute_model_req) + if inputs is None: + return None + + model_input, worker_input, kwargs = inputs + num_steps = worker_input.num_steps + + self.execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + intermediate_tensors = None + orig_model_execute_time = 0.0 + if not get_pp_group().is_first_rank: + intermediate_tensors = IntermediateTensors( + get_pp_group().recv_tensor_dict( + all_gather_group=get_tp_group())) + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time): + orig_model_execute_time = intermediate_tensors.tensors.get( + "model_execute_time", torch.tensor(0)).item() + + output = self.model_runner.execute_model( + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + num_steps=num_steps, + **kwargs, + ) - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - return self.model_runner.add_prompt_adapter(prompt_adapter_request) + model_execute_time = time.perf_counter() - start_time + if not get_pp_group().is_last_rank: + # output is IntermediateTensors + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time): + output.tensors["model_execute_time"] = torch.tensor( + model_execute_time + orig_model_execute_time) + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) + return [None] + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time + and output is not None): + for o in output: + o.model_execute_time = (orig_model_execute_time + + model_execute_time) + + # output is List[SamplerOutput] + return output - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return self.model_runner.remove_lora(prompt_adapter_id) + def _execute_model_spmd( + self, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None + ) -> Optional[List[SamplerOutput]]: + """ + Execute model in Single Program Multiple Data (SPMD) fashion. + All workers take the same request, prepare the input and + execute the model. + """ + assert execute_model_req is not None, ( + "_execute_model_spmd() requires each worker to take in an " + "ExecuteModelRequest") + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list)) + + self.execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + kwargs = extract_previous_hidden_states(execute_model_req) + + return self.model_runner.execute_model( + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + **kwargs, + ) - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return self.model_runner.pin_prompt_adapter(prompt_adapter_id) - def list_prompt_adapters(self) -> Set[int]: - return self.model_runner.list_prompt_adapters() +class WorkerWrapperBase: + """ + The whole point of this class is to lazily initialize the worker. + We first instantiate the WorkerWrapper, which remembers the worker module + and class name. Then, when we call `update_environment_variables`, and the + real initialization happens in `init_worker`. + + If worker_class_fn is specified, it will be executed to get the worker + class. + Otherwise, the worker class will be obtained by dynamically importing it + using worker_module_name and worker_class_name. + """ - @property - def max_model_len(self) -> int: - return self.model_config.max_model_len + def __init__( + self, + worker_module_name: str, + worker_class_name: str, + trust_remote_code: bool = False, + worker_class_fn: Optional[Callable[[], + Type[WorkerBase]]] = None) -> None: + self.worker_module_name = worker_module_name + self.worker_class_name = worker_class_name + self.worker_class_fn = worker_class_fn + self.worker: Optional[WorkerBase] = None + if trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() - @property - def vocab_size(self) -> int: - return self.model_runner.vocab_size + @staticmethod + def update_environment_variables(envs: Dict[str, str]) -> None: + key = 'CUDA_VISIBLE_DEVICES' + if key in envs and key in os.environ: + # overwriting CUDA_VISIBLE_DEVICES is desired behavior + # suppress the warning in `update_environment_variables` + del os.environ[key] + update_environment_variables(envs) - def get_cache_block_size_bytes(self) -> int: - """Get the size of the KV cache block size in bytes. + def init_worker(self, *args, **kwargs): + """ + Here we inject some common logic before initializing the worker. + Arguments are passed to the worker class constructor. """ - return CacheEngine.get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) - - -def init_worker_distributed_environment( - parallel_config: ParallelConfig, - rank: int, - distributed_init_method: Optional[str] = None, - local_rank: int = -1, -) -> None: - """Initialize the distributed environment.""" - set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - - init_distributed_environment(parallel_config.world_size, rank, - distributed_init_method, local_rank) - - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) - - -def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): - # Check if the GPU supports the dtype. - if torch_dtype == torch.bfloat16: # noqa: SIM102 - if not current_platform.has_device_capability(80): - capability = current_platform.get_device_capability() - gpu_name = current_platform.get_device_name() - - if capability is None: - compute_str = "does not have a compute capability" - else: - version_str = capability.as_version_str() - compute_str = f"has compute capability {version_str}" - - raise ValueError( - "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " - "You can use float16 instead by explicitly setting the" - "`dtype` flag in CLI, for example: --dtype=half.") - - -def raise_if_cache_size_invalid(num_gpu_blocks, block_size, - max_model_len) -> None: - if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - max_seq_len = block_size * num_gpu_blocks - if max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") + enable_trace_function_call_for_thread() + + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ['NCCL_CUMEM_ENABLE'] = '0' + + from vllm.plugins import load_general_plugins + load_general_plugins() + + if self.worker_class_fn: + worker_class = self.worker_class_fn() + else: + mod = importlib.import_module(self.worker_module_name) + worker_class = getattr(mod, self.worker_class_name) + + self.worker = worker_class(*args, **kwargs) + assert self.worker is not None + + def execute_method(self, method, *args, **kwargs): + try: + target = self if self.worker is None else self.worker + executor = getattr(target, method) + return executor(*args, **kwargs) + except Exception as e: + # if the driver worker also execute methods, + # exceptions in the rest worker may cause deadlock in rpc like ray + # see https://github.com/vllm-project/vllm/issues/3455 + # print the error and inform the user to solve the error + msg = (f"Error executing method {method}. " + "This might cause deadlock in distributed execution.") + logger.exception(msg) + raise e + + +def extract_previous_hidden_states( + data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \ + Dict[str, torch.Tensor]: + """If data contains previous_hidden_states, extract it. This returns a dict + which can be used directly as additional kwargs in any following + execute_model calls. This is used in draft models like EAGLE.""" + output = {} + + # When called from non-driver worker, data is dict but when called from + # driver worker, data is ExecuteModelRequest. + if isinstance(data, dict): + if "previous_hidden_states" in data: + output["previous_hidden_states"] = data["previous_hidden_states"] + elif data.previous_hidden_states is not None: + output["previous_hidden_states"] = data.previous_hidden_states\ + .hidden_states + + return output \ No newline at end of file From e787e42debbef0947011ba4b61620b5ad29e1735 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 19 Sep 2024 00:54:37 +0000 Subject: [PATCH 241/303] revert worker to vllm main --- vllm/worker/worker.py | 869 +++++++++++++++++++------------------ vllm/worker/worker_base.py | 110 +---- 2 files changed, 445 insertions(+), 534 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index abc6f98b5f30a..a12e026b2bdf5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,485 +1,488 @@ -import dataclasses -import importlib +"""A GPU worker class.""" +import gc import os -import time -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Dict, List, Optional, Set, Tuple, Type, Union import torch - -from vllm.config import ObservabilityConfig -from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group +import torch.distributed + +import vllm.envs as envs +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor import set_random_seed from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.platforms import current_platform -from vllm.sequence import ExecuteModelRequest, IntermediateTensors -from vllm.utils import (enable_trace_function_call_for_thread, - update_environment_variables) -from vllm.worker.model_runner_base import (BroadcastableModelInput, - ModelRunnerBase, - ModelRunnerInputBase) +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SequenceGroupMetadata, SequenceGroupMetadataDelta) +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.embedding_model_runner import EmbeddingModelRunner +from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner +from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner +from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput logger = init_logger(__name__) -class WorkerBase(ABC): - """Worker interface that allows vLLM to cleanly separate implementations for - different hardware. Also abstracts control plane communication, e.g., to - communicate request metadata to other workers. - """ - - @abstractmethod - def init_device(self) -> None: - """Initialize device state, such as loading the model or other on-device - memory allocations. - """ - raise NotImplementedError - - @abstractmethod - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available blocks for the GPU KV cache and - swappable CPU KV cache. - - The implementation may run profiling or other heuristics to determine - the size of caches. - - Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks - are blocks that are "active" on the device and can be appended to. - num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be - appended to. - """ - raise NotImplementedError - - @abstractmethod - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache with the given size in blocks. - """ - raise NotImplementedError - - @current_platform.inference_mode() - def start_worker_execution_loop(self) -> None: - """Execute model loop in parallel worker. +class Worker(LocalOrDistributedWorkerBase): + """A worker class that executes (a partition of) the model on a GPU. - You can stop the loop by executing a driver worker with an empty output. - See `stop_remote_worker_execution_loop` for more details. - """ - while True: - output = self.execute_model(execute_model_req=None) - if output is None: - return None + Each worker is associated with a single GPU. The worker is responsible for + maintaining the KV cache and executing the model on the GPU. In case of + distributed inference, each worker is assigned a partition of the model. + """ - @abstractmethod - def execute_model( + def __init__( self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[List[SamplerOutput]]: - raise NotImplementedError - - @abstractmethod - def get_cache_block_size_bytes(self) -> int: - """Return the size of a single cache block, in bytes. Used in - speculative decoding. - """ - raise NotImplementedError - - @abstractmethod - def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError - - @abstractmethod - def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError - - @abstractmethod - def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError - - @abstractmethod - def list_loras(self) -> Set[int]: - raise NotImplementedError - + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, + observability_config: Optional[ObservabilityConfig] = None, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.parallel_config.rank = rank + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.load_config = load_config + self.prompt_adapter_config = prompt_adapter_config + self.is_driver_worker = is_driver_worker + if parallel_config and is_driver_worker: + assert rank % parallel_config.tensor_parallel_size == 0, \ + "Driver worker should be rank 0 of tensor parallel group." + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + self.observability_config = observability_config + + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_args = {} if speculative_config is None \ + or (speculative_config.draft_model_config.model == + model_config.model) \ + or (speculative_config.draft_model_config.hf_config.model_type + not in ["medusa", "mlp_speculator", "eagle"]) \ + else {"return_hidden_states": True} + + ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner + if model_runner_cls is not None: + ModelRunnerClass = model_runner_cls + elif self._is_embedding_model(): + ModelRunnerClass = EmbeddingModelRunner + elif self._is_encoder_decoder_model(): + ModelRunnerClass = EncoderDecoderModelRunner + self.model_runner: GPUModelRunnerBase = ModelRunnerClass( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + observability_config=observability_config, + **speculative_args, + ) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CacheEngine] + # Initialize gpu_cache as embedding models don't initialize kv_caches + self.gpu_cache: Optional[List[List[torch.Tensor]]] = None + self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None -class LoraNotSupportedWorkerBase(WorkerBase): - """Partial implementation of WorkerBase that raises exceptions when LoRA - methods are invoked. - """ + def start_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.start() - def add_lora(self, lora_request: LoRARequest) -> bool: - raise ValueError(f"{type(self)} does not support LoRA") + def stop_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.stop() - def remove_lora(self, lora_id: int) -> bool: - raise ValueError(f"{type(self)} does not support LoRA") + def _is_encoder_decoder_model(self): + return self.model_config.is_encoder_decoder_model - def pin_lora(self, lora_id: int) -> bool: - return ValueError( - f"{type(self)} does not support LoRA") # type: ignore + def _is_embedding_model(self): + return self.model_config.is_embedding_model - def list_loras(self) -> Set[int]: - raise ValueError(f"{type(self)} does not support LoRA") + def init_device(self) -> None: + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + gc.collect() + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + init_worker_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method, + self.local_rank) + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self.model_runner.save_sharded_state( + path, + pattern=pattern, + max_size=max_size, + ) + def save_tensorized_model( + self, + tensorizer_config: TensorizerConfig, + ) -> None: + self.model_runner.save_tensorized_model( + tensorizer_config=tensorizer_config, ) -@dataclasses.dataclass(frozen=True) -class WorkerInput: - """Local inputs to each worker. May contain device-specific data. These - fields should be broadcastable to other workers. - """ + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. - num_seq_groups: Optional[int] = None - blocks_to_swap_in: Optional[torch.Tensor] = None - blocks_to_swap_out: Optional[torch.Tensor] = None - blocks_to_copy: Optional[torch.Tensor] = None - virtual_engine: int = 0 - num_steps: int = 1 - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type["WorkerInput"], - tensor_dict: Dict[str, Any], - ) -> "WorkerInput": - """ - Pop fields from the given tensor_dict and populate a new instance of - WorkerInput. - """ - return cls( - num_seq_groups=tensor_dict.pop("num_seq_groups"), - blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), - blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), - blocks_to_copy=tensor_dict.pop("blocks_to_copy"), - virtual_engine=tensor_dict["virtual_engine"], - num_steps=tensor_dict.pop("num_steps"), - ) + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - """ - Extract broadcastable fields. + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. """ - tensor_dict = { - "num_seq_groups": self.num_seq_groups, - "blocks_to_swap_in": self.blocks_to_swap_in, - "blocks_to_swap_out": self.blocks_to_swap_out, - "blocks_to_copy": self.blocks_to_copy, - "virtual_engine": self.virtual_engine, - "num_steps": self.num_steps, - } - - return tensor_dict + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + torch.cuda.synchronize() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + peak_memory = self.init_gpu_memory - free_gpu_memory + assert peak_memory > 0, ( + "Error in memory profiling. " + f"Initial free memory {self.init_gpu_memory}, current free memory" + f" {free_gpu_memory}. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + cache_block_size = self.get_cache_block_size_bytes() + num_gpu_blocks = int( + (total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + gc.collect() + torch.cuda.empty_cache() + return num_gpu_blocks, num_cpu_blocks + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Allocate GPU and CPU KV cache with the specified number of blocks. -class LocalOrDistributedWorkerBase(WorkerBase): - """ - Partial implementation of WorkerBase that has a default `execute_model` - definition to perform metadata transfer between workers when in distributed - mode. Subclasses of this interface should use model runners that inherit - from ModelRunnerBase, and should only need to implement worker-local logic. - If custom control plane logic is needed to transfer metadata, or if the - model runner cannot inherit from ModelRunnerBase, use WorkerBase instead. - """ - is_driver_worker: bool - model_runner: ModelRunnerBase - observability_config: Optional[ObservabilityConfig] = None + This also warms up the model, which may record CUDA graphs. + """ + raise_if_cache_size_invalid(num_gpu_blocks, + self.cache_config.block_size, + self.model_config.max_model_len) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._init_cache_engine() + self._warm_up_model() + + def _init_cache_engine(self): + assert self.cache_config.num_gpu_blocks is not None + self.cache_engine = [ + CacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.gpu_cache = [ + self.cache_engine[ve].gpu_cache + for ve in range(self.parallel_config.pipeline_parallel_size) + ] + + def _warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model(self.gpu_cache) + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) @property - @abstractmethod def do_metadata_broadcast(self) -> bool: - """ - Used by the default `execute_model` to check whether broadcast is - needed to transfer request inputs from the driver worker to other - workers in the TP group. If WorkerBase subclass only supports - single-worker execution, then this method should return False. - """ - raise NotImplementedError + return self.parallel_config.tensor_parallel_size > 1 @property - @abstractmethod def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - """ - Gets the list of kv caches to pass to the worker's model runner. Each - element in the list is a kv cache corresponding to a particular virtual - engine (PP stream). Used by the default `execute_model`. If the worker's - model runner does not follow the ModelRunnerBase interface, then inherit - from WorkerBase instead. - """ - raise NotImplementedError + return self.gpu_cache - @abstractmethod + @torch.inference_mode() def prepare_worker_input( self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - """ - Prepare the inputs to WorkerBase.execute_worker from an execution - request. This method may move data to the worker's local device. It is - not allowed to communicate with other workers or devices. - """ - raise NotImplementedError + virtual_engine = execute_model_req.virtual_engine + num_steps = execute_model_req.num_steps + num_seq_groups = len(execute_model_req.seq_group_metadata_list) + # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. + # they contain parameters to launch cudamemcpyasync. + blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, + device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, + device="cpu", + dtype=torch.int64).view(-1, 2) + # `blocks_to_copy` is a gpu tensor. The src and tgt of + # blocks to copy are in the same device, and `blocks_to_copy` + # can be used directly within cuda kernels. + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device=self.device, + dtype=torch.int64).view(-1, 2) + + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, + num_steps=num_steps, + ) - @abstractmethod + @torch.inference_mode() def execute_worker(self, worker_input: WorkerInput) -> None: + virtual_engine = worker_input.virtual_engine + # Issue cache operations. + if (worker_input.blocks_to_swap_in is not None + and worker_input.blocks_to_swap_in.numel() > 0): + self.cache_engine[virtual_engine].swap_in( + worker_input.blocks_to_swap_in) + if (worker_input.blocks_to_swap_out is not None + and worker_input.blocks_to_swap_out.numel() > 0): + self.cache_engine[virtual_engine].swap_out( + worker_input.blocks_to_swap_out) + if (worker_input.blocks_to_copy is not None + and worker_input.blocks_to_copy.numel() > 0): + self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) + + def _get_cached_seq_group_metadata( + self, + seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]], + finished_request_ids: List[str]) -> List[SequenceGroupMetadata]: + """Return a list of cached Sequence Group Metadata after updating its + state. + + It is used because scheduler only sends delta to workers to reduce + the data payload size. The function also cleans up cache based on + a given `finished_request_ids`. """ - Process an execution request. - """ - raise NotImplementedError - - def _get_worker_input_from_broadcast( - self - ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ - str, torch.Tensor]]]: - """ Get the worker input from the broadcasted tensor dict. """ - assert self.do_metadata_broadcast - assert not self.is_driver_worker - broadcast_data = broadcast_tensor_dict(src=0) - if not broadcast_data: - return None - - worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) - model_input = ( - self.model_runner.make_model_input_from_broadcasted_tensor_dict( - broadcast_data)) - - kwargs = extract_previous_hidden_states(broadcast_data) - - return model_input, worker_input, kwargs - - def _get_driver_input_and_broadcast( - self, execute_model_req: ExecuteModelRequest - ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: - """ Get the driver input and broadcast it to other workers. """ - assert self.is_driver_worker - - worker_input: WorkerInput = self.prepare_worker_input( - execute_model_req=execute_model_req) - model_input: ModelRunnerInputBase = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list, - execute_model_req.virtual_engine, - execute_model_req.finished_requests_ids)) - - kwargs = extract_previous_hidden_states(execute_model_req) + new_seq_group_metadata_list = [] + for metadata_or_delta in seq_group_metadata_list: + request_id = metadata_or_delta.request_id + if request_id not in self._seq_group_metadata_cache: + # The first prefill. + assert isinstance(metadata_or_delta, SequenceGroupMetadata) + self._seq_group_metadata_cache[request_id] = metadata_or_delta + else: + # The first prefill is already cached. + if isinstance(metadata_or_delta, SequenceGroupMetadataDelta): + self._seq_group_metadata_cache[request_id].apply_delta( + metadata_or_delta) + else: + # If metadata snapshot is sent again, it is + # preempted. Reset the cache because we need to start + # from scratch. + assert isinstance(metadata_or_delta, SequenceGroupMetadata) + self._seq_group_metadata_cache[ + request_id] = metadata_or_delta + + new_seq_group_metadata_list.append( + self._seq_group_metadata_cache[request_id]) + + # Clean up finished ids + for finished_id in finished_request_ids: + del self._seq_group_metadata_cache[finished_id] + + return new_seq_group_metadata_list - if self.do_metadata_broadcast: - broadcast_data = worker_input.as_broadcastable_tensor_dict() - broadcast_data.update(model_input.as_broadcastable_tensor_dict()) - broadcast_data.update(kwargs) - broadcast_tensor_dict(broadcast_data, src=0) - - if execute_model_req.async_callback: - model_input = dataclasses.replace( # type: ignore - model_input, - async_callback=execute_model_req.async_callback) - - return model_input, worker_input, kwargs - - def prepare_input( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ - str, torch.Tensor]]]: - """ - Prepare the inputs to ModelRunner and workers. - """ - if self.is_driver_worker: - if execute_model_req is None: - if self.do_metadata_broadcast: - # This signals that there's no more requests to process for - # now. All workers are running infinite loop with - # broadcast_tensor_dict, and it stops the loop when the - # driver broadcasts an empty input. Send an empty input to - # notify all other workers to stop their execution loop. - broadcast_tensor_dict({}, src=0) - return None - return self._get_driver_input_and_broadcast(execute_model_req) - else: - return self._get_worker_input_from_broadcast() - - def execute_model( + def _execute_model_spmd( self, - execute_model_req: Optional[ExecuteModelRequest] = None, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Optional[List[SamplerOutput]]: - """Executes at least one model step on the given sequences, unless no - sequences are provided.""" - start_time = time.perf_counter() - - inputs = self.prepare_input(execute_model_req) - if inputs is None: - return None - - model_input, worker_input, kwargs = inputs - num_steps = worker_input.num_steps - - self.execute_worker(worker_input) - - # If there is no input, we don't need to execute the model. - if worker_input.num_seq_groups == 0: - return [] - - intermediate_tensors = None - orig_model_execute_time = 0.0 - if not get_pp_group().is_first_rank: - intermediate_tensors = IntermediateTensors( - get_pp_group().recv_tensor_dict( - all_gather_group=get_tp_group())) - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time): - orig_model_execute_time = intermediate_tensors.tensors.get( - "model_execute_time", torch.tensor(0)).item() - - output = self.model_runner.execute_model( - model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - intermediate_tensors=intermediate_tensors, - num_steps=num_steps, - **kwargs, - ) + if execute_model_req is not None: + new_seq_group_metadata_list = self._get_cached_seq_group_metadata( + execute_model_req.seq_group_metadata_list, + execute_model_req.finished_requests_ids) - model_execute_time = time.perf_counter() - start_time - if not get_pp_group().is_last_rank: - # output is IntermediateTensors - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time): - output.tensors["model_execute_time"] = torch.tensor( - model_execute_time + orig_model_execute_time) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) - return [None] - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time - and output is not None): - for o in output: - o.model_execute_time = (orig_model_execute_time + - model_execute_time) - - # output is List[SamplerOutput] + execute_model_req.seq_group_metadata_list = ( + new_seq_group_metadata_list) + output = super()._execute_model_spmd(execute_model_req, + intermediate_tensors) return output - def _execute_model_spmd( - self, - execute_model_req: ExecuteModelRequest, - intermediate_tensors: Optional[IntermediateTensors] = None - ) -> Optional[List[SamplerOutput]]: - """ - Execute model in Single Program Multiple Data (SPMD) fashion. - All workers take the same request, prepare the input and - execute the model. - """ - assert execute_model_req is not None, ( - "_execute_model_spmd() requires each worker to take in an " - "ExecuteModelRequest") - worker_input: WorkerInput = self.prepare_worker_input( - execute_model_req=execute_model_req) - model_input: ModelRunnerInputBase = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list)) - - self.execute_worker(worker_input) - - # If there is no input, we don't need to execute the model. - if worker_input.num_seq_groups == 0: - return [] - - kwargs = extract_previous_hidden_states(execute_model_req) - - return self.model_runner.execute_model( - model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - intermediate_tensors=intermediate_tensors, - **kwargs, - ) + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) -class WorkerWrapperBase: - """ - The whole point of this class is to lazily initialize the worker. - We first instantiate the WorkerWrapper, which remembers the worker module - and class name. Then, when we call `update_environment_variables`, and the - real initialization happens in `init_worker`. - - If worker_class_fn is specified, it will be executed to get the worker - class. - Otherwise, the worker class will be obtained by dynamically importing it - using worker_module_name and worker_class_name. - """ + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) - def __init__( - self, - worker_module_name: str, - worker_class_name: str, - trust_remote_code: bool = False, - worker_class_fn: Optional[Callable[[], - Type[WorkerBase]]] = None) -> None: - self.worker_module_name = worker_module_name - self.worker_class_name = worker_class_name - self.worker_class_fn = worker_class_fn - self.worker: Optional[WorkerBase] = None - if trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() - @staticmethod - def update_environment_variables(envs: Dict[str, str]) -> None: - key = 'CUDA_VISIBLE_DEVICES' - if key in envs and key in os.environ: - # overwriting CUDA_VISIBLE_DEVICES is desired behavior - # suppress the warning in `update_environment_variables` - del os.environ[key] - update_environment_variables(envs) + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_runner.add_prompt_adapter(prompt_adapter_request) - def init_worker(self, *args, **kwargs): - """ - Here we inject some common logic before initializing the worker. - Arguments are passed to the worker class constructor. - """ - enable_trace_function_call_for_thread() + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_runner.remove_lora(prompt_adapter_id) - # see https://github.com/NVIDIA/nccl/issues/1234 - os.environ['NCCL_CUMEM_ENABLE'] = '0' + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_runner.pin_prompt_adapter(prompt_adapter_id) - from vllm.plugins import load_general_plugins - load_general_plugins() + def list_prompt_adapters(self) -> Set[int]: + return self.model_runner.list_prompt_adapters() - if self.worker_class_fn: - worker_class = self.worker_class_fn() - else: - mod = importlib.import_module(self.worker_module_name) - worker_class = getattr(mod, self.worker_class_name) - - self.worker = worker_class(*args, **kwargs) - assert self.worker is not None - - def execute_method(self, method, *args, **kwargs): - try: - target = self if self.worker is None else self.worker - executor = getattr(target, method) - return executor(*args, **kwargs) - except Exception as e: - # if the driver worker also execute methods, - # exceptions in the rest worker may cause deadlock in rpc like ray - # see https://github.com/vllm-project/vllm/issues/3455 - # print the error and inform the user to solve the error - msg = (f"Error executing method {method}. " - "This might cause deadlock in distributed execution.") - logger.exception(msg) - raise e - - -def extract_previous_hidden_states( - data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \ - Dict[str, torch.Tensor]: - """If data contains previous_hidden_states, extract it. This returns a dict - which can be used directly as additional kwargs in any following - execute_model calls. This is used in draft models like EAGLE.""" - output = {} - - # When called from non-driver worker, data is dict but when called from - # driver worker, data is ExecuteModelRequest. - if isinstance(data, dict): - if "previous_hidden_states" in data: - output["previous_hidden_states"] = data["previous_hidden_states"] - elif data.previous_hidden_states is not None: - output["previous_hidden_states"] = data.previous_hidden_states\ - .hidden_states - - return output \ No newline at end of file + @property + def max_model_len(self) -> int: + return self.model_config.max_model_len + + @property + def vocab_size(self) -> int: + return self.model_runner.vocab_size + + def get_cache_block_size_bytes(self) -> int: + """Get the size of the KV cache block size in bytes. + """ + return CacheEngine.get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + + +def init_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + + init_distributed_environment(parallel_config.world_size, rank, + distributed_init_method, local_rank) + + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + +def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): + # Check if the GPU supports the dtype. + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not current_platform.has_device_capability(80): + capability = current_platform.get_device_capability() + gpu_name = current_platform.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " + "You can use float16 instead by explicitly setting the" + "`dtype` flag in CLI, for example: --dtype=half.") + + +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, + max_model_len) -> None: + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + max_seq_len = block_size * num_gpu_blocks + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") \ No newline at end of file diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index ae2fb65cc455e..abc6f98b5f30a 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -7,8 +7,6 @@ import torch -import vllm.distributed.kv_transfer.vllm_adapter as dist_kv -import vllm.distributed.parallel_state as ps from vllm.config import ObservabilityConfig from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.logger import init_logger @@ -326,64 +324,14 @@ def execute_model( orig_model_execute_time = intermediate_tensors.tensors.get( "model_execute_time", torch.tensor(0)).item() - # for disaggregated prefilling: allow bypassing model execution - bypass_model_exec = False - - # receive KV cache from prefill instance, or from LMCache - if self.need_recv_kv(model_input, worker_input): - from vllm.worker.model_runner import GPUModelRunnerBase - assert isinstance(self.model_runner, GPUModelRunnerBase), \ - "Distributed KV transfer only support GPU modelrunner" - logger.debug("Receiving KV caches") - hidden_or_intermediate_states, bypass_model_exec, model_input = \ - ps.get_disagg_group().recv_kv_caches_and_hidden_states( - # model is used to know which layer the current worker - # is working on, so that we can receive KV for only those - # layers. - self.model_runner.model, - model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - ) - #assert bypass_model_exec - - if not bypass_model_exec: - hidden_or_intermediate_states = self.model_runner.execute_model( - model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - intermediate_tensors=intermediate_tensors, - num_steps=num_steps, - **kwargs, - ) - - # sending out KV cache - if self.need_send_kv(model_input, worker_input): - from vllm.worker.model_runner import GPUModelRunnerBase - assert isinstance(self.model_runner, GPUModelRunnerBase), \ - "Distributed KV transfer only support GPU modelrunner" - logger.debug("Sending KV caches") - ps.get_disagg_group().send_kv_caches_and_hidden_states( - # model is used to know which layer the current worker - # is working on, so that we can send KV for only those - # layers. - self.model_runner.model, - model_input, - self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - hidden_or_intermediate_states, - ) - - # separating postprocessing steps out from execute_model - # so that disaggregated prefill can completely bypass model forwarding - from vllm.worker.model_runner import ModelRunner - if isinstance(self.model_runner, ModelRunner): - output = self.model_runner.postprocess_model( - model_input, - hidden_or_intermediate_states, - ) - else: - output = hidden_or_intermediate_states + output = self.model_runner.execute_model( + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + num_steps=num_steps, + **kwargs, + ) model_execute_time = time.perf_counter() - start_time if not get_pp_group().is_last_rank: @@ -405,46 +353,6 @@ def execute_model( # output is List[SamplerOutput] return output - def need_recv_kv(self, model_input, worker_input) -> bool: - - if self.kv_cache is None: - return False - - kv_caches = self.kv_cache[worker_input.virtual_engine] - prefill_meta = model_input.attn_metadata.prefill_metadata - - # check if the current run is profiling - is_profile_run = (kv_caches is None) or (kv_caches[0] is None) - # check if the current run is prefill - is_prefill_run = prefill_meta is not None - # for disaggregated prefilling: allow bypassing model execution - - return all([ - is_prefill_run, dist_kv.IS_KV_DECODE_INSTANCE - or dist_kv.IS_LMCACHE_INSTANCE, not is_profile_run - ]) - - def need_send_kv(self, model_input, worker_input) -> bool: - - if self.kv_cache is None: - return False - - kv_caches = self.kv_cache[worker_input.virtual_engine] - prefill_meta = model_input.attn_metadata.prefill_metadata - from vllm.worker.model_runner import GPUModelRunnerBase - if not isinstance(self.model_runner, GPUModelRunnerBase): - return False - - # check if the current run is profiling - is_profile_run = (kv_caches is None) or (kv_caches[0] is None) - # check if the current run is prefill - is_prefill_run = prefill_meta is not None - - return all([ - is_prefill_run, dist_kv.IS_KV_PREFILL_INSTANCE - or dist_kv.IS_LMCACHE_INSTANCE, not is_profile_run - ]) - def _execute_model_spmd( self, execute_model_req: ExecuteModelRequest, @@ -574,4 +482,4 @@ def extract_previous_hidden_states( output["previous_hidden_states"] = data.previous_hidden_states\ .hidden_states - return output + return output \ No newline at end of file From 9874b42f5373188f6efeb498a088c50fdee22232 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 19 Sep 2024 00:57:20 +0000 Subject: [PATCH 242/303] bug fix --- vllm/worker/model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 78ffc40082fc8..a6f91667fedf9 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1720,7 +1720,7 @@ def recv_kv_needed(self, model_input, kv_caches) -> bool: is_prefill_run = prefill_meta is not None return all([ - dist_kv.IS_KV_DECODE_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE, + dist_kv.IS_KV_CONSUMER, not is_profile_run, is_prefill_run, ]) @@ -1741,7 +1741,7 @@ def send_kv_needed(self, model_input, kv_caches) -> bool: is_prefill_run = prefill_meta is not None return all([ - dist_kv.IS_KV_PREFILL_INSTANCE or dist_kv.IS_LMCACHE_INSTANCE, + dist_kv.IS_KV_PRODUCER, not is_profile_run, is_prefill_run ]) From 5950ad530ee03c1a8a00a71cad2970979c8d66d9 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 19 Sep 2024 03:17:49 +0000 Subject: [PATCH 243/303] fix typo: Distributerd -> Distributed --- examples/disagg_prefill/disagg_prefill_example.sh | 9 +++++---- vllm/distributed/kv_transfer/vllm_adapter.py | 8 ++++---- vllm/envs.py | 4 ++-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index 56b6f44c7418a..57cebb3775a09 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -5,6 +5,7 @@ export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') export VLLM_PORT=12345 +export VLLM_LOGGING_LEVEL=DEBUG # a function that waits vLLM server to start wait_for_server() { @@ -15,16 +16,16 @@ wait_for_server() { done" && return 0 || return 1 } -# prefilling instance -VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 python3 \ +# prefilling instance, which is the KV producer +VLLM_DISTRIBUTED_KV_ROLE=producer CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8100 \ --max-model-len 10000 \ --gpu-memory-utilization 0.8 & -# decoding instance -VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ +# decoding instance, which is the KV consumer +VLLM_DISTRIBUTED_KV_ROLE=consumer CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8200 \ diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 17c2e52b1174b..1a918ff0ae32e 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -41,12 +41,12 @@ logger = init_logger(__name__) # check VLLM_DISTRIBUTERD_KV_ROLE and set corresponding flags -assert envs.VLLM_DISTRIBUTERD_KV_ROLE in [None, "producer", "consumer", "both"],\ +assert envs.VLLM_DISTRIBUTED_KV_ROLE in [None, "producer", "consumer", "both"],\ "VLLM_DISTRIBUTERD_KV_ROLE can only be producer, consumer or both." -IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISTRIBUTERD_KV_ROLE +IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISTRIBUTED_KV_ROLE in ["producer", "consumer", "both"]) -IS_KV_PRODUCER: bool = (envs.VLLM_DISTRIBUTERD_KV_ROLE in ["producer", "both"]) -IS_KV_CONSUMER: bool = (envs.VLLM_DISTRIBUTERD_KV_ROLE == ["consumer", "both"]) +IS_KV_PRODUCER: bool = (envs.VLLM_DISTRIBUTED_KV_ROLE in ["producer", "both"]) +IS_KV_CONSUMER: bool = (envs.VLLM_DISTRIBUTED_KV_ROLE in ["consumer", "both"]) # When the current instance is both KV producer and KV consumer, # it is likely connected to a KV storage service on CPU/disk diff --git a/vllm/envs.py b/vllm/envs.py index 407dd942cc9f1..cc2c8a11e5af5 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -368,8 +368,8 @@ def get_default_config_root(): # Specify the role of current vllm instance # Value can be "producer", "consumer" or "both". - "VLLM_DISTRIBUTERD_KV_ROLE": - lambda: os.getenv("VLLM_DISTRIBUTERD_KV_ROLE", None), + "VLLM_DISTRIBUTED_KV_ROLE": + lambda: os.getenv("VLLM_DISTRIBUTED_KV_ROLE", None), # If set, vllm will skip the deprecation warnings. "VLLM_NO_DEPRECATION_WARNING": From c116684042e205b789c6e259ef3dd3a6301806b0 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 19 Sep 2024 03:26:39 +0000 Subject: [PATCH 244/303] remove the debug flag in example -- user don't need it --- examples/disagg_prefill/disagg_prefill_example.sh | 9 ++++++++- vllm/distributed/parallel_state.py | 7 ++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh index 57cebb3775a09..efec87855dbee 100644 --- a/examples/disagg_prefill/disagg_prefill_example.sh +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -5,7 +5,14 @@ export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') export VLLM_PORT=12345 -export VLLM_LOGGING_LEVEL=DEBUG + +# install quart first -- required for disagg prefill proxy serve +if python3 -c "import quart" &> /dev/null; then + echo "Quart is already installed." +else + echo "Quart is not installed. Installing..." + python3 -m pip install quart +fi # a function that waits vLLM server to start wait_for_server() { diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 54f7968908de0..2c6f61fca8f4c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -986,14 +986,15 @@ def init_distributed_environment( logger.debug("Disaggregated prefill enabled.") if dist_kv.IS_KV_PRODUCER: # for prefill, the ranks are [0, world_size) + logger.debug("rank %d is KV producer.", rank) maybe_disagg_rank = rank else: # this is decode instance. # offset global rank by tp * pp (which is world_size) maybe_disagg_rank = rank + world_size - - logger.debug("Before: world size %d, rank %d", maybe_disagg_world_size, - maybe_disagg_rank) + logger.debug("rank %d is KV producer, adjust it to %d", + rank, + maybe_disagg_rank) torch.distributed.init_process_group( backend=backend, From 44e8875b83a016c77ae2f3643a5270be2875d1c4 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 19 Sep 2024 03:33:11 +0000 Subject: [PATCH 245/303] fix typo --- benchmarks/benchmark_serving.py | 8 +++---- .../disagg_overhead_benchmark.sh | 22 +++---------------- vllm/distributed/parallel_state.py | 2 +- 3 files changed, 8 insertions(+), 24 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index a2a049ed5671c..f9d719a16008f 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -136,9 +136,9 @@ def sample_sonnet_requests( prefix_len: int, tokenizer: PreTrainedTokenizerBase, ) -> List[Tuple[str, str, int, int, None]]: - assert input_len >= prefix_len, ( - "'args.sonnet-input-len' must be greater than or equal to " - "'args.prefix-input-len'.") + assert ( + input_len > prefix_len + ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." # Load the dataset. with open(dataset_path) as f: @@ -963,4 +963,4 @@ def main(args: argparse.Namespace): ) args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh index 36116172ab7c2..dec00c2c9fe00 100644 --- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -52,31 +52,15 @@ benchmark() { input_len=2048 output_len=$2 - # large model - # VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ - # -m vllm.entrypoints.openai.api_server \ - # --model $model \ - # --port 8100 \ - # -tp 4 \ - # --max-model-len 30000 \ - # --gpu-memory-utilization 0.8 & - # VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ - # -m vllm.entrypoints.openai.api_server \ - # --model $model \ - # --port 8200 \ - # -tp 4 \ - # --max-model-len 30000 \ - # --gpu-memory-utilization 0.8 & - - VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 python3 \ + + VLLM_DISTRIBUTED_KV_ROLE=producer CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8100 \ --max-model-len 10000 \ --gpu-memory-utilization 0.8 & -# decoding instance -VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ + VLLM_DISTRIBUTED_KV_ROLE=consumer CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8200 \ diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 2c6f61fca8f4c..64faa5b81263c 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -992,7 +992,7 @@ def init_distributed_environment( # this is decode instance. # offset global rank by tp * pp (which is world_size) maybe_disagg_rank = rank + world_size - logger.debug("rank %d is KV producer, adjust it to %d", + logger.debug("rank %d is KV consumer, adjust it to %d", rank, maybe_disagg_rank) From 181928fd72b6b297995913e3f7c67ddb79fab7fb Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 19 Sep 2024 03:41:41 +0000 Subject: [PATCH 246/303] fixing benchmark_serving.py --- benchmarks/benchmark_serving.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index f9d719a16008f..a407a263120bb 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -626,9 +626,9 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, ) - input_requests = [(prompt, prompt_len, output_len) + input_requests = [(prompt, prompt_len, output_len, None) for prompt, prompt_formatted, prompt_len, - output_len in input_requests] + output_len, _ in input_requests] else: assert ( tokenizer.chat_template or tokenizer.default_chat_template @@ -641,9 +641,9 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, ) - input_requests = [(prompt_formatted, prompt_len, output_len) + input_requests = [(prompt_formatted, prompt_len, output_len, None) for prompt, prompt_formatted, prompt_len, - output_len in input_requests] + output_len, _ in input_requests] elif args.dataset_name == "hf": input_requests = sample_hf_requests( From c17d18daeafb91726e73cb8be19e34c00a247fce Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 19 Sep 2024 04:03:19 +0000 Subject: [PATCH 247/303] fix the example --- .../disagg_benchmarks/disagg_performance_benchmark.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index 1da5669dd1cd0..0e6875363f4d3 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -38,7 +38,7 @@ wait_for_server() { launch_chunked_prefill() { model="meta-llama/Meta-Llama-3.1-70B-Instruct" # disagg prefill - VLLM_RPC_PORT=5570 CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ --port 8100 \ @@ -48,7 +48,7 @@ launch_chunked_prefill() { --disable-log-requests \ --enable-chunked-prefill \ --gpu-memory-utilization 0.8 & - VLLM_RPC_PORT=5580 CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ --port 8200 \ @@ -68,7 +68,7 @@ launch_chunked_prefill() { launch_disagg_prefill() { model="meta-llama/Meta-Llama-3.1-70B-Instruct" # disagg prefill - VLLM_PORT=12345 VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + VLLM_PORT=12345 VLLM_DISTRIBUTED_KV_ROLE=producer CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ --port 8100 \ @@ -77,7 +77,7 @@ launch_disagg_prefill() { --disable-log-stats \ --disable-log-requests \ --gpu-memory-utilization 0.8 & - VLLM_PORT=12345 VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + VLLM_PORT=12345 VLLM_DISTRIBUTED_KV_ROLE=consumer CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ --port 8200 \ From 0b00876f064a148609150e79d75f1cfaed6c5ab6 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 19 Sep 2024 05:23:47 +0000 Subject: [PATCH 248/303] update build partial prefill input --- .../disagg_prefill_example.sh | 0 vllm/distributed/kv_transfer/vllm_adapter.py | 237 +++++++++--------- 2 files changed, 123 insertions(+), 114 deletions(-) rename examples/{disagg_prefill => distributed_kv}/disagg_prefill_example.sh (100%) diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/distributed_kv/disagg_prefill_example.sh similarity index 100% rename from examples/disagg_prefill/disagg_prefill_example.sh rename to examples/distributed_kv/disagg_prefill_example.sh diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 1a918ff0ae32e..9b9a035185ad9 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -52,6 +52,7 @@ # it is likely connected to a KV storage service on CPU/disk # so the communication backend needs to be "gloo" for that case. DISTRIBUTED_BACKEND: str = "gloo" if (IS_KV_PRODUCER and IS_KV_CONSUMER) else "nccl" +# corresponding device DISTRIBUTED_DEVICE: str = "cpu" if (IS_KV_PRODUCER and IS_KV_CONSUMER) else "cuda" @@ -288,7 +289,7 @@ def recv_kv_caches_and_hidden_states( return None, bypass_model_exec, model_input if not is_complete: - rebuilt_model_input = self.build_partial_prefill_input( + rebuilt_model_input = build_partial_prefill_input( model_input, input_tokens_list, num_computed_tokens_list, @@ -307,117 +308,125 @@ def recv_kv_caches_and_hidden_states( logger.debug("[rank%d]: KV recv DONE.", torch.distributed.get_rank()) return hidden_or_intermediate_states, bypass_model_exec, model_input - def build_partial_prefill_input( - self, - model_input: "ModelInputForGPUWithSamplingMetadata", - input_tokens_list: List[torch.Tensor], - num_computed_tokens_list: List[int], - start_pos_list: List[int], - slot_mapping_flat: torch.Tensor, - device: torch.device, - ) -> "ModelInputForGPUWithSamplingMetadata": - rebuilt_input_tokens = [] - rebuilt_input_positions = [] - rebuilt_query_lens = [] - - rebuilt_num_prefills = 0 - rebuilt_num_prefill_tokens = 0 - rebuilt_slot_mapping = [] - rebuilt_max_query_len = 0 - - rebuilt_block_tables = [] - - rebuilt_query_start_loc = [0] - rebuilt_context_lens_tensor = [] - rebuilt_selected_token_indices = [] - - # recounting query and context lengths - for idx in range(len(input_tokens_list)): - token_tensor = input_tokens_list[idx] - num_token = len(token_tensor) - num_computed_token = num_computed_tokens_list[idx] - start_pos = start_pos_list[idx] - - rebuilt_input_tokens.append(token_tensor[num_computed_token:]) - # TODO(Jiayi): please check the correctness of next line - rebuilt_input_positions.append( - model_input.input_positions[start_pos + - num_computed_token:start_pos + - num_token]) - q_len = num_token - num_computed_token - rebuilt_query_lens.append(q_len) - - # Attn metadata-related - rebuilt_num_prefills += 1 - rebuilt_num_prefill_tokens += q_len - rebuilt_slot_mapping.append( - slot_mapping_flat[start_pos + num_computed_token:start_pos + - num_token]) - rebuilt_max_query_len = max(q_len, rebuilt_max_query_len) - # TODO(Jiayi): remove hard-code (block_size=16) - blk_size = 16 - temp_block_table = [ - i // blk_size - for i in range(start_pos, start_pos + num_token, blk_size) - ] - rebuilt_block_tables.append(temp_block_table) - rebuilt_query_start_loc.append(q_len) #start with 0 - rebuilt_context_lens_tensor.append(num_computed_token) - - # Sampling metadata related - #seq_groups (use rebuilt query lens) - rebuilt_selected_token_indices.append(start_pos + q_len - 1) - - # rebuilt attn_metadata - rebuilt_attn_metadata = deepcopy(model_input.attn_metadata) - rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills - rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens - rebuilt_attn_metadata.slot_mapping = torch.cat( - rebuilt_slot_mapping).to(device) - rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len - - rebuilt_attn_metadata.block_tables = torch.tensor( - rebuilt_block_tables, - dtype=model_input.attn_metadata.block_tables.dtype).to(device) - - rebuilt_attn_metadata.query_start_loc = torch.tensor( - rebuilt_query_start_loc, - dtype=model_input.attn_metadata.query_start_loc.dtype).to(device) - rebuilt_attn_metadata.context_lens_tensor = torch.tensor( - rebuilt_context_lens_tensor, - dtype=model_input.attn_metadata.context_lens_tensor.dtype, - ).to(device) - - rebuilt_attn_metadata._cached_prefill_metadata = None - - # rebuilt sampling_metadata - rebuilt_sampling_metadata = deepcopy(model_input.sampling_metadata) - for idx, q_len in enumerate(rebuilt_query_lens): + + + +def build_partial_prefill_input( + model_input: "ModelInputForGPUWithSamplingMetadata", + input_tokens_list: List[torch.Tensor], + num_computed_tokens_list: List[int], + start_pos_list: List[int], + slot_mapping_flat: torch.Tensor, + device: torch.device, +) -> "ModelInputForGPUWithSamplingMetadata": + """ + Helper function to rebuild the model input for the current request. + Goal: avoid running redundant prefill on those tokens that already has KV + caches received. + """ + rebuilt_input_tokens = [] + rebuilt_input_positions = [] + rebuilt_query_lens = [] + + rebuilt_num_prefills = 0 + rebuilt_num_prefill_tokens = 0 + rebuilt_slot_mapping = [] + rebuilt_max_query_len = 0 + + rebuilt_block_tables = [] + + rebuilt_query_start_loc = [0] + rebuilt_context_lens_tensor = [] + rebuilt_selected_token_indices = [] + + # recounting query and context lengths + for idx in range(len(input_tokens_list)): + token_tensor = input_tokens_list[idx] + num_token = len(token_tensor) + num_computed_token = num_computed_tokens_list[idx] + start_pos = start_pos_list[idx] + + rebuilt_input_tokens.append(token_tensor[num_computed_token:]) + # TODO(Jiayi): please check the correctness of next line + rebuilt_input_positions.append( + model_input.input_positions[start_pos + + num_computed_token : start_pos + + num_token]) + q_len = num_token - num_computed_token + rebuilt_query_lens.append(q_len) + + # Attn metadata-related + rebuilt_num_prefills += 1 + rebuilt_num_prefill_tokens += q_len + new_slot_mapping = slot_mapping_flat[start_pos + num_computed_token : start_pos + num_token] + rebuilt_slot_mapping.append(new_slot_mapping) + rebuilt_max_query_len = max(q_len, rebuilt_max_query_len) + # TODO(Jiayi): remove hard-code (block_size=16) + blk_size = 16 + temp_block_table = [ + slot_mapping_flat[i] // blk_size + for i in range(start_pos, start_pos + num_token, blk_size) + ] + rebuilt_block_tables.append(temp_block_table) + rebuilt_query_start_loc.append(rebuilt_num_prefill_tokens) #start with 0 + rebuilt_context_lens_tensor.append(num_computed_token) + + # Sampling metadata related + #seq_groups (use rebuilt query lens) + rebuilt_selected_token_indices.append(rebuilt_num_prefill_tokens - 1) + + # rebuilt attn_metadata + rebuilt_attn_metadata = deepcopy(model_input.attn_metadata) + rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills + rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens + rebuilt_attn_metadata.slot_mapping = torch.cat( + rebuilt_slot_mapping).to(device) + rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len + + rebuilt_attn_metadata.block_tables = torch.tensor( + rebuilt_block_tables, + dtype=model_input.attn_metadata.block_tables.dtype).to(device) + + rebuilt_attn_metadata.query_start_loc = torch.tensor( + rebuilt_query_start_loc, + dtype=model_input.attn_metadata.query_start_loc.dtype).to(device) + rebuilt_attn_metadata.context_lens_tensor = torch.tensor( + rebuilt_context_lens_tensor, + dtype=model_input.attn_metadata.context_lens_tensor.dtype, + ).to(device) + + rebuilt_attn_metadata._cached_prefill_metadata = None + + # rebuilt sampling_metadata + rebuilt_sampling_metadata = deepcopy(model_input.sampling_metadata) + for idx, q_len in enumerate(rebuilt_query_lens): + if rebuilt_sampling_metadata.seq_groups is not None: rebuilt_sampling_metadata.seq_groups[idx].query_len = q_len - rebuilt_sampling_metadata.selected_token_indices = torch.tensor( - rebuilt_selected_token_indices, - dtype=model_input.sampling_metadata.selected_token_indices.dtype, - ).to(device) - - # import here to avoid circular import. - from vllm.worker.model_runner import ( - ModelInputForGPUWithSamplingMetadata) - rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( - input_tokens=torch.cat(rebuilt_input_tokens).to(device), - input_positions=torch.cat(rebuilt_input_positions).to(device), - seq_lens=model_input.seq_lens, - query_lens=rebuilt_query_lens, - lora_mapping=model_input.lora_mapping, - lora_requests=model_input.lora_requests, - attn_metadata=rebuilt_attn_metadata, - prompt_adapter_mapping=model_input.prompt_adapter_mapping, - prompt_adapter_requests=model_input.prompt_adapter_requests, - multi_modal_kwargs=model_input.multi_modal_kwargs, - request_ids_to_seq_ids=model_input.request_ids_to_seq_ids, - finished_requests_ids=model_input.finished_requests_ids, - virtual_engine=model_input.virtual_engine, - sampling_metadata=rebuilt_sampling_metadata, - is_prompt=model_input.is_prompt, - ) - - return rebuilt_model_input + + rebuilt_sampling_metadata.selected_token_indices = torch.tensor( + rebuilt_selected_token_indices, + dtype=model_input.sampling_metadata.selected_token_indices.dtype, + ).to(device) + + # import here to avoid circular import. + from vllm.worker.model_runner import ( + ModelInputForGPUWithSamplingMetadata) + rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( + input_tokens=torch.cat(rebuilt_input_tokens).to(device), + input_positions=torch.cat(rebuilt_input_positions).to(device), + seq_lens=model_input.seq_lens, + query_lens=rebuilt_query_lens, + lora_mapping=model_input.lora_mapping, + lora_requests=model_input.lora_requests, + attn_metadata=rebuilt_attn_metadata, + prompt_adapter_mapping=model_input.prompt_adapter_mapping, + prompt_adapter_requests=model_input.prompt_adapter_requests, + multi_modal_kwargs=model_input.multi_modal_kwargs, + request_ids_to_seq_ids=model_input.request_ids_to_seq_ids, + finished_requests_ids=model_input.finished_requests_ids, + virtual_engine=model_input.virtual_engine, + sampling_metadata=rebuilt_sampling_metadata, + is_prompt=model_input.is_prompt, + ) + + return rebuilt_model_input From 94a5086ae10d52d3d2d0d8704854aa07d6a789a5 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 20 Sep 2024 01:06:11 +0000 Subject: [PATCH 249/303] bug fix for LMCache -- adjust vLLM's rebuild input, and merge the logic to reduce random exit branch --- vllm/distributed/kv_transfer/vllm_adapter.py | 55 ++++++++++++-------- vllm/distributed/parallel_state.py | 3 +- 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 9b9a035185ad9..f901f9e22abcb 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -86,6 +86,7 @@ def __init__( # In remote KV cache store, vLLM will use both send pipe and recv pipe # So we build both send pipe and recv pipe for simplicity. if IS_KV_PRODUCER: + self.send_pipe = TorchDistributedPipe( group_ranks, local_rank, @@ -207,11 +208,12 @@ def recv_kv_caches_and_hidden_states( ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: - # When this flag is set to False, it means that + # When this flag is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. bypass_model_exec = True - # This is disagg decode instance, during prefill state - # Need to receive KV from the prefill instance input_tokens_tensor = model_input.input_tokens seq_lens = model_input.attn_metadata.seq_lens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() @@ -231,6 +233,7 @@ def recv_kv_caches_and_hidden_states( current_tokens = input_tokens_tensor[start_pos:end_pos] num_tokens = slen + # collecting data for rebuilding the input input_tokens_list.append(current_tokens) start_pos_list.append(start_pos) @@ -246,19 +249,26 @@ def recv_kv_caches_and_hidden_states( num_computed_tokens_list.append(0) continue - # TODO(Jiayi): change the logic here (need roi) roi: torch.Tensor = ret[1] keys: torch.Tensor = ret[2] values: torch.Tensor = ret[3] hidden: torch.Tensor = ret[4] - # Jiayi: currently assume roi is a prefix num_computed_tokens = roi.shape[0] num_computed_tokens_list.append(num_computed_tokens) - is_complete = (num_computed_tokens == num_tokens) + + # check if both KV cache and the hidden states are received + # If not, need to redo the forwarding to compute missing states + if not all([ + (num_computed_tokens == num_tokens), + hidden is not None + ]): + bypass_model_exec = False + + # update the end position based on how many tokens are cached. end_pos = start_pos + num_computed_tokens - # receive KV cache from disaggregated prefill instance + # put received KV caches into paged memory for i in range(model_executable.model.start_layer, model_executable.model.end_layer): @@ -279,33 +289,30 @@ def recv_kv_caches_and_hidden_states( hidden_or_intermediate_states_for_one_req.append(hidden) - # FIXME(Jiayi): we need to support only skip m out of n reqs in a batch - # same for prefix caching - if not bypass_model_exec: + if bypass_model_exec == False: # Some of the KV cache is not retrieved - # so we need to recompute the hidden state - logger.debug("[rank%d]: KV EMPTY recv DONE.", + # so we need to adjust model_input and redo the forwarding. + logger.debug("[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) - return None, bypass_model_exec, model_input - - if not is_complete: rebuilt_model_input = build_partial_prefill_input( model_input, input_tokens_list, num_computed_tokens_list, start_pos_list, slot_mapping, - device=kv_cache[0].device, + device=input_tokens_tensor.device, ) - logger.debug("[rank%d]: KV PARTIAL recv DONE.", - torch.distributed.get_rank()) - return None, False, rebuilt_model_input + model_input = rebuilt_model_input + hidden_or_intermediate_states = None - # concatenate hidden states from different requests - hidden_or_intermediate_states = torch.cat( + else: + logger.debug("[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", + torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( hidden_or_intermediate_states_for_one_req, dim=0) - logger.debug("[rank%d]: KV recv DONE.", torch.distributed.get_rank()) return hidden_or_intermediate_states, bypass_model_exec, model_input @@ -344,6 +351,10 @@ def build_partial_prefill_input( token_tensor = input_tokens_list[idx] num_token = len(token_tensor) num_computed_token = num_computed_tokens_list[idx] + # currently attention kernel cannot handle the case where there is 0 + # query token. + if num_computed_token == num_token: + num_computed_token -= 1 start_pos = start_pos_list[idx] rebuilt_input_tokens.append(token_tensor[num_computed_token:]) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 64faa5b81263c..ab24ef09090cb 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -983,7 +983,7 @@ def init_distributed_environment( maybe_disagg_rank = rank if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: maybe_disagg_world_size = world_size * 2 - logger.debug("Disaggregated prefill enabled.") + logger.debug("Distributed KV transfer enabled.") if dist_kv.IS_KV_PRODUCER: # for prefill, the ranks are [0, world_size) logger.debug("rank %d is KV producer.", rank) @@ -996,6 +996,7 @@ def init_distributed_environment( rank, maybe_disagg_rank) + torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, From 8099fb3a132d1b0fc1beac411189ed019721e692 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 20 Sep 2024 01:07:33 +0000 Subject: [PATCH 250/303] make format checker happy --- vllm/distributed/kv_transfer/vllm_adapter.py | 71 ++++++++++---------- vllm/distributed/parallel_state.py | 9 +-- 2 files changed, 39 insertions(+), 41 deletions(-) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index f901f9e22abcb..8e57b5171e048 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -51,10 +51,11 @@ # When the current instance is both KV producer and KV consumer, # it is likely connected to a KV storage service on CPU/disk # so the communication backend needs to be "gloo" for that case. -DISTRIBUTED_BACKEND: str = "gloo" if (IS_KV_PRODUCER and IS_KV_CONSUMER) else "nccl" +DISTRIBUTED_BACKEND: str = "gloo" if (IS_KV_PRODUCER + and IS_KV_CONSUMER) else "nccl" # corresponding device -DISTRIBUTED_DEVICE: str = "cpu" if (IS_KV_PRODUCER and IS_KV_CONSUMER) else "cuda" - +DISTRIBUTED_DEVICE: str = "cpu" if (IS_KV_PRODUCER + and IS_KV_CONSUMER) else "cuda" class KV_transfer_agent: @@ -86,7 +87,7 @@ def __init__( # In remote KV cache store, vLLM will use both send pipe and recv pipe # So we build both send pipe and recv pipe for simplicity. if IS_KV_PRODUCER: - + self.send_pipe = TorchDistributedPipe( group_ranks, local_rank, @@ -115,10 +116,10 @@ def __init__( self.lookup_buffer_size) self.tensor_device = DISTRIBUTED_DEVICE else: - - # the current vLLM instance is KV consumer, so it needs to connect + + # the current vLLM instance is KV consumer, so it needs to connect # its recv pipe to the send pipe of KV producder - + self.recv_pipe = TorchDistributedPipe( group_ranks, local_rank, @@ -208,10 +209,10 @@ def recv_kv_caches_and_hidden_states( ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: - # When this flag is set to False, it means that at least for one + # When this flag is set to False, it means that at least for one # request its corresponding KV cache or hidden state is missing. - # In this case we need to do prefilling to recompute missing KV cache - # and hidden states. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. bypass_model_exec = True input_tokens_tensor = model_input.input_tokens @@ -259,12 +260,10 @@ def recv_kv_caches_and_hidden_states( # check if both KV cache and the hidden states are received # If not, need to redo the forwarding to compute missing states - if not all([ - (num_computed_tokens == num_tokens), - hidden is not None - ]): + if not all([(num_computed_tokens == num_tokens), hidden is not None + ]): bypass_model_exec = False - + # update the end position based on how many tokens are cached. end_pos = start_pos + num_computed_tokens @@ -277,8 +276,10 @@ def recv_kv_caches_and_hidden_states( key_cache, value_cache = kv_cache[0], kv_cache[1] ops.reshape_and_cache_flash( - keys[i - model_executable.model.start_layer].to(key_cache.device), - values[i - model_executable.model.start_layer].to(value_cache.device), + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), key_cache, value_cache, slot_mapping[start_pos:end_pos], @@ -289,12 +290,12 @@ def recv_kv_caches_and_hidden_states( hidden_or_intermediate_states_for_one_req.append(hidden) - if bypass_model_exec == False: + if not bypass_model_exec: # Some of the KV cache is not retrieved # so we need to adjust model_input and redo the forwarding. - logger.debug("[rank%d]: Failed to receive all KVs and hidden " - "states, redo model forwarding.", - torch.distributed.get_rank()) + logger.debug( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) rebuilt_model_input = build_partial_prefill_input( model_input, input_tokens_list, @@ -307,17 +308,15 @@ def recv_kv_caches_and_hidden_states( hidden_or_intermediate_states = None else: - logger.debug("[rank%d]: Successfully received all KVs and hidden " - "states, skip model forwarding.", - torch.distributed.get_rank()) + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) hidden_or_intermediate_states = torch.cat( - hidden_or_intermediate_states_for_one_req, dim=0) + hidden_or_intermediate_states_for_one_req, dim=0) return hidden_or_intermediate_states, bypass_model_exec, model_input - - def build_partial_prefill_input( model_input: "ModelInputForGPUWithSamplingMetadata", input_tokens_list: List[torch.Tensor], @@ -351,7 +350,7 @@ def build_partial_prefill_input( token_tensor = input_tokens_list[idx] num_token = len(token_tensor) num_computed_token = num_computed_tokens_list[idx] - # currently attention kernel cannot handle the case where there is 0 + # currently attention kernel cannot handle the case where there is 0 # query token. if num_computed_token == num_token: num_computed_token -= 1 @@ -361,7 +360,7 @@ def build_partial_prefill_input( # TODO(Jiayi): please check the correctness of next line rebuilt_input_positions.append( model_input.input_positions[start_pos + - num_computed_token : start_pos + + num_computed_token:start_pos + num_token]) q_len = num_token - num_computed_token rebuilt_query_lens.append(q_len) @@ -369,7 +368,9 @@ def build_partial_prefill_input( # Attn metadata-related rebuilt_num_prefills += 1 rebuilt_num_prefill_tokens += q_len - new_slot_mapping = slot_mapping_flat[start_pos + num_computed_token : start_pos + num_token] + new_slot_mapping = slot_mapping_flat[start_pos + + num_computed_token:start_pos + + num_token] rebuilt_slot_mapping.append(new_slot_mapping) rebuilt_max_query_len = max(q_len, rebuilt_max_query_len) # TODO(Jiayi): remove hard-code (block_size=16) @@ -379,7 +380,8 @@ def build_partial_prefill_input( for i in range(start_pos, start_pos + num_token, blk_size) ] rebuilt_block_tables.append(temp_block_table) - rebuilt_query_start_loc.append(rebuilt_num_prefill_tokens) #start with 0 + rebuilt_query_start_loc.append( + rebuilt_num_prefill_tokens) #start with 0 rebuilt_context_lens_tensor.append(num_computed_token) # Sampling metadata related @@ -390,8 +392,8 @@ def build_partial_prefill_input( rebuilt_attn_metadata = deepcopy(model_input.attn_metadata) rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens - rebuilt_attn_metadata.slot_mapping = torch.cat( - rebuilt_slot_mapping).to(device) + rebuilt_attn_metadata.slot_mapping = torch.cat(rebuilt_slot_mapping).to( + device) rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len rebuilt_attn_metadata.block_tables = torch.tensor( @@ -420,8 +422,7 @@ def build_partial_prefill_input( ).to(device) # import here to avoid circular import. - from vllm.worker.model_runner import ( - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata) rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.cat(rebuilt_input_tokens).to(device), input_positions=torch.cat(rebuilt_input_positions).to(device), diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index ab24ef09090cb..c4f2b8529dd26 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -992,11 +992,9 @@ def init_distributed_environment( # this is decode instance. # offset global rank by tp * pp (which is world_size) maybe_disagg_rank = rank + world_size - logger.debug("rank %d is KV consumer, adjust it to %d", - rank, + logger.debug("rank %d is KV consumer, adjust it to %d", rank, maybe_disagg_rank) - torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, @@ -1137,7 +1135,6 @@ def initialize_model_parallel( group_name="pp") logger.debug("_PP initialized for rank %d", torch.distributed.get_rank()) - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: global _DISAGG logger.debug("Disaggregated prefill enabled, create _DISAGG group") @@ -1152,8 +1149,8 @@ def initialize_model_parallel( local_rank=get_world_group().local_rank, ) logger.debug("_DISAGG initialized for rank %d", - torch.distributed.get_rank()) - + torch.distributed.get_rank()) + def ensure_model_parallel_initialized( tensor_model_parallel_size: int, From 603864e302650e01ae46634da51bcc7252e08cb5 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 20 Sep 2024 01:12:10 +0000 Subject: [PATCH 251/303] make ruff and yapf happy, also fix test bug --- tests/kv_transfer/disagg_test.py | 36 +++++++++++++++----- vllm/distributed/kv_transfer/vllm_adapter.py | 2 +- vllm/worker/model_runner.py | 23 +++++-------- 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py index fffd9ab6f42a7..96203b2dc65a0 100644 --- a/tests/kv_transfer/disagg_test.py +++ b/tests/kv_transfer/disagg_test.py @@ -23,23 +23,43 @@ def setup_servers(): # Start prefill instance prefill_cmd = [ - sys.executable, "-m", "vllm.entrypoints.openai.api_server", "-tp", "2", - "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--port", "8100", - "--gpu-memory-utilization", "0.8", "--max-model-len", "1000", + sys.executable, + "-m", + "vllm.entrypoints.openai.api_server", + "-tp", + "2", + "--model", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", + "8100", + "--gpu-memory-utilization", + "0.8", + "--max-model-len", + "1000", ] prefill_env = os.environ.copy() - prefill_env["VLLM_DISAGG_PREFILL_ROLE"] = "prefill" + prefill_env["VLLM_DISTRIBUTED_KV_ROLE"] = "producer" prefill_env["CUDA_VISIBLE_DEVICES"] = "0,1" prefill_proc = Popen(prefill_cmd, env=prefill_env) # Start decode instance decode_cmd = [ - sys.executable, "-m", "vllm.entrypoints.openai.api_server", "-tp", "2", - "--model", "meta-llama/Meta-Llama-3.1-8B-Instruct", "--port", "8200", - "--gpu-memory-utilization", "0.8", "--max-model-len", "1000", + sys.executable, + "-m", + "vllm.entrypoints.openai.api_server", + "-tp", + "2", + "--model", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", + "8200", + "--gpu-memory-utilization", + "0.8", + "--max-model-len", + "1000", ] decode_env = os.environ.copy() - decode_env["VLLM_DISAGG_PREFILL_ROLE"] = "decode" + decode_env["VLLM_DISTRIBUTED_KV_ROLE"] = "consumer" decode_env["CUDA_VISIBLE_DEVICES"] = "2,3" decode_proc = Popen(decode_cmd, env=decode_env) diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 8e57b5171e048..7516e7c5ff307 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -422,7 +422,7 @@ def build_partial_prefill_input( ).to(device) # import here to avoid circular import. - from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.cat(rebuilt_input_tokens).to(device), input_positions=torch.cat(rebuilt_input_positions).to(device), diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a6f91667fedf9..f1dbc7c1803ee 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,6 +14,7 @@ import torch.distributed import torch.nn as nn +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState @@ -22,7 +23,7 @@ ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_pp_group, get_disagg_group +from vllm.distributed import get_disagg_group, get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger @@ -55,8 +56,6 @@ _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict, dump_input_when_exception) -import vllm.distributed.kv_transfer.vllm_adapter as dist_kv - if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -1615,13 +1614,13 @@ def execute_model( attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), + device=self.device), **seqlen_agnostic_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() - + # Sending KV cache in distributed KV cache transfer setting # NOTE: the send operation is non-blocking if self.send_kv_needed(model_input, kv_caches): @@ -1702,8 +1701,7 @@ def execute_model( output.hidden_states = hidden_states return [output] - - + def recv_kv_needed(self, model_input, kv_caches) -> bool: """ Need to receive KV when @@ -1722,7 +1720,7 @@ def recv_kv_needed(self, model_input, kv_caches) -> bool: return all([ dist_kv.IS_KV_CONSUMER, not is_profile_run, - is_prefill_run, + is_prefill_run, ]) def send_kv_needed(self, model_input, kv_caches) -> bool: @@ -1740,11 +1738,8 @@ def send_kv_needed(self, model_input, kv_caches) -> bool: # check if the current run is prefill is_prefill_run = prefill_meta is not None - return all([ - dist_kv.IS_KV_PRODUCER, - not is_profile_run, - is_prefill_run - ]) + return all( + [dist_kv.IS_KV_PRODUCER, not is_profile_run, is_prefill_run]) class CUDAGraphRunner: @@ -1928,4 +1923,4 @@ def _get_max_graph_batch_size(max_num_seqs: int) -> int: if padded_size in _BATCH_SIZES_TO_CAPTURE: return padded_size assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] - return _BATCH_SIZES_TO_CAPTURE[-1] \ No newline at end of file + return _BATCH_SIZES_TO_CAPTURE[-1] From 1d7a1c99ff45246024013049cae54dfd4c624502 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 20 Sep 2024 01:13:25 +0000 Subject: [PATCH 252/303] remove empty file --- tests/random_send_recv.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests/random_send_recv.py diff --git a/tests/random_send_recv.py b/tests/random_send_recv.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 From 10ad09c0213fba08896f841c2ac109c83fa88ff4 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 20 Sep 2024 04:33:24 +0000 Subject: [PATCH 253/303] fix bug when world_size == -1 --- vllm/distributed/parallel_state.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c4f2b8529dd26..efeb6253e56b5 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -979,6 +979,8 @@ def init_distributed_environment( "distributed_init_method must be provided when initializing " "distributed environment") # this backend is used for WORLD + + # offset world size and rank in disaggregated prefill scenario maybe_disagg_world_size = world_size maybe_disagg_rank = rank if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: @@ -994,6 +996,7 @@ def init_distributed_environment( maybe_disagg_rank = rank + world_size logger.debug("rank %d is KV consumer, adjust it to %d", rank, maybe_disagg_rank) + torch.distributed.init_process_group( backend=backend, @@ -1014,11 +1017,16 @@ def init_distributed_environment( global _WORLD if _WORLD is None: - ranks = [[i for i in range(world_size)]] - # offset the distributed group + # in single node the world size can be -1 + # need to infer the world size from torch.distributed.get_world_size() + torch_dist_world_size = torch.distributed.get_world_size() if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: - ranks = include_decoding_groups_if_disagg_enabled( - ranks, world_size) + # two vLLM instances in the world + # so this vLLM instance's world size is half of the world size + torch_dist_world_size = torch_dist_world_size // 2 + ranks = [[i for i in range(torch_dist_world_size)]] + ranks = include_decoding_groups_if_disagg_enabled( + ranks, world_size) _WORLD = init_world_group(ranks, local_rank, backend) logger.debug("_WORLD initialized for rank %d", From 38e3a5759cffff52110d42401e1b728b65aad184 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 20 Sep 2024 04:34:43 +0000 Subject: [PATCH 254/303] adjust comments --- vllm/distributed/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index efeb6253e56b5..a44487561141e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1017,7 +1017,7 @@ def init_distributed_environment( global _WORLD if _WORLD is None: - # in single node the world size can be -1 + # in single node single process the world size can be -1 # need to infer the world size from torch.distributed.get_world_size() torch_dist_world_size = torch.distributed.get_world_size() if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: From e2bd481b7795b4b398231d516c947abab9f4e5bd Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 20 Sep 2024 04:37:51 +0000 Subject: [PATCH 255/303] make yapf and ruff happy --- vllm/distributed/parallel_state.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a44487561141e..fcef0f53b0e2d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -996,7 +996,6 @@ def init_distributed_environment( maybe_disagg_rank = rank + world_size logger.debug("rank %d is KV consumer, adjust it to %d", rank, maybe_disagg_rank) - torch.distributed.init_process_group( backend=backend, @@ -1025,8 +1024,7 @@ def init_distributed_environment( # so this vLLM instance's world size is half of the world size torch_dist_world_size = torch_dist_world_size // 2 ranks = [[i for i in range(torch_dist_world_size)]] - ranks = include_decoding_groups_if_disagg_enabled( - ranks, world_size) + ranks = include_decoding_groups_if_disagg_enabled(ranks, world_size) _WORLD = init_world_group(ranks, local_rank, backend) logger.debug("_WORLD initialized for rank %d", From 49793376dc4242c5e2974d45e4e8fa787a2a2048 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 20 Sep 2024 17:10:55 +0000 Subject: [PATCH 256/303] relaunch CI --- vllm/distributed/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index fcef0f53b0e2d..4630046d1073a 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1021,7 +1021,7 @@ def init_distributed_environment( torch_dist_world_size = torch.distributed.get_world_size() if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: # two vLLM instances in the world - # so this vLLM instance's world size is half of the world size + # so this vLLM instance's world size is half of torch's world size torch_dist_world_size = torch_dist_world_size // 2 ranks = [[i for i in range(torch_dist_world_size)]] ranks = include_decoding_groups_if_disagg_enabled(ranks, world_size) From a2007dc17c2c7ffc022913e3fc7c1dee4a8b57b7 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 06:49:58 +0000 Subject: [PATCH 257/303] change get_open_port so that it is easier to understand --- vllm/utils.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 3efd7e51b2902..17a0a47ef8d9d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -34,6 +34,7 @@ import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger from vllm.platforms import current_platform +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv logger = init_logger(__name__) @@ -545,16 +546,34 @@ def get_open_zmq_ipc_path() -> str: def get_open_port(force: bool = False) -> int: port = envs.VLLM_PORT - if port is not None: - if force and port is not None: - # force vLLM to use envs.VLLM_PORT for torch.distributed init - # This is because this port will binded by prefill instance - # But both prefill and decode instance need to use this port to - # initialize torch.distributed + + if force: + # This flag will only be True in disaggregated prefill scenario + # and it has to be set so that vLLM can connect prefill vLLM instance + # and decode vLLM instance. + assert port is not None, "Please set VLLM_PORT in order to use " + "disaggregated prefill and distributed KV cache transfer." + + # For prefill vLLM instance (KV producer) this port must be empty. + # For decode vLLM instance this port can be non-empty. + if dist_kv.IS_KV_PRODUCER: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + return port + except OSError as e: + logger.error("Port %d must be empty so that prefill vLLM " + "instance can use this port to initialize " + "distributed KV communication with decode " + "vLLM instance.", port) + raise e + else: return port + + if port is not None: while True: try: - logger.error('Trying port %d', port) + logger.debug('Trying port %d', port) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", port)) return port From ce434f5ed0d3af2d7de1280ef47dee0c08bd1f41 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 06:52:22 +0000 Subject: [PATCH 258/303] adjust comment --- vllm/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 17a0a47ef8d9d..0c87bfc65f10d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -551,8 +551,8 @@ def get_open_port(force: bool = False) -> int: # This flag will only be True in disaggregated prefill scenario # and it has to be set so that vLLM can connect prefill vLLM instance # and decode vLLM instance. - assert port is not None, "Please set VLLM_PORT in order to use " - "disaggregated prefill and distributed KV cache transfer." + assert port is not None, "Please set environment variable VLLM_PORT in" + " order to use disaggregated prefill and distributed KV cache transfer" # For prefill vLLM instance (KV producer) this port must be empty. # For decode vLLM instance this port can be non-empty. @@ -573,7 +573,7 @@ def get_open_port(force: bool = False) -> int: if port is not None: while True: try: - logger.debug('Trying port %d', port) + logger.info('Trying port %d', port) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", port)) return port From f224c71f0b2de554a199d334ab0186895b566c2b Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 06:54:25 +0000 Subject: [PATCH 259/303] make format checker happy --- vllm/utils.py | 17 +++++++++-------- vllm/worker/worker.py | 2 +- vllm/worker/worker_base.py | 2 +- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 0c87bfc65f10d..d2ba2272575e7 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -31,10 +31,10 @@ from packaging.version import Version from typing_extensions import ParamSpec, TypeIs, assert_never +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger from vllm.platforms import current_platform -import vllm.distributed.kv_transfer.vllm_adapter as dist_kv logger = init_logger(__name__) @@ -546,14 +546,14 @@ def get_open_zmq_ipc_path() -> str: def get_open_port(force: bool = False) -> int: port = envs.VLLM_PORT - + if force: # This flag will only be True in disaggregated prefill scenario # and it has to be set so that vLLM can connect prefill vLLM instance # and decode vLLM instance. assert port is not None, "Please set environment variable VLLM_PORT in" " order to use disaggregated prefill and distributed KV cache transfer" - + # For prefill vLLM instance (KV producer) this port must be empty. # For decode vLLM instance this port can be non-empty. if dist_kv.IS_KV_PRODUCER: @@ -562,14 +562,15 @@ def get_open_port(force: bool = False) -> int: s.bind(("", port)) return port except OSError as e: - logger.error("Port %d must be empty so that prefill vLLM " - "instance can use this port to initialize " - "distributed KV communication with decode " - "vLLM instance.", port) + logger.error( + "Port %d must be empty so that prefill vLLM " + "instance can use this port to initialize " + "distributed KV communication with decode " + "vLLM instance.", port) raise e else: return port - + if port is not None: while True: try: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index a12e026b2bdf5..3851843afc960 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -485,4 +485,4 @@ def raise_if_cache_size_invalid(num_gpu_blocks, block_size, "is larger than the maximum number of tokens that can be " f"stored in KV cache ({max_seq_len}). Try increasing " "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") \ No newline at end of file + "initializing the engine.") diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index abc6f98b5f30a..6ba4f272315ce 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -482,4 +482,4 @@ def extract_previous_hidden_states( output["previous_hidden_states"] = data.previous_hidden_states\ .hidden_states - return output \ No newline at end of file + return output From 5d9b007b3b26eb445cfe1689dd1fb179aa5fb737 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 07:01:42 +0000 Subject: [PATCH 260/303] adjust model runner docstring --- vllm/worker/model_runner.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f1dbc7c1803ee..da33cda71dff0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1584,7 +1584,7 @@ def execute_model( # we can skip prefilling on tokens that successfully received KV caches # NOTE: The receive operation is blocking bypass_model_exec = False - if self.recv_kv_needed(model_input, kv_caches): + if self.need_recv_kv(model_input, kv_caches): hidden_or_intermediate_states, bypass_model_exec, model_input = \ get_disagg_group().recv_kv_caches_and_hidden_states( # model is used to know which layer the current worker @@ -1623,11 +1623,10 @@ def execute_model( # Sending KV cache in distributed KV cache transfer setting # NOTE: the send operation is non-blocking - if self.send_kv_needed(model_input, kv_caches): - logger.debug("Sending KV caches") + if self.need_send_kv(model_input, kv_caches): get_disagg_group().send_kv_caches_and_hidden_states( - # model is used to know which layer the current worker - # is working on, so that we can send KV for only those + # model_executable is used to know which layer the current + # worker is working on, so that we can send KV for only those # layers. model_executable, model_input, @@ -1702,12 +1701,16 @@ def execute_model( return [output] - def recv_kv_needed(self, model_input, kv_caches) -> bool: - """ - Need to receive KV when - 1. current vLLM instance is KV cache *consumer* + def need_recv_kv(self, model_input, kv_caches) -> bool: + """Check if we need to receive kv-cache from the other worker. + We need to receive KV when + 1. current vLLM instance is KV cache consumer/decode vLLM instance 2. this batch is not a profiling run 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory """ prefill_meta = model_input.attn_metadata.prefill_metadata @@ -1723,12 +1726,16 @@ def recv_kv_needed(self, model_input, kv_caches) -> bool: is_prefill_run, ]) - def send_kv_needed(self, model_input, kv_caches) -> bool: - """ - Need to receive KV when - 1. current vLLM instance is KV cache *producer* + def need_send_kv(self, model_input, kv_caches) -> bool: + """Check if we need to send kv-cache from the other worker. + We need to send KV when + 1. current vLLM instance is KV cache producer/prefill vLLM instance 2. this batch is not a profiling run 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory """ prefill_meta = model_input.attn_metadata.prefill_metadata From 6255dca7d649953b51f3f68e4157045c21997604 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 07:04:11 +0000 Subject: [PATCH 261/303] make format checker happy --- vllm/worker/model_runner.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index da33cda71dff0..eba5ddd14090d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1625,7 +1625,7 @@ def execute_model( # NOTE: the send operation is non-blocking if self.need_send_kv(model_input, kv_caches): get_disagg_group().send_kv_caches_and_hidden_states( - # model_executable is used to know which layer the current + # model_executable is used to know which layer the current # worker is working on, so that we can send KV for only those # layers. model_executable, @@ -1720,14 +1720,11 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: # check if the current run is prefill is_prefill_run = prefill_meta is not None - return all([ - dist_kv.IS_KV_CONSUMER, - not is_profile_run, - is_prefill_run, - ]) + return dist_kv.IS_KV_CONSUMER and ( + not is_profile_run) and is_prefill_run def need_send_kv(self, model_input, kv_caches) -> bool: - """Check if we need to send kv-cache from the other worker. + """Check if we need to send kv-cache to the other worker. We need to send KV when 1. current vLLM instance is KV cache producer/prefill vLLM instance 2. this batch is not a profiling run @@ -1745,8 +1742,8 @@ def need_send_kv(self, model_input, kv_caches) -> bool: # check if the current run is prefill is_prefill_run = prefill_meta is not None - return all( - [dist_kv.IS_KV_PRODUCER, not is_profile_run, is_prefill_run]) + return dist_kv.IS_KV_PRODUCER and ( + not is_profile_run) and is_prefill_run class CUDAGraphRunner: From 71ae27592b35ecf4d48d44ab12eb8295ed2a6f1e Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 18:38:29 +0000 Subject: [PATCH 262/303] change data == [] to not data (thanks Cody) --- vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 9696032002fda..96f8f14561e77 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -78,7 +78,7 @@ def _send_tensor_and_dec_size(self, def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): - if data == [] or data is None: + if not data: return 0 if isinstance(data, torch.Tensor): return data.element_size() * data.numel() From 80164ea3f33b17a5daaa9aa8202f561f1d829867 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 18:47:18 +0000 Subject: [PATCH 263/303] fix misleading to available --- vllm/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index d2ba2272575e7..a423eb355370e 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -549,14 +549,15 @@ def get_open_port(force: bool = False) -> int: if force: # This flag will only be True in disaggregated prefill scenario - # and it has to be set so that vLLM can connect prefill vLLM instance - # and decode vLLM instance. + # and VLLM_PORT must be set so that vLLM can connect prefill vLLM + # instance and decode vLLM instance. assert port is not None, "Please set environment variable VLLM_PORT in" " order to use disaggregated prefill and distributed KV cache transfer" - # For prefill vLLM instance (KV producer) this port must be empty. - # For decode vLLM instance this port can be non-empty. + # For prefill vLLM instance (KV producer), `port` must be available. + # For decode vLLM instance `port` can be not available. if dist_kv.IS_KV_PRODUCER: + # `port` must be available. try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", port)) @@ -569,6 +570,7 @@ def get_open_port(force: bool = False) -> int: "vLLM instance.", port) raise e else: + # `port` can be not available return port if port is not None: From 52c2d1084c7607f637d1c3e5cac6b8f02bd46490 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 18:49:50 +0000 Subject: [PATCH 264/303] add new line and run format checker --- benchmarks/benchmark_serving.py | 2 +- vllm/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index a407a263120bb..d0e1cb41a68bc 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -963,4 +963,4 @@ def main(args: argparse.Namespace): ) args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/vllm/utils.py b/vllm/utils.py index a423eb355370e..55de17e242b21 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -549,7 +549,7 @@ def get_open_port(force: bool = False) -> int: if force: # This flag will only be True in disaggregated prefill scenario - # and VLLM_PORT must be set so that vLLM can connect prefill vLLM + # and VLLM_PORT must be set so that vLLM can connect prefill vLLM # instance and decode vLLM instance. assert port is not None, "Please set environment variable VLLM_PORT in" " order to use disaggregated prefill and distributed KV cache transfer" From 09478ef416d99a75c96ec8b2547518aef7d05771 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 19:46:34 +0000 Subject: [PATCH 265/303] add docstring for lookup buffer --- .../kv_transfer/kv_lookup_buffer/base.py | 93 ++++++++++++++++++- .../kv_lookup_buffer/simple_buffer.py | 18 +++- 2 files changed, 102 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py index 80802f87987ac..4bacde2884341 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -1,3 +1,12 @@ +""" +This file contains a new class `KVLookupBufferBase` that allows developers to +think of KV cache operations as inserting new KV cache entries (`insert`) +into the lookup buffer and querying existing KV caches (`drop_select`) +from the lookup buffer. + +All distributed communications are abstracted behind this class. +""" + from abc import ABC, abstractmethod from typing import List, Optional @@ -5,21 +14,97 @@ class KVLookupBufferBase(ABC): + """ + Abstract base class for a lookup buffer. + + This class provides an abstraction for a key-value (KV) cache lookup buffer. + + The key of the lookup buffer: + - input_tokens: token IDs of the request + - roi: a binary mask on top of input_tokens. + - Purpose of roi: Since KV cache may only be available for a subset of + tokens in the input (for example, when vLLM is connected to an external + KV cache service), roi specifies the subset of tokens that the KV cache + is associated with. + - NOTE: roi can be further extended to describe which part of KV the + current process is holding (each process may only hold a part of KV + due to TP and PP). This is not implemented for now. + + The value of the lookup buffer: + - key: the key tensor in the KV cache + - value: the value tensor in the KV cache + - hidden: the final hidden state generated by model forwarding. This allows + vLLM to bypass further model forwarding by transmitting the hidden state. + """ @abstractmethod def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: + """Insert into the lookup buffer. + + The functionality is similar to the following python statement + ``` + buffer[input_tokens, roi] = [key, value, hidden] + ``` + + FIXME: in the future, we should only have two arguments, key and value, + where key is a tensor dict and value is a tensor dict. + + FIXME: we should transmit both sampler outputs and the hidden states. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + key (torch.Tensor): The key tensor in the KV cache. + value (torch.Tensor): The value tensor in the KV cache. + hidden (torch.Tensor): The final hidden state tensor generated + during model forwarding to bypass model + forwarding. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ raise NotImplementedError @abstractmethod - def drop_select(self, input_tokens: torch.Tensor, - roi: torch.Tensor) -> List[Optional[torch.Tensor]]: + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + """ + Select and *drop* KV cache entries from the lookup buffer. + + The functionality is similar to the following python statements + ``` + ret = buffer.pop(input_tokens, roi) + return ret + ``` + + If `input_tokens` and `roi` is `None`, it means selecting any of the + KV caches in the buffer, return, and remove it from the buffer, useful + when offloading KV cache to KV cache storage service. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + + Returns: + List[Optional[torch.Tensor]]: A list of tensors. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ raise NotImplementedError @abstractmethod - def close(self): + def close(self) -> None: """ - Close the buffer, release resources. + Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + lookup buffer when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. """ raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 96f8f14561e77..33f9b7c440240 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -57,7 +57,7 @@ def _matches(self, tokens_roi_sender: List[torch.Tensor], # so any of the data in the buffer can be drop-selected return True - # Assuming that roi is a mask on tokens + # Assuming that roi is a binary mask on tokens tokens_sender = tokens_sender[roi_sender] tokens_recver = tokens_recver[roi_recver] @@ -128,6 +128,9 @@ def drop_select_handler(self): matched_length = 0 # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. with self.buffer_lock: for _ in range(len(self.buffer)): @@ -158,11 +161,13 @@ def drop_select_handler(self): logger.debug("Closing drop_select_handler") - def drop_select(self, input_tokens: torch.Tensor, - roi: torch.Tensor) -> List[Optional[torch.Tensor]]: + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: assert self.request_handling_thread is None, \ - "drop_select should be called by the receiver" + "drop_select should be called by the KV cache consumer "\ + "(e.g. the decode vLLM instance)" if isinstance(input_tokens, torch.Tensor): input_tokens = input_tokens.clone() @@ -188,8 +193,11 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: + if self.buffer_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. Handling...") while self.buffer_size > self.buffer_size_threshold: - # logger.debug("KV transfer buffer is full. Handling...") self.full_handler() self._add_to_buffer(input_tokens, roi, key, value, hidden) From 06cb15c7b13c84fd201f99dafb866e8669c1494a Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 19:47:09 +0000 Subject: [PATCH 266/303] align docstring syntax --- vllm/distributed/kv_transfer/kv_lookup_buffer/base.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py index 4bacde2884341..bad119a1aa929 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -71,8 +71,7 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, def drop_select( self, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - """ - Select and *drop* KV cache entries from the lookup buffer. + """Select and *drop* KV cache entries from the lookup buffer. The functionality is similar to the following python statements ``` @@ -98,8 +97,7 @@ def drop_select( @abstractmethod def close(self) -> None: - """ - Close the buffer and release resources. + """Close the buffer and release resources. This method is responsible for cleaning up resources related to the lookup buffer when it is no longer needed. From 7c11a392c012bc795a444ff547df103cb45cbbe6 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 19:58:56 +0000 Subject: [PATCH 267/303] add docstring for abstract classes --- .../kv_lookup_buffer/simple_buffer.py | 2 +- vllm/distributed/kv_transfer/kv_pipe/base.py | 47 ++++++++++++++++++- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 33f9b7c440240..a3a9e8a2846b5 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -194,7 +194,7 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, hidden: torch.Tensor) -> None: if self.buffer_size > self.buffer_size_threshold: - # log outside the while loop to avoid this message being logged + # log outside the while loop to avoid this message being logged # repeatedly. logger.debug("KV transfer buffer is full. Handling...") while self.buffer_size > self.buffer_size_threshold: diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py index 0955b4e838896..79e235b48fd72 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -1,3 +1,13 @@ +""" +This file defines +`KVPipeBase` +that provides an abstraction for sending and receiving tensors, or None, via +distributed communications. + +All distributed communications for disagg prefill & KV cache storage should be +handled by `KVPipeBase`. +""" + from abc import ABC, abstractmethod from typing import Optional @@ -5,15 +15,50 @@ class KVPipeBase(ABC): + """ + This class provides an interface for sending and receiving tensors, or + None, by distributed communications. + """ @abstractmethod def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Send a tensor, or None, via the pipe. + + Need to support sending None -- important for error handling. + + TODO: add a `key` argument so that we can use traditional + key-value database as the distributed communication mechanism behind + the pipe. + + Args: + tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ raise NotImplementedError @abstractmethod def recv_tensor(self) -> Optional[torch.Tensor]: + """Receive a tensor (can be None) from the pipeline. + + Returns: + Optional[torch.Tensor]: The tensor received from the pipeline. Can + be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ raise NotImplementedError @abstractmethod - def close(self): + def close(self) -> None: + """Close the pipeline and release resources. + + This method is responsible for closing the communication pipeline + and releasing any resources associated with it. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ raise NotImplementedError From 37bac34d2e59c2a374a2d557ba38b026c68d5b72 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 20:03:23 +0000 Subject: [PATCH 268/303] put assertion at the end of the function --- .../distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index a3a9e8a2846b5..8dfa61780ddaf 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -82,8 +82,8 @@ def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): return 0 if isinstance(data, torch.Tensor): return data.element_size() * data.numel() - else: - raise AssertionError("Unknown data type %s" % type(data)) + + raise AssertionError("Unknown data type %s" % type(data)) def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, From 111abb463df1825aeab0c8632429d297f1deeb08 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 20:17:22 +0000 Subject: [PATCH 269/303] add fp8 support to pipe --- .../kv_transfer/kv_pipe/torch_distributed_pipe.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index c2c5cbbe95b0a..f58643d316a07 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -15,13 +15,15 @@ # this means that the sended object is None. NONE_INT = -150886311 -# Mapping tensor dtype to a int, used for tensor metadata transmission +# Mapping tensor dtype to INT64, used for tensor metadata transmission FLOAT16_INT = -543205003776624 INT64_INT = -375623078607432 BOOL_INT = -28035262008646 BFLOAT16_INT = -452084912267662 FLOAT32_INT = -1049557997456592 FLOAT64_INT = -452201007054137 +FLOAT8_E4M3FN_INT = -1066697177659525 +FLOAT8_E5M2_INT = -618182574682355 DTYPE2INT = { torch.float16: FLOAT16_INT, @@ -30,6 +32,8 @@ torch.bfloat16: BFLOAT16_INT, torch.float32: FLOAT32_INT, torch.float64: FLOAT64_INT, + torch.float8_e4m3fn: FLOAT8_E4M3FN_INT, + torch.float8_e5m2: FLOAT8_E5M2_INT, } INT2DTYPE = { @@ -39,6 +43,8 @@ BFLOAT16_INT: torch.bfloat16, FLOAT32_INT: torch.float32, FLOAT64_INT: torch.float64, + FLOAT8_E4M3FN_INT: torch.float8_e4m3fn, + FLOAT8_E5M2_INT: torch.float8_e5m2, } From 394afaa21f0636f84d5b39239f726a2edf0c4816 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 20:24:34 +0000 Subject: [PATCH 270/303] adjust docstrings --- .../kv_pipe/torch_distributed_pipe.py | 37 ++++++------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py index f58643d316a07..3fe3fa289c662 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -89,9 +89,6 @@ def __init__( self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % self.world_size] - # FIXME: why we need this? - torch.set_default_device(self.device) - self.transport_thread: Optional[ThreadPoolExecutor] = None self.buffer_size = 0 self.buffer_size_lock = threading.Lock() @@ -110,8 +107,7 @@ def _select_device(self, backend: Union[str, Backend]): return "cpu" def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: - """ - Create the metadata on based on the input tensor, and move it to GPU. + """Create the metadata on based on the input tensor, and move it to GPU. The metadata's length is `TorchDistributedPipe.METADATA_LENGTH`. Currently, the metadata is a int64 tensor and it includes dtype, number @@ -129,7 +125,9 @@ def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: Returns: - metadata: the metadata tensor, on self.device """ - buffer = torch.empty(self.METADATA_LENGTH, dtype=self.METADATA_DTYPE) + buffer = torch.empty(self.METADATA_LENGTH, + dtype=self.METADATA_DTYPE, + device="cpu") buffer[0] = DTYPE2INT[tensor.dtype] ndims = len(tensor.shape) buffer[1] = len(tensor.shape) @@ -139,8 +137,7 @@ def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: def _prepare_recv_buffer(self, d_metadata_buffer: torch.Tensor) -> torch.Tensor: - """ - Create a buffer to receive the tensor based on the metadata. + """Create a buffer to receive the tensor based on the metadata. Parameters: - d_metadata_buffer: the metadata tensor on self.device @@ -155,8 +152,7 @@ def _prepare_recv_buffer(self, return torch.empty(shape, dtype=dtype, device=self.device) def _send_metadata(self, d_metadata_buffer: torch.Tensor): - """ - Send the metadata buffer to the target rank. + """Send the metadata buffer to the target rank. """ torch.distributed.send( d_metadata_buffer, @@ -165,8 +161,7 @@ def _send_metadata(self, d_metadata_buffer: torch.Tensor): ) def _recv_metadata(self) -> torch.Tensor: - """ - Receive the metadata buffer from the target rank. + """Receive the metadata buffer from the target rank. Returns: - metadata_buffer: the metadata buffer tensor, on self.device @@ -195,11 +190,9 @@ def _send_impl(self, tensor): metadata = self._make_metadata(tensor) self._send_metadata(metadata) - #logger.debug(f"Sent meta {metadata}") torch.distributed.send(tensor.to(self.device), dst=self.target_rank_for_send, group=self.device_group) - #logger.debug(f"Sent tensor {tensor}") def _recv_impl(self) -> torch.Tensor: """ @@ -235,17 +228,14 @@ def send_tensor_wrapper(self, tensor): traceback.print_exc() def block_if_full(self): - """ - Block the current thread if the buffer size is larger than 1e9. - """ + """Block the current thread if the buffer size is larger than 1e9.""" # TODO: replace this 1e9 with a configurable parameter or a constant while self.buffer_size > 1e9: logger.debug("KV cache transfer pipe is full. Waiting...") time.sleep(0.05) def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: - """ - Sends a tensor to the destination rank in a non-blocking way. + """Sends a tensor to the destination rank in a non-blocking way. Flow: send tensor dim -- send tensor shape -- send tensor data """ @@ -254,9 +244,8 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: if tensor is None: tensor = self.none_tensor - tensor_size = 0 - else: - tensor_size = tensor.element_size() * tensor.numel() + + tensor_size = tensor.element_size() * tensor.numel() assert ( 0 < len(tensor.shape) < self.MAX_TENSOR_DIMENSIONS @@ -294,9 +283,7 @@ def recv_tensor(self) -> Optional[torch.Tensor]: return tensor def close(self): - """ - Close the pipe and release the resources. - """ + """Close the pipe and release the resources.""" if (hasattr(self, "transport_thread") and self.transport_thread is not None): self.transport_thread.shutdown() From 76019f1f42c8142b682f6c277f2282decc4c965c Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 21:49:05 +0000 Subject: [PATCH 271/303] bug fix: check isinstance(torch.Tensor) before checking NOne --- .../kv_transfer/kv_lookup_buffer/simple_buffer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 8dfa61780ddaf..41c2fba31fbea 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -77,11 +77,13 @@ def _send_tensor_and_dec_size(self, self.data_pipe.send_tensor(tensor) def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): - - if not data: - return 0 + if isinstance(data, torch.Tensor): return data.element_size() * data.numel() + if not data: + # cannot perform `not data` on a tensor + # so this check needs to go after the check above + return 0 raise AssertionError("Unknown data type %s" % type(data)) From 93ec62b8556e279d2c050bdc1c3247831bd39466 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 24 Sep 2024 22:16:03 +0000 Subject: [PATCH 272/303] make format check happy --- vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 41c2fba31fbea..eb052e2e41e11 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -77,7 +77,7 @@ def _send_tensor_and_dec_size(self, self.data_pipe.send_tensor(tensor) def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): - + if isinstance(data, torch.Tensor): return data.element_size() * data.numel() if not data: From c5bdf64887725a5719bfb304419318fdc3f49ef0 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 10 Oct 2024 09:35:48 -0700 Subject: [PATCH 273/303] Adjust to latest changes of `kv_caches`: it is now always a tensor. --- vllm/worker/model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 91c352d9babe2..b7459dc838a52 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1774,7 +1774,7 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: prefill_meta = model_input.attn_metadata.prefill_metadata # check if the current run is profiling - is_profile_run = (kv_caches is None) or (kv_caches[0] is None) + is_profile_run = (kv_caches.numel() == 0) # check if the current run is prefill is_prefill_run = prefill_meta is not None @@ -1796,7 +1796,7 @@ def need_send_kv(self, model_input, kv_caches) -> bool: prefill_meta = model_input.attn_metadata.prefill_metadata # check if the current run is profiling - is_profile_run = (kv_caches is None) or (kv_caches[0] is None) + is_profile_run = (kv_caches.numel() == 0) # check if the current run is prefill is_prefill_run = prefill_meta is not None From 596eb642e3d80823c0c3c7efe2fb7c65ae48e391 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 10 Oct 2024 09:42:29 -0700 Subject: [PATCH 274/303] debug --- vllm/worker/model_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b7459dc838a52..1949f3b581788 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1774,6 +1774,7 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: prefill_meta = model_input.attn_metadata.prefill_metadata # check if the current run is profiling + print(kv_caches) is_profile_run = (kv_caches.numel() == 0) # check if the current run is prefill is_prefill_run = prefill_meta is not None @@ -1796,6 +1797,7 @@ def need_send_kv(self, model_input, kv_caches) -> bool: prefill_meta = model_input.attn_metadata.prefill_metadata # check if the current run is profiling + print(kv_caches) is_profile_run = (kv_caches.numel() == 0) # check if the current run is prefill is_prefill_run = prefill_meta is not None From 683bd9cff790adcbeb3fe8ccbc1b07a33810e518 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 10 Oct 2024 09:44:33 -0700 Subject: [PATCH 275/303] bug fix: kv_caches will be list of torch.tensor([]) in profile run. --- vllm/worker/model_runner.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1949f3b581788..852c040413496 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1774,8 +1774,7 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: prefill_meta = model_input.attn_metadata.prefill_metadata # check if the current run is profiling - print(kv_caches) - is_profile_run = (kv_caches.numel() == 0) + is_profile_run = (kv_caches[0].numel() == 0) # check if the current run is prefill is_prefill_run = prefill_meta is not None @@ -1797,8 +1796,7 @@ def need_send_kv(self, model_input, kv_caches) -> bool: prefill_meta = model_input.attn_metadata.prefill_metadata # check if the current run is profiling - print(kv_caches) - is_profile_run = (kv_caches.numel() == 0) + is_profile_run = (kv_caches[0].numel() == 0) # check if the current run is prefill is_prefill_run = prefill_meta is not None From 521daba7eead630d1b0c6b5bf4f8ba83580fb22a Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Thu, 10 Oct 2024 12:59:40 -0700 Subject: [PATCH 276/303] Relax server start timeout limit --- tests/kv_transfer/disagg_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py index 96203b2dc65a0..3dfacbdc5fe84 100644 --- a/tests/kv_transfer/disagg_test.py +++ b/tests/kv_transfer/disagg_test.py @@ -80,7 +80,7 @@ def setup_servers(): # Helper function to wait for server -def wait_for_server(port, timeout=120): +def wait_for_server(port, timeout=240): start_time = time.time() while time.time() - start_time < timeout: try: From 7efdf6085dd3f9f37676f91748c4e9d010e3356b Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 8 Nov 2024 20:35:12 +0000 Subject: [PATCH 277/303] Adjust folder format --- vllm/config.py | 43 + .../__init__.py | 0 .../base.py | 40 +- .../kv_connector/lmcache_connector.py | 31 + .../torch_distributed_connector.py | 743 ++++++++++++++++++ .../kv_lookup_buffer/simple_buffer.py | 223 ------ 6 files changed, 848 insertions(+), 232 deletions(-) rename vllm/distributed/kv_transfer/{kv_lookup_buffer => kv_connector}/__init__.py (100%) rename vllm/distributed/kv_transfer/{kv_lookup_buffer => kv_connector}/base.py (77%) create mode 100644 vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py delete mode 100644 vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py diff --git a/vllm/config.py b/vllm/config.py index 91bbbfec4b7b3..5284d0a09a6fe 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1968,6 +1968,48 @@ def __post_init__(self): "OpenTelemetry is not available. Unable to configure " "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " f"installed. Original error:\n{otel_import_error_traceback}") + + +@dataclass +class KVTransferConfig: + """Configuration for KV cache transfer between vLLM instances + + To distinguish between other configs, all the configs here are prefixed + with "kv" + """ + + # the connector to use for kv cache transfer + # vLLM natively supports TorchDistributedConnector and also include + # third-party connector, but does not provide official support, please + # reach out to the code owner for those connectors for support. + kv_connector: Optional[str] = None + + # the buffer size parameter, used to configure the buffer size of receiving + # KV cache + kv_buffer_size: Optional[float] = None + kv_transfer_role: Optional[str] = None + kv_init_method: Optional[str] = None + kv_xpyd: Optional[str] = None + kv_rank: Optional[int] = None + + def __post_init__(self): + if self.kv_connector is None and not all([ + self.kv_buffer_size is None, + self.kv_transfer_role is None, + self.kv_init_method is None, + self.kv_xpyd is None, + self.kv_xypd_rank is None, + ]): + raise ValueError("Please specify kv_connector before configuring " + "variables with prefix `kv_`") + + assert self.kv_connector in [ + None, + 'TorchDistributedConnector', + 'LMCacheConnector' + ], f"Existing kv connectors are `TorchDistributedConnector` and "\ + f"`LMCacheConnector`. Got {self.kv_connector}" + @dataclass @@ -1988,6 +2030,7 @@ class VllmConfig: observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None quant_config: Optional[QuantizationConfig] = None + kv_transfer_config: Optional[KVTransferConfig] = None @staticmethod def _get_quantization_config( diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py b/vllm/distributed/kv_transfer/kv_connector/__init__.py similarity index 100% rename from vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py rename to vllm/distributed/kv_transfer/kv_connector/__init__.py diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py similarity index 77% rename from vllm/distributed/kv_transfer/kv_lookup_buffer/base.py rename to vllm/distributed/kv_transfer/kv_connector/base.py index bad119a1aa929..51f0633f8d5bc 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -13,9 +13,9 @@ import torch -class KVLookupBufferBase(ABC): +class KVConnectorBase(ABC): """ - Abstract base class for a lookup buffer. + Abstract base class for a KV connector. This class provides an abstraction for a key-value (KV) cache lookup buffer. @@ -36,16 +36,20 @@ class KVLookupBufferBase(ABC): - hidden: the final hidden state generated by model forwarding. This allows vLLM to bypass further model forwarding by transmitting the hidden state. """ + + @abstractmethod + def init(self, config: vLLMConfig): + raise NotImplementedError @abstractmethod def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: - """Insert into the lookup buffer. + """Insert into the lookup buffer, similar to SQL insert The functionality is similar to the following python statement ``` - buffer[input_tokens, roi] = [key, value, hidden] + connector[input_tokens, roi] = [key, value, hidden] ``` FIXME: in the future, we should only have two arguments, key and value, @@ -68,15 +72,14 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, raise NotImplementedError @abstractmethod - def drop_select( + def select( self, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - """Select and *drop* KV cache entries from the lookup buffer. + """Select KV cache entries from the connector. The functionality is similar to the following python statements ``` - ret = buffer.pop(input_tokens, roi) - return ret + return connector[input_tokens, roi] ``` If `input_tokens` and `roi` is `None`, it means selecting any of the @@ -100,9 +103,28 @@ def close(self) -> None: """Close the buffer and release resources. This method is responsible for cleaning up resources related to the - lookup buffer when it is no longer needed. + connector when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + + @abstractmethod + def rebuild_model_input( + self, + model_input: "ModelInputForGPUWithSamplingMetadata", + input_tokens_list: List[torch.Tensor], + num_computed_tokens_list: List[int], + start_pos_list: List[int], + slot_mapping_flat: torch.Tensor, + device: torch.device, + ) -> "ModelInputForGPUWithSamplingMetadata": + """Rebuild the model input based on how many KV caches are received Raises: NotImplementedError: This method must be implemented in subclasses. """ raise NotImplementedError + diff --git a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py new file mode 100644 index 0000000000000..5fa45fb0b337f --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py @@ -0,0 +1,31 @@ +""" + This file implements a simple torch distributed connector by 2 classes: + - `TorchDistributedPipe`: a tensor transmission pipe between P/D instance, + using `torch.distributed` + - `TorchDistributedConnector`: a torch distributed connector between P/D + instance, implemented on top of `TorchDistributedPipe` +""" +import threading +import time +from collections import deque +from concurrent.futures import ThreadPoolExecutor +from typing import Deque, List, Optional, Union + +import torch +from torch.distributed import Backend + +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.logger import init_logger + +logger = init_logger(__name__) + +try: + import lmcache +except ModuleNotFoundError as e: + logger.error("LMcache not installed, please install LMCache.") + raise e + + +class LMCacheConnector(KVConnectorBase): + + pass \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py b/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py new file mode 100644 index 0000000000000..d7e386f935f03 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py @@ -0,0 +1,743 @@ +""" + This file implements a simple torch distributed connector by 2 classes: + - `TorchDistributedPipe`: a tensor transmission pipe between P/D instance, + using `torch.distributed` + - `TorchDistributedConnector`: a torch distributed connector between P/D + instance, implemented on top of `TorchDistributedPipe` +""" +import threading +import time +from collections import deque +from concurrent.futures import ThreadPoolExecutor +from typing import Deque, List, Optional, Union + +import torch +from torch.distributed import Backend + +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.logger import init_logger + + + +logger = init_logger(__name__) + +# if the tensor is only one-element and only contains NONE_INT +# this means that the sended object is None. +NONE_INT = -150886311 + +# Mapping tensor dtype to INT64, used for tensor metadata transmission +FLOAT16_INT = -543205003776624 +INT64_INT = -375623078607432 +BOOL_INT = -28035262008646 +BFLOAT16_INT = -452084912267662 +FLOAT32_INT = -1049557997456592 +FLOAT64_INT = -452201007054137 +FLOAT8_E4M3FN_INT = -1066697177659525 +FLOAT8_E5M2_INT = -618182574682355 + +DTYPE2INT = { + torch.float16: FLOAT16_INT, + torch.int64: INT64_INT, + torch.bool: BOOL_INT, + torch.bfloat16: BFLOAT16_INT, + torch.float32: FLOAT32_INT, + torch.float64: FLOAT64_INT, + torch.float8_e4m3fn: FLOAT8_E4M3FN_INT, + torch.float8_e5m2: FLOAT8_E5M2_INT, +} + +INT2DTYPE = { + FLOAT16_INT: torch.float16, + INT64_INT: torch.int64, + BOOL_INT: torch.bool, + BFLOAT16_INT: torch.bfloat16, + FLOAT32_INT: torch.float32, + FLOAT64_INT: torch.float64, + FLOAT8_E4M3FN_INT: torch.float8_e4m3fn, + FLOAT8_E5M2_INT: torch.float8_e5m2, +} + + +class BrokenPipeException(Exception): + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +class TorchDistributedPipe: + + METADATA_LENGTH = 16 + MAX_TENSOR_DIMENSIONS = 14 + METADATA_DTYPE = torch.int64 + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend) + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + + assert self.device_group is not None + assert self.rank_in_group <= 1 + + self.device = self._select_device(torch_distributed_backend) + + self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % + self.world_size] + self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % + self.world_size] + + self.transport_thread: Optional[ThreadPoolExecutor] = None + self.buffer_size = 0 + self.buffer_size_lock = threading.Lock() + + self.none_tensor = torch.tensor([NONE_INT], device=self.device) + + # On-device tensors to be reused for recv + self.rcv_metadata_buffer = torch.zeros(self.METADATA_LENGTH, + dtype=self.METADATA_DTYPE, + device=self.device) + + def _select_device(self, backend: Union[str, Backend]): + if torch.cuda.is_available() and backend == Backend.NCCL: + return torch.device(f"cuda:{self.local_rank}") + else: + return "cpu" + + def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: + """Create the metadata on based on the input tensor, and move it to GPU. + The metadata's length is `TorchDistributedPipe.METADATA_LENGTH`. + + Currently, the metadata is a int64 tensor and it includes dtype, number + of dimensions, and the shape information of the input tensor. + + + The information follows the layout below: + - metadata[0] -- dtype + - metadata[1] -- number of dimensions + - metadata[2 : 2+ndims] -- the shape of the input tensor + + Parameters: + - tensor: the input tensor + + Returns: + - metadata: the metadata tensor, on self.device + """ + buffer = torch.empty(self.METADATA_LENGTH, + dtype=self.METADATA_DTYPE, + device="cpu") + buffer[0] = DTYPE2INT[tensor.dtype] + ndims = len(tensor.shape) + buffer[1] = len(tensor.shape) + buffer[2:2 + ndims] = torch.tensor(tensor.shape, + dtype=self.METADATA_DTYPE) + return buffer.to(self.device) + + def _prepare_recv_buffer(self, + d_metadata_buffer: torch.Tensor) -> torch.Tensor: + """Create a buffer to receive the tensor based on the metadata. + + Parameters: + - d_metadata_buffer: the metadata tensor on self.device + + Returns: + - buffer: the buffer tensor to receive the tensor, on self.device + """ + h_buffer = d_metadata_buffer.cpu().numpy() + dtype = INT2DTYPE[h_buffer[0]] + ndims = h_buffer[1] + shape = tuple(h_buffer[2:2 + ndims]) + return torch.empty(shape, dtype=dtype, device=self.device) + + def _send_metadata(self, d_metadata_buffer: torch.Tensor): + """Send the metadata buffer to the target rank. + """ + torch.distributed.send( + d_metadata_buffer, + dst=self.target_rank_for_send, + group=self.device_group, + ) + + def _recv_metadata(self) -> torch.Tensor: + """Receive the metadata buffer from the target rank. + + Returns: + - metadata_buffer: the metadata buffer tensor, on self.device + + Note: + The current implementation uses the assumption that there is no + race conditions during sending/receiving. Therefore, the metadata + buffer can be reused + """ + torch.distributed.recv( + self.rcv_metadata_buffer, + src=self.target_rank_for_recv, + group=self.device_group, + ) + + return self.rcv_metadata_buffer + + def _send_impl(self, tensor): + """ + The actual implementation of sending the tensor to the target rank. + This function will first send the metadata, and then send the tensor. + + Parameters: + - tensor: the input tensor to be sent + """ + + metadata = self._make_metadata(tensor) + self._send_metadata(metadata) + torch.distributed.send(tensor.to(self.device), + dst=self.target_rank_for_send, + group=self.device_group) + + def _recv_impl(self) -> torch.Tensor: + """ + The actual implementation of receiving the tensor from the target rank. + This function will first receive the metadata, then receive the tensor. + + This function will block if there is no tensor to receive. + + Returns: + - buffer: the received tensor, on self.device + """ + d_metadata = self._recv_metadata() + buffer = self._prepare_recv_buffer(d_metadata) + + torch.distributed.recv(buffer, + src=self.target_rank_for_recv, + group=self.device_group) + + return buffer + + def send_tensor_wrapper(self, tensor): + try: + """Wrapper for send_tensor_dict""" + tensor_size = tensor.element_size() * tensor.numel() + self._send_impl(tensor) + + with self.buffer_size_lock: + self.buffer_size = self.buffer_size - tensor_size + except Exception as e: + logger.error("[rank%d]: Exception when trying to send %s, msg: %s", + torch.distributed.get_rank(), str(tensor), str(e)) + import traceback + traceback.print_exc() + + def block_if_full(self): + """Block the current thread if the buffer size is larger than 1e9.""" + # TODO: replace this 1e9 with a configurable parameter or a constant + while self.buffer_size > 1e9: + logger.debug("KV cache transfer pipe is full. Waiting...") + time.sleep(0.05) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Sends a tensor to the destination rank in a non-blocking way. + Flow: send tensor dim -- send tensor shape -- send tensor data + """ + + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + if tensor is None: + tensor = self.none_tensor + + tensor_size = tensor.element_size() * tensor.numel() + + assert ( + 0 < len(tensor.shape) < self.MAX_TENSOR_DIMENSIONS + ), f"Only support dimensions within 1-{self.MAX_TENSOR_DIMENSIONS}" + + self.block_if_full() + + with self.buffer_size_lock: + self.buffer_size = self.buffer_size + tensor_size + + self.transport_thread.submit( + self.send_tensor_wrapper, + tensor, + ) + + def recv_tensor(self) -> Optional[torch.Tensor]: + """Receives a tensor from the src rank. Blocking.""" + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + future = self.transport_thread.submit(self._recv_impl) + + try: + tensor = future.result() + except Exception as e: + # the underlying pipe is likely broken + logger.error("Encountering exception in KV receiving thread") + logger.error("%s", e) + # fault tolerance: if the pipe is broken, return None + return None + + if tensor.numel() == 1 and tensor.item() == NONE_INT: + return None + else: + return tensor + + def close(self): + """Close the pipe and release the resources.""" + if (hasattr(self, "transport_thread") + and self.transport_thread is not None): + self.transport_thread.shutdown() + + +class TorchDistributedBuffer: + + def __init__(self, + signal_pipe: TorchDistributedPipe, + data_pipe: TorchDistributedPipe, + buffer_size_thresh: int): + """ + signal_pipe: on CPU + + NOTE: on-device recv will block all threads in the process, making the + KV cache producer unable to listen to new request while transmitting + KV cache. Luckily CPU recv only blocks the current thread so we use + CPU recv to listen to new request. + + data_pipe: on device (e.g. GPU) + """ + + self.buffer: Deque[List[torch.Tensor]] = deque() + + self.buffer_size = 0 + self.buffer_size_threshold = buffer_size_thresh + self.buffer_lock = threading.Lock() + self.signal_pipe = signal_pipe + self.data_pipe = data_pipe + self.request_handling_thread: Optional[threading.Thread] = None + + self.normal_signal = torch.tensor([0]) + self.end_signal = None + + def _matches(self, tokens_roi_sender: List[torch.Tensor], + tokens_roi_recver: List[torch.Tensor]): + + # tokens_roi_sender: tokens and roi of the producer (in the buffer) + # tokens_roi_recver: tokens and roi of the consumer (query) + + tokens_sender = tokens_roi_sender[0] + tokens_recver = tokens_roi_recver[0] + roi_sender = tokens_roi_sender[1] + roi_recver = tokens_roi_recver[1] + + if tokens_recver is None: + # consumer sends an empty request + # semantics: DROP SELECT * LIMIT 1 + # so any of the data in the buffer can be drop-selected + return True + + # Assuming that roi is a binary mask on tokens + tokens_sender = tokens_sender[roi_sender] + tokens_recver = tokens_recver[roi_recver] + + # simple common prefix matching + min_length = min(len(tokens_sender), len(tokens_recver)) + if torch.allclose(tokens_sender[:min_length], + tokens_recver[:min_length]): + return min_length + + return 0 + + def _send_tensor_and_dec_size(self, + tensor: Optional[torch.Tensor]) -> None: + + assert tensor is not None, "Use self.data_pipe.send(None) instead" + self.buffer_size -= tensor.element_size() * tensor.numel() + self.data_pipe.send_tensor(tensor) + + def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): + + if isinstance(data, torch.Tensor): + return data.element_size() * data.numel() + if not data: + # cannot perform `not data` on a tensor + # so this check needs to go after the check above + return 0 + + raise AssertionError("Unknown data type %s" % type(data)) + + def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor): + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() + if isinstance(key, torch.Tensor): + key = key.clone() + if isinstance(value, torch.Tensor): + value = value.clone() + if isinstance(hidden, torch.Tensor): + hidden = hidden.clone() + + buffer_item = [input_tokens, roi, key, value, hidden] + + with self.buffer_lock: + for data in buffer_item: + self.buffer_size += self._get_element_size(data) + self.buffer.append(buffer_item) + + def _is_end_signal(self, signal): + return signal is None + + def drop_select_handler(self): + + try: + + while True: + signal = self.signal_pipe.recv_tensor() + if self._is_end_signal(signal): + logger.info("Received end signal!") + break + + input_tokens = self.data_pipe.recv_tensor() + + roi = self.data_pipe.recv_tensor() + tokens_roi_recver = [input_tokens, roi] + + matched_length = 0 + + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. + with self.buffer_lock: + + for _ in range(len(self.buffer)): + + temp_length = self._matches(self.buffer[0], + tokens_roi_recver) + if temp_length > 0: + matched_length = temp_length + break + # rotate the element we just accessed to the end + self.buffer.rotate(-1) + + if matched_length > 0: + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor) + + else: + # no match, just send None + for _ in range(5): + self.data_pipe.send_tensor(None) + + except RuntimeError as e: + if 'Connection closed by peer' not in str(e): + raise e + + logger.debug("Closing drop_select_handler") + + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + assert self.request_handling_thread is None, \ + "drop_select should be called by the KV cache consumer "\ + "(e.g. the decode vLLM instance)" + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() + + self.signal_pipe.send_tensor(self.normal_signal) + self.data_pipe.send_tensor(input_tokens) + self.data_pipe.send_tensor(roi) + + input_tokens = self.data_pipe.recv_tensor() + roi = self.data_pipe.recv_tensor() + key = self.data_pipe.recv_tensor() + value = self.data_pipe.recv_tensor() + hidden = self.data_pipe.recv_tensor() + + return [input_tokens, roi, key, value, hidden] + + def full_handler(self): + time.sleep(0.001) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + if self.buffer_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. Handling...") + while self.buffer_size > self.buffer_size_threshold: + self.full_handler() + + self._add_to_buffer(input_tokens, roi, key, value, hidden) + + # when calling the insert, the current process is a sender + # need to launch the request handler and start listening to request. + if self.request_handling_thread is None: + self.request_handling_thread = threading.Thread( + target=self.drop_select_handler) + self.request_handling_thread.start() + + def close(self): + + if hasattr(self, "request_handling_thread" + ) and self.request_handling_thread is not None: + self.request_handling_thread.join() + + else: + # TODO: have a explicit close signal and have a explicit way to + # check if it's requester + self.signal_pipe.send_tensor(self.end_signal) + + +class TorchDistributedConnector(KVConnectorBase): + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + # FIXME(Kuntai): remove this hardcoding + lookup_buffer_size: int): + + self.lookup_buffer_size = lookup_buffer_size + + self.send_buffer: Optional[TorchDistributedBuffer] = None + self.recv_buffer: Optional[TorchDistributedBuffer] = None + + SimpleKVLookupBuffer = sklb.SimpleKVLookupBuffer + + # In disaggregated prefill, the prefill vLLM only uses send pipe + # and the decode vLLM only uses recv pipe + # In remote KV cache store, vLLM will use both send pipe and recv pipe + # So we build both send pipe and recv pipe for simplicity. + if IS_KV_PRODUCER: + + self.send_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + DISTRIBUTED_BACKEND, + ) + self.send_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + ) + self.recv_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + DISTRIBUTED_BACKEND, + ) + self.recv_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + ) + self.send_buffer = SimpleKVLookupBuffer(self.send_signal_pipe, + self.send_pipe, + self.lookup_buffer_size) + self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, + self.recv_pipe, + self.lookup_buffer_size) + self.tensor_device = DISTRIBUTED_DEVICE + else: + + # the current vLLM instance is KV consumer, so it needs to connect + # its recv pipe to the send pipe of KV producder + + self.recv_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + DISTRIBUTED_BACKEND, + ) + self.recv_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + ) + self.send_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + DISTRIBUTED_BACKEND, + ) + self.send_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + ) + self.send_buffer = SimpleKVLookupBuffer(self.send_signal_pipe, + self.send_pipe, + self.lookup_buffer_size) + self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, + self.recv_pipe, + self.lookup_buffer_size) + self.tensor_device = DISTRIBUTED_DEVICE + + + def select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + return self.send_buffer.drop_select(input, roi) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + return self.recv_buffer.insert( + input_tokens, + roi, + key, + value, + hidden + ) + + + + def build_partial_prefill_input( + self, + model_input: "ModelInputForGPUWithSamplingMetadata", + input_tokens_list: List[torch.Tensor], + num_computed_tokens_list: List[int], + start_pos_list: List[int], + slot_mapping_flat: torch.Tensor, + device: torch.device, + ) -> "ModelInputForGPUWithSamplingMetadata": + """ + Helper function to rebuild the model input for the current request. + Goal: avoid running redundant prefill on those tokens that already has + KV caches received. + """ + rebuilt_input_tokens = [] + rebuilt_input_positions = [] + rebuilt_query_lens = [] + + rebuilt_num_prefills = 0 + rebuilt_num_prefill_tokens = 0 + rebuilt_slot_mapping = [] + rebuilt_max_query_len = 0 + + rebuilt_block_tables = [] + + rebuilt_query_start_loc = [0] + rebuilt_context_lens_tensor = [] + rebuilt_selected_token_indices = [] + + # recounting query and context lengths + for idx in range(len(input_tokens_list)): + token_tensor = input_tokens_list[idx] + num_token = len(token_tensor) + num_computed_token = num_computed_tokens_list[idx] + # currently attention kernel cannot handle the case where there is 0 + # query token. + if num_computed_token == num_token: + num_computed_token -= 1 + start_pos = start_pos_list[idx] + + rebuilt_input_tokens.append(token_tensor[num_computed_token:]) + # TODO(Jiayi): please check the correctness of next line + rebuilt_input_positions.append( + model_input.input_positions[start_pos + + num_computed_token:start_pos + + num_token]) + q_len = num_token - num_computed_token + rebuilt_query_lens.append(q_len) + + # Attn metadata-related + rebuilt_num_prefills += 1 + rebuilt_num_prefill_tokens += q_len + new_slot_mapping = slot_mapping_flat[start_pos + + num_computed_token:start_pos + + num_token] + rebuilt_slot_mapping.append(new_slot_mapping) + rebuilt_max_query_len = max(q_len, rebuilt_max_query_len) + # TODO(Jiayi): remove hard-code (block_size=16) + blk_size = 16 + temp_block_table = [ + slot_mapping_flat[i] // blk_size + for i in range(start_pos, start_pos + num_token, blk_size) + ] + rebuilt_block_tables.append(temp_block_table) + rebuilt_query_start_loc.append( + rebuilt_num_prefill_tokens) #start with 0 + rebuilt_context_lens_tensor.append(num_computed_token) + + # Sampling metadata related + #seq_groups (use rebuilt query lens) + rebuilt_selected_token_indices.append(rebuilt_num_prefill_tokens - 1) + + # rebuilt attn_metadata + rebuilt_attn_metadata = deepcopy(model_input.attn_metadata) + rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills + rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens + rebuilt_attn_metadata.slot_mapping = torch.cat(rebuilt_slot_mapping).to( + device) + rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len + + rebuilt_attn_metadata.block_tables = torch.tensor( + rebuilt_block_tables, + dtype=model_input.attn_metadata.block_tables.dtype).to(device) + + rebuilt_attn_metadata.query_start_loc = torch.tensor( + rebuilt_query_start_loc, + dtype=model_input.attn_metadata.query_start_loc.dtype).to(device) + rebuilt_attn_metadata.context_lens_tensor = torch.tensor( + rebuilt_context_lens_tensor, + dtype=model_input.attn_metadata.context_lens_tensor.dtype, + ).to(device) + + rebuilt_attn_metadata._cached_prefill_metadata = None + + # rebuilt sampling_metadata + rebuilt_sampling_metadata = deepcopy(model_input.sampling_metadata) + for idx, q_len in enumerate(rebuilt_query_lens): + if rebuilt_sampling_metadata.seq_groups is not None: + rebuilt_sampling_metadata.seq_groups[idx].query_len = q_len + + rebuilt_sampling_metadata.selected_token_indices = torch.tensor( + rebuilt_selected_token_indices, + dtype=model_input.sampling_metadata.selected_token_indices.dtype, + ).to(device) + + # import here to avoid circular import. + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( + input_tokens=torch.cat(rebuilt_input_tokens).to(device), + input_positions=torch.cat(rebuilt_input_positions).to(device), + seq_lens=model_input.seq_lens, + query_lens=rebuilt_query_lens, + lora_mapping=model_input.lora_mapping, + lora_requests=model_input.lora_requests, + attn_metadata=rebuilt_attn_metadata, + prompt_adapter_mapping=model_input.prompt_adapter_mapping, + prompt_adapter_requests=model_input.prompt_adapter_requests, + multi_modal_kwargs=model_input.multi_modal_kwargs, + request_ids_to_seq_ids=model_input.request_ids_to_seq_ids, + finished_requests_ids=model_input.finished_requests_ids, + virtual_engine=model_input.virtual_engine, + sampling_metadata=rebuilt_sampling_metadata, + is_prompt=model_input.is_prompt, + ) + + return rebuilt_model_input + diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py deleted file mode 100644 index eb052e2e41e11..0000000000000 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ /dev/null @@ -1,223 +0,0 @@ -import threading -import time -from collections import deque -from typing import Deque, List, Optional, Union - -import torch - -from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( - KVLookupBufferBase) -from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class SimpleKVLookupBuffer(KVLookupBufferBase): - - def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, - buffer_size_thresh: int): - """ - signal_pipe: on CPU - - NOTE: on-device recv will block all threads in the process, making the - KV cache producer unable to listen to new request while transmitting - KV cache. Luckily CPU recv only blocks the current thread so we use - CPU recv to listen to new request. - - data_pipe: on device (e.g. GPU) - """ - - self.buffer: Deque[List[torch.Tensor]] = deque() - - self.buffer_size = 0 - self.buffer_size_threshold = buffer_size_thresh - self.buffer_lock = threading.Lock() - self.signal_pipe = signal_pipe - self.data_pipe = data_pipe - self.request_handling_thread: Optional[threading.Thread] = None - - self.normal_signal = torch.tensor([0]) - self.end_signal = None - - def _matches(self, tokens_roi_sender: List[torch.Tensor], - tokens_roi_recver: List[torch.Tensor]): - - # tokens_roi_sender: tokens and roi of the producer (in the buffer) - # tokens_roi_recver: tokens and roi of the consumer (query) - - tokens_sender = tokens_roi_sender[0] - tokens_recver = tokens_roi_recver[0] - roi_sender = tokens_roi_sender[1] - roi_recver = tokens_roi_recver[1] - - if tokens_recver is None: - # consumer sends an empty request - # semantics: DROP SELECT * LIMIT 1 - # so any of the data in the buffer can be drop-selected - return True - - # Assuming that roi is a binary mask on tokens - tokens_sender = tokens_sender[roi_sender] - tokens_recver = tokens_recver[roi_recver] - - # simple common prefix matching - min_length = min(len(tokens_sender), len(tokens_recver)) - if torch.allclose(tokens_sender[:min_length], - tokens_recver[:min_length]): - return min_length - - return 0 - - def _send_tensor_and_dec_size(self, - tensor: Optional[torch.Tensor]) -> None: - - assert tensor is not None, "Use self.data_pipe.send(None) instead" - self.buffer_size -= tensor.element_size() * tensor.numel() - self.data_pipe.send_tensor(tensor) - - def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): - - if isinstance(data, torch.Tensor): - return data.element_size() * data.numel() - if not data: - # cannot perform `not data` on a tensor - # so this check needs to go after the check above - return 0 - - raise AssertionError("Unknown data type %s" % type(data)) - - def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor): - - if isinstance(input_tokens, torch.Tensor): - input_tokens = input_tokens.clone() - if isinstance(roi, torch.Tensor): - roi = roi.clone() - if isinstance(key, torch.Tensor): - key = key.clone() - if isinstance(value, torch.Tensor): - value = value.clone() - if isinstance(hidden, torch.Tensor): - hidden = hidden.clone() - - buffer_item = [input_tokens, roi, key, value, hidden] - - with self.buffer_lock: - for data in buffer_item: - self.buffer_size += self._get_element_size(data) - self.buffer.append(buffer_item) - - def _is_end_signal(self, signal): - return signal is None - - def drop_select_handler(self): - - try: - - while True: - signal = self.signal_pipe.recv_tensor() - if self._is_end_signal(signal): - logger.info("Received end signal!") - break - - input_tokens = self.data_pipe.recv_tensor() - - roi = self.data_pipe.recv_tensor() - tokens_roi_recver = [input_tokens, roi] - - matched_length = 0 - - # perform input tokens and roi matching - # FIXME: this matching is O(n), ideally it should be O(1) - # but this buffer size won't (and shouldn't) be too large so - # the fix is not urgent. - with self.buffer_lock: - - for _ in range(len(self.buffer)): - - temp_length = self._matches(self.buffer[0], - tokens_roi_recver) - if temp_length > 0: - matched_length = temp_length - break - # rotate the element we just accessed to the end - self.buffer.rotate(-1) - - if matched_length > 0: - # need to clone the tensor - # in case the tensor is freed before sending finishes - matched_item = self.buffer.popleft() - for tensor in matched_item: - self._send_tensor_and_dec_size(tensor) - - else: - # no match, just send None - for _ in range(5): - self.data_pipe.send_tensor(None) - - except RuntimeError as e: - if 'Connection closed by peer' not in str(e): - raise e - - logger.debug("Closing drop_select_handler") - - def drop_select( - self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - - assert self.request_handling_thread is None, \ - "drop_select should be called by the KV cache consumer "\ - "(e.g. the decode vLLM instance)" - - if isinstance(input_tokens, torch.Tensor): - input_tokens = input_tokens.clone() - if isinstance(roi, torch.Tensor): - roi = roi.clone() - - self.signal_pipe.send_tensor(self.normal_signal) - self.data_pipe.send_tensor(input_tokens) - self.data_pipe.send_tensor(roi) - - input_tokens = self.data_pipe.recv_tensor() - roi = self.data_pipe.recv_tensor() - key = self.data_pipe.recv_tensor() - value = self.data_pipe.recv_tensor() - hidden = self.data_pipe.recv_tensor() - - return [input_tokens, roi, key, value, hidden] - - def full_handler(self): - time.sleep(0.001) - - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: - - if self.buffer_size > self.buffer_size_threshold: - # log outside the while loop to avoid this message being logged - # repeatedly. - logger.debug("KV transfer buffer is full. Handling...") - while self.buffer_size > self.buffer_size_threshold: - self.full_handler() - - self._add_to_buffer(input_tokens, roi, key, value, hidden) - - # when calling the insert, the current process is a sender - # need to launch the request handler and start listening to request. - if self.request_handling_thread is None: - self.request_handling_thread = threading.Thread( - target=self.drop_select_handler) - self.request_handling_thread.start() - - def close(self): - - if hasattr(self, "request_handling_thread" - ) and self.request_handling_thread is not None: - self.request_handling_thread.join() - - else: - # TODO: have a explicit close signal and have a explicit way to - # check if it's requester - self.signal_pipe.send_tensor(self.end_signal) From 1c608e69d08d9dee16fe96c26800a1222f1aa13b Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sat, 9 Nov 2024 05:55:59 +0000 Subject: [PATCH 278/303] config fix --- vllm/config.py | 24 +- .../kv_transfer/kv_connector/base.py | 4 +- .../torch_distributed_connector.py | 66 ++-- .../kv_transfer/kv_pipe/__init__.py | 0 vllm/distributed/kv_transfer/kv_pipe/base.py | 64 ---- .../kv_pipe/torch_distributed_pipe.py | 289 ------------------ vllm/engine/arg_utils.py | 44 ++- 7 files changed, 101 insertions(+), 390 deletions(-) delete mode 100644 vllm/distributed/kv_transfer/kv_pipe/__init__.py delete mode 100644 vllm/distributed/kv_transfer/kv_pipe/base.py delete mode 100644 vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py diff --git a/vllm/config.py b/vllm/config.py index 5284d0a09a6fe..e369707ec9660 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1988,27 +1988,35 @@ class KVTransferConfig: # KV cache kv_buffer_size: Optional[float] = None kv_transfer_role: Optional[str] = None - kv_init_method: Optional[str] = None - kv_xpyd: Optional[str] = None - kv_rank: Optional[int] = None + kv_device: Optional[str] = None def __post_init__(self): if self.kv_connector is None and not all([ self.kv_buffer_size is None, self.kv_transfer_role is None, - self.kv_init_method is None, - self.kv_xpyd is None, - self.kv_xypd_rank is None, + self.kv_device is None, ]): raise ValueError("Please specify kv_connector before configuring " "variables with prefix `kv_`") assert self.kv_connector in [ None, - 'TorchDistributedConnector', - 'LMCacheConnector' + "TorchDistributedConnector", + "LMCacheConnector", ], f"Existing kv connectors are `TorchDistributedConnector` and "\ f"`LMCacheConnector`. Got {self.kv_connector}" + + @property + def is_distributed_kv_instance(self) -> bool: + return self.kv_transfer_role in ["kv_producer", "kv_consumer", "kv_both"] + + @property + def is_kv_producer(self) -> bool: + return self.kv_transfer_role in ["kv_producer", "kv_both"] + + @property + def is_kv_consumer(self) -> bool: + return self.kv_transfer_role in ["kv_consumer", "kv_both"] diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index 51f0633f8d5bc..e0e26c0b1ae8d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -10,6 +10,8 @@ from abc import ABC, abstractmethod from typing import List, Optional +from vllm.config import KVTransferConfig + import torch @@ -38,7 +40,7 @@ class KVConnectorBase(ABC): """ @abstractmethod - def init(self, config: vLLMConfig): + def init(self, config: KVTransferConfig): raise NotImplementedError @abstractmethod diff --git a/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py b/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py index d7e386f935f03..0977793b014ce 100644 --- a/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py @@ -16,11 +16,16 @@ from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.logger import init_logger +from vllm.config import KVTransferConfig logger = init_logger(__name__) + + +# magic constants to transmit tensors + # if the tensor is only one-element and only contains NONE_INT # this means that the sended object is None. NONE_INT = -150886311 @@ -76,10 +81,12 @@ def __init__( group_ranks: List[List[int]], local_rank: int, torch_distributed_backend: Union[str, Backend], + buffer_size_thresh: float, ): self.rank = torch.distributed.get_rank() self.local_rank = local_rank self.device_group = None + self.buffer_size_thresh = buffer_size_thresh for ranks in group_ranks: device_group = torch.distributed.new_group( @@ -241,7 +248,7 @@ def send_tensor_wrapper(self, tensor): def block_if_full(self): """Block the current thread if the buffer size is larger than 1e9.""" # TODO: replace this 1e9 with a configurable parameter or a constant - while self.buffer_size > 1e9: + while self.buffer_size > self.buffer_size_thresh: logger.debug("KV cache transfer pipe is full. Waiting...") time.sleep(0.05) @@ -305,7 +312,7 @@ class TorchDistributedBuffer: def __init__(self, signal_pipe: TorchDistributedPipe, data_pipe: TorchDistributedPipe, - buffer_size_thresh: int): + buffer_size_thresh: float): """ signal_pipe: on CPU @@ -518,50 +525,50 @@ def __init__( self, group_ranks: List[List[int]], local_rank: int, - torch_distributed_backend: Union[str, Backend], - # FIXME(Kuntai): remove this hardcoding - lookup_buffer_size: int): + config: KVTransferConfig, + ): - self.lookup_buffer_size = lookup_buffer_size + self.lookup_buffer_size = self.kv_buffer_size self.send_buffer: Optional[TorchDistributedBuffer] = None self.recv_buffer: Optional[TorchDistributedBuffer] = None - - SimpleKVLookupBuffer = sklb.SimpleKVLookupBuffer + + device2backend = { + "cpu": "gloo", + "gpu": "nccl", + } # In disaggregated prefill, the prefill vLLM only uses send pipe # and the decode vLLM only uses recv pipe # In remote KV cache store, vLLM will use both send pipe and recv pipe # So we build both send pipe and recv pipe for simplicity. - if IS_KV_PRODUCER: + if config.is_kv_producer: self.send_pipe = TorchDistributedPipe( group_ranks, local_rank, - DISTRIBUTED_BACKEND, + device2backend[config.kv_device], + self.kv_buffer_size, ) self.send_signal_pipe = TorchDistributedPipe( group_ranks, local_rank, "gloo", + self.kv_buffer_size, ) self.recv_pipe = TorchDistributedPipe( group_ranks, local_rank, - DISTRIBUTED_BACKEND, + device2backend[config.kv_device], + self.kv_buffer_size, ) self.recv_signal_pipe = TorchDistributedPipe( group_ranks, local_rank, "gloo", + self.kv_buffer_size ) - self.send_buffer = SimpleKVLookupBuffer(self.send_signal_pipe, - self.send_pipe, - self.lookup_buffer_size) - self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, - self.recv_pipe, - self.lookup_buffer_size) - self.tensor_device = DISTRIBUTED_DEVICE + else: # the current vLLM instance is KV consumer, so it needs to connect @@ -570,30 +577,35 @@ def __init__( self.recv_pipe = TorchDistributedPipe( group_ranks, local_rank, - DISTRIBUTED_BACKEND, + device2backend[config.kv_device], + self.kv_buffer_size, ) self.recv_signal_pipe = TorchDistributedPipe( group_ranks, local_rank, "gloo", + self.kv_buffer_size, ) self.send_pipe = TorchDistributedPipe( group_ranks, local_rank, - DISTRIBUTED_BACKEND, + device2backend[config.kv_device], + self.kv_buffer_size, ) self.send_signal_pipe = TorchDistributedPipe( group_ranks, local_rank, "gloo", + self.kv_buffer_size ) - self.send_buffer = SimpleKVLookupBuffer(self.send_signal_pipe, - self.send_pipe, - self.lookup_buffer_size) - self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, - self.recv_pipe, - self.lookup_buffer_size) - self.tensor_device = DISTRIBUTED_DEVICE + + self.send_buffer = TorchDistributedBuffer(self.send_signal_pipe, + self.send_pipe, + self.lookup_buffer_size) + self.recv_buffer = TorchDistributedBuffer(self.recv_signal_pipe, + self.recv_pipe, + self.lookup_buffer_size) + self.tensor_device = config.kv_device def select( diff --git a/vllm/distributed/kv_transfer/kv_pipe/__init__.py b/vllm/distributed/kv_transfer/kv_pipe/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py deleted file mode 100644 index 79e235b48fd72..0000000000000 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -This file defines -`KVPipeBase` -that provides an abstraction for sending and receiving tensors, or None, via -distributed communications. - -All distributed communications for disagg prefill & KV cache storage should be -handled by `KVPipeBase`. -""" - -from abc import ABC, abstractmethod -from typing import Optional - -import torch - - -class KVPipeBase(ABC): - """ - This class provides an interface for sending and receiving tensors, or - None, by distributed communications. - """ - - @abstractmethod - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: - """Send a tensor, or None, via the pipe. - - Need to support sending None -- important for error handling. - - TODO: add a `key` argument so that we can use traditional - key-value database as the distributed communication mechanism behind - the pipe. - - Args: - tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - @abstractmethod - def recv_tensor(self) -> Optional[torch.Tensor]: - """Receive a tensor (can be None) from the pipeline. - - Returns: - Optional[torch.Tensor]: The tensor received from the pipeline. Can - be None. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - @abstractmethod - def close(self) -> None: - """Close the pipeline and release resources. - - This method is responsible for closing the communication pipeline - and releasing any resources associated with it. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py deleted file mode 100644 index 3fe3fa289c662..0000000000000 --- a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py +++ /dev/null @@ -1,289 +0,0 @@ -import threading -import time -from concurrent.futures import ThreadPoolExecutor -from typing import List, Optional, Union - -import torch -from torch.distributed import Backend - -from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from vllm.logger import init_logger - -logger = init_logger(__name__) - -# if the tensor is only one-element and only contains NONE_INT -# this means that the sended object is None. -NONE_INT = -150886311 - -# Mapping tensor dtype to INT64, used for tensor metadata transmission -FLOAT16_INT = -543205003776624 -INT64_INT = -375623078607432 -BOOL_INT = -28035262008646 -BFLOAT16_INT = -452084912267662 -FLOAT32_INT = -1049557997456592 -FLOAT64_INT = -452201007054137 -FLOAT8_E4M3FN_INT = -1066697177659525 -FLOAT8_E5M2_INT = -618182574682355 - -DTYPE2INT = { - torch.float16: FLOAT16_INT, - torch.int64: INT64_INT, - torch.bool: BOOL_INT, - torch.bfloat16: BFLOAT16_INT, - torch.float32: FLOAT32_INT, - torch.float64: FLOAT64_INT, - torch.float8_e4m3fn: FLOAT8_E4M3FN_INT, - torch.float8_e5m2: FLOAT8_E5M2_INT, -} - -INT2DTYPE = { - FLOAT16_INT: torch.float16, - INT64_INT: torch.int64, - BOOL_INT: torch.bool, - BFLOAT16_INT: torch.bfloat16, - FLOAT32_INT: torch.float32, - FLOAT64_INT: torch.float64, - FLOAT8_E4M3FN_INT: torch.float8_e4m3fn, - FLOAT8_E5M2_INT: torch.float8_e5m2, -} - - -class BrokenPipeException(Exception): - - def __init__(self, message): - self.message = message - super().__init__(self.message) - - -class TorchDistributedPipe(KVPipeBase): - METADATA_LENGTH = 16 - MAX_TENSOR_DIMENSIONS = 14 - METADATA_DTYPE = torch.int64 - - def __init__( - self, - group_ranks: List[List[int]], - local_rank: int, - torch_distributed_backend: Union[str, Backend], - ): - self.rank = torch.distributed.get_rank() - self.local_rank = local_rank - self.device_group = None - - for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend) - if self.rank in ranks: - self.ranks = ranks - self.world_size = len(ranks) - self.rank_in_group = ranks.index(self.rank) - self.device_group = device_group - - assert self.device_group is not None - assert self.rank_in_group <= 1 - - self.device = self._select_device(torch_distributed_backend) - - self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % - self.world_size] - self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % - self.world_size] - - self.transport_thread: Optional[ThreadPoolExecutor] = None - self.buffer_size = 0 - self.buffer_size_lock = threading.Lock() - - self.none_tensor = torch.tensor([NONE_INT], device=self.device) - - # On-device tensors to be reused for recv - self.rcv_metadata_buffer = torch.zeros(self.METADATA_LENGTH, - dtype=self.METADATA_DTYPE, - device=self.device) - - def _select_device(self, backend: Union[str, Backend]): - if torch.cuda.is_available() and backend == Backend.NCCL: - return torch.device(f"cuda:{self.local_rank}") - else: - return "cpu" - - def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: - """Create the metadata on based on the input tensor, and move it to GPU. - The metadata's length is `TorchDistributedPipe.METADATA_LENGTH`. - - Currently, the metadata is a int64 tensor and it includes dtype, number - of dimensions, and the shape information of the input tensor. - - - The information follows the layout below: - - metadata[0] -- dtype - - metadata[1] -- number of dimensions - - metadata[2 : 2+ndims] -- the shape of the input tensor - - Parameters: - - tensor: the input tensor - - Returns: - - metadata: the metadata tensor, on self.device - """ - buffer = torch.empty(self.METADATA_LENGTH, - dtype=self.METADATA_DTYPE, - device="cpu") - buffer[0] = DTYPE2INT[tensor.dtype] - ndims = len(tensor.shape) - buffer[1] = len(tensor.shape) - buffer[2:2 + ndims] = torch.tensor(tensor.shape, - dtype=self.METADATA_DTYPE) - return buffer.to(self.device) - - def _prepare_recv_buffer(self, - d_metadata_buffer: torch.Tensor) -> torch.Tensor: - """Create a buffer to receive the tensor based on the metadata. - - Parameters: - - d_metadata_buffer: the metadata tensor on self.device - - Returns: - - buffer: the buffer tensor to receive the tensor, on self.device - """ - h_buffer = d_metadata_buffer.cpu().numpy() - dtype = INT2DTYPE[h_buffer[0]] - ndims = h_buffer[1] - shape = tuple(h_buffer[2:2 + ndims]) - return torch.empty(shape, dtype=dtype, device=self.device) - - def _send_metadata(self, d_metadata_buffer: torch.Tensor): - """Send the metadata buffer to the target rank. - """ - torch.distributed.send( - d_metadata_buffer, - dst=self.target_rank_for_send, - group=self.device_group, - ) - - def _recv_metadata(self) -> torch.Tensor: - """Receive the metadata buffer from the target rank. - - Returns: - - metadata_buffer: the metadata buffer tensor, on self.device - - Note: - The current implementation uses the assumption that there is no - race conditions during sending/receiving. Therefore, the metadata - buffer can be reused - """ - torch.distributed.recv( - self.rcv_metadata_buffer, - src=self.target_rank_for_recv, - group=self.device_group, - ) - - return self.rcv_metadata_buffer - - def _send_impl(self, tensor): - """ - The actual implementation of sending the tensor to the target rank. - This function will first send the metadata, and then send the tensor. - - Parameters: - - tensor: the input tensor to be sent - """ - - metadata = self._make_metadata(tensor) - self._send_metadata(metadata) - torch.distributed.send(tensor.to(self.device), - dst=self.target_rank_for_send, - group=self.device_group) - - def _recv_impl(self) -> torch.Tensor: - """ - The actual implementation of receiving the tensor from the target rank. - This function will first receive the metadata, then receive the tensor. - - This function will block if there is no tensor to receive. - - Returns: - - buffer: the received tensor, on self.device - """ - d_metadata = self._recv_metadata() - buffer = self._prepare_recv_buffer(d_metadata) - - torch.distributed.recv(buffer, - src=self.target_rank_for_recv, - group=self.device_group) - - return buffer - - def send_tensor_wrapper(self, tensor): - try: - """Wrapper for send_tensor_dict""" - tensor_size = tensor.element_size() * tensor.numel() - self._send_impl(tensor) - - with self.buffer_size_lock: - self.buffer_size = self.buffer_size - tensor_size - except Exception as e: - logger.error("[rank%d]: Exception when trying to send %s, msg: %s", - torch.distributed.get_rank(), str(tensor), str(e)) - import traceback - traceback.print_exc() - - def block_if_full(self): - """Block the current thread if the buffer size is larger than 1e9.""" - # TODO: replace this 1e9 with a configurable parameter or a constant - while self.buffer_size > 1e9: - logger.debug("KV cache transfer pipe is full. Waiting...") - time.sleep(0.05) - - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: - """Sends a tensor to the destination rank in a non-blocking way. - Flow: send tensor dim -- send tensor shape -- send tensor data - """ - - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - - if tensor is None: - tensor = self.none_tensor - - tensor_size = tensor.element_size() * tensor.numel() - - assert ( - 0 < len(tensor.shape) < self.MAX_TENSOR_DIMENSIONS - ), f"Only support dimensions within 1-{self.MAX_TENSOR_DIMENSIONS}" - - self.block_if_full() - - with self.buffer_size_lock: - self.buffer_size = self.buffer_size + tensor_size - - self.transport_thread.submit( - self.send_tensor_wrapper, - tensor, - ) - - def recv_tensor(self) -> Optional[torch.Tensor]: - """Receives a tensor from the src rank. Blocking.""" - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - - future = self.transport_thread.submit(self._recv_impl) - - try: - tensor = future.result() - except Exception as e: - # the underlying pipe is likely broken - logger.error("Encountering exception in KV receiving thread") - logger.error("%s", e) - # fault tolerance: if the pipe is broken, return None - return None - - if tensor.numel() == 1 and tensor.item() == NONE_INT: - return None - else: - return tensor - - def close(self): - """Close the pipe and release the resources.""" - if (hasattr(self, "transport_thread") - and self.transport_thread is not None): - self.transport_thread.shutdown() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b556c0eed3776..3e82e2092e88d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -13,7 +13,7 @@ ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig, TaskOption, TokenizerPoolConfig, - VllmConfig) + KVTransferConfig, VllmConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -913,6 +913,40 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "such as the token IDs of good_token and bad_token in " "the math-shepherd-mistral-7b-prm model.") + parser.add_argument( + '--kv-connector', + type=str, + default=None, + choices=["TorchDistributedConnector", "LMCacheConnector"], + help="The KV connector for vLLM to transmit KV caches between vLLM" + " instances.") + + parser.add_argument( + '--kv-buffer-size', + type=float, + default=None, + help="The buffer size for TorchDistributedConnector. Measured in " + "number of bytes. Recommended value: 1e9 (about 1GB)." + ) + + parser.add_argument( + '--kv-transfer-role', + type=str, + default=None, + choices=["kv_producer", "kv_consumer", "both"], + help="Whether this vLLM instance produces KV caches, consume KV " + "caches, or both." + ) + + parser.add_argument( + '--kv-device', + type=str, + default=None, + choices=["CPU", "GPU"], + help="The device used by kv connector to buffer the KV cache. Can " + "be CPU or GPU. Recommended value: CPU." + ) + return parser @classmethod @@ -1180,6 +1214,13 @@ def create_engine_config(self) -> VllmConfig: or "all" in detailed_trace_modules, ) + kv_transfer_config = KVTransferConfig( + kv_connector=self.kv_connector, + kv_buffer_size=self.kv_buffer_size, + kv_transfer_role=self.kv_transfer_role, + kv_device=self.kv_device, + ) + return VllmConfig( model_config=model_config, cache_config=cache_config, @@ -1192,6 +1233,7 @@ def create_engine_config(self) -> VllmConfig: decoding_config=decoding_config, observability_config=observability_config, prompt_adapter_config=prompt_adapter_config, + kv_transfer_config=kv_transfer_config, ) From 303ff859b0628de90c70788982e34a0afda893d3 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 10 Nov 2024 02:17:59 +0000 Subject: [PATCH 279/303] misc fixes --- vllm/config.py | 119 ++++---- .../kv_transfer/kv_connector/__init__.py | 18 ++ .../torch_distributed_connector.py | 15 +- vllm/distributed/kv_transfer/vllm_adapter.py | 270 +++--------------- vllm/distributed/parallel_state.py | 27 -- vllm/engine/arg_utils.py | 45 ++- 6 files changed, 159 insertions(+), 335 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index e369707ec9660..d63baa0cc66a4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -907,6 +907,7 @@ class ParallelConfig: """Configuration for the distributed execution. Args: + kv_disagg_parallel_size: Number of kv disagg groups. pipeline_parallel_size: Number of pipeline parallel groups. tensor_parallel_size: Number of tensor parallel groups. worker_use_ray: Deprecated, use distributed_executor_backend instead. @@ -924,10 +925,18 @@ class ParallelConfig: workers, either "ray" or "mp" (multiprocessing). If either pipeline_parallel_size or tensor_parallel_size is greater than 1, will default to "ray" if Ray is installed or "mp" otherwise. + kv_connector: The connector to use for kv cache transfer, value can be + None, "TorchDistributedConnector" or "LMCacheConnector". + kv_buffer_device: The buffer device to use for kv cache transfer. + kv_buffer_size: The buffer size to use for kv cache transfer. + kv_disagg_role: The role of the kv disagg worker, can be "kv_producer", + "kv_consumer", "kv_both" or None. + kv_disagg_rank: The rank of the kv disagg worker. """ def __init__( self, + kv_disagg_parallel_size: int, pipeline_parallel_size: int, tensor_parallel_size: int, worker_use_ray: Optional[bool] = None, @@ -938,7 +947,13 @@ def __init__( placement_group: Optional["PlacementGroup"] = None, distributed_executor_backend: Optional[Union[ str, Type["ExecutorBase"]]] = None, + kv_connector: Optional[str] = None, + kv_buffer_device: Optional[str] = None, + kv_buffer_size: Optional[float] = None, + kv_disagg_role: Optional[str] = None, + kv_disagg_rank: int = 0, ) -> None: + self.kv_disagg_parallel_size = kv_disagg_parallel_size self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size self.distributed_executor_backend = distributed_executor_backend @@ -947,7 +962,13 @@ def __init__( self.tokenizer_pool_config = tokenizer_pool_config self.ray_workers_use_nsight = ray_workers_use_nsight self.placement_group = placement_group - self.world_size = pipeline_parallel_size * self.tensor_parallel_size + self.world_size = kv_disagg_parallel_size * pipeline_parallel_size *\ + tensor_parallel_size + self.kv_connector = kv_connector + self.kv_buffer_device = kv_buffer_device + self.kv_buffer_size = kv_buffer_size + self.kv_disagg_role = kv_disagg_role + self.kv_disagg_rank = kv_disagg_rank if worker_use_ray: if self.distributed_executor_backend is None: @@ -1008,6 +1029,19 @@ def use_ray(self) -> bool: isinstance(self.distributed_executor_backend, type) and self.distributed_executor_backend.uses_ray) + @property + def is_distributed_kv_instance(self) -> bool: + return self.kv_transfer_role in ["kv_producer", "kv_consumer", "kv_both"] + + @property + def is_kv_producer(self) -> bool: + return self.kv_transfer_role in ["kv_producer", "kv_both"] + + @property + def is_kv_consumer(self) -> bool: + return self.kv_transfer_role in ["kv_consumer", "kv_both"] + + def _verify_args(self) -> None: # Lazy import to avoid circular import from vllm.executor.executor_base import ExecutorBase @@ -1032,6 +1066,38 @@ def _verify_args(self) -> None: raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") + # A series of checks for P/D disaggregation (and future disaggregation) + if self.kv_connector is None and not all([ + self.kv_disagg_parallel_size == 1, + self.kv_disagg_rank == 0, + self.kv_buffer_size is None, + self.kv_disagg_role is None, + self.kv_buffer_device is None, + ]): + raise ValueError("Please specify kv_connector before configuring " + "variables with prefix `kv_`") + + if self.kv_connector not in [None, + "TorchDistributedConnector", + "LMCacheConnector"]: + raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. " + f"Supported connectors are " + f"`TorchDistributedConnector` and " + f"`LMCacheConnector`") + + if self.kv_disagg_role not in [None, + "kv_producer", + "kv_consumer", + "kv_both"]: + raise ValueError(f"Unsupported kv_disagg_role: {self.kv_disagg_role}. " + f"Supported roles are `kv_producer`, `kv_consumer`, " + f"and `kv_both`") + + if self.kv_connector is not None and self.kv_disagg_role is None: + raise ValueError("Please specify kv_disagg_role when kv_connector " + "is set, supported roles are `kv_producer`, " + "`kv_consumer`, and `kv_both`") + class SchedulerConfig: """Scheduler configuration. @@ -1969,56 +2035,6 @@ def __post_init__(self): "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " f"installed. Original error:\n{otel_import_error_traceback}") - -@dataclass -class KVTransferConfig: - """Configuration for KV cache transfer between vLLM instances - - To distinguish between other configs, all the configs here are prefixed - with "kv" - """ - - # the connector to use for kv cache transfer - # vLLM natively supports TorchDistributedConnector and also include - # third-party connector, but does not provide official support, please - # reach out to the code owner for those connectors for support. - kv_connector: Optional[str] = None - - # the buffer size parameter, used to configure the buffer size of receiving - # KV cache - kv_buffer_size: Optional[float] = None - kv_transfer_role: Optional[str] = None - kv_device: Optional[str] = None - - def __post_init__(self): - if self.kv_connector is None and not all([ - self.kv_buffer_size is None, - self.kv_transfer_role is None, - self.kv_device is None, - ]): - raise ValueError("Please specify kv_connector before configuring " - "variables with prefix `kv_`") - - assert self.kv_connector in [ - None, - "TorchDistributedConnector", - "LMCacheConnector", - ], f"Existing kv connectors are `TorchDistributedConnector` and "\ - f"`LMCacheConnector`. Got {self.kv_connector}" - - @property - def is_distributed_kv_instance(self) -> bool: - return self.kv_transfer_role in ["kv_producer", "kv_consumer", "kv_both"] - - @property - def is_kv_producer(self) -> bool: - return self.kv_transfer_role in ["kv_producer", "kv_both"] - - @property - def is_kv_consumer(self) -> bool: - return self.kv_transfer_role in ["kv_consumer", "kv_both"] - - @dataclass class VllmConfig: @@ -2038,7 +2054,6 @@ class VllmConfig: observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None quant_config: Optional[QuantizationConfig] = None - kv_transfer_config: Optional[KVTransferConfig] = None @staticmethod def _get_quantization_config( diff --git a/vllm/distributed/kv_transfer/kv_connector/__init__.py b/vllm/distributed/kv_transfer/kv_connector/__init__.py index e69de29bb2d1d..7d0f202f7e1d8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/__init__.py +++ b/vllm/distributed/kv_transfer/kv_connector/__init__.py @@ -0,0 +1,18 @@ + +from .base import KVConnectorBase +from vllm.config import ParallelConfig + +class KVConnectorFactory: + + @staticmethod + def create_connector( + config: ParallelConfig + ) -> KVConnectorBase: + if config.kv_connector == 'LMCacheConnector': + from .lmcache_connector import LMCacheConnector + return LMCacheConnector(config) + elif config.kv_connector == 'TorchDistributedConnector': + from .torch_distributed_connector import TorchDistributedConnector + return TorchDistributedConnector(config) + else: + raise ValueError(f"Unsupported connector type: {connector_type}") \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py b/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py index 0977793b014ce..fe6787cfe88f4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py @@ -1,15 +1,18 @@ """ - This file implements a simple torch distributed connector by 2 classes: - - `TorchDistributedPipe`: a tensor transmission pipe between P/D instance, - using `torch.distributed` + This file implements a simple torch distributed connector by 3 classes: + - `TorchDistributedPipe`: a tensor transmission pipe between vllm instances, + using `torch.distributed` + - `TorchDistributedBuffer`: a buffer to store tensors, implemented on top + of `TorchDistributedPipe` - `TorchDistributedConnector`: a torch distributed connector between P/D - instance, implemented on top of `TorchDistributedPipe` + instance, implemented on top of `TorchDistributedBuffer` """ import threading import time from collections import deque from concurrent.futures import ThreadPoolExecutor from typing import Deque, List, Optional, Union +from copy import deepcopy import torch from torch.distributed import Backend @@ -24,8 +27,6 @@ -# magic constants to transmit tensors - # if the tensor is only one-element and only contains NONE_INT # this means that the sended object is None. NONE_INT = -150886311 @@ -611,11 +612,13 @@ def __init__( def select( self, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + return self.send_buffer.drop_select(input, roi) def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: + return self.recv_buffer.insert( input_tokens, roi, diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index 7516e7c5ff307..e59443e51b3ea 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -28,34 +28,28 @@ import torch from torch.distributed import Backend -import vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer as sklb import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( - KVLookupBufferBase) -from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import ( - TorchDistributedPipe) +from vllm.distributed.kv_transfer.kv_connector import KVConnectorFactory from vllm.logger import init_logger from vllm.sequence import IntermediateTensors +from vllm.config import ParallelConfig logger = init_logger(__name__) -# check VLLM_DISTRIBUTERD_KV_ROLE and set corresponding flags -assert envs.VLLM_DISTRIBUTED_KV_ROLE in [None, "producer", "consumer", "both"],\ - "VLLM_DISTRIBUTERD_KV_ROLE can only be producer, consumer or both." -IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISTRIBUTED_KV_ROLE - in ["producer", "consumer", "both"]) -IS_KV_PRODUCER: bool = (envs.VLLM_DISTRIBUTED_KV_ROLE in ["producer", "both"]) -IS_KV_CONSUMER: bool = (envs.VLLM_DISTRIBUTED_KV_ROLE in ["consumer", "both"]) -# When the current instance is both KV producer and KV consumer, -# it is likely connected to a KV storage service on CPU/disk -# so the communication backend needs to be "gloo" for that case. -DISTRIBUTED_BACKEND: str = "gloo" if (IS_KV_PRODUCER - and IS_KV_CONSUMER) else "nccl" -# corresponding device -DISTRIBUTED_DEVICE: str = "cpu" if (IS_KV_PRODUCER - and IS_KV_CONSUMER) else "cuda" +# several flags used for indicating the role of current vLLM worker +IS_DISTRIBUTED_KV_INSTANCE: Optional[bool] = None +IS_KV_PRODUCER: Optional[bool] = None +IS_KV_CONSUMER: Optional[bool] = None + + +def set_kv_transfer_attribute(config: ParallelConfig): + global IS_DISTRIBUTED_KV_INSTANCE, IS_KV_PRODUCER, IS_KV_CONSUMER + + IS_DISTRIBUTED_KV_INSTANCE = config.is_distributed_kv_instance + IS_KV_PRODUCER = config.is_kv_producer + IS_KV_CONSUMER = config.is_kv_consumer class KV_transfer_agent: @@ -71,83 +65,16 @@ def __init__( self, group_ranks: List[List[int]], local_rank: int, - torch_distributed_backend: Union[str, Backend] = DISTRIBUTED_BACKEND, - # FIXME(Kuntai): remove this hardcoding - lookup_buffer_size: int = int(1e10)): - - self.lookup_buffer_size = lookup_buffer_size + config: ParallelConfig, + ): - self.send_buffer: Optional[KVLookupBufferBase] = None - self.recv_buffer: Optional[KVLookupBufferBase] = None + assert self.config.is_distributed_kv_instance, "KV cache transfer "\ + "agent should only be used when kv_connector is set." - SimpleKVLookupBuffer = sklb.SimpleKVLookupBuffer - - # In disaggregated prefill, the prefill vLLM only uses send pipe - # and the decode vLLM only uses recv pipe - # In remote KV cache store, vLLM will use both send pipe and recv pipe - # So we build both send pipe and recv pipe for simplicity. - if IS_KV_PRODUCER: - - self.send_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - DISTRIBUTED_BACKEND, - ) - self.send_signal_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo", - ) - self.recv_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - DISTRIBUTED_BACKEND, - ) - self.recv_signal_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo", - ) - self.send_buffer = SimpleKVLookupBuffer(self.send_signal_pipe, - self.send_pipe, - self.lookup_buffer_size) - self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, - self.recv_pipe, - self.lookup_buffer_size) - self.tensor_device = DISTRIBUTED_DEVICE - else: - - # the current vLLM instance is KV consumer, so it needs to connect - # its recv pipe to the send pipe of KV producder - - self.recv_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - DISTRIBUTED_BACKEND, - ) - self.recv_signal_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo", - ) - self.send_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - DISTRIBUTED_BACKEND, - ) - self.send_signal_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo", - ) - self.send_buffer = SimpleKVLookupBuffer(self.send_signal_pipe, - self.send_pipe, - self.lookup_buffer_size) - self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, - self.recv_pipe, - self.lookup_buffer_size) - self.tensor_device = DISTRIBUTED_DEVICE + self.connector = KVConnectorFactory.create_connector(config) + + def send_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, @@ -188,19 +115,16 @@ def send_kv_caches_and_hidden_states( keys = torch.cat(keys, dim=0) values = torch.cat(values, dim=0) - if self.send_buffer is not None: - self.send_buffer.insert( - current_tokens, torch.ones_like(current_tokens, - dtype=bool), keys, values, - hidden_or_intermediate_states[start_pos:end_pos]) + + self.connector.insert( + current_tokens, torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) - def destroy(self) -> None: - if self.send_buffer is not None: - self.send_buffer.close() - if self.recv_buffer is not None: - self.recv_buffer.close() + def close(self) -> None: + self.connector.close() def recv_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, @@ -238,11 +162,7 @@ def recv_kv_caches_and_hidden_states( input_tokens_list.append(current_tokens) start_pos_list.append(start_pos) - if self.recv_buffer is None: - bypass_model_exec = False - break - - ret = self.recv_buffer.drop_select( + ret = self.connector.select( current_tokens, torch.ones_like(current_tokens, dtype=bool)) if ret[0] is None: # didn't find any match. @@ -296,7 +216,10 @@ def recv_kv_caches_and_hidden_states( logger.debug( "[rank%d]: Failed to receive all KVs and hidden " "states, redo model forwarding.", torch.distributed.get_rank()) - rebuilt_model_input = build_partial_prefill_input( + + # allow the connector to mutate the model input + # useful for injecting memory movement / computation requests + rebuilt_model_input = self.connector.build_partial_prefill_input( model_input, input_tokens_list, num_computed_tokens_list, @@ -315,130 +238,3 @@ def recv_kv_caches_and_hidden_states( hidden_or_intermediate_states_for_one_req, dim=0) return hidden_or_intermediate_states, bypass_model_exec, model_input - - -def build_partial_prefill_input( - model_input: "ModelInputForGPUWithSamplingMetadata", - input_tokens_list: List[torch.Tensor], - num_computed_tokens_list: List[int], - start_pos_list: List[int], - slot_mapping_flat: torch.Tensor, - device: torch.device, -) -> "ModelInputForGPUWithSamplingMetadata": - """ - Helper function to rebuild the model input for the current request. - Goal: avoid running redundant prefill on those tokens that already has KV - caches received. - """ - rebuilt_input_tokens = [] - rebuilt_input_positions = [] - rebuilt_query_lens = [] - - rebuilt_num_prefills = 0 - rebuilt_num_prefill_tokens = 0 - rebuilt_slot_mapping = [] - rebuilt_max_query_len = 0 - - rebuilt_block_tables = [] - - rebuilt_query_start_loc = [0] - rebuilt_context_lens_tensor = [] - rebuilt_selected_token_indices = [] - - # recounting query and context lengths - for idx in range(len(input_tokens_list)): - token_tensor = input_tokens_list[idx] - num_token = len(token_tensor) - num_computed_token = num_computed_tokens_list[idx] - # currently attention kernel cannot handle the case where there is 0 - # query token. - if num_computed_token == num_token: - num_computed_token -= 1 - start_pos = start_pos_list[idx] - - rebuilt_input_tokens.append(token_tensor[num_computed_token:]) - # TODO(Jiayi): please check the correctness of next line - rebuilt_input_positions.append( - model_input.input_positions[start_pos + - num_computed_token:start_pos + - num_token]) - q_len = num_token - num_computed_token - rebuilt_query_lens.append(q_len) - - # Attn metadata-related - rebuilt_num_prefills += 1 - rebuilt_num_prefill_tokens += q_len - new_slot_mapping = slot_mapping_flat[start_pos + - num_computed_token:start_pos + - num_token] - rebuilt_slot_mapping.append(new_slot_mapping) - rebuilt_max_query_len = max(q_len, rebuilt_max_query_len) - # TODO(Jiayi): remove hard-code (block_size=16) - blk_size = 16 - temp_block_table = [ - slot_mapping_flat[i] // blk_size - for i in range(start_pos, start_pos + num_token, blk_size) - ] - rebuilt_block_tables.append(temp_block_table) - rebuilt_query_start_loc.append( - rebuilt_num_prefill_tokens) #start with 0 - rebuilt_context_lens_tensor.append(num_computed_token) - - # Sampling metadata related - #seq_groups (use rebuilt query lens) - rebuilt_selected_token_indices.append(rebuilt_num_prefill_tokens - 1) - - # rebuilt attn_metadata - rebuilt_attn_metadata = deepcopy(model_input.attn_metadata) - rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills - rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens - rebuilt_attn_metadata.slot_mapping = torch.cat(rebuilt_slot_mapping).to( - device) - rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len - - rebuilt_attn_metadata.block_tables = torch.tensor( - rebuilt_block_tables, - dtype=model_input.attn_metadata.block_tables.dtype).to(device) - - rebuilt_attn_metadata.query_start_loc = torch.tensor( - rebuilt_query_start_loc, - dtype=model_input.attn_metadata.query_start_loc.dtype).to(device) - rebuilt_attn_metadata.context_lens_tensor = torch.tensor( - rebuilt_context_lens_tensor, - dtype=model_input.attn_metadata.context_lens_tensor.dtype, - ).to(device) - - rebuilt_attn_metadata._cached_prefill_metadata = None - - # rebuilt sampling_metadata - rebuilt_sampling_metadata = deepcopy(model_input.sampling_metadata) - for idx, q_len in enumerate(rebuilt_query_lens): - if rebuilt_sampling_metadata.seq_groups is not None: - rebuilt_sampling_metadata.seq_groups[idx].query_len = q_len - - rebuilt_sampling_metadata.selected_token_indices = torch.tensor( - rebuilt_selected_token_indices, - dtype=model_input.sampling_metadata.selected_token_indices.dtype, - ).to(device) - - # import here to avoid circular import. - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( - input_tokens=torch.cat(rebuilt_input_tokens).to(device), - input_positions=torch.cat(rebuilt_input_positions).to(device), - seq_lens=model_input.seq_lens, - query_lens=rebuilt_query_lens, - lora_mapping=model_input.lora_mapping, - lora_requests=model_input.lora_requests, - attn_metadata=rebuilt_attn_metadata, - prompt_adapter_mapping=model_input.prompt_adapter_mapping, - prompt_adapter_requests=model_input.prompt_adapter_requests, - multi_modal_kwargs=model_input.multi_modal_kwargs, - request_ids_to_seq_ids=model_input.request_ids_to_seq_ids, - finished_requests_ids=model_input.finished_requests_ids, - virtual_engine=model_input.virtual_engine, - sampling_metadata=rebuilt_sampling_metadata, - is_prompt=model_input.is_prompt, - ) - - return rebuilt_model_input diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 68d44f10b211b..1a82a9e87cb10 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -990,33 +990,6 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable -def include_decoding_groups_if_disagg_enabled( - groups: List[List[int]], - world_size: int, -) -> List[List[int]]: - """ - Include the distributed group for decode - Only for disaggregated prefill - - Example: - Original group: [ [0,1], [2,3] ], world_size = 4 - Extended: [ [0,1], [2,3], [4,5], [6,7] ] - Arguments: - groups: original distributed group - world_size: the vLLM world size, which is half of - torch.distributed.get_world_size() - """ - - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: - new_groups = [] - for group in groups: - new_groups.append([rank for rank in group]) - for group in groups: - new_groups.append([rank + world_size for rank in group]) - return new_groups - else: - return groups - def init_distributed_environment( world_size: int = -1, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3e82e2092e88d..6d113d0616b27 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -13,7 +13,7 @@ ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig, TaskOption, TokenizerPoolConfig, - KVTransferConfig, VllmConfig) + VllmConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -22,6 +22,7 @@ maybe_register_config_serialize_by_value) from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import FlexibleArgumentParser, StoreBoolean +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv if TYPE_CHECKING: from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup @@ -109,6 +110,8 @@ class EngineArgs: # notice. distributed_executor_backend: Optional[Union[str, Type[ExecutorBase]]] = None + # number of P/D disaggregation (or other disaggregation) workers + kv_disagg_parapllel_size: int = 1 pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None @@ -196,6 +199,13 @@ class EngineArgs: pooling_step_tag_id: Optional[int] = None pooling_returned_token_ids: Optional[List[int]] = None + # P/D disaggregation coonfiguration + kv_connector: Optional[str] = None + kv_buffer_size: Optional[int] = None + kv_buffer_device: Optional[str] = None + kv_disagg_role: Optional[str] = None + kv_disagg_device: Optional[str] = None + def __post_init__(self): if not self.tokenizer: self.tokenizer = self.model @@ -913,6 +923,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "such as the token IDs of good_token and bad_token in " "the math-shepherd-mistral-7b-prm model.") + parser.add_argument( + '--kv-disagg-parallel-size', + '-kdp', + type=int, + default=1 + ) + parser.add_argument( '--kv-connector', type=str, @@ -930,16 +947,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) parser.add_argument( - '--kv-transfer-role', + '--kv-disagg-role', type=str, default=None, choices=["kv_producer", "kv_consumer", "both"], - help="Whether this vLLM instance produces KV caches, consume KV " - "caches, or both." + help="Whether this vLLM instance produces, consumes KV cache, or " + "both. Choices are 'kv_producer', 'kv_consumer', and 'both'." ) parser.add_argument( - '--kv-device', + '--kv-buffer-device', type=str, default=None, choices=["CPU", "GPU"], @@ -1054,6 +1071,7 @@ def create_engine_config(self) -> VllmConfig: cpu_offload_gb=self.cpu_offload_gb, ) parallel_config = ParallelConfig( + kv_disagg_parallel_size=self.kv_disagg_parallel_size, pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, worker_use_ray=self.worker_use_ray, @@ -1065,7 +1083,15 @@ def create_engine_config(self) -> VllmConfig: self.tokenizer_pool_extra_config, ), ray_workers_use_nsight=self.ray_workers_use_nsight, - distributed_executor_backend=self.distributed_executor_backend) + distributed_executor_backend=self.distributed_executor_backend, + kv_connector=self.kv_connector, + kv_buffer_size=self.kv_buffer_size, + kv_buffer_device=self.kv_buffer_device, + kv_disagg_role=self.kv_transfer_role, + kv_disagg_rank=self.kv_disagg_rank, + ) + # set the kv cache transfer condition check variables + dist_kv.set_kv_transfer_attribute(parallel_config) max_model_len = model_config.max_model_len use_long_context = max_model_len > 32768 @@ -1214,12 +1240,6 @@ def create_engine_config(self) -> VllmConfig: or "all" in detailed_trace_modules, ) - kv_transfer_config = KVTransferConfig( - kv_connector=self.kv_connector, - kv_buffer_size=self.kv_buffer_size, - kv_transfer_role=self.kv_transfer_role, - kv_device=self.kv_device, - ) return VllmConfig( model_config=model_config, @@ -1233,7 +1253,6 @@ def create_engine_config(self) -> VllmConfig: decoding_config=decoding_config, observability_config=observability_config, prompt_adapter_config=prompt_adapter_config, - kv_transfer_config=kv_transfer_config, ) From 0f172d50f0111e4cb164bd00e0cf3d02fdad7439 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 13 Nov 2024 16:52:52 +0000 Subject: [PATCH 280/303] stage changes --- .../pynccl_connector/lookup_buffer.py | 241 ++++++ .../pynccl_connector/pynccl_connector.py | 270 +++++++ .../pynccl_connector/pynccl_pipe.py | 307 +++++++ .../torch_distributed_connector.py | 758 ------------------ vllm/distributed/parallel_state.py | 24 +- vllm/executor/multiproc_gpu_executor.py | 2 +- 6 files changed, 828 insertions(+), 774 deletions(-) create mode 100644 vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py delete mode 100644 vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py new file mode 100644 index 0000000000000..965d6552d722c --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py @@ -0,0 +1,241 @@ +""" + This file implements a simple torch distributed connector by 3 classes: + - `TorchDistributedPipe`: a tensor transmission pipe between vllm instances, + using `torch.distributed` + - `TorchDistributedBuffer`: a buffer to store tensors, implemented on top + of `TorchDistributedPipe` + - `TorchDistributedConnector`: a torch distributed connector between P/D + instance, implemented on top of `TorchDistributedBuffer` +""" +import threading +import time +from collections import deque +from concurrent.futures import ThreadPoolExecutor +from typing import Deque, List, Optional, Union +from copy import deepcopy + +import torch +from torch.distributed import Backend + +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ + import PyncclPipe +from vllm.logger import init_logger +from vllm.config import KVTransferConfig + + + +logger = init_logger(__name__) + + +class LookupBuffer: + + def __init__(self, + signal_pipe: PyncclPipe, + data_pipe: PyncclPipe, + buffer_size_thresh: float): + """ + signal_pipe: on CPU + + NOTE: on-device recv will block all threads in the process, making the + KV cache producer unable to listen to new request while transmitting + KV cache. Luckily CPU recv only blocks the current thread so we use + CPU recv to listen to new request. + + data_pipe: on device (e.g. GPU) + """ + + self.buffer: Deque[List[torch.Tensor]] = deque() + + self.buffer_size = 0 + self.buffer_size_threshold = buffer_size_thresh + self.buffer_lock = threading.Lock() + self.signal_pipe = signal_pipe + self.data_pipe = data_pipe + self.request_handling_thread: Optional[threading.Thread] = None + + self.normal_signal = torch.tensor([0]) + self.end_signal = None + + def _matches(self, tokens_roi_sender: List[torch.Tensor], + tokens_roi_recver: List[torch.Tensor]): + + # tokens_roi_sender: tokens and roi of the producer (in the buffer) + # tokens_roi_recver: tokens and roi of the consumer (query) + + tokens_sender = tokens_roi_sender[0] + tokens_recver = tokens_roi_recver[0] + roi_sender = tokens_roi_sender[1] + roi_recver = tokens_roi_recver[1] + + if tokens_recver is None: + # consumer sends an empty request + # semantics: DROP SELECT * LIMIT 1 + # so any of the data in the buffer can be drop-selected + return True + + # Assuming that roi is a binary mask on tokens + tokens_sender = tokens_sender[roi_sender] + tokens_recver = tokens_recver[roi_recver] + + # simple common prefix matching + min_length = min(len(tokens_sender), len(tokens_recver)) + if torch.allclose(tokens_sender[:min_length], + tokens_recver[:min_length]): + return min_length + + return 0 + + def _send_tensor_and_dec_size(self, + tensor: Optional[torch.Tensor]) -> None: + + assert tensor is not None, "Use self.data_pipe.send(None) instead" + self.buffer_size -= tensor.element_size() * tensor.numel() + self.data_pipe.send_tensor(tensor) + + def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): + + if isinstance(data, torch.Tensor): + return data.element_size() * data.numel() + if not data: + # cannot perform `not data` on a tensor + # so this check needs to go after the check above + return 0 + + raise AssertionError("Unknown data type %s" % type(data)) + + def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor): + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() + if isinstance(key, torch.Tensor): + key = key.clone() + if isinstance(value, torch.Tensor): + value = value.clone() + if isinstance(hidden, torch.Tensor): + hidden = hidden.clone() + + buffer_item = [input_tokens, roi, key, value, hidden] + + with self.buffer_lock: + for data in buffer_item: + self.buffer_size += self._get_element_size(data) + self.buffer.append(buffer_item) + + def _is_end_signal(self, signal): + return signal is None + + def drop_select_handler(self): + + try: + + while True: + signal = self.signal_pipe.recv_tensor() + if self._is_end_signal(signal): + logger.info("Received end signal!") + break + + input_tokens = self.data_pipe.recv_tensor() + + roi = self.data_pipe.recv_tensor() + tokens_roi_recver = [input_tokens, roi] + + matched_length = 0 + + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. + with self.buffer_lock: + + for _ in range(len(self.buffer)): + + temp_length = self._matches(self.buffer[0], + tokens_roi_recver) + if temp_length > 0: + matched_length = temp_length + break + # rotate the element we just accessed to the end + self.buffer.rotate(-1) + + if matched_length > 0: + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor) + + else: + # no match, just send None + for _ in range(5): + self.data_pipe.send_tensor(None) + + except RuntimeError as e: + if 'Connection closed by peer' not in str(e): + raise e + + logger.debug("Closing drop_select_handler") + + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + assert self.request_handling_thread is None, \ + "drop_select should be called by the KV cache consumer "\ + "(e.g. the decode vLLM instance)" + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() + + self.signal_pipe.send_tensor(self.normal_signal) + self.data_pipe.send_tensor(input_tokens) + self.data_pipe.send_tensor(roi) + + input_tokens = self.data_pipe.recv_tensor() + roi = self.data_pipe.recv_tensor() + key = self.data_pipe.recv_tensor() + value = self.data_pipe.recv_tensor() + hidden = self.data_pipe.recv_tensor() + + return [input_tokens, roi, key, value, hidden] + + def full_handler(self): + time.sleep(0.001) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + if self.buffer_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. Handling...") + while self.buffer_size > self.buffer_size_threshold: + self.full_handler() + + self._add_to_buffer(input_tokens, roi, key, value, hidden) + + # when calling the insert, the current process is a sender + # need to launch the request handler and start listening to request. + if self.request_handling_thread is None: + self.request_handling_thread = threading.Thread( + target=self.drop_select_handler) + self.request_handling_thread.start() + + def close(self): + + if hasattr(self, "request_handling_thread" + ) and self.request_handling_thread is not None: + self.request_handling_thread.join() + + else: + # TODO: have a explicit close signal and have a explicit way to + # check if it's requester + self.signal_pipe.send_tensor(self.end_signal) + \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py new file mode 100644 index 0000000000000..353270b02de86 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py @@ -0,0 +1,270 @@ +""" + This file implements a simple torch distributed connector by 3 classes: + - `TorchDistributedPipe`: a tensor transmission pipe between vllm instances, + using `torch.distributed` + - `TorchDistributedBuffer`: a buffer to store tensors, implemented on top + of `TorchDistributedPipe` + - `TorchDistributedConnector`: a torch distributed connector between P/D + instance, implemented on top of `TorchDistributedBuffer` +""" +import threading +import time +from collections import deque +from concurrent.futures import ThreadPoolExecutor +from typing import Deque, List, Optional, Union +from copy import deepcopy + +import torch +from torch.distributed import Backend + +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ + import PyncclPipe +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.lookup_buffer \ + import LookupBuffer +from vllm.logger import init_logger +from vllm.config import KVTransferConfig + + + +logger = init_logger(__name__) + + + + +class TorchDistributedConnector(KVConnectorBase): + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + config: KVTransferConfig, + ): + + self.lookup_buffer_size = self.kv_buffer_size + + self.send_buffer: Optional[TorchDistributedBuffer] = None + self.recv_buffer: Optional[TorchDistributedBuffer] = None + + device2backend = { + "cpu": "gloo", + "gpu": "nccl", + } + + # In disaggregated prefill, the prefill vLLM only uses send pipe + # and the decode vLLM only uses recv pipe + # In remote KV cache store, vLLM will use both send pipe and recv pipe + # So we build both send pipe and recv pipe for simplicity. + if config.is_kv_producer: + + self.send_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + device2backend[config.kv_device], + self.kv_buffer_size, + ) + self.send_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + self.kv_buffer_size, + ) + self.recv_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + device2backend[config.kv_device], + self.kv_buffer_size, + ) + self.recv_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + self.kv_buffer_size + ) + + else: + + # the current vLLM instance is KV consumer, so it needs to connect + # its recv pipe to the send pipe of KV producder + + self.recv_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + device2backend[config.kv_device], + self.kv_buffer_size, + ) + self.recv_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + self.kv_buffer_size, + ) + self.send_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + device2backend[config.kv_device], + self.kv_buffer_size, + ) + self.send_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + self.kv_buffer_size + ) + + self.send_buffer = TorchDistributedBuffer(self.send_signal_pipe, + self.send_pipe, + self.lookup_buffer_size) + self.recv_buffer = TorchDistributedBuffer(self.recv_signal_pipe, + self.recv_pipe, + self.lookup_buffer_size) + self.tensor_device = config.kv_device + + + def select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + return self.send_buffer.drop_select(input, roi) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + return self.recv_buffer.insert( + input_tokens, + roi, + key, + value, + hidden + ) + + + + def build_partial_prefill_input( + self, + model_input: "ModelInputForGPUWithSamplingMetadata", + input_tokens_list: List[torch.Tensor], + num_computed_tokens_list: List[int], + start_pos_list: List[int], + slot_mapping_flat: torch.Tensor, + device: torch.device, + ) -> "ModelInputForGPUWithSamplingMetadata": + """ + Helper function to rebuild the model input for the current request. + Goal: avoid running redundant prefill on those tokens that already has + KV caches received. + """ + rebuilt_input_tokens = [] + rebuilt_input_positions = [] + rebuilt_query_lens = [] + + rebuilt_num_prefills = 0 + rebuilt_num_prefill_tokens = 0 + rebuilt_slot_mapping = [] + rebuilt_max_query_len = 0 + + rebuilt_block_tables = [] + + rebuilt_query_start_loc = [0] + rebuilt_context_lens_tensor = [] + rebuilt_selected_token_indices = [] + + # recounting query and context lengths + for idx in range(len(input_tokens_list)): + token_tensor = input_tokens_list[idx] + num_token = len(token_tensor) + num_computed_token = num_computed_tokens_list[idx] + # currently attention kernel cannot handle the case where there is 0 + # query token. + if num_computed_token == num_token: + num_computed_token -= 1 + start_pos = start_pos_list[idx] + + rebuilt_input_tokens.append(token_tensor[num_computed_token:]) + # TODO(Jiayi): please check the correctness of next line + rebuilt_input_positions.append( + model_input.input_positions[start_pos + + num_computed_token:start_pos + + num_token]) + q_len = num_token - num_computed_token + rebuilt_query_lens.append(q_len) + + # Attn metadata-related + rebuilt_num_prefills += 1 + rebuilt_num_prefill_tokens += q_len + new_slot_mapping = slot_mapping_flat[start_pos + + num_computed_token:start_pos + + num_token] + rebuilt_slot_mapping.append(new_slot_mapping) + rebuilt_max_query_len = max(q_len, rebuilt_max_query_len) + # TODO(Jiayi): remove hard-code (block_size=16) + blk_size = 16 + temp_block_table = [ + slot_mapping_flat[i] // blk_size + for i in range(start_pos, start_pos + num_token, blk_size) + ] + rebuilt_block_tables.append(temp_block_table) + rebuilt_query_start_loc.append( + rebuilt_num_prefill_tokens) #start with 0 + rebuilt_context_lens_tensor.append(num_computed_token) + + # Sampling metadata related + #seq_groups (use rebuilt query lens) + rebuilt_selected_token_indices.append(rebuilt_num_prefill_tokens - 1) + + # rebuilt attn_metadata + rebuilt_attn_metadata = deepcopy(model_input.attn_metadata) + rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills + rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens + rebuilt_attn_metadata.slot_mapping = torch.cat(rebuilt_slot_mapping).to( + device) + rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len + + rebuilt_attn_metadata.block_tables = torch.tensor( + rebuilt_block_tables, + dtype=model_input.attn_metadata.block_tables.dtype).to(device) + + rebuilt_attn_metadata.query_start_loc = torch.tensor( + rebuilt_query_start_loc, + dtype=model_input.attn_metadata.query_start_loc.dtype).to(device) + rebuilt_attn_metadata.context_lens_tensor = torch.tensor( + rebuilt_context_lens_tensor, + dtype=model_input.attn_metadata.context_lens_tensor.dtype, + ).to(device) + + rebuilt_attn_metadata._cached_prefill_metadata = None + + # rebuilt sampling_metadata + rebuilt_sampling_metadata = deepcopy(model_input.sampling_metadata) + for idx, q_len in enumerate(rebuilt_query_lens): + if rebuilt_sampling_metadata.seq_groups is not None: + rebuilt_sampling_metadata.seq_groups[idx].query_len = q_len + + rebuilt_sampling_metadata.selected_token_indices = torch.tensor( + rebuilt_selected_token_indices, + dtype=model_input.sampling_metadata.selected_token_indices.dtype, + ).to(device) + + # import here to avoid circular import. + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( + input_tokens=torch.cat(rebuilt_input_tokens).to(device), + input_positions=torch.cat(rebuilt_input_positions).to(device), + seq_lens=model_input.seq_lens, + query_lens=rebuilt_query_lens, + lora_mapping=model_input.lora_mapping, + lora_requests=model_input.lora_requests, + attn_metadata=rebuilt_attn_metadata, + prompt_adapter_mapping=model_input.prompt_adapter_mapping, + prompt_adapter_requests=model_input.prompt_adapter_requests, + multi_modal_kwargs=model_input.multi_modal_kwargs, + request_ids_to_seq_ids=model_input.request_ids_to_seq_ids, + finished_requests_ids=model_input.finished_requests_ids, + virtual_engine=model_input.virtual_engine, + sampling_metadata=rebuilt_sampling_metadata, + is_prompt=model_input.is_prompt, + ) + + return rebuilt_model_input + diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py new file mode 100644 index 0000000000000..30d9e54cd1edb --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py @@ -0,0 +1,307 @@ +""" + This file implements a simple torch distributed connector by 3 classes: + - `TorchDistributedPipe`: a tensor transmission pipe between vllm instances, + using `torch.distributed` + - `TorchDistributedBuffer`: a buffer to store tensors, implemented on top + of `TorchDistributedPipe` + - `TorchDistributedConnector`: a torch distributed connector between P/D + instance, implemented on top of `TorchDistributedBuffer` +""" +import threading +import time +from collections import deque +from concurrent.futures import ThreadPoolExecutor +from typing import Deque, List, Optional, Union +from copy import deepcopy + +import torch +from torch.distributed import Backend + +from vllm.distributed.utils import StatelessProcessGroup +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.logger import init_logger +from vllm.config import ParallelConfig + + + +logger = init_logger(__name__) + + + +# if the tensor is only one-element and only contains NONE_INT +# this means that the sended object is None. +NONE_INT = -150886311 + +# Mapping tensor dtype to INT64, used for tensor metadata transmission +FLOAT16_INT = -543205003776624 +INT64_INT = -375623078607432 +BOOL_INT = -28035262008646 +BFLOAT16_INT = -452084912267662 +FLOAT32_INT = -1049557997456592 +FLOAT64_INT = -452201007054137 +FLOAT8_E4M3FN_INT = -1066697177659525 +FLOAT8_E5M2_INT = -618182574682355 + +DTYPE2INT = { + torch.float16: FLOAT16_INT, + torch.int64: INT64_INT, + torch.bool: BOOL_INT, + torch.bfloat16: BFLOAT16_INT, + torch.float32: FLOAT32_INT, + torch.float64: FLOAT64_INT, + torch.float8_e4m3fn: FLOAT8_E4M3FN_INT, + torch.float8_e5m2: FLOAT8_E5M2_INT, +} + +INT2DTYPE = { + FLOAT16_INT: torch.float16, + INT64_INT: torch.int64, + BOOL_INT: torch.bool, + BFLOAT16_INT: torch.bfloat16, + FLOAT32_INT: torch.float32, + FLOAT64_INT: torch.float64, + FLOAT8_E4M3FN_INT: torch.float8_e4m3fn, + FLOAT8_E5M2_INT: torch.float8_e5m2, +} + + +class BrokenPipeException(Exception): + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +class PyncclPipe: + + METADATA_LENGTH = 16 + MAX_TENSOR_DIMENSIONS = 14 + METADATA_DTYPE = torch.int64 + + def __init__( + self, + config: ParallelConfig + ): + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.buffer_size_thresh = buffer_size_thresh + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend) + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + + assert self.device_group is not None + assert self.rank_in_group <= 1 + + self.device = self._select_device(torch_distributed_backend) + + self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % + self.world_size] + self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % + self.world_size] + + self.transport_thread: Optional[ThreadPoolExecutor] = None + self.buffer_size = 0 + self.buffer_size_lock = threading.Lock() + + self.none_tensor = torch.tensor([NONE_INT], device=self.device) + + # On-device tensors to be reused for recv + self.rcv_metadata_buffer = torch.zeros(self.METADATA_LENGTH, + dtype=self.METADATA_DTYPE, + device=self.device) + + def _select_device(self, backend: Union[str, Backend]): + if torch.cuda.is_available() and backend == Backend.NCCL: + return torch.device(f"cuda:{self.local_rank}") + else: + return "cpu" + + def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: + """Create the metadata on based on the input tensor, and move it to GPU. + The metadata's length is `TorchDistributedPipe.METADATA_LENGTH`. + + Currently, the metadata is a int64 tensor and it includes dtype, number + of dimensions, and the shape information of the input tensor. + + + The information follows the layout below: + - metadata[0] -- dtype + - metadata[1] -- number of dimensions + - metadata[2 : 2+ndims] -- the shape of the input tensor + + Parameters: + - tensor: the input tensor + + Returns: + - metadata: the metadata tensor, on self.device + """ + buffer = torch.empty(self.METADATA_LENGTH, + dtype=self.METADATA_DTYPE, + device="cpu") + buffer[0] = DTYPE2INT[tensor.dtype] + ndims = len(tensor.shape) + buffer[1] = len(tensor.shape) + buffer[2:2 + ndims] = torch.tensor(tensor.shape, + dtype=self.METADATA_DTYPE) + return buffer.to(self.device) + + def _prepare_recv_buffer(self, + d_metadata_buffer: torch.Tensor) -> torch.Tensor: + """Create a buffer to receive the tensor based on the metadata. + + Parameters: + - d_metadata_buffer: the metadata tensor on self.device + + Returns: + - buffer: the buffer tensor to receive the tensor, on self.device + """ + h_buffer = d_metadata_buffer.cpu().numpy() + dtype = INT2DTYPE[h_buffer[0]] + ndims = h_buffer[1] + shape = tuple(h_buffer[2:2 + ndims]) + return torch.empty(shape, dtype=dtype, device=self.device) + + def _send_metadata(self, d_metadata_buffer: torch.Tensor): + """Send the metadata buffer to the target rank. + """ + torch.distributed.send( + d_metadata_buffer, + dst=self.target_rank_for_send, + group=self.device_group, + ) + + def _recv_metadata(self) -> torch.Tensor: + """Receive the metadata buffer from the target rank. + + Returns: + - metadata_buffer: the metadata buffer tensor, on self.device + + Note: + The current implementation uses the assumption that there is no + race conditions during sending/receiving. Therefore, the metadata + buffer can be reused + """ + torch.distributed.recv( + self.rcv_metadata_buffer, + src=self.target_rank_for_recv, + group=self.device_group, + ) + + return self.rcv_metadata_buffer + + def _send_impl(self, tensor): + """ + The actual implementation of sending the tensor to the target rank. + This function will first send the metadata, and then send the tensor. + + Parameters: + - tensor: the input tensor to be sent + """ + + metadata = self._make_metadata(tensor) + self._send_metadata(metadata) + torch.distributed.send(tensor.to(self.device), + dst=self.target_rank_for_send, + group=self.device_group) + + def _recv_impl(self) -> torch.Tensor: + """ + The actual implementation of receiving the tensor from the target rank. + This function will first receive the metadata, then receive the tensor. + + This function will block if there is no tensor to receive. + + Returns: + - buffer: the received tensor, on self.device + """ + d_metadata = self._recv_metadata() + buffer = self._prepare_recv_buffer(d_metadata) + + torch.distributed.recv(buffer, + src=self.target_rank_for_recv, + group=self.device_group) + + return buffer + + def send_tensor_wrapper(self, tensor): + try: + """Wrapper for send_tensor_dict""" + tensor_size = tensor.element_size() * tensor.numel() + self._send_impl(tensor) + + with self.buffer_size_lock: + self.buffer_size = self.buffer_size - tensor_size + except Exception as e: + logger.error("[rank%d]: Exception when trying to send %s, msg: %s", + torch.distributed.get_rank(), str(tensor), str(e)) + import traceback + traceback.print_exc() + + def block_if_full(self): + """Block the current thread if the buffer size is larger than 1e9.""" + # TODO: replace this 1e9 with a configurable parameter or a constant + while self.buffer_size > self.buffer_size_thresh: + logger.debug("KV cache transfer pipe is full. Waiting...") + time.sleep(0.05) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Sends a tensor to the destination rank in a non-blocking way. + Flow: send tensor dim -- send tensor shape -- send tensor data + """ + + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + if tensor is None: + tensor = self.none_tensor + + tensor_size = tensor.element_size() * tensor.numel() + + assert ( + 0 < len(tensor.shape) < self.MAX_TENSOR_DIMENSIONS + ), f"Only support dimensions within 1-{self.MAX_TENSOR_DIMENSIONS}" + + self.block_if_full() + + with self.buffer_size_lock: + self.buffer_size = self.buffer_size + tensor_size + + self.transport_thread.submit( + self.send_tensor_wrapper, + tensor, + ) + + def recv_tensor(self) -> Optional[torch.Tensor]: + """Receives a tensor from the src rank. Blocking.""" + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + future = self.transport_thread.submit(self._recv_impl) + + try: + tensor = future.result() + except Exception as e: + # the underlying pipe is likely broken + logger.error("Encountering exception in KV receiving thread") + logger.error("%s", e) + # fault tolerance: if the pipe is broken, return None + return None + + if tensor.numel() == 1 and tensor.item() == NONE_INT: + return None + else: + return tensor + + def close(self): + """Close the pipe and release the resources.""" + if (hasattr(self, "transport_thread") + and self.transport_thread is not None): + self.transport_thread.shutdown() + diff --git a/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py b/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py deleted file mode 100644 index fe6787cfe88f4..0000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/torch_distributed_connector.py +++ /dev/null @@ -1,758 +0,0 @@ -""" - This file implements a simple torch distributed connector by 3 classes: - - `TorchDistributedPipe`: a tensor transmission pipe between vllm instances, - using `torch.distributed` - - `TorchDistributedBuffer`: a buffer to store tensors, implemented on top - of `TorchDistributedPipe` - - `TorchDistributedConnector`: a torch distributed connector between P/D - instance, implemented on top of `TorchDistributedBuffer` -""" -import threading -import time -from collections import deque -from concurrent.futures import ThreadPoolExecutor -from typing import Deque, List, Optional, Union -from copy import deepcopy - -import torch -from torch.distributed import Backend - -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.logger import init_logger -from vllm.config import KVTransferConfig - - - -logger = init_logger(__name__) - - - -# if the tensor is only one-element and only contains NONE_INT -# this means that the sended object is None. -NONE_INT = -150886311 - -# Mapping tensor dtype to INT64, used for tensor metadata transmission -FLOAT16_INT = -543205003776624 -INT64_INT = -375623078607432 -BOOL_INT = -28035262008646 -BFLOAT16_INT = -452084912267662 -FLOAT32_INT = -1049557997456592 -FLOAT64_INT = -452201007054137 -FLOAT8_E4M3FN_INT = -1066697177659525 -FLOAT8_E5M2_INT = -618182574682355 - -DTYPE2INT = { - torch.float16: FLOAT16_INT, - torch.int64: INT64_INT, - torch.bool: BOOL_INT, - torch.bfloat16: BFLOAT16_INT, - torch.float32: FLOAT32_INT, - torch.float64: FLOAT64_INT, - torch.float8_e4m3fn: FLOAT8_E4M3FN_INT, - torch.float8_e5m2: FLOAT8_E5M2_INT, -} - -INT2DTYPE = { - FLOAT16_INT: torch.float16, - INT64_INT: torch.int64, - BOOL_INT: torch.bool, - BFLOAT16_INT: torch.bfloat16, - FLOAT32_INT: torch.float32, - FLOAT64_INT: torch.float64, - FLOAT8_E4M3FN_INT: torch.float8_e4m3fn, - FLOAT8_E5M2_INT: torch.float8_e5m2, -} - - -class BrokenPipeException(Exception): - - def __init__(self, message): - self.message = message - super().__init__(self.message) - - -class TorchDistributedPipe: - - METADATA_LENGTH = 16 - MAX_TENSOR_DIMENSIONS = 14 - METADATA_DTYPE = torch.int64 - - def __init__( - self, - group_ranks: List[List[int]], - local_rank: int, - torch_distributed_backend: Union[str, Backend], - buffer_size_thresh: float, - ): - self.rank = torch.distributed.get_rank() - self.local_rank = local_rank - self.device_group = None - self.buffer_size_thresh = buffer_size_thresh - - for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend) - if self.rank in ranks: - self.ranks = ranks - self.world_size = len(ranks) - self.rank_in_group = ranks.index(self.rank) - self.device_group = device_group - - assert self.device_group is not None - assert self.rank_in_group <= 1 - - self.device = self._select_device(torch_distributed_backend) - - self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % - self.world_size] - self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % - self.world_size] - - self.transport_thread: Optional[ThreadPoolExecutor] = None - self.buffer_size = 0 - self.buffer_size_lock = threading.Lock() - - self.none_tensor = torch.tensor([NONE_INT], device=self.device) - - # On-device tensors to be reused for recv - self.rcv_metadata_buffer = torch.zeros(self.METADATA_LENGTH, - dtype=self.METADATA_DTYPE, - device=self.device) - - def _select_device(self, backend: Union[str, Backend]): - if torch.cuda.is_available() and backend == Backend.NCCL: - return torch.device(f"cuda:{self.local_rank}") - else: - return "cpu" - - def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: - """Create the metadata on based on the input tensor, and move it to GPU. - The metadata's length is `TorchDistributedPipe.METADATA_LENGTH`. - - Currently, the metadata is a int64 tensor and it includes dtype, number - of dimensions, and the shape information of the input tensor. - - - The information follows the layout below: - - metadata[0] -- dtype - - metadata[1] -- number of dimensions - - metadata[2 : 2+ndims] -- the shape of the input tensor - - Parameters: - - tensor: the input tensor - - Returns: - - metadata: the metadata tensor, on self.device - """ - buffer = torch.empty(self.METADATA_LENGTH, - dtype=self.METADATA_DTYPE, - device="cpu") - buffer[0] = DTYPE2INT[tensor.dtype] - ndims = len(tensor.shape) - buffer[1] = len(tensor.shape) - buffer[2:2 + ndims] = torch.tensor(tensor.shape, - dtype=self.METADATA_DTYPE) - return buffer.to(self.device) - - def _prepare_recv_buffer(self, - d_metadata_buffer: torch.Tensor) -> torch.Tensor: - """Create a buffer to receive the tensor based on the metadata. - - Parameters: - - d_metadata_buffer: the metadata tensor on self.device - - Returns: - - buffer: the buffer tensor to receive the tensor, on self.device - """ - h_buffer = d_metadata_buffer.cpu().numpy() - dtype = INT2DTYPE[h_buffer[0]] - ndims = h_buffer[1] - shape = tuple(h_buffer[2:2 + ndims]) - return torch.empty(shape, dtype=dtype, device=self.device) - - def _send_metadata(self, d_metadata_buffer: torch.Tensor): - """Send the metadata buffer to the target rank. - """ - torch.distributed.send( - d_metadata_buffer, - dst=self.target_rank_for_send, - group=self.device_group, - ) - - def _recv_metadata(self) -> torch.Tensor: - """Receive the metadata buffer from the target rank. - - Returns: - - metadata_buffer: the metadata buffer tensor, on self.device - - Note: - The current implementation uses the assumption that there is no - race conditions during sending/receiving. Therefore, the metadata - buffer can be reused - """ - torch.distributed.recv( - self.rcv_metadata_buffer, - src=self.target_rank_for_recv, - group=self.device_group, - ) - - return self.rcv_metadata_buffer - - def _send_impl(self, tensor): - """ - The actual implementation of sending the tensor to the target rank. - This function will first send the metadata, and then send the tensor. - - Parameters: - - tensor: the input tensor to be sent - """ - - metadata = self._make_metadata(tensor) - self._send_metadata(metadata) - torch.distributed.send(tensor.to(self.device), - dst=self.target_rank_for_send, - group=self.device_group) - - def _recv_impl(self) -> torch.Tensor: - """ - The actual implementation of receiving the tensor from the target rank. - This function will first receive the metadata, then receive the tensor. - - This function will block if there is no tensor to receive. - - Returns: - - buffer: the received tensor, on self.device - """ - d_metadata = self._recv_metadata() - buffer = self._prepare_recv_buffer(d_metadata) - - torch.distributed.recv(buffer, - src=self.target_rank_for_recv, - group=self.device_group) - - return buffer - - def send_tensor_wrapper(self, tensor): - try: - """Wrapper for send_tensor_dict""" - tensor_size = tensor.element_size() * tensor.numel() - self._send_impl(tensor) - - with self.buffer_size_lock: - self.buffer_size = self.buffer_size - tensor_size - except Exception as e: - logger.error("[rank%d]: Exception when trying to send %s, msg: %s", - torch.distributed.get_rank(), str(tensor), str(e)) - import traceback - traceback.print_exc() - - def block_if_full(self): - """Block the current thread if the buffer size is larger than 1e9.""" - # TODO: replace this 1e9 with a configurable parameter or a constant - while self.buffer_size > self.buffer_size_thresh: - logger.debug("KV cache transfer pipe is full. Waiting...") - time.sleep(0.05) - - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: - """Sends a tensor to the destination rank in a non-blocking way. - Flow: send tensor dim -- send tensor shape -- send tensor data - """ - - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - - if tensor is None: - tensor = self.none_tensor - - tensor_size = tensor.element_size() * tensor.numel() - - assert ( - 0 < len(tensor.shape) < self.MAX_TENSOR_DIMENSIONS - ), f"Only support dimensions within 1-{self.MAX_TENSOR_DIMENSIONS}" - - self.block_if_full() - - with self.buffer_size_lock: - self.buffer_size = self.buffer_size + tensor_size - - self.transport_thread.submit( - self.send_tensor_wrapper, - tensor, - ) - - def recv_tensor(self) -> Optional[torch.Tensor]: - """Receives a tensor from the src rank. Blocking.""" - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - - future = self.transport_thread.submit(self._recv_impl) - - try: - tensor = future.result() - except Exception as e: - # the underlying pipe is likely broken - logger.error("Encountering exception in KV receiving thread") - logger.error("%s", e) - # fault tolerance: if the pipe is broken, return None - return None - - if tensor.numel() == 1 and tensor.item() == NONE_INT: - return None - else: - return tensor - - def close(self): - """Close the pipe and release the resources.""" - if (hasattr(self, "transport_thread") - and self.transport_thread is not None): - self.transport_thread.shutdown() - - -class TorchDistributedBuffer: - - def __init__(self, - signal_pipe: TorchDistributedPipe, - data_pipe: TorchDistributedPipe, - buffer_size_thresh: float): - """ - signal_pipe: on CPU - - NOTE: on-device recv will block all threads in the process, making the - KV cache producer unable to listen to new request while transmitting - KV cache. Luckily CPU recv only blocks the current thread so we use - CPU recv to listen to new request. - - data_pipe: on device (e.g. GPU) - """ - - self.buffer: Deque[List[torch.Tensor]] = deque() - - self.buffer_size = 0 - self.buffer_size_threshold = buffer_size_thresh - self.buffer_lock = threading.Lock() - self.signal_pipe = signal_pipe - self.data_pipe = data_pipe - self.request_handling_thread: Optional[threading.Thread] = None - - self.normal_signal = torch.tensor([0]) - self.end_signal = None - - def _matches(self, tokens_roi_sender: List[torch.Tensor], - tokens_roi_recver: List[torch.Tensor]): - - # tokens_roi_sender: tokens and roi of the producer (in the buffer) - # tokens_roi_recver: tokens and roi of the consumer (query) - - tokens_sender = tokens_roi_sender[0] - tokens_recver = tokens_roi_recver[0] - roi_sender = tokens_roi_sender[1] - roi_recver = tokens_roi_recver[1] - - if tokens_recver is None: - # consumer sends an empty request - # semantics: DROP SELECT * LIMIT 1 - # so any of the data in the buffer can be drop-selected - return True - - # Assuming that roi is a binary mask on tokens - tokens_sender = tokens_sender[roi_sender] - tokens_recver = tokens_recver[roi_recver] - - # simple common prefix matching - min_length = min(len(tokens_sender), len(tokens_recver)) - if torch.allclose(tokens_sender[:min_length], - tokens_recver[:min_length]): - return min_length - - return 0 - - def _send_tensor_and_dec_size(self, - tensor: Optional[torch.Tensor]) -> None: - - assert tensor is not None, "Use self.data_pipe.send(None) instead" - self.buffer_size -= tensor.element_size() * tensor.numel() - self.data_pipe.send_tensor(tensor) - - def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): - - if isinstance(data, torch.Tensor): - return data.element_size() * data.numel() - if not data: - # cannot perform `not data` on a tensor - # so this check needs to go after the check above - return 0 - - raise AssertionError("Unknown data type %s" % type(data)) - - def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor): - - if isinstance(input_tokens, torch.Tensor): - input_tokens = input_tokens.clone() - if isinstance(roi, torch.Tensor): - roi = roi.clone() - if isinstance(key, torch.Tensor): - key = key.clone() - if isinstance(value, torch.Tensor): - value = value.clone() - if isinstance(hidden, torch.Tensor): - hidden = hidden.clone() - - buffer_item = [input_tokens, roi, key, value, hidden] - - with self.buffer_lock: - for data in buffer_item: - self.buffer_size += self._get_element_size(data) - self.buffer.append(buffer_item) - - def _is_end_signal(self, signal): - return signal is None - - def drop_select_handler(self): - - try: - - while True: - signal = self.signal_pipe.recv_tensor() - if self._is_end_signal(signal): - logger.info("Received end signal!") - break - - input_tokens = self.data_pipe.recv_tensor() - - roi = self.data_pipe.recv_tensor() - tokens_roi_recver = [input_tokens, roi] - - matched_length = 0 - - # perform input tokens and roi matching - # FIXME: this matching is O(n), ideally it should be O(1) - # but this buffer size won't (and shouldn't) be too large so - # the fix is not urgent. - with self.buffer_lock: - - for _ in range(len(self.buffer)): - - temp_length = self._matches(self.buffer[0], - tokens_roi_recver) - if temp_length > 0: - matched_length = temp_length - break - # rotate the element we just accessed to the end - self.buffer.rotate(-1) - - if matched_length > 0: - # need to clone the tensor - # in case the tensor is freed before sending finishes - matched_item = self.buffer.popleft() - for tensor in matched_item: - self._send_tensor_and_dec_size(tensor) - - else: - # no match, just send None - for _ in range(5): - self.data_pipe.send_tensor(None) - - except RuntimeError as e: - if 'Connection closed by peer' not in str(e): - raise e - - logger.debug("Closing drop_select_handler") - - def drop_select( - self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - - assert self.request_handling_thread is None, \ - "drop_select should be called by the KV cache consumer "\ - "(e.g. the decode vLLM instance)" - - if isinstance(input_tokens, torch.Tensor): - input_tokens = input_tokens.clone() - if isinstance(roi, torch.Tensor): - roi = roi.clone() - - self.signal_pipe.send_tensor(self.normal_signal) - self.data_pipe.send_tensor(input_tokens) - self.data_pipe.send_tensor(roi) - - input_tokens = self.data_pipe.recv_tensor() - roi = self.data_pipe.recv_tensor() - key = self.data_pipe.recv_tensor() - value = self.data_pipe.recv_tensor() - hidden = self.data_pipe.recv_tensor() - - return [input_tokens, roi, key, value, hidden] - - def full_handler(self): - time.sleep(0.001) - - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: - - if self.buffer_size > self.buffer_size_threshold: - # log outside the while loop to avoid this message being logged - # repeatedly. - logger.debug("KV transfer buffer is full. Handling...") - while self.buffer_size > self.buffer_size_threshold: - self.full_handler() - - self._add_to_buffer(input_tokens, roi, key, value, hidden) - - # when calling the insert, the current process is a sender - # need to launch the request handler and start listening to request. - if self.request_handling_thread is None: - self.request_handling_thread = threading.Thread( - target=self.drop_select_handler) - self.request_handling_thread.start() - - def close(self): - - if hasattr(self, "request_handling_thread" - ) and self.request_handling_thread is not None: - self.request_handling_thread.join() - - else: - # TODO: have a explicit close signal and have a explicit way to - # check if it's requester - self.signal_pipe.send_tensor(self.end_signal) - - -class TorchDistributedConnector(KVConnectorBase): - - def __init__( - self, - group_ranks: List[List[int]], - local_rank: int, - config: KVTransferConfig, - ): - - self.lookup_buffer_size = self.kv_buffer_size - - self.send_buffer: Optional[TorchDistributedBuffer] = None - self.recv_buffer: Optional[TorchDistributedBuffer] = None - - device2backend = { - "cpu": "gloo", - "gpu": "nccl", - } - - # In disaggregated prefill, the prefill vLLM only uses send pipe - # and the decode vLLM only uses recv pipe - # In remote KV cache store, vLLM will use both send pipe and recv pipe - # So we build both send pipe and recv pipe for simplicity. - if config.is_kv_producer: - - self.send_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - device2backend[config.kv_device], - self.kv_buffer_size, - ) - self.send_signal_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo", - self.kv_buffer_size, - ) - self.recv_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - device2backend[config.kv_device], - self.kv_buffer_size, - ) - self.recv_signal_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo", - self.kv_buffer_size - ) - - else: - - # the current vLLM instance is KV consumer, so it needs to connect - # its recv pipe to the send pipe of KV producder - - self.recv_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - device2backend[config.kv_device], - self.kv_buffer_size, - ) - self.recv_signal_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo", - self.kv_buffer_size, - ) - self.send_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - device2backend[config.kv_device], - self.kv_buffer_size, - ) - self.send_signal_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo", - self.kv_buffer_size - ) - - self.send_buffer = TorchDistributedBuffer(self.send_signal_pipe, - self.send_pipe, - self.lookup_buffer_size) - self.recv_buffer = TorchDistributedBuffer(self.recv_signal_pipe, - self.recv_pipe, - self.lookup_buffer_size) - self.tensor_device = config.kv_device - - - def select( - self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - - return self.send_buffer.drop_select(input, roi) - - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: - - return self.recv_buffer.insert( - input_tokens, - roi, - key, - value, - hidden - ) - - - - def build_partial_prefill_input( - self, - model_input: "ModelInputForGPUWithSamplingMetadata", - input_tokens_list: List[torch.Tensor], - num_computed_tokens_list: List[int], - start_pos_list: List[int], - slot_mapping_flat: torch.Tensor, - device: torch.device, - ) -> "ModelInputForGPUWithSamplingMetadata": - """ - Helper function to rebuild the model input for the current request. - Goal: avoid running redundant prefill on those tokens that already has - KV caches received. - """ - rebuilt_input_tokens = [] - rebuilt_input_positions = [] - rebuilt_query_lens = [] - - rebuilt_num_prefills = 0 - rebuilt_num_prefill_tokens = 0 - rebuilt_slot_mapping = [] - rebuilt_max_query_len = 0 - - rebuilt_block_tables = [] - - rebuilt_query_start_loc = [0] - rebuilt_context_lens_tensor = [] - rebuilt_selected_token_indices = [] - - # recounting query and context lengths - for idx in range(len(input_tokens_list)): - token_tensor = input_tokens_list[idx] - num_token = len(token_tensor) - num_computed_token = num_computed_tokens_list[idx] - # currently attention kernel cannot handle the case where there is 0 - # query token. - if num_computed_token == num_token: - num_computed_token -= 1 - start_pos = start_pos_list[idx] - - rebuilt_input_tokens.append(token_tensor[num_computed_token:]) - # TODO(Jiayi): please check the correctness of next line - rebuilt_input_positions.append( - model_input.input_positions[start_pos + - num_computed_token:start_pos + - num_token]) - q_len = num_token - num_computed_token - rebuilt_query_lens.append(q_len) - - # Attn metadata-related - rebuilt_num_prefills += 1 - rebuilt_num_prefill_tokens += q_len - new_slot_mapping = slot_mapping_flat[start_pos + - num_computed_token:start_pos + - num_token] - rebuilt_slot_mapping.append(new_slot_mapping) - rebuilt_max_query_len = max(q_len, rebuilt_max_query_len) - # TODO(Jiayi): remove hard-code (block_size=16) - blk_size = 16 - temp_block_table = [ - slot_mapping_flat[i] // blk_size - for i in range(start_pos, start_pos + num_token, blk_size) - ] - rebuilt_block_tables.append(temp_block_table) - rebuilt_query_start_loc.append( - rebuilt_num_prefill_tokens) #start with 0 - rebuilt_context_lens_tensor.append(num_computed_token) - - # Sampling metadata related - #seq_groups (use rebuilt query lens) - rebuilt_selected_token_indices.append(rebuilt_num_prefill_tokens - 1) - - # rebuilt attn_metadata - rebuilt_attn_metadata = deepcopy(model_input.attn_metadata) - rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills - rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens - rebuilt_attn_metadata.slot_mapping = torch.cat(rebuilt_slot_mapping).to( - device) - rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len - - rebuilt_attn_metadata.block_tables = torch.tensor( - rebuilt_block_tables, - dtype=model_input.attn_metadata.block_tables.dtype).to(device) - - rebuilt_attn_metadata.query_start_loc = torch.tensor( - rebuilt_query_start_loc, - dtype=model_input.attn_metadata.query_start_loc.dtype).to(device) - rebuilt_attn_metadata.context_lens_tensor = torch.tensor( - rebuilt_context_lens_tensor, - dtype=model_input.attn_metadata.context_lens_tensor.dtype, - ).to(device) - - rebuilt_attn_metadata._cached_prefill_metadata = None - - # rebuilt sampling_metadata - rebuilt_sampling_metadata = deepcopy(model_input.sampling_metadata) - for idx, q_len in enumerate(rebuilt_query_lens): - if rebuilt_sampling_metadata.seq_groups is not None: - rebuilt_sampling_metadata.seq_groups[idx].query_len = q_len - - rebuilt_sampling_metadata.selected_token_indices = torch.tensor( - rebuilt_selected_token_indices, - dtype=model_input.sampling_metadata.selected_token_indices.dtype, - ).to(device) - - # import here to avoid circular import. - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( - input_tokens=torch.cat(rebuilt_input_tokens).to(device), - input_positions=torch.cat(rebuilt_input_positions).to(device), - seq_lens=model_input.seq_lens, - query_lens=rebuilt_query_lens, - lora_mapping=model_input.lora_mapping, - lora_requests=model_input.lora_requests, - attn_metadata=rebuilt_attn_metadata, - prompt_adapter_mapping=model_input.prompt_adapter_mapping, - prompt_adapter_requests=model_input.prompt_adapter_requests, - multi_modal_kwargs=model_input.multi_modal_kwargs, - request_ids_to_seq_ids=model_input.request_ids_to_seq_ids, - finished_requests_ids=model_input.finished_requests_ids, - virtual_engine=model_input.virtual_engine, - sampling_metadata=rebuilt_sampling_metadata, - is_prompt=model_input.is_prompt, - ) - - return rebuilt_model_input - diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 1a82a9e87cb10..7ede63da3027a 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -32,6 +32,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch +from numpy import product import torch import torch.distributed from torch.distributed import Backend, ProcessGroup @@ -1052,7 +1053,6 @@ def init_distributed_environment( # so this vLLM instance's world size is half of torch's world size torch_dist_world_size = torch_dist_world_size // 2 ranks = [[i for i in range(torch_dist_world_size)]] - ranks = include_decoding_groups_if_disagg_enabled(ranks, world_size) _WORLD = init_world_group(ranks, local_rank, backend) logger.debug("_WORLD initialized for rank %d", @@ -1064,6 +1064,7 @@ def init_distributed_environment( def initialize_model_parallel( + kv_transfer_parallel_size: int = 1, tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, backend: Optional[str] = None, @@ -1113,18 +1114,15 @@ def initialize_model_parallel( world_size: int = torch.distributed.get_world_size() backend = backend or torch.distributed.get_backend( get_world_group().device_group) - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: - # Disaggregated prefill enabled - # This vLLM instance thinks its word size is tp * pp, but - # torch.distributed contains 2 vLLM instances, - # its world size is 2 * tp * pp - # Adjust the world_size to match. - world_size = world_size // 2 - - if (world_size != - tensor_model_parallel_size * pipeline_model_parallel_size): + + if (world_size != product([ + kv_transfer_parallel_size, + tensor_model_parallel_size, + pipeline_model_parallel_size, + ])): raise RuntimeError( f"world_size ({world_size}) is not equal to " + f"kv_transfer_parallel_size ({kv_transfer_parallel_size}) x " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") @@ -1139,8 +1137,6 @@ def initialize_model_parallel( range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) group_ranks.append(ranks) - group_ranks = include_decoding_groups_if_disagg_enabled( - group_ranks, world_size) # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, @@ -1159,8 +1155,6 @@ def initialize_model_parallel( for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group_ranks.append(ranks) - group_ranks = include_decoding_groups_if_disagg_enabled( - group_ranks, world_size) # pipeline parallel does not need custom allreduce _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 1f9a1fcebcc47..25d1b8fe8a789 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -68,7 +68,7 @@ def _init_executor(self) -> None: # 127.0.0.1 for communication. distributed_init_method = get_distributed_init_method( "127.0.0.1", - get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) + get_open_port(force=self.config.IS_DISTRIBUTED_KV_INSTANCE)) self.workers: List[ProcessWorkerWrapper] = [] # This is the list of workers that are rank 0 of each TP group EXCEPT From 6f3d1b3b20bf6cd63d44eea2dd8b7ae372c58ad5 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 14 Nov 2024 06:54:14 +0000 Subject: [PATCH 281/303] debugging pynccl pipe --- tests/kv_transfer/test_send_recv.py | 47 +++- vllm/config.py | 252 +++++------------- .../kv_transfer/kv_connector/__init__.py | 3 +- .../kv_transfer/kv_connector/base.py | 3 +- .../pynccl_connector/pynccl_pipe.py | 112 ++++---- vllm/distributed/kv_transfer/vllm_adapter.py | 5 +- vllm/distributed/parallel_state.py | 5 +- vllm/distributed/utils.py | 12 + vllm/utils.py | 1 - 9 files changed, 187 insertions(+), 253 deletions(-) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index ff771f34c0325..d79bec36b6ceb 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -5,7 +5,9 @@ import torch from tqdm import tqdm -import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp +import vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ + as pnp +from vllm.config import ParallelConfig def test_run(my_rank, pipe): @@ -35,10 +37,15 @@ def test_run(my_rank, pipe): assert torch.allclose(x, x2) assert torch.allclose(y, y2) + +def barrier(my_rank): + torch.distributed.barrier() + def stress_test(my_rank, pipe): - torch.distributed.barrier() + # barrier + barrier(my_rank) tensors: List[torch.Tensor] = [] @@ -59,7 +66,8 @@ def stress_test(my_rank, pipe): tensors.append(x.mean().unsqueeze(0)) tensors.append(x.std().unsqueeze(0)) - torch.distributed.barrier() + barrier(my_rank) + for i in tqdm(range(500)): if my_rank == int((i % 10) > 3): @@ -78,7 +86,6 @@ def stress_test(my_rank, pipe): assert x.mean() == mean[0] assert x.std() == std[0] - torch.distributed.barrier() print("Stress test passed.") @@ -123,13 +130,31 @@ def latency_test(my_rank, pipe, nelement, ntensor): my_rank = int(os.environ['RANK']) - torch.distributed.init_process_group(init_method="tcp://127.0.0.1:23456", - world_size=2, - rank=my_rank) - - print("initialized! My rank is %d" % my_rank) - - pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "nccl") + torch.distributed.init_process_group( + backend='gloo', + init_method='tcp://localhost:12567', + rank=my_rank, + world_size=2, + ) + print('done') + + config = ParallelConfig( + 1, + 1, + kv_connector='PyNcclConnector', + kv_buffer_device='cuda', + kv_buffer_size=1e9, + kv_rank=my_rank, + kv_role="kv_both", # this arg doesn't matter in this test + kv_parallel_size=2, + kv_ip="127.0.0.1", + kv_port=12345, + ) + + pipe = pnp.PyNcclPipe( + local_rank=my_rank, + config=config, + ) torch.manual_seed(0) test_run(my_rank, pipe) diff --git a/vllm/config.py b/vllm/config.py index 613db6480f473..97a845ec1fd88 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,8 +1,6 @@ -import copy import enum import json -import warnings -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, field from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal, Mapping, Optional, Set, Tuple, Type, Union) @@ -15,10 +13,9 @@ from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback -from vllm.transformers_utils.config import ( - ConfigFormat, get_config, get_hf_image_processor_config, - get_hf_text_config, get_pooling_config, - get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) +from vllm.transformers_utils.config import (ConfigFormat, get_config, + get_hf_image_processor_config, + get_hf_text_config) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, print_warning_once) @@ -76,6 +73,9 @@ class ModelConfig: code_revision: The specific revision to use for the model code on Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. + rope_scaling: Dictionary containing the scaling configuration for the + RoPE embeddings. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. tokenizer_revision: The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. @@ -107,7 +107,7 @@ class ModelConfig: matches the model name exposed via the APIs. If multiple model names provided, the first name will be used. If not specified, the model name will be the same as `model`. - limit_mm_per_prompt: Maximum number of data items per modality + limit_mm_per_prompt: Maximum number of data instances per modality per prompt. Only applicable for multimodal models. override_neuron_config: Initialize non default neuron config or override default neuron config that are specific to Neuron devices, @@ -115,7 +115,6 @@ class ModelConfig: can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. - hf_overrides: Arguments to be forwarded to the HuggingFace config. mm_processor_kwargs: Arguments to be forwarded to the model's processor for multi-modal data, e.g., image processor. pooling_type: Used to configure the pooling method in the embedding @@ -146,7 +145,7 @@ def __init__( allowed_local_media_path: str = "", revision: Optional[str] = None, code_revision: Optional[str] = None, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: Optional[dict] = None, rope_theta: Optional[float] = None, tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, @@ -164,7 +163,6 @@ def __init__( override_neuron_config: Optional[Dict[str, Any]] = None, config_format: ConfigFormat = ConfigFormat.AUTO, chat_template_text_format: str = "string", - hf_overrides: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, pooling_type: Optional[str] = None, pooling_norm: Optional[bool] = None, @@ -179,22 +177,8 @@ def __init__( self.seed = seed self.revision = revision self.code_revision = code_revision - - if hf_overrides is None: - hf_overrides = {} - if rope_scaling is not None: - hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling} - hf_overrides.update(hf_override) - msg = ("`--rope-scaling` will be removed in a future release. " - f"'Please instead use `--hf-overrides '{hf_override!r}'`") - warnings.warn(DeprecationWarning(msg), stacklevel=2) - if rope_theta is not None: - hf_override = {"rope_theta": rope_theta} - hf_overrides.update(hf_override) - msg = ("`--rope-theta` will be removed in a future release. " - f"'Please instead use `--hf-overrides '{hf_override!r}'`") - warnings.warn(DeprecationWarning(msg), stacklevel=2) - + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta # The tokenizer version is consistent with the model version by default. if tokenizer_revision is None: self.tokenizer_revision = revision @@ -207,11 +191,11 @@ def __init__( self.max_logprobs = max_logprobs self.disable_sliding_window = disable_sliding_window self.skip_tokenizer_init = skip_tokenizer_init + self.hf_config = get_config(self.model, trust_remote_code, revision, - code_revision, config_format, - **hf_overrides) + code_revision, rope_scaling, rope_theta, + config_format) self.hf_text_config = get_hf_text_config(self.hf_config) - self.encoder_config = self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) @@ -244,8 +228,7 @@ def __init__( max_model_len=max_model_len, disable_sliding_window=self.disable_sliding_window, sliding_window_len=self.get_hf_config_sliding_window(), - spec_target_max_model_len=spec_target_max_model_len, - encoder_config=self.encoder_config) + spec_target_max_model_len=spec_target_max_model_len) self.served_model_name = get_served_model_name(model, served_model_name) self.multimodal_config = self._init_multimodal_config( @@ -289,10 +272,6 @@ def _init_multimodal_config( return None - def _get_encoder_config(self): - return get_sentence_transformer_tokenizer_config( - self.model, self.revision) - def _init_pooler_config( self, pooling_type: Optional[str] = None, @@ -302,14 +281,6 @@ def _init_pooler_config( pooling_returned_token_ids: Optional[List[int]] = None ) -> Optional["PoolerConfig"]: if self.task == "embedding": - pooling_config = get_pooling_config(self.model, self.revision) - if pooling_config is not None: - # override if user does not - # specifies pooling_type and/or pooling_norm - if pooling_type is None: - pooling_type = pooling_config["pooling_type"] - if pooling_norm is None: - pooling_norm = pooling_config["normalize"] return PoolerConfig( pooling_type=pooling_type, pooling_norm=pooling_norm, @@ -696,13 +667,12 @@ def get_multimodal_config(self) -> "MultiModalConfig": return self.multimodal_config @property - def is_encoder_decoder(self) -> bool: + def is_encoder_decoder_model(self) -> bool: """Extract the HF encoder/decoder model flag.""" - return is_encoder_decoder(self.hf_config) - - @property - def uses_mrope(self) -> bool: - return uses_mrope(self.hf_config) + return getattr( + self.hf_config, "is_encoder_decoder", + False) or (hasattr(self.hf_config, "text_config") and getattr( + self.hf_config.text_config, "is_encoder_decoder", False)) @property def is_multimodal_model(self) -> bool: @@ -952,12 +922,9 @@ class ParallelConfig: https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. placement_group: ray distributed model workers placement group. distributed_executor_backend: Backend to use for distributed model - workers, either "ray" or "mp" (multiprocessing). If the product - of pipeline_parallel_size and tensor_parallel_size is less than - or equal to the number of GPUs available, "mp" will be used to - keep processing on a single host. Otherwise, this will default - to "ray" if Ray is installed and fail otherwise. Note that tpu - and hpu only support Ray for distributed inference. + workers, either "ray" or "mp" (multiprocessing). If either + pipeline_parallel_size or tensor_parallel_size is greater than 1, + will default to "ray" if Ray is installed or "mp" otherwise. kv_connector: The connector to use for kv cache transfer, value can be None, "TorchDistributedConnector" or "LMCacheConnector". kv_buffer_device: The buffer device to use for kv cache transfer. @@ -969,7 +936,6 @@ class ParallelConfig: def __init__( self, - kv_disagg_parallel_size: int, pipeline_parallel_size: int, tensor_parallel_size: int, worker_use_ray: Optional[bool] = None, @@ -983,10 +949,12 @@ def __init__( kv_connector: Optional[str] = None, kv_buffer_device: Optional[str] = None, kv_buffer_size: Optional[float] = None, - kv_disagg_role: Optional[str] = None, - kv_disagg_rank: int = 0, + kv_role: Optional[str] = None, + kv_rank: Optional[int] = None, + kv_parallel_size: Optional[int] = None, + kv_ip: Optional[str] = None, + kv_port: Optional[str] = None, ) -> None: - self.kv_disagg_parallel_size = kv_disagg_parallel_size self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size self.distributed_executor_backend = distributed_executor_backend @@ -995,13 +963,15 @@ def __init__( self.tokenizer_pool_config = tokenizer_pool_config self.ray_workers_use_nsight = ray_workers_use_nsight self.placement_group = placement_group - self.world_size = kv_disagg_parallel_size * pipeline_parallel_size *\ - tensor_parallel_size + self.world_size = pipeline_parallel_size * tensor_parallel_size self.kv_connector = kv_connector self.kv_buffer_device = kv_buffer_device self.kv_buffer_size = kv_buffer_size - self.kv_disagg_role = kv_disagg_role - self.kv_disagg_rank = kv_disagg_rank + self.kv_role = kv_role + self.kv_rank = kv_rank + self.kv_parallel_size = kv_parallel_size + self.kv_ip = kv_ip + self.kv_port = kv_port if worker_use_ray: if self.distributed_executor_backend is None: @@ -1111,22 +1081,22 @@ def _verify_args(self) -> None: "variables with prefix `kv_`") if self.kv_connector not in [None, - "TorchDistributedConnector", + "PyNcclConnector", "LMCacheConnector"]: raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. " f"Supported connectors are " - f"`TorchDistributedConnector` and " + f"`PyNcclConnector` and " f"`LMCacheConnector`") - if self.kv_disagg_role not in [None, - "kv_producer", - "kv_consumer", - "kv_both"]: - raise ValueError(f"Unsupported kv_disagg_role: {self.kv_disagg_role}. " + if self.kv_role not in [None, + "kv_producer", + "kv_consumer", + "kv_both"]: + raise ValueError(f"Unsupported kv_role: {self.kv_disagg_role}. " f"Supported roles are `kv_producer`, `kv_consumer`, " f"and `kv_both`") - if self.kv_connector is not None and self.kv_disagg_role is None: + if self.kv_connector is not None and self.kv_role is None: raise ValueError("Please specify kv_disagg_role when kv_connector " "is set, supported roles are `kv_producer`, " "`kv_consumer`, and `kv_both`") @@ -1400,6 +1370,13 @@ def maybe_create_spec_config( "speculative decoding is > 1, but got " f"{speculative_disable_by_batch_size=}") + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if enable_chunked_prefill: + raise ValueError( + "Speculative decoding and chunked prefill are " + f"currently mutually exclusive ({enable_chunked_prefill=}).") + # TODO: The user should be able to specify revision/max model len # for the draft model. It is not currently supported. draft_revision = None @@ -1466,29 +1443,6 @@ def maybe_create_spec_config( f"num_speculative_tokens={n_predict}, but " f"{num_speculative_tokens=} was provided.") - if enable_chunked_prefill and draft_hf_config.model_type in ( - "medusa", "mlp_speculator", "eagle"): - raise ValueError( - "Chunked prefill and hidden-state based draft models are " - "not compatible.") - - speculative_draft_tensor_parallel_size = \ - SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size( - target_parallel_config, - speculative_draft_tensor_parallel_size, - draft_hf_config - ) - - if (enable_chunked_prefill and \ - speculative_draft_tensor_parallel_size != 1): - # TODO - Investigate why the error reported in - # https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258 - # is happening and re-enable it. - raise ValueError( - "Chunked prefill and speculative decoding can be enabled " - "simultaneously only for draft models with tensor " - "parallel size 1.") - draft_model_config.max_model_len = ( SpeculativeConfig._maybe_override_draft_max_model_len( speculative_max_model_len, @@ -1567,16 +1521,15 @@ def _maybe_override_draft_max_model_len( ) @staticmethod - def _verify_and_get_draft_model_tensor_parallel_size( - target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int], - draft_hf_config: PretrainedConfig) -> int: - """ - Verifies and adjusts the tensor parallel size for a draft model - specified using speculative_draft_tensor_parallel_size. + def create_draft_parallel_config( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig, + ) -> ParallelConfig: + """Create a parallel config for use by the draft worker. + + This is mostly a copy of the target parallel config, except the tp_size. """ - # If speculative_draft_tensor_parallel_size is unset then set it - # appropriately else verify that it is set correctly. if speculative_draft_tensor_parallel_size is None: if draft_hf_config.model_type == "mlp_speculator": speculative_draft_tensor_parallel_size = 1 @@ -1592,18 +1545,7 @@ def _verify_and_get_draft_model_tensor_parallel_size( raise ValueError( f"{speculative_draft_tensor_parallel_size=} cannot be " f"other value than 1 or target model tensor_parallel_size") - return speculative_draft_tensor_parallel_size - - @staticmethod - def create_draft_parallel_config( - target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: int, - draft_hf_config: PretrainedConfig, - ) -> ParallelConfig: - """Create a parallel config for use by the draft worker. - This is mostly a copy of the target parallel config, except the tp_size. - """ draft_parallel_config = ParallelConfig( pipeline_parallel_size=target_parallel_config. pipeline_parallel_size, @@ -1753,7 +1695,6 @@ class LoRAConfig: # This is a constant. lora_vocab_padding_size: ClassVar[int] = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None - bias_enabled: bool = False def __post_init__(self): # Setting the maximum rank to 256 should be able to satisfy the vast @@ -1921,7 +1862,6 @@ def _get_and_verify_max_len( disable_sliding_window: bool, sliding_window_len: Optional[Union[int, List[Optional[int]]]], spec_target_max_model_len: Optional[int] = None, - encoder_config: Optional[Any] = None, ) -> int: """Get and verify the model's maximum length.""" derived_max_model_len = float("inf") @@ -2004,9 +1944,6 @@ def _get_and_verify_max_len( "original_max_position_embeddings"] derived_max_model_len *= scaling_factor - if encoder_config and "max_seq_length" in encoder_config: - derived_max_model_len = encoder_config["max_seq_length"] - # If the user specified a max length, make sure it is smaller than the # derived length from the HF model config. if max_model_len is None: @@ -2108,15 +2045,12 @@ class VllmConfig: simplifies passing around the distinct configurations in the codebase. """ - model_config: ModelConfig = field(default=None, init=True) # type: ignore - cache_config: CacheConfig = field(default=None, init=True) # type: ignore - parallel_config: ParallelConfig = field(default=None, - init=True) # type: ignore - scheduler_config: SchedulerConfig = field(default=None, - init=True) # type: ignore - device_config: DeviceConfig = field(default=None, - init=True) # type: ignore - load_config: LoadConfig = field(default=None, init=True) # type: ignore + model_config: ModelConfig + cache_config: CacheConfig + parallel_config: ParallelConfig + scheduler_config: SchedulerConfig + device_config: DeviceConfig + load_config: LoadConfig lora_config: Optional[LoRAConfig] = None speculative_config: Optional[SpeculativeConfig] = None decoding_config: Optional[DecodingConfig] = None @@ -2152,23 +2086,14 @@ def _get_quantization_config( return quant_config return None - def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig": - model_config = copy.deepcopy(self.model_config) - model_config.hf_config = hf_config - - return replace(self, model_config=model_config) - def __post_init__(self): """Verify configs are valid & consistent with each other. """ - if self.model_config is not None: - self.model_config.verify_async_output_proc(self.parallel_config, - self.speculative_config, - self.device_config) - self.model_config.verify_with_parallel_config(self.parallel_config) - - if self.cache_config is not None: - self.cache_config.verify_with_parallel_config(self.parallel_config) + self.model_config.verify_async_output_proc(self.parallel_config, + self.speculative_config, + self.device_config) + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) @@ -2182,44 +2107,3 @@ def __post_init__(self): self.model_config is not None and self.load_config is not None: self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config) - - def __str__(self): - return ("model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s") % \ - (self.model_config.model, self.speculative_config, - self.model_config.tokenizer, - self.model_config.skip_tokenizer_init, - self.model_config.tokenizer_mode, - self.model_config.revision, - self.model_config.override_neuron_config, - self.model_config.tokenizer_revision, - self.model_config.trust_remote_code, - self.model_config.dtype, - self.model_config.max_model_len, - self.load_config.download_dir, - self.load_config.load_format, - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size, - self.parallel_config.disable_custom_all_reduce, - self.model_config.quantization, - self.model_config.enforce_eager, - self.cache_config.cache_dtype, - self.model_config.quantization_param_path, - self.device_config.device, self.decoding_config, - self.observability_config, self.model_config.seed, - self.model_config.served_model_name, - self.scheduler_config.num_scheduler_steps, - self.cache_config.enable_prefix_caching, - self.model_config.use_async_output_proc, - self.model_config.mm_processor_kwargs) diff --git a/vllm/distributed/kv_transfer/kv_connector/__init__.py b/vllm/distributed/kv_transfer/kv_connector/__init__.py index 7d0f202f7e1d8..fdbaa76cefeeb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/__init__.py +++ b/vllm/distributed/kv_transfer/kv_connector/__init__.py @@ -1,12 +1,11 @@ from .base import KVConnectorBase -from vllm.config import ParallelConfig class KVConnectorFactory: @staticmethod def create_connector( - config: ParallelConfig + config ) -> KVConnectorBase: if config.kv_connector == 'LMCacheConnector': from .lmcache_connector import LMCacheConnector diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index e0e26c0b1ae8d..c9ac9ccabe557 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -10,7 +10,6 @@ from abc import ABC, abstractmethod from typing import List, Optional -from vllm.config import KVTransferConfig import torch @@ -40,7 +39,7 @@ class KVConnectorBase(ABC): """ @abstractmethod - def init(self, config: KVTransferConfig): + def init(self, config): raise NotImplementedError @abstractmethod diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py index 30d9e54cd1edb..3d4c28d88fb9e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py @@ -27,7 +27,6 @@ logger = init_logger(__name__) - # if the tensor is only one-element and only contains NONE_INT # this means that the sended object is None. NONE_INT = -150886311 @@ -72,7 +71,7 @@ def __init__(self, message): super().__init__(self.message) -class PyncclPipe: +class PyNcclPipe: METADATA_LENGTH = 16 MAX_TENSOR_DIMENSIONS = 14 @@ -80,45 +79,69 @@ class PyncclPipe: def __init__( self, + local_rank: int, config: ParallelConfig ): - self.rank = torch.distributed.get_rank() + self.config = config + self.local_rank = local_rank - self.device_group = None - self.buffer_size_thresh = buffer_size_thresh - - for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend) - if self.rank in ranks: - self.ranks = ranks - self.world_size = len(ranks) - self.rank_in_group = ranks.index(self.rank) - self.device_group = device_group - - assert self.device_group is not None - assert self.rank_in_group <= 1 - - self.device = self._select_device(torch_distributed_backend) - - self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % - self.world_size] - self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % - self.world_size] - + self.kv_rank = self.config.kv_rank + self.kv_parallel_size = self.config.kv_parallel_size + self.device = self._select_device() + + + # build distributed connection and send/recv implementation + self.group = StatelessProcessGroup.create( + host = self.config.kv_ip, + port = self.config.kv_port, + rank = self.kv_rank, + world_size = self.kv_parallel_size + ) + # add a barrier to make sure all ranks are ready + self.group.barrier() + self.metadata_send_func, self.metadata_recv_func = \ + self._get_metadata_send_recv_impl(self.group) + self.device_send_func, self.device_recv_func = \ + self._get_device_send_recv_impl(self.group) + # set target rank + self.target_rank_for_send = (self.kv_rank+ 1) % self.kv_parallel_size + self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size + + + # transportation-related variables self.transport_thread: Optional[ThreadPoolExecutor] = None self.buffer_size = 0 self.buffer_size_lock = threading.Lock() + self.buffer_size_thresh = self.config.kv_buffer_size self.none_tensor = torch.tensor([NONE_INT], device=self.device) # On-device tensors to be reused for recv self.rcv_metadata_buffer = torch.zeros(self.METADATA_LENGTH, dtype=self.METADATA_DTYPE, - device=self.device) + device="cpu") + + def _get_metadata_send_recv_impl(self, group: StatelessProcessGroup): + return group.send, group.recv + + def _get_device_send_recv_impl(self, group: StatelessProcessGroup): + if self.config.kv_buffer_device == "cuda": + # use PyNCCL for send / recv + comm = PyNcclCommunicator( + group, + device=self.local_rank, + ) + comm.disabled = False + send, recv = comm.send, comm.recv + else: + # use torch c10store for send / recv + send = group.send + recv = group.recv - def _select_device(self, backend: Union[str, Backend]): - if torch.cuda.is_available() and backend == Backend.NCCL: + return send, recv + + def _select_device(self): + if self.config.kv_buffer_device == "cuda": return torch.device(f"cuda:{self.local_rank}") else: return "cpu" @@ -149,7 +172,8 @@ def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: ndims = len(tensor.shape) buffer[1] = len(tensor.shape) buffer[2:2 + ndims] = torch.tensor(tensor.shape, - dtype=self.METADATA_DTYPE) + dtype=self.METADATA_DTYPE, + device="cpu") return buffer.to(self.device) def _prepare_recv_buffer(self, @@ -171,11 +195,7 @@ def _prepare_recv_buffer(self, def _send_metadata(self, d_metadata_buffer: torch.Tensor): """Send the metadata buffer to the target rank. """ - torch.distributed.send( - d_metadata_buffer, - dst=self.target_rank_for_send, - group=self.device_group, - ) + self.metadata_send_func(d_metadata_buffer, self.target_rank_for_send) def _recv_metadata(self) -> torch.Tensor: """Receive the metadata buffer from the target rank. @@ -188,10 +208,9 @@ def _recv_metadata(self) -> torch.Tensor: race conditions during sending/receiving. Therefore, the metadata buffer can be reused """ - torch.distributed.recv( - self.rcv_metadata_buffer, - src=self.target_rank_for_recv, - group=self.device_group, + self.metadata_recv_func( + self.rcv_metadata_buffer, + self.target_rank_for_recv ) return self.rcv_metadata_buffer @@ -207,9 +226,7 @@ def _send_impl(self, tensor): metadata = self._make_metadata(tensor) self._send_metadata(metadata) - torch.distributed.send(tensor.to(self.device), - dst=self.target_rank_for_send, - group=self.device_group) + self.device_send_func(tensor.to(self.device), self.target_rank_for_send) def _recv_impl(self) -> torch.Tensor: """ @@ -221,12 +238,12 @@ def _recv_impl(self) -> torch.Tensor: Returns: - buffer: the received tensor, on self.device """ + print('recv_metadata...') d_metadata = self._recv_metadata() + print('recv metadata done, receiving tensor ...') buffer = self._prepare_recv_buffer(d_metadata) - - torch.distributed.recv(buffer, - src=self.target_rank_for_recv, - group=self.device_group) + self.device_recv_func(buffer, self.target_rank_for_recv) + print('recv tensor done.') return buffer @@ -291,8 +308,9 @@ def recv_tensor(self) -> Optional[torch.Tensor]: # the underlying pipe is likely broken logger.error("Encountering exception in KV receiving thread") logger.error("%s", e) - # fault tolerance: if the pipe is broken, return None - return None + import traceback + traceback.print_exc() + raise e if tensor.numel() == 1 and tensor.item() == NONE_INT: return None diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index e59443e51b3ea..e2d77fe920362 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -33,7 +33,6 @@ from vllm.distributed.kv_transfer.kv_connector import KVConnectorFactory from vllm.logger import init_logger from vllm.sequence import IntermediateTensors -from vllm.config import ParallelConfig logger = init_logger(__name__) @@ -44,7 +43,7 @@ IS_KV_CONSUMER: Optional[bool] = None -def set_kv_transfer_attribute(config: ParallelConfig): +def set_kv_transfer_attribute(config): global IS_DISTRIBUTED_KV_INSTANCE, IS_KV_PRODUCER, IS_KV_CONSUMER IS_DISTRIBUTED_KV_INSTANCE = config.is_distributed_kv_instance @@ -65,7 +64,7 @@ def __init__( self, group_ranks: List[List[int]], local_rank: int, - config: ParallelConfig, + config, ): assert self.config.is_distributed_kv_instance, "KV cache transfer "\ diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d5f1ad689c58c..2268f48534311 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -37,14 +37,13 @@ import torch.distributed from torch.distributed import Backend, ProcessGroup -# Use this import to check if disagg prefill is enabled. -# if enabled, need to adjust distributed group correspondingly. -import vllm.distributed.kv_transfer.vllm_adapter as dist_kv import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op, supports_custom_op +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv + @dataclass class GraphCaptureContext: diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index dcfcb848cbe06..7cd35d85b8932 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -131,6 +131,9 @@ def send_obj(self, obj: Any, dst: int): self.send_dst_counter[dst] += 1 self.entries.append((key, time.time())) + def send(self, tensor: torch.Tensor, dst: int): + self.send_obj(tensor, dst) + def expire_data(self): """Expire data that is older than `data_expiration_seconds` seconds.""" while self.entries: @@ -150,6 +153,15 @@ def recv_obj(self, src: int) -> Any: self.recv_src_counter[src] += 1 return obj + def recv(self, tensor: torch.Tensor, src: int): + """Receive a tensor from a source rank.""" + recv_tensor = self.recv_obj(src) + assert isinstance(recv_tensor, torch.Tensor), "Received object is"\ + " not a tensor." + assert tensor.size() == recv_tensor.size(), "Received tensor size"\ + " does not match the recv buffer size." + tensor[...] = recv_tensor + def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: """Broadcast an object from a source rank to all other ranks. It does not clean up after all ranks have received the object. diff --git a/vllm/utils.py b/vllm/utils.py index d69461a167780..5ad8a32fa8ab7 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -36,7 +36,6 @@ from torch.library import Library from typing_extensions import ParamSpec, TypeIs, assert_never -import vllm.distributed.kv_transfer.vllm_adapter as dist_kv import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger from vllm.platforms import current_platform From 20e045093306b00c146a63aeb6cb50e517b3f0ad Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 14 Nov 2024 07:24:57 +0000 Subject: [PATCH 282/303] bug found: NONE Tensor did not return none --- tests/kv_transfer/test_send_recv.py | 44 +++++++++++++++++++---------- tests/kv_transfer/test_send_recv.sh | 3 ++ 2 files changed, 32 insertions(+), 15 deletions(-) create mode 100644 tests/kv_transfer/test_send_recv.sh diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index d79bec36b6ceb..7b97252e6da47 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -45,13 +45,17 @@ def barrier(my_rank): def stress_test(my_rank, pipe): # barrier - barrier(my_rank) + + torch.distributed.barrier() tensors: List[torch.Tensor] = [] + + torch.manual_seed(0) + for i in tqdm(range(500)): - mean = torch.rand(1).item() - std = torch.rand(1).item() + mean = torch.rand(1).item() * 100 + std = torch.rand(1).item() * 100 size = torch.randint(900, 1000, (2, )) x = torch.normal(mean * 1.0, std * 1.0, size=size.tolist()).to(pipe.device) @@ -66,7 +70,10 @@ def stress_test(my_rank, pipe): tensors.append(x.mean().unsqueeze(0)) tensors.append(x.std().unsqueeze(0)) - barrier(my_rank) + + + torch.distributed.barrier() + for i in tqdm(range(500)): @@ -78,13 +85,20 @@ def stress_test(my_rank, pipe): x = pipe.recv_tensor() mean = pipe.recv_tensor() std = pipe.recv_tensor() - if x is None: - assert mean is None - assert std is None - else: - assert torch.allclose(x, tensors[3 * i]) - assert x.mean() == mean[0] - assert x.std() == std[0] + try: + if x is None: + assert mean is None + assert std is None + else: + assert torch.allclose(x, tensors[3 * i]) + assert x.mean() == mean[0] + assert x.std() == std[0] + except Exception as e: + print("Error at iteration", i, "rank", my_rank) + print(x) + raise e + + torch.distributed.barrier() print("Stress test passed.") @@ -132,11 +146,11 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.distributed.init_process_group( backend='gloo', - init_method='tcp://localhost:12567', - rank=my_rank, + init_method='tcp://localhost:12398', world_size=2, + rank=my_rank, ) - print('done') + config = ParallelConfig( 1, @@ -156,8 +170,8 @@ def latency_test(my_rank, pipe, nelement, ntensor): config=config, ) - torch.manual_seed(0) test_run(my_rank, pipe) + # torch.distributed.barrier() stress_test(my_rank, pipe) # Use this function if you want to test the latency of pipe impl. diff --git a/tests/kv_transfer/test_send_recv.sh b/tests/kv_transfer/test_send_recv.sh new file mode 100644 index 0000000000000..c9335434473ea --- /dev/null +++ b/tests/kv_transfer/test_send_recv.sh @@ -0,0 +1,3 @@ + +RANK=0 python test_send_recv.py & +RANK=1 python test_send_recv.py & \ No newline at end of file From cc9e8f4eebf5c9a63640b90de6020b3a64601f24 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 14 Nov 2024 07:56:35 +0000 Subject: [PATCH 283/303] save code for Kaichao to debug --- tests/kv_transfer/test_send_recv.py | 10 ++++++++-- .../kv_connector/pynccl_connector/pynccl_pipe.py | 4 ++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 7b97252e6da47..0195b028c7cb3 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -82,8 +82,11 @@ def stress_test(my_rank, pipe): pipe.send_tensor(tensors[3 * i + 1]) pipe.send_tensor(tensors[3 * i + 2]) else: + print('receiving x') x = pipe.recv_tensor() + print('receiving mean') mean = pipe.recv_tensor() + print('receiving std') std = pipe.recv_tensor() try: if x is None: @@ -96,13 +99,16 @@ def stress_test(my_rank, pipe): except Exception as e: print("Error at iteration", i, "rank", my_rank) print(x) + print(x.numel()) + print(x.item() == -150886311) raise e + + if i == 80: + break torch.distributed.barrier() - print("Stress test passed.") - def latency_test(my_rank, pipe, nelement, ntensor): diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py index 3d4c28d88fb9e..580c34487aade 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py @@ -311,6 +311,10 @@ def recv_tensor(self) -> Optional[torch.Tensor]: import traceback traceback.print_exc() raise e + + if tensor.numel() == 1: + print(tensor.item()) + print(tensor.sum()) if tensor.numel() == 1 and tensor.item() == NONE_INT: return None From 3e7e3416506e7ceecf262ba00a4e8b6826ab6117 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Thu, 14 Nov 2024 08:01:32 +0000 Subject: [PATCH 284/303] adjust --- .../kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py index 580c34487aade..7e555d9d2f547 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py @@ -315,6 +315,7 @@ def recv_tensor(self) -> Optional[torch.Tensor]: if tensor.numel() == 1: print(tensor.item()) print(tensor.sum()) + print(tensor) if tensor.numel() == 1 and tensor.item() == NONE_INT: return None From 49e89a2992252c3fb668bf1cfe5d9b130aa853cf Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 15 Nov 2024 07:26:48 +0000 Subject: [PATCH 285/303] NCCL pipe bug fix: only transmit metadata when the tensor is None --- tests/kv_transfer/test_send_recv.py | 33 ++---- .../pynccl_connector/pynccl_pipe.py | 101 +++++++----------- 2 files changed, 49 insertions(+), 85 deletions(-) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 0195b028c7cb3..61ccc0380a53b 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -37,15 +37,10 @@ def test_run(my_rank, pipe): assert torch.allclose(x, x2) assert torch.allclose(y, y2) - -def barrier(my_rank): - torch.distributed.barrier() def stress_test(my_rank, pipe): - # barrier - torch.distributed.barrier() tensors: List[torch.Tensor] = [] @@ -82,29 +77,18 @@ def stress_test(my_rank, pipe): pipe.send_tensor(tensors[3 * i + 1]) pipe.send_tensor(tensors[3 * i + 2]) else: - print('receiving x') x = pipe.recv_tensor() - print('receiving mean') mean = pipe.recv_tensor() - print('receiving std') std = pipe.recv_tensor() - try: - if x is None: - assert mean is None - assert std is None - else: - assert torch.allclose(x, tensors[3 * i]) - assert x.mean() == mean[0] - assert x.std() == std[0] - except Exception as e: - print("Error at iteration", i, "rank", my_rank) - print(x) - print(x.numel()) - print(x.item() == -150886311) - raise e - if i == 80: - break + if x is None: + assert mean is None + assert std is None + else: + assert torch.allclose(x, tensors[3 * i]) + assert x.mean() == mean[0] + assert x.std() == std[0] + torch.distributed.barrier() @@ -177,7 +161,6 @@ def latency_test(my_rank, pipe, nelement, ntensor): ) test_run(my_rank, pipe) - # torch.distributed.barrier() stress_test(my_rank, pipe) # Use this function if you want to test the latency of pipe impl. diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py index 7e555d9d2f547..a8a688d742329 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py @@ -9,9 +9,8 @@ """ import threading import time -from collections import deque from concurrent.futures import ThreadPoolExecutor -from typing import Deque, List, Optional, Union +from typing import Dict, List, Optional, Union from copy import deepcopy import torch @@ -69,6 +68,9 @@ class BrokenPipeException(Exception): def __init__(self, message): self.message = message super().__init__(self.message) + + +Metadata = Dict[str, Optional[torch.Tensor]] class PyNcclPipe: @@ -97,10 +99,8 @@ def __init__( rank = self.kv_rank, world_size = self.kv_parallel_size ) - # add a barrier to make sure all ranks are ready + # add a barrier to make sure the connection is initiated properly self.group.barrier() - self.metadata_send_func, self.metadata_recv_func = \ - self._get_metadata_send_recv_impl(self.group) self.device_send_func, self.device_recv_func = \ self._get_device_send_recv_impl(self.group) # set target rank @@ -116,13 +116,6 @@ def __init__( self.none_tensor = torch.tensor([NONE_INT], device=self.device) - # On-device tensors to be reused for recv - self.rcv_metadata_buffer = torch.zeros(self.METADATA_LENGTH, - dtype=self.METADATA_DTYPE, - device="cpu") - - def _get_metadata_send_recv_impl(self, group: StatelessProcessGroup): - return group.send, group.recv def _get_device_send_recv_impl(self, group: StatelessProcessGroup): if self.config.kv_buffer_device == "cuda": @@ -134,7 +127,7 @@ def _get_device_send_recv_impl(self, group: StatelessProcessGroup): comm.disabled = False send, recv = comm.send, comm.recv else: - # use torch c10store for send / recv + # use cpu communication send = group.send recv = group.recv @@ -146,7 +139,7 @@ def _select_device(self): else: return "cpu" - def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: + def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata: """Create the metadata on based on the input tensor, and move it to GPU. The metadata's length is `TorchDistributedPipe.METADATA_LENGTH`. @@ -165,19 +158,16 @@ def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: Returns: - metadata: the metadata tensor, on self.device """ - buffer = torch.empty(self.METADATA_LENGTH, - dtype=self.METADATA_DTYPE, - device="cpu") - buffer[0] = DTYPE2INT[tensor.dtype] - ndims = len(tensor.shape) - buffer[1] = len(tensor.shape) - buffer[2:2 + ndims] = torch.tensor(tensor.shape, - dtype=self.METADATA_DTYPE, - device="cpu") - return buffer.to(self.device) + if tensor is None: + return {"dtype": None, "shape": None} + else: + return { + "dtype": tensor.dtype, + "shape": tensor.shape + } def _prepare_recv_buffer(self, - d_metadata_buffer: torch.Tensor) -> torch.Tensor: + metadata: Metadata) -> torch.Tensor: """Create a buffer to receive the tensor based on the metadata. Parameters: @@ -186,16 +176,13 @@ def _prepare_recv_buffer(self, Returns: - buffer: the buffer tensor to receive the tensor, on self.device """ - h_buffer = d_metadata_buffer.cpu().numpy() - dtype = INT2DTYPE[h_buffer[0]] - ndims = h_buffer[1] - shape = tuple(h_buffer[2:2 + ndims]) - return torch.empty(shape, dtype=dtype, device=self.device) + return torch.empty(metadata["shape"], + dtype=metadata["dtype"], device=self.device) - def _send_metadata(self, d_metadata_buffer: torch.Tensor): + def _send_metadata(self, metadata): """Send the metadata buffer to the target rank. """ - self.metadata_send_func(d_metadata_buffer, self.target_rank_for_send) + self.group.send_obj(metadata, self.target_rank_for_send) def _recv_metadata(self) -> torch.Tensor: """Receive the metadata buffer from the target rank. @@ -208,14 +195,11 @@ def _recv_metadata(self) -> torch.Tensor: race conditions during sending/receiving. Therefore, the metadata buffer can be reused """ - self.metadata_recv_func( - self.rcv_metadata_buffer, + return self.group.recv_obj( self.target_rank_for_recv ) - return self.rcv_metadata_buffer - - def _send_impl(self, tensor): + def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: """ The actual implementation of sending the tensor to the target rank. This function will first send the metadata, and then send the tensor. @@ -226,9 +210,11 @@ def _send_impl(self, tensor): metadata = self._make_metadata(tensor) self._send_metadata(metadata) - self.device_send_func(tensor.to(self.device), self.target_rank_for_send) + if tensor is not None: + self.device_send_func(tensor.to(self.device), + self.target_rank_for_send) - def _recv_impl(self) -> torch.Tensor: + def _recv_impl(self) -> Optional[torch.Tensor]: """ The actual implementation of receiving the tensor from the target rank. This function will first receive the metadata, then receive the tensor. @@ -239,18 +225,24 @@ def _recv_impl(self) -> torch.Tensor: - buffer: the received tensor, on self.device """ print('recv_metadata...') - d_metadata = self._recv_metadata() - print('recv metadata done, receiving tensor ...') - buffer = self._prepare_recv_buffer(d_metadata) + metadata = self._recv_metadata() + print('recv metadata done') + if metadata["dtype"] is None: + return None + print('receiving tensor ...') + buffer = self._prepare_recv_buffer(metadata) self.device_recv_func(buffer, self.target_rank_for_recv) print('recv tensor done.') return buffer - def send_tensor_wrapper(self, tensor): + def send_tensor_wrapper( + self, + tensor: Optional[torch.Tensor], + tensor_size: int + ) -> None: try: """Wrapper for send_tensor_dict""" - tensor_size = tensor.element_size() * tensor.numel() self._send_impl(tensor) with self.buffer_size_lock: @@ -277,13 +269,9 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: self.transport_thread = ThreadPoolExecutor(max_workers=1) if tensor is None: - tensor = self.none_tensor - - tensor_size = tensor.element_size() * tensor.numel() - - assert ( - 0 < len(tensor.shape) < self.MAX_TENSOR_DIMENSIONS - ), f"Only support dimensions within 1-{self.MAX_TENSOR_DIMENSIONS}" + tensor_size = 0 + else: + tensor_size = tensor.element_size() * tensor.numel() self.block_if_full() @@ -293,6 +281,7 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: self.transport_thread.submit( self.send_tensor_wrapper, tensor, + tensor_size, ) def recv_tensor(self) -> Optional[torch.Tensor]: @@ -312,15 +301,7 @@ def recv_tensor(self) -> Optional[torch.Tensor]: traceback.print_exc() raise e - if tensor.numel() == 1: - print(tensor.item()) - print(tensor.sum()) - print(tensor) - - if tensor.numel() == 1 and tensor.item() == NONE_INT: - return None - else: - return tensor + return tensor def close(self): """Close the pipe and release the resources.""" From b6e83a26aa6cdab9eeb1919b5dbc8c44e4792cdb Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 15 Nov 2024 07:35:37 +0000 Subject: [PATCH 286/303] Update docstring using GPT, and clean up unnecessary variables --- .../pynccl_connector/pynccl_pipe.py | 240 +++++++----------- 1 file changed, 93 insertions(+), 147 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py index a8a688d742329..21c7c69b2afea 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py @@ -1,12 +1,14 @@ """ - This file implements a simple torch distributed connector by 3 classes: - - `TorchDistributedPipe`: a tensor transmission pipe between vllm instances, - using `torch.distributed` - - `TorchDistributedBuffer`: a buffer to store tensors, implemented on top - of `TorchDistributedPipe` - - `TorchDistributedConnector`: a torch distributed connector between P/D - instance, implemented on top of `TorchDistributedBuffer` + This file implements a simple PyNccl pipe that can send and receive + Optional[torch.Tensor] between two ranks. + + We will first transmit the metadata, and then the tensor. + Metadata format: + Metadata = Dict[str, Optional[torch.Tensor]] + - "dtype": The data type of the tensor (tensor.dtype) or None + - "shape": The shape of the tensor (tensor.shape) or None """ + import threading import time from concurrent.futures import ThreadPoolExecutor @@ -21,109 +23,57 @@ from vllm.logger import init_logger from vllm.config import ParallelConfig - - logger = init_logger(__name__) -# if the tensor is only one-element and only contains NONE_INT -# this means that the sended object is None. -NONE_INT = -150886311 - -# Mapping tensor dtype to INT64, used for tensor metadata transmission -FLOAT16_INT = -543205003776624 -INT64_INT = -375623078607432 -BOOL_INT = -28035262008646 -BFLOAT16_INT = -452084912267662 -FLOAT32_INT = -1049557997456592 -FLOAT64_INT = -452201007054137 -FLOAT8_E4M3FN_INT = -1066697177659525 -FLOAT8_E5M2_INT = -618182574682355 - -DTYPE2INT = { - torch.float16: FLOAT16_INT, - torch.int64: INT64_INT, - torch.bool: BOOL_INT, - torch.bfloat16: BFLOAT16_INT, - torch.float32: FLOAT32_INT, - torch.float64: FLOAT64_INT, - torch.float8_e4m3fn: FLOAT8_E4M3FN_INT, - torch.float8_e5m2: FLOAT8_E5M2_INT, -} - -INT2DTYPE = { - FLOAT16_INT: torch.float16, - INT64_INT: torch.int64, - BOOL_INT: torch.bool, - BFLOAT16_INT: torch.bfloat16, - FLOAT32_INT: torch.float32, - FLOAT64_INT: torch.float64, - FLOAT8_E4M3FN_INT: torch.float8_e4m3fn, - FLOAT8_E5M2_INT: torch.float8_e5m2, -} - - class BrokenPipeException(Exception): - def __init__(self, message): self.message = message super().__init__(self.message) - - + + Metadata = Dict[str, Optional[torch.Tensor]] class PyNcclPipe: - + METADATA_LENGTH = 16 MAX_TENSOR_DIMENSIONS = 14 METADATA_DTYPE = torch.int64 - def __init__( - self, - local_rank: int, - config: ParallelConfig - ): + def __init__(self, local_rank: int, config: ParallelConfig): self.config = config - self.local_rank = local_rank self.kv_rank = self.config.kv_rank self.kv_parallel_size = self.config.kv_parallel_size self.device = self._select_device() - # build distributed connection and send/recv implementation self.group = StatelessProcessGroup.create( - host = self.config.kv_ip, - port = self.config.kv_port, - rank = self.kv_rank, - world_size = self.kv_parallel_size + host=self.config.kv_ip, + port=self.config.kv_port, + rank=self.kv_rank, + world_size=self.kv_parallel_size, ) # add a barrier to make sure the connection is initiated properly self.group.barrier() - self.device_send_func, self.device_recv_func = \ - self._get_device_send_recv_impl(self.group) + impl = self._get_device_send_recv_impl(self.group) + self.device_send_func, self.device_recv_func = impl # set target rank - self.target_rank_for_send = (self.kv_rank+ 1) % self.kv_parallel_size + self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size - # transportation-related variables self.transport_thread: Optional[ThreadPoolExecutor] = None self.buffer_size = 0 self.buffer_size_lock = threading.Lock() self.buffer_size_thresh = self.config.kv_buffer_size - self.none_tensor = torch.tensor([NONE_INT], device=self.device) - def _get_device_send_recv_impl(self, group: StatelessProcessGroup): if self.config.kv_buffer_device == "cuda": # use PyNCCL for send / recv - comm = PyNcclCommunicator( - group, - device=self.local_rank, - ) + comm = PyNcclCommunicator(group, device=self.local_rank) comm.disabled = False send, recv = comm.send, comm.recv else: @@ -140,113 +90,100 @@ def _select_device(self): return "cpu" def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata: - """Create the metadata on based on the input tensor, and move it to GPU. - The metadata's length is `TorchDistributedPipe.METADATA_LENGTH`. - - Currently, the metadata is a int64 tensor and it includes dtype, number - of dimensions, and the shape information of the input tensor. - - - The information follows the layout below: - - metadata[0] -- dtype - - metadata[1] -- number of dimensions - - metadata[2 : 2+ndims] -- the shape of the input tensor + """ + Create the metadata as a dictionary based on the input tensor. Parameters: - - tensor: the input tensor + - tensor: The input tensor or None if no tensor is provided. Returns: - - metadata: the metadata tensor, on self.device + - metadata: A dictionary with the following keys: + - "dtype": The data type of the tensor or None. + - "shape": The shape of the tensor or None. """ if tensor is None: return {"dtype": None, "shape": None} else: - return { - "dtype": tensor.dtype, - "shape": tensor.shape - } + return {"dtype": tensor.dtype, "shape": tensor.shape} - def _prepare_recv_buffer(self, - metadata: Metadata) -> torch.Tensor: - """Create a buffer to receive the tensor based on the metadata. + def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor: + """ + Create a buffer to receive the tensor based on the provided metadata. Parameters: - - d_metadata_buffer: the metadata tensor on self.device + - metadata: A dictionary with keys "dtype" and "shape", describing + the tensor's data type and shape. Returns: - - buffer: the buffer tensor to receive the tensor, on self.device + - buffer: A tensor of the specified type and shape, allocated on + self.device. + """ + return torch.empty( + metadata["shape"], + dtype=metadata["dtype"], + device=self.device) + + def _send_metadata(self, metadata: Metadata): """ - return torch.empty(metadata["shape"], - dtype=metadata["dtype"], device=self.device) + Send the metadata dictionary to the target rank. - def _send_metadata(self, metadata): - """Send the metadata buffer to the target rank. + Parameters: + - metadata: A dictionary with keys "dtype" and "shape". """ self.group.send_obj(metadata, self.target_rank_for_send) - def _recv_metadata(self) -> torch.Tensor: - """Receive the metadata buffer from the target rank. + def _recv_metadata(self) -> Metadata: + """ + Receive the metadata dictionary from the target rank. Returns: - - metadata_buffer: the metadata buffer tensor, on self.device - - Note: - The current implementation uses the assumption that there is no - race conditions during sending/receiving. Therefore, the metadata - buffer can be reused + - metadata: A dictionary with keys "dtype" and "shape" describing + the tensor. """ - return self.group.recv_obj( - self.target_rank_for_recv - ) + return self.group.recv_obj(self.target_rank_for_recv) def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: """ - The actual implementation of sending the tensor to the target rank. - This function will first send the metadata, and then send the tensor. + The actual implementation of sending the tensor and its metadata to the + target rank. Parameters: - - tensor: the input tensor to be sent + - tensor: The input tensor to be sent, or None if no tensor is + being sent. """ - metadata = self._make_metadata(tensor) self._send_metadata(metadata) if tensor is not None: - self.device_send_func(tensor.to(self.device), - self.target_rank_for_send) + self.device_send_func(tensor.to(self.device), self.target_rank_for_send) def _recv_impl(self) -> Optional[torch.Tensor]: """ - The actual implementation of receiving the tensor from the target rank. - This function will first receive the metadata, then receive the tensor. - - This function will block if there is no tensor to receive. + The actual implementation of receiving a tensor and its metadata from + the target rank. Returns: - - buffer: the received tensor, on self.device + - buffer: The received tensor, or None if no tensor is received. """ - print('recv_metadata...') metadata = self._recv_metadata() - print('recv metadata done') if metadata["dtype"] is None: return None - print('receiving tensor ...') buffer = self._prepare_recv_buffer(metadata) self.device_recv_func(buffer, self.target_rank_for_recv) - print('recv tensor done.') return buffer def send_tensor_wrapper( self, - tensor: Optional[torch.Tensor], - tensor_size: int - ) -> None: + tensor: Optional[torch.Tensor], + tensor_size: int) -> None: + """ + Wrapper for _send_impl to handle exceptions and update buffer size. + """ try: - """Wrapper for send_tensor_dict""" self._send_impl(tensor) with self.buffer_size_lock: - self.buffer_size = self.buffer_size - tensor_size + self.buffer_size -= tensor_size except Exception as e: logger.error("[rank%d]: Exception when trying to send %s, msg: %s", torch.distributed.get_rank(), str(tensor), str(e)) @@ -254,38 +191,48 @@ def send_tensor_wrapper( traceback.print_exc() def block_if_full(self): - """Block the current thread if the buffer size is larger than 1e9.""" - # TODO: replace this 1e9 with a configurable parameter or a constant + """ + Block the current thread if the buffer size is larger than the + threshold. + """ while self.buffer_size > self.buffer_size_thresh: logger.debug("KV cache transfer pipe is full. Waiting...") time.sleep(0.05) def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: - """Sends a tensor to the destination rank in a non-blocking way. - Flow: send tensor dim -- send tensor shape -- send tensor data """ + Sends a tensor and its metadata to the destination rank in a + non-blocking way. + Parameters: + - tensor: The tensor to send, or None if no tensor is being sent. + """ if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) - - if tensor is None: - tensor_size = 0 - else: + + if tensor is not None: tensor_size = tensor.element_size() * tensor.numel() + else: + tensor_size = 0 self.block_if_full() with self.buffer_size_lock: - self.buffer_size = self.buffer_size + tensor_size + self.buffer_size += tensor_size self.transport_thread.submit( - self.send_tensor_wrapper, - tensor, - tensor_size, + self.send_tensor_wrapper, + tensor, + tensor_size ) def recv_tensor(self) -> Optional[torch.Tensor]: - """Receives a tensor from the src rank. Blocking.""" + """ + Receives a tensor and its metadata from the source rank. Blocking call. + + Returns: + - tensor: The received tensor, or None if no tensor is received. + """ if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) @@ -294,18 +241,17 @@ def recv_tensor(self) -> Optional[torch.Tensor]: try: tensor = future.result() except Exception as e: - # the underlying pipe is likely broken logger.error("Encountering exception in KV receiving thread") logger.error("%s", e) import traceback traceback.print_exc() raise e - + return tensor def close(self): - """Close the pipe and release the resources.""" - if (hasattr(self, "transport_thread") - and self.transport_thread is not None): + """ + Close the pipe and release associated resources. + """ + if hasattr(self, "transport_thread") and self.transport_thread is not None: self.transport_thread.shutdown() - From 8d20116be110cca0e2b4960e7d15c039122c3f8a Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 15 Nov 2024 08:27:46 +0000 Subject: [PATCH 287/303] Bug fix on PyNcclPipe: the device of sending tensor should be inferred by class variable instead of the config --- .../pynccl_connector/pynccl_pipe.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py index 21c7c69b2afea..ed30bef955d99 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py @@ -41,17 +41,24 @@ class PyNcclPipe: MAX_TENSOR_DIMENSIONS = 14 METADATA_DTYPE = torch.int64 - def __init__(self, local_rank: int, config: ParallelConfig): + def __init__(self, + local_rank: int, + config: ParallelConfig, + device: Optional[str] = None, + port_offset: int = 0): self.config = config self.local_rank = local_rank self.kv_rank = self.config.kv_rank self.kv_parallel_size = self.config.kv_parallel_size - self.device = self._select_device() + if device is None: + self.device = self._select_device(self.config.kv_buffer_device) + else: + self.device = self._select_device(device) # build distributed connection and send/recv implementation self.group = StatelessProcessGroup.create( host=self.config.kv_ip, - port=self.config.kv_port, + port=self.config.kv_port + port_offset, rank=self.kv_rank, world_size=self.kv_parallel_size, ) @@ -71,7 +78,7 @@ def __init__(self, local_rank: int, config: ParallelConfig): def _get_device_send_recv_impl(self, group: StatelessProcessGroup): - if self.config.kv_buffer_device == "cuda": + if self.device.type == "cuda": # use PyNCCL for send / recv comm = PyNcclCommunicator(group, device=self.local_rank) comm.disabled = False @@ -83,11 +90,11 @@ def _get_device_send_recv_impl(self, group: StatelessProcessGroup): return send, recv - def _select_device(self): - if self.config.kv_buffer_device == "cuda": + def _select_device(self, device: str): + if device == "cuda": return torch.device(f"cuda:{self.local_rank}") else: - return "cpu" + return torch.device("cpu") def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata: """ @@ -243,6 +250,7 @@ def recv_tensor(self) -> Optional[torch.Tensor]: except Exception as e: logger.error("Encountering exception in KV receiving thread") logger.error("%s", e) + logger.error("My device: %s", self.device) import traceback traceback.print_exc() raise e From fdc4aad16343bda73d7f1bbe5e861eadc1a45037 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 15 Nov 2024 08:27:54 +0000 Subject: [PATCH 288/303] Fix lookup buffer --- tests/kv_transfer/test_lookup_buffer.py | 66 ++++++++++++++----- tests/kv_transfer/test_lookup_buffer.sh | 3 + .../pynccl_connector/lookup_buffer.py | 21 +++--- 3 files changed, 66 insertions(+), 24 deletions(-) create mode 100644 tests/kv_transfer/test_lookup_buffer.sh diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index 0730f091a34b8..32c3d90dfa92d 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -4,8 +4,11 @@ import torch from tqdm import tqdm -import vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer as sklb -import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp +import vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ + as pnp +import vllm.distributed.kv_transfer.kv_connector.pynccl_connector.lookup_buffer\ + as lb +from vllm.config import ParallelConfig # TODO: the test depends on a lot of fields in the current implementation. # We should have standard interface instead direct field access @@ -17,6 +20,8 @@ def test_run(my_rank, buffer, device): if my_rank == 0: assert buffer.buffer_size == 0 assert len(buffer.buffer) == 0 + + print("My rank: %d, device: %s" % (my_rank, device)) # insert tokens = torch.tensor([1, 2, 3]).to(device) @@ -108,24 +113,53 @@ def stress_test(my_rank, buf, device): if __name__ == "__main__": - + + my_rank = int(os.environ['RANK']) - torch.distributed.init_process_group(init_method="tcp://127.0.0.1:23456", - world_size=2, - rank=my_rank) - + torch.distributed.init_process_group( + backend='gloo', + init_method='tcp://localhost:12398', + world_size=2, + rank=my_rank, + ) + print("initialized! My rank is %d" % my_rank) - - pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "nccl") - cpu_pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "gloo") - buffer = sklb.SimpleKVLookupBuffer(cpu_pipe, pipe, 170000) - - test_run(my_rank, buffer, pipe.device) - - stress_test(my_rank, buffer, pipe.device) + + + config = ParallelConfig( + 1, + 1, + kv_connector='PyNcclConnector', + kv_buffer_device='cuda', + kv_buffer_size=1e9, + kv_rank=my_rank, + kv_role="kv_both", # this arg doesn't matter in this test + kv_parallel_size=2, + kv_ip="127.0.0.1", + kv_port=12345, + ) + + data_pipe = pnp.PyNcclPipe( + local_rank=my_rank, + config=config, + device="cuda", + port_offset=0, + ) + cpu_pipe = pnp.PyNcclPipe( + local_rank=my_rank, + config=config, + device="cpu", + port_offset=1, + ) + + buffer = lb.LookupBuffer(cpu_pipe, data_pipe, 170000) + + test_run(my_rank, buffer, data_pipe.device) + + stress_test(my_rank, buffer, data_pipe.device) buffer.close() - pipe.close() + data_pipe.close() cpu_pipe.close() print('Done') diff --git a/tests/kv_transfer/test_lookup_buffer.sh b/tests/kv_transfer/test_lookup_buffer.sh new file mode 100644 index 0000000000000..eec2a9fb84797 --- /dev/null +++ b/tests/kv_transfer/test_lookup_buffer.sh @@ -0,0 +1,3 @@ + +RANK=0 python test_lookup_buffer.py & +RANK=1 python test_lookup_buffer.py & \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py index 965d6552d722c..2a82132d9d80a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py @@ -17,11 +17,9 @@ import torch from torch.distributed import Backend -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ - import PyncclPipe +import vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ + as pnp from vllm.logger import init_logger -from vllm.config import KVTransferConfig @@ -31,8 +29,8 @@ class LookupBuffer: def __init__(self, - signal_pipe: PyncclPipe, - data_pipe: PyncclPipe, + signal_pipe: pnp.PyNcclPipe, + data_pipe: pnp.PyNcclPipe, buffer_size_thresh: float): """ signal_pipe: on CPU @@ -54,7 +52,7 @@ def __init__(self, self.data_pipe = data_pipe self.request_handling_thread: Optional[threading.Thread] = None - self.normal_signal = torch.tensor([0]) + self.normal_signal = torch.tensor([0], device="cpu") self.end_signal = None def _matches(self, tokens_roi_sender: List[torch.Tensor], @@ -91,6 +89,8 @@ def _send_tensor_and_dec_size(self, assert tensor is not None, "Use self.data_pipe.send(None) instead" self.buffer_size -= tensor.element_size() * tensor.numel() + if tensor.dtype == torch.bool: + tensor = tensor.float() self.data_pipe.send_tensor(tensor) def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): @@ -142,6 +142,7 @@ def drop_select_handler(self): input_tokens = self.data_pipe.recv_tensor() roi = self.data_pipe.recv_tensor() + roi = (roi > 0.5) tokens_roi_recver = [input_tokens, roi] matched_length = 0 @@ -195,10 +196,14 @@ def drop_select( self.signal_pipe.send_tensor(self.normal_signal) self.data_pipe.send_tensor(input_tokens) - self.data_pipe.send_tensor(roi) + self.data_pipe.send_tensor(roi.float()) input_tokens = self.data_pipe.recv_tensor() roi = self.data_pipe.recv_tensor() + if roi is not None: + # convert from float tensor to bool tensor + # as PyNccl does not support sending bool tensor + roi = (roi > 0.5) key = self.data_pipe.recv_tensor() value = self.data_pipe.recv_tensor() hidden = self.data_pipe.recv_tensor() From e7432e9d966dfedb06325372d6818cc314ffa7a5 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Fri, 15 Nov 2024 21:23:07 +0000 Subject: [PATCH 289/303] Adjust init parameters for connector. --- .../kv_transfer/kv_connector/__init__.py | 13 ++- .../pynccl_connector/pynccl_connector.py | 98 +++++++------------ vllm/distributed/kv_transfer/vllm_adapter.py | 14 --- 3 files changed, 40 insertions(+), 85 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/__init__.py b/vllm/distributed/kv_transfer/kv_connector/__init__.py index fdbaa76cefeeb..b5121f0240af3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/__init__.py +++ b/vllm/distributed/kv_transfer/kv_connector/__init__.py @@ -5,13 +5,12 @@ class KVConnectorFactory: @staticmethod def create_connector( + local_rank: int, config ) -> KVConnectorBase: - if config.kv_connector == 'LMCacheConnector': - from .lmcache_connector import LMCacheConnector - return LMCacheConnector(config) - elif config.kv_connector == 'TorchDistributedConnector': - from .torch_distributed_connector import TorchDistributedConnector - return TorchDistributedConnector(config) + if config.kv_connector == 'PyNcclConnector': + from . import PyNcclConnector + return PyNcclConnector(local_rank, config) else: - raise ValueError(f"Unsupported connector type: {connector_type}") \ No newline at end of file + raise ValueError(f"Unsupported connector type: " + f"{config.kv_connector}") \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py index 353270b02de86..3a48d3caf7327 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py @@ -19,7 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ - import PyncclPipe + import PyNcclPipe from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.lookup_buffer \ import LookupBuffer from vllm.logger import init_logger @@ -32,93 +32,63 @@ -class TorchDistributedConnector(KVConnectorBase): +class PyNcclConnector(KVConnectorBase): def __init__( self, - group_ranks: List[List[int]], local_rank: int, config: KVTransferConfig, ): self.lookup_buffer_size = self.kv_buffer_size - self.send_buffer: Optional[TorchDistributedBuffer] = None - self.recv_buffer: Optional[TorchDistributedBuffer] = None - - device2backend = { - "cpu": "gloo", - "gpu": "nccl", - } + self.send_buffer: Optional[LookupBuffer] = None + self.recv_buffer: Optional[LookupBuffer] = None + + # 2 pipes for every rank in the world + port_offset_base = 2 * config.rank + # In disaggregated prefill, the prefill vLLM only uses send pipe # and the decode vLLM only uses recv pipe - # In remote KV cache store, vLLM will use both send pipe and recv pipe - # So we build both send pipe and recv pipe for simplicity. if config.is_kv_producer: - self.send_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - device2backend[config.kv_device], - self.kv_buffer_size, - ) - self.send_signal_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo", - self.kv_buffer_size, - ) - self.recv_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - device2backend[config.kv_device], - self.kv_buffer_size, + self.send_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=config, + port_offset=port_offset_base, ) - self.recv_signal_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo", - self.kv_buffer_size + self.send_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=config, + port_offset=port_offset_base + 1, + device="cpu", ) + self.send_buffer = LookupBuffer( + self.send_signal_pipe, + self.send_data_pipe, + config.kv_buffer_size) else: # the current vLLM instance is KV consumer, so it needs to connect # its recv pipe to the send pipe of KV producder - - self.recv_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - device2backend[config.kv_device], - self.kv_buffer_size, - ) - self.recv_signal_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo", - self.kv_buffer_size, + self.recv_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=config, + port_offset=port_offset_base, ) - self.send_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - device2backend[config.kv_device], - self.kv_buffer_size, + self.recv_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=config, + port_offset=port_offset_base + 1, + device="cpu", ) - self.send_signal_pipe = TorchDistributedPipe( - group_ranks, - local_rank, - "gloo", - self.kv_buffer_size + self.recv_buffer = LookupBuffer( + self.recv_signal_pipe, + self.recv_data_pipe, + config.kv_buffer_size ) - - self.send_buffer = TorchDistributedBuffer(self.send_signal_pipe, - self.send_pipe, - self.lookup_buffer_size) - self.recv_buffer = TorchDistributedBuffer(self.recv_signal_pipe, - self.recv_pipe, - self.lookup_buffer_size) - self.tensor_device = config.kv_device def select( diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py index e2d77fe920362..992033598e095 100644 --- a/vllm/distributed/kv_transfer/vllm_adapter.py +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -37,19 +37,6 @@ logger = init_logger(__name__) -# several flags used for indicating the role of current vLLM worker -IS_DISTRIBUTED_KV_INSTANCE: Optional[bool] = None -IS_KV_PRODUCER: Optional[bool] = None -IS_KV_CONSUMER: Optional[bool] = None - - -def set_kv_transfer_attribute(config): - global IS_DISTRIBUTED_KV_INSTANCE, IS_KV_PRODUCER, IS_KV_CONSUMER - - IS_DISTRIBUTED_KV_INSTANCE = config.is_distributed_kv_instance - IS_KV_PRODUCER = config.is_kv_producer - IS_KV_CONSUMER = config.is_kv_consumer - class KV_transfer_agent: """ @@ -62,7 +49,6 @@ class KV_transfer_agent: def __init__( self, - group_ranks: List[List[int]], local_rank: int, config, ): From d1ce09d3bf48348141551e5d8361e8ac2e2204c0 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 17 Nov 2024 03:45:39 +0000 Subject: [PATCH 290/303] Move KVTransferConfig outside ParallelConfig --- .../disagg_prefill_example.sh | 13 +- tests/kv_transfer/test_lookup_buffer.py | 6 +- tests/kv_transfer/test_send_recv.py | 6 +- vllm/config.py | 132 +++++++----------- .../pynccl_connector/pynccl_pipe.py | 4 +- .../{vllm_adapter.py => kv_transfer_agent.py} | 0 vllm/distributed/parallel_state.py | 130 +++++------------ vllm/engine/arg_utils.py | 75 +++++++--- vllm/executor/gpu_executor.py | 3 +- vllm/worker/worker.py | 6 +- vllm/worker/worker_base.py | 3 +- 11 files changed, 161 insertions(+), 217 deletions(-) rename examples/{distributed_kv => kv_transfer}/disagg_prefill_example.sh (89%) rename vllm/distributed/kv_transfer/{vllm_adapter.py => kv_transfer_agent.py} (100%) diff --git a/examples/distributed_kv/disagg_prefill_example.sh b/examples/kv_transfer/disagg_prefill_example.sh similarity index 89% rename from examples/distributed_kv/disagg_prefill_example.sh rename to examples/kv_transfer/disagg_prefill_example.sh index efec87855dbee..b19f00227d77c 100644 --- a/examples/distributed_kv/disagg_prefill_example.sh +++ b/examples/kv_transfer/disagg_prefill_example.sh @@ -4,7 +4,6 @@ # and then transfer the KV cache between them. export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') -export VLLM_PORT=12345 # install quart first -- required for disagg prefill proxy serve if python3 -c "import quart" &> /dev/null; then @@ -24,20 +23,24 @@ wait_for_server() { } # prefilling instance, which is the KV producer -VLLM_DISTRIBUTED_KV_ROLE=producer CUDA_VISIBLE_DEVICES=0 python3 \ +CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8100 \ --max-model-len 10000 \ - --gpu-memory-utilization 0.8 & + --gpu-memory-utilization 0.8 \ + --kv-connector PyNcclConnector \ + --kv-role kv_producer & # decoding instance, which is the KV consumer -VLLM_DISTRIBUTED_KV_ROLE=consumer CUDA_VISIBLE_DEVICES=1 python3 \ +CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8200 \ --max-model-len 10000 \ - --gpu-memory-utilization 0.8 & + --gpu-memory-utilization 0.8 \ + --kv-connector PyNcclConnector \ + --kv-role kv_consumer & # wait until prefill and decode instances are ready wait_for_server 8100 diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index 32c3d90dfa92d..d3552aca22c4a 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -8,7 +8,7 @@ as pnp import vllm.distributed.kv_transfer.kv_connector.pynccl_connector.lookup_buffer\ as lb -from vllm.config import ParallelConfig +from vllm.config import KVTransferConfig # TODO: the test depends on a lot of fields in the current implementation. # We should have standard interface instead direct field access @@ -127,9 +127,7 @@ def stress_test(my_rank, buf, device): print("initialized! My rank is %d" % my_rank) - config = ParallelConfig( - 1, - 1, + config = KVTransferConfig( kv_connector='PyNcclConnector', kv_buffer_device='cuda', kv_buffer_size=1e9, diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 61ccc0380a53b..239cba19eba51 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -7,7 +7,7 @@ import vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ as pnp -from vllm.config import ParallelConfig +from vllm.config import KVTransferConfig def test_run(my_rank, pipe): @@ -142,9 +142,7 @@ def latency_test(my_rank, pipe, nelement, ntensor): ) - config = ParallelConfig( - 1, - 1, + config = KVTransferConfig( kv_connector='PyNcclConnector', kv_buffer_device='cuda', kv_buffer_size=1e9, diff --git a/vllm/config.py b/vllm/config.py index 057a64444f12c..0ec8dfab29fef 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -919,7 +919,6 @@ class ParallelConfig: """Configuration for the distributed execution. Args: - kv_disagg_parallel_size: Number of kv disagg groups. pipeline_parallel_size: Number of pipeline parallel groups. tensor_parallel_size: Number of tensor parallel groups. worker_use_ray: Deprecated, use distributed_executor_backend instead. @@ -937,13 +936,6 @@ class ParallelConfig: workers, either "ray" or "mp" (multiprocessing). If either pipeline_parallel_size or tensor_parallel_size is greater than 1, will default to "ray" if Ray is installed or "mp" otherwise. - kv_connector: The connector to use for kv cache transfer, value can be - None, "TorchDistributedConnector" or "LMCacheConnector". - kv_buffer_device: The buffer device to use for kv cache transfer. - kv_buffer_size: The buffer size to use for kv cache transfer. - kv_disagg_role: The role of the kv disagg worker, can be "kv_producer", - "kv_consumer", "kv_both" or None. - kv_disagg_rank: The rank of the kv disagg worker. """ def __init__( @@ -958,14 +950,6 @@ def __init__( placement_group: Optional["PlacementGroup"] = None, distributed_executor_backend: Optional[Union[ str, Type["ExecutorBase"]]] = None, - kv_connector: Optional[str] = None, - kv_buffer_device: Optional[str] = None, - kv_buffer_size: Optional[float] = None, - kv_role: Optional[str] = None, - kv_rank: Optional[int] = None, - kv_parallel_size: Optional[int] = None, - kv_ip: Optional[str] = None, - kv_port: Optional[str] = None, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size @@ -975,15 +959,7 @@ def __init__( self.tokenizer_pool_config = tokenizer_pool_config self.ray_workers_use_nsight = ray_workers_use_nsight self.placement_group = placement_group - self.world_size = pipeline_parallel_size * tensor_parallel_size - self.kv_connector = kv_connector - self.kv_buffer_device = kv_buffer_device - self.kv_buffer_size = kv_buffer_size - self.kv_role = kv_role - self.kv_rank = kv_rank - self.kv_parallel_size = kv_parallel_size - self.kv_ip = kv_ip - self.kv_port = kv_port + self.world_size = pipeline_parallel_size * self.tensor_parallel_size if worker_use_ray: if self.distributed_executor_backend is None: @@ -1000,13 +976,6 @@ def __init__( raise ValueError( "TPU backend only supports Ray for distributed inference.") - if current_platform.is_hpu() and self.world_size > 1: - if self.distributed_executor_backend is None: - self.distributed_executor_backend = "ray" - if self.distributed_executor_backend != "ray": - raise ValueError( - "HPU backend only supports Ray for distributed inference.") - if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. @@ -1044,19 +1013,6 @@ def use_ray(self) -> bool: isinstance(self.distributed_executor_backend, type) and self.distributed_executor_backend.uses_ray) - @property - def is_distributed_kv_instance(self) -> bool: - return self.kv_transfer_role in ["kv_producer", "kv_consumer", "kv_both"] - - @property - def is_kv_producer(self) -> bool: - return self.kv_transfer_role in ["kv_producer", "kv_both"] - - @property - def is_kv_consumer(self) -> bool: - return self.kv_transfer_role in ["kv_consumer", "kv_both"] - - def _verify_args(self) -> None: # Lazy import to avoid circular import from vllm.executor.executor_base import ExecutorBase @@ -1072,7 +1028,7 @@ def _verify_args(self) -> None: if self.use_ray: from vllm.executor import ray_utils ray_utils.assert_ray_available() - if current_platform.is_rocm(): + if is_hip(): self.disable_custom_all_reduce = True logger.info( "Disabled the custom all-reduce kernel because it is not " @@ -1081,38 +1037,6 @@ def _verify_args(self) -> None: raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") - # A series of checks for P/D disaggregation (and future disaggregation) - if self.kv_connector is None and not all([ - self.kv_disagg_parallel_size == 1, - self.kv_disagg_rank == 0, - self.kv_buffer_size is None, - self.kv_disagg_role is None, - self.kv_buffer_device is None, - ]): - raise ValueError("Please specify kv_connector before configuring " - "variables with prefix `kv_`") - - if self.kv_connector not in [None, - "PyNcclConnector", - "LMCacheConnector"]: - raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. " - f"Supported connectors are " - f"`PyNcclConnector` and " - f"`LMCacheConnector`") - - if self.kv_role not in [None, - "kv_producer", - "kv_consumer", - "kv_both"]: - raise ValueError(f"Unsupported kv_role: {self.kv_disagg_role}. " - f"Supported roles are `kv_producer`, `kv_consumer`, " - f"and `kv_both`") - - if self.kv_connector is not None and self.kv_role is None: - raise ValueError("Please specify kv_disagg_role when kv_connector " - "is set, supported roles are `kv_producer`, " - "`kv_consumer`, and `kv_both`") - class SchedulerConfig: """Scheduler configuration. @@ -2079,6 +2003,57 @@ def __post_init__(self): "OpenTelemetry is not available. Unable to configure " "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " f"installed. Original error:\n{otel_import_error_traceback}") + + +@dataclass +class KVTransferConfig: + """Configuration for distributed KV cache transfer.""" + + # NOTE: these default values should align with EngineArgs + kv_connector: Optional[str] = None + kv_buffer_device: Optional[str] = None + kv_buffer_size: float = 1e9 + kv_role: Optional[str] = None + kv_rank: Optional[int] = None + kv_parallel_size: int = 1 + kv_ip: str = "127.0.0.1" + kv_port: int = 14579 + + @property + def is_distributed_kv_instance(self) -> bool: + return self.kv_transfer_role in ["kv_producer", "kv_consumer", "kv_both"] + + @property + def is_kv_producer(self) -> bool: + return self.kv_transfer_role in ["kv_producer", "kv_both"] + + @property + def is_kv_consumer(self) -> bool: + return self.kv_transfer_role in ["kv_consumer", "kv_both"] + + + def __post_init__(self): + + if self.kv_connector not in [None, + "PyNcclConnector", + "LMCacheConnector"]: + raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. " + f"Supported connectors are " + f"`PyNcclConnector` and " + f"`LMCacheConnector`") + + if self.kv_role not in [None, + "kv_producer", + "kv_consumer", + "kv_both"]: + raise ValueError(f"Unsupported kv_role: {self.kv_disagg_role}. " + f"Supported roles are `kv_producer`, `kv_consumer`, " + f"and `kv_both`") + + if self.kv_connector is not None and self.kv_role is None: + raise ValueError("Please specify kv_disagg_role when kv_connector " + "is set, supported roles are `kv_producer`, " + "`kv_consumer`, and `kv_both`") @dataclass @@ -2093,6 +2068,7 @@ class VllmConfig: scheduler_config: SchedulerConfig device_config: DeviceConfig load_config: LoadConfig + kv_transfer_config: KVTransferConfig lora_config: Optional[LoRAConfig] = None speculative_config: Optional[SpeculativeConfig] = None decoding_config: Optional[DecodingConfig] = None diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py index ed30bef955d99..fbe86a38bbe77 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py @@ -21,7 +21,7 @@ from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.logger import init_logger -from vllm.config import ParallelConfig +from vllm.config import KVTransferConfig logger = init_logger(__name__) @@ -43,7 +43,7 @@ class PyNcclPipe: def __init__(self, local_rank: int, - config: ParallelConfig, + config: KVTransferConfig, device: Optional[str] = None, port_offset: int = 0): self.config = config diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py similarity index 100% rename from vllm/distributed/kv_transfer/vllm_adapter.py rename to vllm/distributed/kv_transfer/kv_transfer_agent.py diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 2268f48534311..17532f9c7f20b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -7,9 +7,8 @@ The typical workflow is: - call `init_distributed_environment` to initialize the distributed environment. -- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to - initialize the model parallel groups and disaggregated prefill parallel - groups. +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to + initialize the model parallel groups. - any code dealing with the distributed stuff @@ -23,7 +22,6 @@ import contextlib import gc import pickle -import time import weakref from collections import namedtuple from contextlib import contextmanager, nullcontext @@ -32,7 +30,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch -from numpy import product import torch import torch.distributed from torch.distributed import Backend, ProcessGroup @@ -41,8 +38,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op, supports_custom_op - -import vllm.distributed.kv_transfer.vllm_adapter as dist_kv +import vllm.distributed.kv_transfer.kv_transfer_agent as dist_kv @dataclass @@ -885,10 +881,10 @@ def get_world_group() -> GroupCoordinator: return _WORLD -def init_world_group(ranks: List[List[int]], local_rank: int, +def init_world_group(ranks: List[int], local_rank: int, backend: str) -> GroupCoordinator: return GroupCoordinator( - group_ranks=ranks, + group_ranks=[ranks], local_rank=local_rank, torch_distributed_backend=backend, use_pynccl=False, @@ -947,13 +943,14 @@ def get_pp_group() -> GroupCoordinator: # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group -_DISAGG: Optional[dist_kv.KV_transfer_agent] = None + +_KV_TRANSFER: Optional[dist_kv.KV_transfer_agent] = None -def get_disagg_group() -> dist_kv.KV_transfer_agent: - assert _DISAGG is not None, ( - "disaggregated prefill parallel group is not initialized") - return _DISAGG +def get_kv_transfer_group() -> dist_kv.KV_transfer_agent: + assert _KV_TRANSFER is not None, ( + "disaggregated KV cache transfer parallel group is not initialized") + return _KV_TRANSFER @contextmanager @@ -986,7 +983,6 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable - def init_distributed_environment( world_size: int = -1, rank: int = -1, @@ -1003,30 +999,11 @@ def init_distributed_environment( "distributed_init_method must be provided when initializing " "distributed environment") # this backend is used for WORLD - - # offset world size and rank in disaggregated prefill scenario - maybe_disagg_world_size = world_size - maybe_disagg_rank = rank - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: - maybe_disagg_world_size = world_size * 2 - logger.debug("Distributed KV transfer enabled.") - if dist_kv.IS_KV_PRODUCER: - # for prefill, the ranks are [0, world_size) - logger.debug("rank %d is KV producer.", rank) - maybe_disagg_rank = rank - else: - # this is decode instance. - # offset global rank by tp * pp (which is world_size) - maybe_disagg_rank = rank + world_size - logger.debug("rank %d is KV consumer, adjust it to %d", rank, - maybe_disagg_rank) - torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, - world_size=maybe_disagg_world_size, - rank=maybe_disagg_rank) - logger.debug("torch.distributed initialized") + world_size=world_size, + rank=rank) # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -1037,29 +1014,16 @@ def init_distributed_environment( local_rank = envs.LOCAL_RANK else: local_rank = rank - global _WORLD if _WORLD is None: - # in single node single process the world size can be -1 - # need to infer the world size from torch.distributed.get_world_size() - torch_dist_world_size = torch.distributed.get_world_size() - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: - # two vLLM instances in the world - # so this vLLM instance's world size is half of torch's world size - torch_dist_world_size = torch_dist_world_size // 2 - ranks = [[i for i in range(torch_dist_world_size)]] - + ranks = list(range(torch.distributed.get_world_size())) _WORLD = init_world_group(ranks, local_rank, backend) - logger.debug("_WORLD initialized for rank %d", - torch.distributed.get_rank()) - time.sleep(5) else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized with a different world size") def initialize_model_parallel( - kv_transfer_parallel_size: int = 1, tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, backend: Optional[str] = None, @@ -1085,39 +1049,17 @@ def initialize_model_parallel( are on the same DGX box. For example if we are using 2 DGX-1 boxes with a total of 16 GPUs, rank 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box. - - - Disaggregated prefill will also init its process group using this function. - Changes: - - vLLM world size: unchanged (tp * pp) - - torch.distributed.get_world_size(): - - 2 * tp * pp - - Why: both prefill vLLM and decode vLLM is in the world - - Global rank: - - [0, tp * pp) for prefill - - [tp * pp, 2 * tp * pp) for decode - - Parallel groups - - Extend _WORLD, _TP and _PP using - `include_decoding_groups_if_disagg_enabled` - - Add a new parallel group `_DISAGG` for disaggregated prefill - - [ [0, tp * pp], [1, tp * pp + 1], .. ] - - Local rank: unchanged """ - # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() backend = backend or torch.distributed.get_backend( get_world_group().device_group) - if (world_size != product([ - kv_transfer_parallel_size, - tensor_model_parallel_size, - pipeline_model_parallel_size, - ])): + if (world_size != + tensor_model_parallel_size * pipeline_model_parallel_size): raise RuntimeError( f"world_size ({world_size}) is not equal to " - f"kv_transfer_parallel_size ({kv_transfer_parallel_size}) x " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") @@ -1132,13 +1074,13 @@ def initialize_model_parallel( range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) group_ranks.append(ranks) + # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_message_queue_broadcaster=True, group_name="tp") - logger.debug("_TP initialized for rank %d", torch.distributed.get_rank()) # Build the pipeline model-parallel groups. num_pipeline_model_parallel_groups: int = (world_size // @@ -1156,23 +1098,22 @@ def initialize_model_parallel( backend, use_custom_allreduce=False, group_name="pp") - logger.debug("_PP initialized for rank %d", torch.distributed.get_rank()) - - if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: - global _DISAGG - logger.debug("Disaggregated prefill enabled, create _DISAGG group") - group_ranks = [] - for i in range(world_size): - # prefill local rank: i - # decode global rank: i + world_size - group_ranks.append([i, i + world_size]) - logger.debug("Distributed group is %s", str(group_ranks)) - _DISAGG = dist_kv.KV_transfer_agent( - group_ranks=group_ranks, + + +def ensure_kv_transfer_initialized( + config: "KVTransferConfig" +) -> None: + """ + Initialize KV cache transfer parallel group. + """ + + global _KV_TRANSFER + if config.is_distributed_kv_instance and _KV_TRANSFER is None: + _KV_TRANSFER = dist_kv.KV_transfer_agent( local_rank=get_world_group().local_rank, + config=config ) - logger.debug("_DISAGG initialized for rank %d", - torch.distributed.get_rank()) + def ensure_model_parallel_initialized( @@ -1215,7 +1156,7 @@ def model_parallel_is_initialized(): def patch_tensor_parallel_group(tp_group: GroupCoordinator): """Patch the tp group temporarily until this function ends. - This method is for draft workers of speculative decode to run draft model + This method is for draft workers of speculative decoding to run draft model with different tp degree from that of target model workers. Args: @@ -1258,11 +1199,6 @@ def destroy_model_parallel(): _PP.destroy() _PP = None - global _DISAGG - if _DISAGG: - _DISAGG.destroy() - _DISAGG = None - def destroy_distributed_environment(): global _WORLD @@ -1347,4 +1283,4 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: shm.unlink() torch.distributed.all_reduce(is_in_the_same_node, group=pg) - return [x == 1 for x in is_in_the_same_node.tolist()] + return [x == 1 for x in is_in_the_same_node.tolist()] \ No newline at end of file diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e824740c4f87e..2318ae6fb0593 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -13,14 +13,13 @@ LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PoolerConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig, TaskOption, - TokenizerPoolConfig, VllmConfig) + TokenizerPoolConfig, KVTransferConfig, VllmConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.platforms import current_platform from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import FlexibleArgumentParser, StoreBoolean -import vllm.distributed.kv_transfer.vllm_adapter as dist_kv if TYPE_CHECKING: from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup @@ -109,7 +108,6 @@ class EngineArgs: distributed_executor_backend: Optional[Union[str, Type[ExecutorBase]]] = None # number of P/D disaggregation (or other disaggregation) workers - kv_disagg_parapllel_size: int = 1 pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None @@ -196,10 +194,13 @@ class EngineArgs: # P/D disaggregation coonfiguration kv_connector: Optional[str] = None - kv_buffer_size: Optional[int] = None - kv_buffer_device: Optional[str] = None - kv_disagg_role: Optional[str] = None - kv_disagg_device: Optional[str] = None + kv_buffer_size: Optional[int] = 1e9 + kv_buffer_device: Optional[str] = "gpu" + kv_role: Optional[str] = None + kv_rank: Optional[str] = None + kv_parallel_size: int = 1 + kv_ip: str = "127.0.0.1" + kv_port: int = 14579 def __post_init__(self): if not self.tokenizer: @@ -888,30 +889,40 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "e.g. {\"pooling_type\": \"mean\", \"normalize\": false}.'") parser.add_argument( - '--kv-disagg-parallel-size', - '-kdp', + '--kv-parallel-size', type=int, - default=1 + default=EngineArgs.kv_parallel_size, + help="The number of parallel instances for KV cache transfer. " + "For PyNcclConnector, this should be >1." ) parser.add_argument( '--kv-connector', type=str, default=None, - choices=["TorchDistributedConnector", "LMCacheConnector"], + choices=["PyNcclConnector"], help="The KV connector for vLLM to transmit KV caches between vLLM" " instances.") parser.add_argument( '--kv-buffer-size', type=float, - default=None, + default=EngineArgs.kv_buffer_size, help="The buffer size for TorchDistributedConnector. Measured in " "number of bytes. Recommended value: 1e9 (about 1GB)." ) parser.add_argument( - '--kv-disagg-role', + '--kv-buffer-device', + type=str, + default=None, + choices=["CPU", "GPU"], + help="The device used by kv connector to buffer the KV cache. Can " + "be CPU or GPU. Recommended value: CPU." + ) + + parser.add_argument( + '--kv-role', type=str, default=None, choices=["kv_producer", "kv_consumer", "both"], @@ -920,14 +931,30 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) parser.add_argument( - '--kv-buffer-device', - type=str, + '--kv-rank', + type=int, default=None, - choices=["CPU", "GPU"], - help="The device used by kv connector to buffer the KV cache. Can " - "be CPU or GPU. Recommended value: CPU." + help="The rank of this vLLM instance in the KV cache transfer." + " Typicall value: 0 for prefill instance, 1 for decode instance." + ) + + parser.add_argument( + '--kv-ip', + type=str, + default=EngineArgs.kv_ip, + help="The IP address of the KV cache producer." ) + + parser.add_argument( + '--kv-port', + type=int, + default=EngineArgs.kv_port, + help="The port of the KV cache producer." + ) + + + return parser @classmethod @@ -1030,7 +1057,6 @@ def create_engine_config(self) -> VllmConfig: cpu_offload_gb=self.cpu_offload_gb, ) parallel_config = ParallelConfig( - kv_disagg_parallel_size=self.kv_disagg_parallel_size, pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, worker_use_ray=self.worker_use_ray, @@ -1043,14 +1069,17 @@ def create_engine_config(self) -> VllmConfig: ), ray_workers_use_nsight=self.ray_workers_use_nsight, distributed_executor_backend=self.distributed_executor_backend, + ) + kv_transfer_config = KVTransferConfig( + kv_parallel_size=self.kv_parallel_size, kv_connector=self.kv_connector, kv_buffer_size=self.kv_buffer_size, kv_buffer_device=self.kv_buffer_device, - kv_disagg_role=self.kv_transfer_role, - kv_disagg_rank=self.kv_disagg_rank, + kv_role=self.kv_role, + kv_rank=self.kv_rank, + kv_ip=self.kv_ip, + kv_port=self.kv_port, ) - # set the kv cache transfer condition check variables - dist_kv.set_kv_transfer_attribute(parallel_config) max_model_len = model_config.max_model_len use_long_context = max_model_len > 32768 diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index a88f9f4bcb36c..808fc8c82742b 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,6 +1,5 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -import vllm.distributed.kv_transfer.vllm_adapter as dist_kv from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -49,7 +48,7 @@ def _get_worker_kwargs( if distributed_init_method is None: distributed_init_method = get_distributed_init_method( get_ip(), - get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) + get_open_port()) return dict( vllm_config=self.vllm_config, local_rank=local_rank, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d3ca6d9d0b17e..814295961adcf 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,8 +7,9 @@ import torch.distributed import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import ParallelConfig, VllmConfig, KVTransferConfig from vllm.distributed import (ensure_model_parallel_initialized, + ensure_kv_transfer_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.logger import init_logger @@ -449,6 +450,7 @@ def get_cache_block_size_bytes(self) -> int: def init_worker_distributed_environment( parallel_config: ParallelConfig, + kv_transfer_config: KVTransferConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, @@ -462,6 +464,8 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + ensure_kv_transfer_initialized(local_rank, kv_transfer_config) + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index cf8a4946a71c4..e548ff0a83386 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -44,6 +44,7 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config + self.kv_transfer_config = vllm_config.kv_transfer_config @abstractmethod def init_device(self) -> None: @@ -369,7 +370,7 @@ def execute_model( # output is List[SamplerOutput] return output - def _execute_model_spmd( + def _execute_model_spmdt( self, execute_model_req: ExecuteModelRequest, intermediate_tensors: Optional[IntermediateTensors] = None From 9c4cbc5f890e58fa9cdc21111d6e9e8b94dbdc9e Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Sun, 17 Nov 2024 04:53:58 +0000 Subject: [PATCH 291/303] A series of bug fix: previous merge is buggy and need to manually revert things. --- .../kv_transfer/disagg_prefill_example.sh | 8 +- vllm/config.py | 177 ++++++++++++------ .../kv_transfer/kv_transfer_agent.py | 5 +- vllm/distributed/parallel_state.py | 5 +- vllm/engine/arg_utils.py | 1 + .../tokenizer_group/__init__.py | 7 +- vllm/worker/model_runner.py | 11 +- vllm/worker/worker.py | 8 +- 8 files changed, 147 insertions(+), 75 deletions(-) diff --git a/examples/kv_transfer/disagg_prefill_example.sh b/examples/kv_transfer/disagg_prefill_example.sh index b19f00227d77c..8d274d40ba23f 100644 --- a/examples/kv_transfer/disagg_prefill_example.sh +++ b/examples/kv_transfer/disagg_prefill_example.sh @@ -30,7 +30,9 @@ CUDA_VISIBLE_DEVICES=0 python3 \ --max-model-len 10000 \ --gpu-memory-utilization 0.8 \ --kv-connector PyNcclConnector \ - --kv-role kv_producer & + --kv-role kv_producer \ + --kv-rank 0 \ + --kv-parallel-size 2 & # decoding instance, which is the KV consumer CUDA_VISIBLE_DEVICES=1 python3 \ @@ -40,7 +42,9 @@ CUDA_VISIBLE_DEVICES=1 python3 \ --max-model-len 10000 \ --gpu-memory-utilization 0.8 \ --kv-connector PyNcclConnector \ - --kv-role kv_consumer & + --kv-role kv_consumer \ + --kv-rank 1 \ + --kv-parallel-size 2 & # wait until prefill and decode instances are ready wait_for_server 8100 diff --git a/vllm/config.py b/vllm/config.py index 575705dbf761f..8c65ac30f1cea 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,3 +1,4 @@ +import copy import enum import json import warnings @@ -16,9 +17,10 @@ from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback -from vllm.transformers_utils.config import (ConfigFormat, get_config, - get_hf_image_processor_config, - get_hf_text_config) +from vllm.transformers_utils.config import ( + ConfigFormat, get_config, get_hf_image_processor_config, + get_hf_text_config, get_pooling_config, + get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, identity, print_warning_once) @@ -79,9 +81,6 @@ class ModelConfig: code_revision: The specific revision to use for the model code on Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. - rope_scaling: Dictionary containing the scaling configuration for the - RoPE embeddings. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. tokenizer_revision: The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. @@ -113,7 +112,7 @@ class ModelConfig: matches the model name exposed via the APIs. If multiple model names provided, the first name will be used. If not specified, the model name will be the same as `model`. - limit_mm_per_prompt: Maximum number of data instances per modality + limit_mm_per_prompt: Maximum number of data items per modality per prompt. Only applicable for multimodal models. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. @@ -142,7 +141,7 @@ def __init__( allowed_local_media_path: str = "", revision: Optional[str] = None, code_revision: Optional[str] = None, - rope_scaling: Optional[dict] = None, + rope_scaling: Optional[Dict[str, Any]] = None, rope_theta: Optional[float] = None, tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, @@ -213,6 +212,7 @@ def __init__( self.hf_config = hf_config self.hf_text_config = get_hf_text_config(self.hf_config) + self.encoder_config = self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) @@ -244,7 +244,8 @@ def __init__( max_model_len=max_model_len, disable_sliding_window=self.disable_sliding_window, sliding_window_len=self.get_hf_config_sliding_window(), - spec_target_max_model_len=spec_target_max_model_len) + spec_target_max_model_len=spec_target_max_model_len, + encoder_config=self.encoder_config) self.served_model_name = get_served_model_name(model, served_model_name) self.multimodal_config = self._init_multimodal_config( @@ -282,6 +283,10 @@ def _init_multimodal_config( return None + def _get_encoder_config(self): + return get_sentence_transformer_tokenizer_config( + self.model, self.revision) + def _init_pooler_config( self, override_pooler_config: Optional["PoolerConfig"], @@ -679,12 +684,13 @@ def get_multimodal_config(self) -> "MultiModalConfig": return self.multimodal_config @property - def is_encoder_decoder_model(self) -> bool: + def is_encoder_decoder(self) -> bool: """Extract the HF encoder/decoder model flag.""" - return getattr( - self.hf_config, "is_encoder_decoder", - False) or (hasattr(self.hf_config, "text_config") and getattr( - self.hf_config.text_config, "is_encoder_decoder", False)) + return is_encoder_decoder(self.hf_config) + + @property + def uses_mrope(self) -> bool: + return uses_mrope(self.hf_config) @property def is_multimodal_model(self) -> bool: @@ -933,9 +939,12 @@ class ParallelConfig: https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. placement_group: ray distributed model workers placement group. distributed_executor_backend: Backend to use for distributed model - workers, either "ray" or "mp" (multiprocessing). If either - pipeline_parallel_size or tensor_parallel_size is greater than 1, - will default to "ray" if Ray is installed or "mp" otherwise. + workers, either "ray" or "mp" (multiprocessing). If the product + of pipeline_parallel_size and tensor_parallel_size is less than + or equal to the number of GPUs available, "mp" will be used to + keep processing on a single host. Otherwise, this will default + to "ray" if Ray is installed and fail otherwise. Note that tpu + and hpu only support Ray for distributed inference. """ def __init__( @@ -976,6 +985,13 @@ def __init__( raise ValueError( "TPU backend only supports Ray for distributed inference.") + if current_platform.is_hpu() and self.world_size > 1: + if self.distributed_executor_backend is None: + self.distributed_executor_backend = "ray" + if self.distributed_executor_backend != "ray": + raise ValueError( + "HPU backend only supports Ray for distributed inference.") + if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. @@ -1028,7 +1044,7 @@ def _verify_args(self) -> None: if self.use_ray: from vllm.executor import ray_utils ray_utils.assert_ray_available() - if is_hip(): + if current_platform.is_rocm(): self.disable_custom_all_reduce = True logger.info( "Disabled the custom all-reduce kernel because it is not " @@ -1306,13 +1322,6 @@ def maybe_create_spec_config( "speculative decoding is > 1, but got " f"{speculative_disable_by_batch_size=}") - # Reminder: Please update docs/source/serving/compatibility_matrix.rst - # If the feature combo become valid - if enable_chunked_prefill: - raise ValueError( - "Speculative decoding and chunked prefill are " - f"currently mutually exclusive ({enable_chunked_prefill=}).") - # TODO: The user should be able to specify revision/max model len # for the draft model. It is not currently supported. draft_revision = None @@ -1379,6 +1388,29 @@ def maybe_create_spec_config( f"num_speculative_tokens={n_predict}, but " f"{num_speculative_tokens=} was provided.") + if enable_chunked_prefill and draft_hf_config.model_type in ( + "medusa", "mlp_speculator", "eagle"): + raise ValueError( + "Chunked prefill and hidden-state based draft models are " + "not compatible.") + + speculative_draft_tensor_parallel_size = \ + SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size( + target_parallel_config, + speculative_draft_tensor_parallel_size, + draft_hf_config + ) + + if (enable_chunked_prefill and \ + speculative_draft_tensor_parallel_size != 1): + # TODO - Investigate why the error reported in + # https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258 + # is happening and re-enable it. + raise ValueError( + "Chunked prefill and speculative decoding can be enabled " + "simultaneously only for draft models with tensor " + "parallel size 1.") + draft_model_config.max_model_len = ( SpeculativeConfig._maybe_override_draft_max_model_len( speculative_max_model_len, @@ -1457,15 +1489,16 @@ def _maybe_override_draft_max_model_len( ) @staticmethod - def create_draft_parallel_config( - target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int], - draft_hf_config: PretrainedConfig, - ) -> ParallelConfig: - """Create a parallel config for use by the draft worker. - - This is mostly a copy of the target parallel config, except the tp_size. + def _verify_and_get_draft_model_tensor_parallel_size( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig) -> int: + """ + Verifies and adjusts the tensor parallel size for a draft model + specified using speculative_draft_tensor_parallel_size. """ + # If speculative_draft_tensor_parallel_size is unset then set it + # appropriately else verify that it is set correctly. if speculative_draft_tensor_parallel_size is None: if draft_hf_config.model_type == "mlp_speculator": speculative_draft_tensor_parallel_size = 1 @@ -1481,7 +1514,18 @@ def create_draft_parallel_config( raise ValueError( f"{speculative_draft_tensor_parallel_size=} cannot be " f"other value than 1 or target model tensor_parallel_size") + return speculative_draft_tensor_parallel_size + + @staticmethod + def create_draft_parallel_config( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: int, + draft_hf_config: PretrainedConfig, + ) -> ParallelConfig: + """Create a parallel config for use by the draft worker. + This is mostly a copy of the target parallel config, except the tp_size. + """ draft_parallel_config = ParallelConfig( pipeline_parallel_size=target_parallel_config. pipeline_parallel_size, @@ -1631,6 +1675,7 @@ class LoRAConfig: # This is a constant. lora_vocab_padding_size: ClassVar[int] = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None + bias_enabled: bool = False def __post_init__(self): # Setting the maximum rank to 256 should be able to satisfy the vast @@ -1828,6 +1873,7 @@ def _get_and_verify_max_len( disable_sliding_window: bool, sliding_window_len: Optional[Union[int, List[Optional[int]]]], spec_target_max_model_len: Optional[int] = None, + encoder_config: Optional[Any] = None, ) -> int: """Get and verify the model's maximum length.""" derived_max_model_len = float("inf") @@ -1910,6 +1956,9 @@ def _get_and_verify_max_len( "original_max_position_embeddings"] derived_max_model_len *= scaling_factor + if encoder_config and "max_seq_length" in encoder_config: + derived_max_model_len = encoder_config["max_seq_length"] + # If the user specified a max length, make sure it is smaller than the # derived length from the HF model config. if max_model_len is None: @@ -1983,7 +2032,6 @@ def __post_init__(self): raise ValueError(f"Invalid guided_decoding_backend '{backend}," f"must be one of {valid_guided_backends}") - @dataclass class ObservabilityConfig: """Configuration for observability.""" @@ -2020,27 +2068,35 @@ class KVTransferConfig: kv_port: int = 14579 @property - def is_distributed_kv_instance(self) -> bool: - return self.kv_transfer_role in ["kv_producer", "kv_consumer", "kv_both"] + def is_kv_transfer_instance(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in ["kv_producer", "kv_consumer", "kv_both"] + + @property + def need_kv_parallel_group(self) -> bool: + # for those database-based connector, vLLM does not need to create + # parallel group, and in that case the kv parallel size will be 1. + return self.kv_connector is not None and self.kv_parallel_size > 1 @property def is_kv_producer(self) -> bool: - return self.kv_transfer_role in ["kv_producer", "kv_both"] + return self.kv_connector is not None and \ + self.kv_role in ["kv_producer", "kv_both"] @property def is_kv_consumer(self) -> bool: - return self.kv_transfer_role in ["kv_consumer", "kv_both"] + return self.kv_connector is not None and \ + self.kv_role in ["kv_consumer", "kv_both"] def __post_init__(self): if self.kv_connector not in [None, - "PyNcclConnector", - "LMCacheConnector"]: + "PyNcclConnector"]: raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. " f"Supported connectors are " - f"`PyNcclConnector` and " - f"`LMCacheConnector`") + f"`PyNcclConnector`.") + if self.kv_role not in [None, "kv_producer", @@ -2241,13 +2297,15 @@ class VllmConfig: simplifies passing around the distinct configurations in the codebase. """ - model_config: ModelConfig - cache_config: CacheConfig - parallel_config: ParallelConfig - scheduler_config: SchedulerConfig - device_config: DeviceConfig - load_config: LoadConfig - kv_transfer_config: KVTransferConfig + model_config: ModelConfig = field(default=None, init=True) # type: ignore + cache_config: CacheConfig = field(default=None, init=True) # type: ignore + parallel_config: ParallelConfig = field(default=None, + init=True) # type: ignore + scheduler_config: SchedulerConfig = field(default=None, + init=True) # type: ignore + device_config: DeviceConfig = field(default=None, + init=True) # type: ignore + load_config: LoadConfig = field(default=None, init=True) # type: ignore lora_config: Optional[LoRAConfig] = None speculative_config: Optional[SpeculativeConfig] = None decoding_config: Optional[DecodingConfig] = None @@ -2256,6 +2314,8 @@ class VllmConfig: quant_config: Optional[QuantizationConfig] = None compilation_config: CompilationConfig = field(default=None, init=True) # type: ignore + kv_transfer_config: KVTransferConfig = field(default=None, + init=True) # type: ignore @staticmethod def _get_quantization_config( @@ -2285,14 +2345,23 @@ def _get_quantization_config( return quant_config return None + def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig": + model_config = copy.deepcopy(self.model_config) + model_config.hf_config = hf_config + + return replace(self, model_config=model_config) + def __post_init__(self): """Verify configs are valid & consistent with each other. """ - self.model_config.verify_async_output_proc(self.parallel_config, - self.speculative_config, - self.device_config) - self.model_config.verify_with_parallel_config(self.parallel_config) - self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.model_config is not None: + self.model_config.verify_async_output_proc(self.parallel_config, + self.speculative_config, + self.device_config) + self.model_config.verify_with_parallel_config(self.parallel_config) + + if self.cache_config is not None: + self.cache_config.verify_with_parallel_config(self.parallel_config) if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py index 992033598e095..66772a1601776 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -52,8 +52,9 @@ def __init__( local_rank: int, config, ): - - assert self.config.is_distributed_kv_instance, "KV cache transfer "\ + + self.config = config + assert self.config.is_kv_transfer_instance, "KV cache transfer "\ "agent should only be used when kv_connector is set." self.connector = KVConnectorFactory.create_connector(config) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 17532f9c7f20b..bfbacb38051d9 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -999,6 +999,7 @@ def init_distributed_environment( "distributed_init_method must be provided when initializing " "distributed environment") # this backend is used for WORLD + print(distributed_init_method) torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, @@ -1101,14 +1102,14 @@ def initialize_model_parallel( def ensure_kv_transfer_initialized( - config: "KVTransferConfig" + config: "KVTransferConfig", ) -> None: """ Initialize KV cache transfer parallel group. """ global _KV_TRANSFER - if config.is_distributed_kv_instance and _KV_TRANSFER is None: + if config.need_kv_parallel_group and _KV_TRANSFER is None: _KV_TRANSFER = dist_kv.KV_transfer_agent( local_rank=get_world_group().local_rank, config=config diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 457683f41cfb6..0fdbf123bba89 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1232,6 +1232,7 @@ def create_engine_config(self) -> VllmConfig: decoding_config=decoding_config, observability_config=observability_config, prompt_adapter_config=prompt_adapter_config, + kv_transfer_config=kv_transfer_config, ) diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 6a114b513f382..dc2acc9d2f5fa 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -25,11 +25,6 @@ def init_tokenizer_from_configs(model_config: ModelConfig, trust_remote_code=model_config.trust_remote_code, revision=model_config.tokenizer_revision) - if (model_config.encoder_config is not None - and "do_lower_case" in model_config.encoder_config): - init_kwargs["do_lower_case"] = model_config.encoder_config[ - "do_lower_case"] - return get_tokenizer_group(parallel_config.tokenizer_pool_config, **init_kwargs) @@ -54,4 +49,4 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs) -__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"] +__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"] \ No newline at end of file diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8aefdaf504511..d569e2a5e58b3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,7 +14,6 @@ import torch.distributed import torch.nn as nn -import vllm.distributed.kv_transfer.vllm_adapter as dist_kv import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState @@ -22,7 +21,7 @@ from vllm.compilation.compile_context import set_compile_context from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_disagg_group, get_pp_group +from vllm.distributed import get_kv_transfer_group, get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry @@ -1649,7 +1648,7 @@ def execute_model( bypass_model_exec = False if self.need_recv_kv(model_input, kv_caches): hidden_or_intermediate_states, bypass_model_exec, model_input = \ - get_disagg_group().recv_kv_caches_and_hidden_states( + get_kv_transfer_group().recv_kv_caches_and_hidden_states( # model is used to know which layer the current worker # is working on, so that we can receive KV for only those # layers. @@ -1688,7 +1687,7 @@ def execute_model( # Sending KV cache in distributed KV cache transfer setting # NOTE: the send operation is non-blocking if self.need_send_kv(model_input, kv_caches): - get_disagg_group().send_kv_caches_and_hidden_states( + get_kv_transfer_group().send_kv_caches_and_hidden_states( # model_executable is used to know which layer the current # worker is working on, so that we can send KV for only those # layers. @@ -1784,7 +1783,7 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: # check if the current run is prefill is_prefill_run = prefill_meta is not None - return dist_kv.IS_KV_CONSUMER and ( + return self.vllm_config.kv_transfer_config.is_kv_consumer and ( not is_profile_run) and is_prefill_run def need_send_kv(self, model_input, kv_caches) -> bool: @@ -1806,7 +1805,7 @@ def need_send_kv(self, model_input, kv_caches) -> bool: # check if the current run is prefill is_prefill_run = prefill_meta is not None - return dist_kv.IS_KV_PRODUCER and ( + return self.vllm_config.kv_transfer_config.is_kv_producer and ( not is_profile_run) and is_prefill_run diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 814295961adcf..23eeff4a3697d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -143,7 +143,9 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, + init_worker_distributed_environment(self.parallel_config, + self.kv_transfer_config, + self.rank, self.distributed_init_method, self.local_rank) # Set random seed. @@ -460,11 +462,11 @@ def init_worker_distributed_environment( init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) - + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - ensure_kv_transfer_initialized(local_rank, kv_transfer_config) + ensure_kv_transfer_initialized(kv_transfer_config) def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): From 9e8affc70404e642036502f2de3bf421eeaba98b Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 18 Nov 2024 06:59:07 +0000 Subject: [PATCH 292/303] Fix typo (input_token wrongfully typed as input) and make default kv_buffer_device to cuda. --- .../kv_transfer/disagg_prefill_example.sh | 20 ++++++--- .../kv_transfer/kv_connector/__init__.py | 16 ------- .../kv_transfer/kv_connector/base.py | 10 +++-- .../kv_transfer/kv_connector/factory.py | 17 +++++++ .../kv_connector/lmcache_connector.py | 31 ------------- .../pynccl_connector/pynccl_connector.py | 44 ++++++++++++------- .../kv_transfer/kv_transfer_agent.py | 9 +++- vllm/distributed/parallel_state.py | 2 +- vllm/engine/arg_utils.py | 6 +-- 9 files changed, 76 insertions(+), 79 deletions(-) create mode 100644 vllm/distributed/kv_transfer/kv_connector/factory.py delete mode 100644 vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py diff --git a/examples/kv_transfer/disagg_prefill_example.sh b/examples/kv_transfer/disagg_prefill_example.sh index 8d274d40ba23f..00f290bf1edf5 100644 --- a/examples/kv_transfer/disagg_prefill_example.sh +++ b/examples/kv_transfer/disagg_prefill_example.sh @@ -22,6 +22,9 @@ wait_for_server() { done" && return 0 || return 1 } + +# You can also adjust --kv-ip and --kv-port for distributed inference. + # prefilling instance, which is the KV producer CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ @@ -52,8 +55,10 @@ wait_for_server 8200 # launch a proxy server that opens the service at port 8000 # the workflow of this proxy: -# - send the request to prefill vLLM instance (port 8100), change max_tokens to 1 -# - after the prefill vLLM finishes prefill, send the request to decode vLLM instance +# - send the request to prefill vLLM instance (port 8100), change max_tokens +# to 1 +# - after the prefill vLLM finishes prefill, send the request to decode vLLM +# instance python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & sleep 1 @@ -76,6 +81,13 @@ output2=$(curl -s http://localhost:8000/v1/completions \ "temperature": 0 }') + +# Cleanup commands, suppressing their output +ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 > /dev/null 2>&1 +pkill -f python3 > /dev/null 2>&1 + +sleep 3 + # Print the outputs of the curl requests echo "" echo "Output of first request: $output1" @@ -83,7 +95,3 @@ echo "Output of second request: $output2" echo "Successfully finished 2 test requests!" echo "" - -# Cleanup commands, suppressing their output -ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 > /dev/null 2>&1 -pkill -f python3 > /dev/null 2>&1 diff --git a/vllm/distributed/kv_transfer/kv_connector/__init__.py b/vllm/distributed/kv_transfer/kv_connector/__init__.py index b5121f0240af3..e69de29bb2d1d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/__init__.py +++ b/vllm/distributed/kv_transfer/kv_connector/__init__.py @@ -1,16 +0,0 @@ - -from .base import KVConnectorBase - -class KVConnectorFactory: - - @staticmethod - def create_connector( - local_rank: int, - config - ) -> KVConnectorBase: - if config.kv_connector == 'PyNcclConnector': - from . import PyNcclConnector - return PyNcclConnector(local_rank, config) - else: - raise ValueError(f"Unsupported connector type: " - f"{config.kv_connector}") \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index c9ac9ccabe557..81783669c4d47 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -10,7 +10,6 @@ from abc import ABC, abstractmethod from typing import List, Optional - import torch @@ -39,7 +38,12 @@ class KVConnectorBase(ABC): """ @abstractmethod - def init(self, config): + def __init__( + self, + rank: int, + local_rank: int, + config: "KVTransferConfig", + ): raise NotImplementedError @abstractmethod @@ -113,7 +117,7 @@ def close(self) -> None: @abstractmethod - def rebuild_model_input( + def build_partial_prefill_input( self, model_input: "ModelInputForGPUWithSamplingMetadata", input_tokens_list: List[torch.Tensor], diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py new file mode 100644 index 0000000000000..b2542e0db290a --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -0,0 +1,17 @@ + +from .base import KVConnectorBase + +class KVConnectorFactory: + + @staticmethod + def create_connector( + rank: int, + local_rank: int, + config + ) -> KVConnectorBase: + if config.kv_connector == 'PyNcclConnector': + from .pynccl_connector.pynccl_connector import PyNcclConnector + return PyNcclConnector(rank, local_rank, config) + else: + raise ValueError(f"Unsupported connector type: " + f"{config.kv_connector}") \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py deleted file mode 100644 index 5fa45fb0b337f..0000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py +++ /dev/null @@ -1,31 +0,0 @@ -""" - This file implements a simple torch distributed connector by 2 classes: - - `TorchDistributedPipe`: a tensor transmission pipe between P/D instance, - using `torch.distributed` - - `TorchDistributedConnector`: a torch distributed connector between P/D - instance, implemented on top of `TorchDistributedPipe` -""" -import threading -import time -from collections import deque -from concurrent.futures import ThreadPoolExecutor -from typing import Deque, List, Optional, Union - -import torch -from torch.distributed import Backend - -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.logger import init_logger - -logger = init_logger(__name__) - -try: - import lmcache -except ModuleNotFoundError as e: - logger.error("LMcache not installed, please install LMCache.") - raise e - - -class LMCacheConnector(KVConnectorBase): - - pass \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py index 3a48d3caf7327..ac5f7b3ff4595 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py @@ -36,57 +36,60 @@ class PyNcclConnector(KVConnectorBase): def __init__( self, + rank: int, local_rank: int, config: KVTransferConfig, ): - self.lookup_buffer_size = self.kv_buffer_size + self.config = config - self.send_buffer: Optional[LookupBuffer] = None - self.recv_buffer: Optional[LookupBuffer] = None + self.lookup_buffer_size = self.config.kv_buffer_size + + self.producer_buffer: Optional[LookupBuffer] = None + self.consumer_buffer: Optional[LookupBuffer] = None # 2 pipes for every rank in the world - port_offset_base = 2 * config.rank + port_offset_base = 2 * rank # In disaggregated prefill, the prefill vLLM only uses send pipe # and the decode vLLM only uses recv pipe if config.is_kv_producer: - self.send_data_pipe = PyNcclPipe( + self.producer_data_pipe = PyNcclPipe( local_rank=local_rank, config=config, port_offset=port_offset_base, ) - self.send_signal_pipe = PyNcclPipe( + self.producer_signal_pipe = PyNcclPipe( local_rank=local_rank, config=config, port_offset=port_offset_base + 1, device="cpu", ) - self.send_buffer = LookupBuffer( - self.send_signal_pipe, - self.send_data_pipe, + self.producer_buffer = LookupBuffer( + self.producer_signal_pipe, + self.producer_data_pipe, config.kv_buffer_size) else: # the current vLLM instance is KV consumer, so it needs to connect # its recv pipe to the send pipe of KV producder - self.recv_data_pipe = PyNcclPipe( + self.consumer_data_pipe = PyNcclPipe( local_rank=local_rank, config=config, port_offset=port_offset_base, ) - self.recv_signal_pipe = PyNcclPipe( + self.consumer_signal_pipe = PyNcclPipe( local_rank=local_rank, config=config, port_offset=port_offset_base + 1, device="cpu", ) - self.recv_buffer = LookupBuffer( - self.recv_signal_pipe, - self.recv_data_pipe, + self.consumer_buffer = LookupBuffer( + self.consumer_signal_pipe, + self.consumer_data_pipe, config.kv_buffer_size ) @@ -95,13 +98,13 @@ def select( self, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - return self.send_buffer.drop_select(input, roi) + return self.consumer_buffer.drop_select(input_tokens, roi) def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: - return self.recv_buffer.insert( + return self.producer_buffer.insert( input_tokens, roi, key, @@ -236,5 +239,12 @@ def build_partial_prefill_input( is_prompt=model_input.is_prompt, ) - return rebuilt_model_input + return rebuilt_model_input + + + def close(self): + self.producer_data_pipe.close() + self.producer_signal_pipe.close() + self.consumer_data_pipe.close() + self.consumer_signal_pipe.close() diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py index 66772a1601776..d473b4972e712 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -30,7 +30,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed.kv_transfer.kv_connector import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.logger import init_logger from vllm.sequence import IntermediateTensors @@ -49,6 +49,7 @@ class KV_transfer_agent: def __init__( self, + rank: int, local_rank: int, config, ): @@ -57,7 +58,11 @@ def __init__( assert self.config.is_kv_transfer_instance, "KV cache transfer "\ "agent should only be used when kv_connector is set." - self.connector = KVConnectorFactory.create_connector(config) + self.connector = KVConnectorFactory.create_connector( + rank, + local_rank, + config + ) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index bfbacb38051d9..392f99e90d478 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -999,7 +999,6 @@ def init_distributed_environment( "distributed_init_method must be provided when initializing " "distributed environment") # this backend is used for WORLD - print(distributed_init_method) torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, @@ -1111,6 +1110,7 @@ def ensure_kv_transfer_initialized( global _KV_TRANSFER if config.need_kv_parallel_group and _KV_TRANSFER is None: _KV_TRANSFER = dist_kv.KV_transfer_agent( + rank=get_world_group().rank, local_rank=get_world_group().local_rank, config=config ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0fdbf123bba89..57b7623bdd30d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -194,7 +194,7 @@ class EngineArgs: # P/D disaggregation coonfiguration kv_connector: Optional[str] = None kv_buffer_size: Optional[int] = 1e9 - kv_buffer_device: Optional[str] = "gpu" + kv_buffer_device: Optional[str] = "cuda" kv_role: Optional[str] = None kv_rank: Optional[str] = None kv_parallel_size: int = 1 @@ -906,8 +906,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--kv-buffer-device', type=str, - default=None, - choices=["CPU", "GPU"], + default=EngineArgs.kv_buffer_device, + choices=["cpu", "cuda"], help="The device used by kv connector to buffer the KV cache. Can " "be CPU or GPU. Recommended value: CPU." ) From d8e79fa332ff677fb46092699ebed0d939a6272a Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 18 Nov 2024 07:00:35 +0000 Subject: [PATCH 293/303] Make sure the output is shown at the end of the run by sleeping longer -- killing engine really takes time --- examples/kv_transfer/disagg_prefill_example.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/kv_transfer/disagg_prefill_example.sh b/examples/kv_transfer/disagg_prefill_example.sh index 00f290bf1edf5..d4262b68b796a 100644 --- a/examples/kv_transfer/disagg_prefill_example.sh +++ b/examples/kv_transfer/disagg_prefill_example.sh @@ -86,7 +86,7 @@ output2=$(curl -s http://localhost:8000/v1/completions \ ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 > /dev/null 2>&1 pkill -f python3 > /dev/null 2>&1 -sleep 3 +sleep 4 # Print the outputs of the curl requests echo "" From 62f3966d48a027e773eaa6684f0bdaa3c492c045 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 18 Nov 2024 07:09:50 +0000 Subject: [PATCH 294/303] A series of changes to clean up the diff --- vllm/distributed/parallel_state.py | 8 +++---- vllm/executor/multiproc_gpu_executor.py | 1 - vllm/executor/ray_gpu_executor.py | 3 +-- vllm/utils.py | 30 +------------------------ vllm/worker/worker.py | 1 - vllm/worker/worker_base.py | 2 +- 6 files changed, 7 insertions(+), 38 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 392f99e90d478..f520d3409c2f5 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -38,7 +38,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op, supports_custom_op -import vllm.distributed.kv_transfer.kv_transfer_agent as dist_kv +import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer @dataclass @@ -944,10 +944,10 @@ def get_pp_group() -> GroupCoordinator: get_pipeline_model_parallel_group = get_pp_group -_KV_TRANSFER: Optional[dist_kv.KV_transfer_agent] = None +_KV_TRANSFER: Optional[kv_transfer.KV_transfer_agent] = None -def get_kv_transfer_group() -> dist_kv.KV_transfer_agent: +def get_kv_transfer_group() -> kv_transfer.KV_transfer_agent: assert _KV_TRANSFER is not None, ( "disaggregated KV cache transfer parallel group is not initialized") return _KV_TRANSFER @@ -1109,7 +1109,7 @@ def ensure_kv_transfer_initialized( global _KV_TRANSFER if config.need_kv_parallel_group and _KV_TRANSFER is None: - _KV_TRANSFER = dist_kv.KV_transfer_agent( + _KV_TRANSFER = kv_transfer.KV_transfer_agent( rank=get_world_group().rank, local_rank=get_world_group().local_rank, config=config diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 03552da7f0a23..e39cef088c03c 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -5,7 +5,6 @@ import torch -import vllm.distributed.kv_transfer.vllm_adapter as dist_kv from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.gpu_executor import create_worker diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 1f635de42fe3e..b16fc9201f0ec 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -6,7 +6,6 @@ import msgspec -import vllm.distributed.kv_transfer.vllm_adapter as dist_kv import vllm.envs as envs from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) @@ -268,7 +267,7 @@ def sort_by_driver_then_worker_ip(worker): # this port will be binded by prefill instance # but the decode instance must use that port to init torch.distributed distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) + driver_ip, get_open_port()) # Initialize the actual workers inside worker wrapper. init_worker_all_kwargs = [ diff --git a/vllm/utils.py b/vllm/utils.py index 55c65607a80d4..111460a29de47 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -516,39 +516,11 @@ def get_open_zmq_ipc_path() -> str: return f"ipc://{base_rpc_path}/{uuid4()}" -def get_open_port(force: bool = False) -> int: +def get_open_port() -> int: port = envs.VLLM_PORT - - if force: - # This flag will only be True in disaggregated prefill scenario - # and VLLM_PORT must be set so that vLLM can connect prefill vLLM - # instance and decode vLLM instance. - assert port is not None, "Please set environment variable VLLM_PORT in" - " order to use disaggregated prefill and distributed KV cache transfer" - - # For prefill vLLM instance (KV producer), `port` must be available. - # For decode vLLM instance `port` can be not available. - if dist_kv.IS_KV_PRODUCER: - # `port` must be available. - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", port)) - return port - except OSError as e: - logger.error( - "Port %d must be empty so that prefill vLLM " - "instance can use this port to initialize " - "distributed KV communication with decode " - "vLLM instance.", port) - raise e - else: - # `port` can be not available - return port - if port is not None: while True: try: - logger.info('Trying port %d', port) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", port)) return port diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 23eeff4a3697d..139ede64c7937 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -462,7 +462,6 @@ def init_worker_distributed_environment( init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index e548ff0a83386..9ca4667f78cc8 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -370,7 +370,7 @@ def execute_model( # output is List[SamplerOutput] return output - def _execute_model_spmdt( + def _execute_model_spmd( self, execute_model_req: ExecuteModelRequest, intermediate_tensors: Optional[IntermediateTensors] = None From 6478fb2804102bc89f6667ad8d9d7f9e2911cc95 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 18 Nov 2024 07:13:36 +0000 Subject: [PATCH 295/303] Clean up the diff --- vllm/config.py | 1 + vllm/transformers_utils/tokenizer_group/__init__.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 8c65ac30f1cea..8642da68c630f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2032,6 +2032,7 @@ def __post_init__(self): raise ValueError(f"Invalid guided_decoding_backend '{backend}," f"must be one of {valid_guided_backends}") + @dataclass class ObservabilityConfig: """Configuration for observability.""" diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index dc2acc9d2f5fa..f1bf616968773 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -25,6 +25,11 @@ def init_tokenizer_from_configs(model_config: ModelConfig, trust_remote_code=model_config.trust_remote_code, revision=model_config.tokenizer_revision) + if (model_config.encoder_config is not None + and "do_lower_case" in model_config.encoder_config): + init_kwargs["do_lower_case"] = model_config.encoder_config[ + "do_lower_case"] + return get_tokenizer_group(parallel_config.tokenizer_pool_config, **init_kwargs) @@ -49,4 +54,4 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs) -__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"] \ No newline at end of file +__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"] From c2ebcb67faa1c76742a99af506cf843662b8c6b5 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 18 Nov 2024 07:15:39 +0000 Subject: [PATCH 296/303] Clean up the diff in several executor files --- vllm/executor/gpu_executor.py | 3 +-- vllm/executor/multiproc_gpu_executor.py | 3 +-- vllm/executor/ray_gpu_executor.py | 3 --- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 808fc8c82742b..c65d0836e5ff7 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -47,8 +47,7 @@ def _get_worker_kwargs( """Return worker init args for a given rank.""" if distributed_init_method is None: distributed_init_method = get_distributed_init_method( - get_ip(), - get_open_port()) + get_ip(), get_open_port()) return dict( vllm_config=self.vllm_config, local_rank=local_rank, diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index e39cef088c03c..3eb14fb931925 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -69,8 +69,7 @@ def _init_executor(self) -> None: # Since it only works for single node, we can use the loopback address # 127.0.0.1 for communication. distributed_init_method = get_distributed_init_method( - "127.0.0.1", - get_open_port(force=self.config.IS_DISTRIBUTED_KV_INSTANCE)) + "127.0.0.1", get_open_port()) self.workers: List[ProcessWorkerWrapper] = [] # This is the list of workers that are rank 0 of each TP group EXCEPT diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index b16fc9201f0ec..66bab2c686c67 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -263,9 +263,6 @@ def sort_by_driver_then_worker_ip(worker): # solves this issue, as it always works for communication inside # the node. driver_ip = "127.0.0.1" - # force vLLM to use the port specified by envs.VLLM_PORT - # this port will be binded by prefill instance - # but the decode instance must use that port to init torch.distributed distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) From 9158549c8ec22b447352613fcf6cffd7b6352d94 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 18 Nov 2024 07:17:52 +0000 Subject: [PATCH 297/303] clean up wierd spaces in tokenizer groups --- vllm/transformers_utils/tokenizer_group/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index f1bf616968773..6a114b513f382 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -26,9 +26,9 @@ def init_tokenizer_from_configs(model_config: ModelConfig, revision=model_config.tokenizer_revision) if (model_config.encoder_config is not None - and "do_lower_case" in model_config.encoder_config): - init_kwargs["do_lower_case"] = model_config.encoder_config[ - "do_lower_case"] + and "do_lower_case" in model_config.encoder_config): + init_kwargs["do_lower_case"] = model_config.encoder_config[ + "do_lower_case"] return get_tokenizer_group(parallel_config.tokenizer_pool_config, **init_kwargs) From 98a44df48d86d5c90875528311f0ee3f71bf0cc6 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 18 Nov 2024 07:18:10 +0000 Subject: [PATCH 298/303] Remove previous environment variable -- now we initialize distributed communication via CLI args --- vllm/envs.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index cc839850c224f..716e835a555f1 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -388,11 +388,6 @@ def get_default_config_root(): "VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), - # Specify the role of current vllm instance - # Value can be "producer", "consumer" or "both". - "VLLM_DISTRIBUTED_KV_ROLE": - lambda: os.getenv("VLLM_DISTRIBUTED_KV_ROLE", None), - # If set, vllm will skip the deprecation warnings. "VLLM_NO_DEPRECATION_WARNING": lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))), From 744a40f53bc49531d8f09c8ec005f9ba4913378d Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 18 Nov 2024 07:19:24 +0000 Subject: [PATCH 299/303] add a new line at the end of parallel_state.py to clean up the diff --- vllm/distributed/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index f520d3409c2f5..8ce9089aabdb8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1284,4 +1284,4 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: shm.unlink() torch.distributed.all_reduce(is_in_the_same_node, group=pg) - return [x == 1 for x in is_in_the_same_node.tolist()] \ No newline at end of file + return [x == 1 for x in is_in_the_same_node.tolist()] From c93c236cd319d16acbf092c2de1c6e8c4ee0247d Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 18 Nov 2024 22:49:40 +0000 Subject: [PATCH 300/303] fix disagg test --- tests/kv_transfer/disagg_test.py | 23 +++++++++++++++---- .../pynccl_connector/lookup_buffer.py | 8 +++---- .../pynccl_connector/pynccl_connector.py | 8 +++---- .../kv_transfer/kv_transfer_agent.py | 2 +- vllm/distributed/parallel_state.py | 6 ++--- 5 files changed, 30 insertions(+), 17 deletions(-) diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py index 3dfacbdc5fe84..d86be2eaa5d0c 100644 --- a/tests/kv_transfer/disagg_test.py +++ b/tests/kv_transfer/disagg_test.py @@ -19,7 +19,6 @@ def setup_servers(): VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", shell=True).decode().strip() os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP - os.environ["VLLM_PORT"] = "12345" # Start prefill instance prefill_cmd = [ @@ -33,12 +32,19 @@ def setup_servers(): "--port", "8100", "--gpu-memory-utilization", - "0.8", + "0.5", "--max-model-len", "1000", + "--kv-connector", + "PyNcclConnector", + "--kv-role", + "kv_producer", + "--kv-rank", + "0", + "--kv-parallel-size", + "2", ] prefill_env = os.environ.copy() - prefill_env["VLLM_DISTRIBUTED_KV_ROLE"] = "producer" prefill_env["CUDA_VISIBLE_DEVICES"] = "0,1" prefill_proc = Popen(prefill_cmd, env=prefill_env) @@ -54,12 +60,19 @@ def setup_servers(): "--port", "8200", "--gpu-memory-utilization", - "0.8", + "0.5", "--max-model-len", "1000", + "--kv-connector", + "PyNcclConnector", + "--kv-role", + "kv_consumer", + "--kv-rank", + "1", + "--kv-parallel-size", + "2", ] decode_env = os.environ.copy() - decode_env["VLLM_DISTRIBUTED_KV_ROLE"] = "consumer" decode_env["CUDA_VISIBLE_DEVICES"] = "2,3" decode_proc = Popen(decode_cmd, env=decode_env) diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py index 2a82132d9d80a..f05e3ed8b17e3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py @@ -17,8 +17,8 @@ import torch from torch.distributed import Backend -import vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ - as pnp +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ + import PyNcclPipe from vllm.logger import init_logger @@ -29,8 +29,8 @@ class LookupBuffer: def __init__(self, - signal_pipe: pnp.PyNcclPipe, - data_pipe: pnp.PyNcclPipe, + signal_pipe: PyNcclPipe, + data_pipe: PyNcclPipe, buffer_size_thresh: float): """ signal_pipe: on CPU diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py index ac5f7b3ff4595..a765eae62b517 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py @@ -90,7 +90,7 @@ def __init__( self.consumer_buffer = LookupBuffer( self.consumer_signal_pipe, self.consumer_data_pipe, - config.kv_buffer_size + config.kv_buffer_size, ) @@ -148,8 +148,8 @@ def build_partial_prefill_input( token_tensor = input_tokens_list[idx] num_token = len(token_tensor) num_computed_token = num_computed_tokens_list[idx] - # currently attention kernel cannot handle the case where there is 0 - # query token. + # currently attention kernel cannot handle the case where there is + # 0 query token. if num_computed_token == num_token: num_computed_token -= 1 start_pos = start_pos_list[idx] @@ -239,7 +239,7 @@ def build_partial_prefill_input( is_prompt=model_input.is_prompt, ) - return rebuilt_model_input + return rebuilt_model_input def close(self): diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py index d473b4972e712..057de698c59d9 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -38,7 +38,7 @@ -class KV_transfer_agent: +class KVTransferAgent: """ A class designated for distributed KV transfer diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 8ce9089aabdb8..21f892ffd5e29 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -944,10 +944,10 @@ def get_pp_group() -> GroupCoordinator: get_pipeline_model_parallel_group = get_pp_group -_KV_TRANSFER: Optional[kv_transfer.KV_transfer_agent] = None +_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None -def get_kv_transfer_group() -> kv_transfer.KV_transfer_agent: +def get_kv_transfer_group() -> kv_transfer.KVTransferAgent: assert _KV_TRANSFER is not None, ( "disaggregated KV cache transfer parallel group is not initialized") return _KV_TRANSFER @@ -1109,7 +1109,7 @@ def ensure_kv_transfer_initialized( global _KV_TRANSFER if config.need_kv_parallel_group and _KV_TRANSFER is None: - _KV_TRANSFER = kv_transfer.KV_transfer_agent( + _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank, local_rank=get_world_group().local_rank, config=config From 6b965610a2c6517ea558300a3b2a0744d437c7f8 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Mon, 18 Nov 2024 23:36:40 +0000 Subject: [PATCH 301/303] make format checker happy --- .../disagg_overhead_benchmark.sh | 6 +- .../disagg_performance_benchmark.sh | 8 +- .../visualize_benchmark_results.py | 2 +- .../kv_transfer/disagg_prefill_example.sh | 2 +- tests/kv_transfer/test_lookup_buffer.py | 18 ++--- tests/kv_transfer/test_lookup_buffer.sh | 2 +- tests/kv_transfer/test_send_recv.py | 17 +--- tests/kv_transfer/test_send_recv.sh | 2 +- vllm/config.py | 29 +++---- .../kv_transfer/kv_connector/base.py | 17 ++-- .../kv_transfer/kv_connector/factory.py | 15 ++-- .../{lookup_buffer.py => buffer.py} | 22 ++---- .../{pynccl_connector.py => connector.py} | 79 ++++++++----------- .../{pynccl_pipe.py => pipe.py} | 50 ++++++------ .../kv_transfer/kv_transfer_agent.py | 26 ++---- vllm/distributed/parallel_state.py | 17 ++-- vllm/engine/arg_utils.py | 54 +++++-------- vllm/worker/worker.py | 11 ++- 18 files changed, 158 insertions(+), 219 deletions(-) rename vllm/distributed/kv_transfer/kv_connector/pynccl_connector/{lookup_buffer.py => buffer.py} (95%) rename vllm/distributed/kv_transfer/kv_connector/pynccl_connector/{pynccl_connector.py => connector.py} (85%) rename vllm/distributed/kv_transfer/kv_connector/pynccl_connector/{pynccl_pipe.py => pipe.py} (88%) diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh index dec00c2c9fe00..90b772a6d9d0a 100644 --- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -77,7 +77,7 @@ benchmark() { --dataset-name $dataset_name \ --dataset-path $dataset_path \ --sonnet-input-len $input_len \ - --sonnet-output-len $output_len \ + --sonnet-output-len "$output_len" \ --sonnet-prefix-len $prefix_len \ --num-prompts $num_prompts \ --port 8100 \ @@ -95,14 +95,14 @@ benchmark() { --dataset-name $dataset_name \ --dataset-path $dataset_path \ --sonnet-input-len $input_len \ - --sonnet-output-len $output_len \ + --sonnet-output-len "$output_len" \ --sonnet-prefix-len $prefix_len \ --num-prompts $num_prompts \ --port 8200 \ --save-result \ --result-dir $results_folder \ --result-filename disagg_prefill_2xtp4.json \ - --request-rate $qps + --request-rate "$qps" kill_gpu_processes } diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index 0e6875363f4d3..bd9e51257e936 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -19,7 +19,7 @@ kill_gpu_processes() { # kill all processes on GPU. pkill -f pt_main_thread pkill -f python3 - ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 + pgrep pt_main_thread | xargs kill -9 for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done sleep 1 } @@ -111,14 +111,14 @@ benchmark() { --dataset-name $dataset_name \ --dataset-path $dataset_path \ --sonnet-input-len $input_len \ - --sonnet-output-len $output_len \ + --sonnet-output-len "$output_len" \ --sonnet-prefix-len $prefix_len \ --num-prompts $num_prompts \ --port 8000 \ --save-result \ --result-dir $results_folder \ - --result-filename $tag-qps-$qps.json \ - --request-rate $qps + --result-filename "$tag"-qps-"$qps".json \ + --request-rate "$qps" sleep 2 diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py index 6c5bf5c791dc9..e59d8bb0e6c8c 100644 --- a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -8,7 +8,7 @@ data = [] for name in ['disagg_prefill', 'chunked_prefill']: for qps in [2, 4, 6, 8]: - with open(f"results/{name}-qps-{qps}.json", "r") as f: + with open(f"results/{name}-qps-{qps}.json") as f: x = json.load(f) x['name'] = name x['qps'] = qps diff --git a/examples/kv_transfer/disagg_prefill_example.sh b/examples/kv_transfer/disagg_prefill_example.sh index d4262b68b796a..e6c9d17227c76 100644 --- a/examples/kv_transfer/disagg_prefill_example.sh +++ b/examples/kv_transfer/disagg_prefill_example.sh @@ -83,7 +83,7 @@ output2=$(curl -s http://localhost:8000/v1/completions \ # Cleanup commands, suppressing their output -ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 > /dev/null 2>&1 +pgrep pt_main_thread | xargs kill -9 > /dev/null 2>&1 pkill -f python3 > /dev/null 2>&1 sleep 4 diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index d3552aca22c4a..a323ac7319909 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -4,11 +4,11 @@ import torch from tqdm import tqdm -import vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ - as pnp -import vllm.distributed.kv_transfer.kv_connector.pynccl_connector.lookup_buffer\ - as lb from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.buffer import ( + lookup_buffer as lb) +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pipe import ( + pynccl_pipe as pnp) # TODO: the test depends on a lot of fields in the current implementation. # We should have standard interface instead direct field access @@ -20,7 +20,7 @@ def test_run(my_rank, buffer, device): if my_rank == 0: assert buffer.buffer_size == 0 assert len(buffer.buffer) == 0 - + print("My rank: %d, device: %s" % (my_rank, device)) # insert @@ -113,8 +113,7 @@ def stress_test(my_rank, buf, device): if __name__ == "__main__": - - + my_rank = int(os.environ['RANK']) torch.distributed.init_process_group( @@ -123,16 +122,15 @@ def stress_test(my_rank, buf, device): world_size=2, rank=my_rank, ) - + print("initialized! My rank is %d" % my_rank) - config = KVTransferConfig( kv_connector='PyNcclConnector', kv_buffer_device='cuda', kv_buffer_size=1e9, kv_rank=my_rank, - kv_role="kv_both", # this arg doesn't matter in this test + kv_role="kv_both", # this arg doesn't matter in this test kv_parallel_size=2, kv_ip="127.0.0.1", kv_port=12345, diff --git a/tests/kv_transfer/test_lookup_buffer.sh b/tests/kv_transfer/test_lookup_buffer.sh index eec2a9fb84797..09d7ee018c3f4 100644 --- a/tests/kv_transfer/test_lookup_buffer.sh +++ b/tests/kv_transfer/test_lookup_buffer.sh @@ -1,3 +1,3 @@ - +#!/bin/bash RANK=0 python test_lookup_buffer.py & RANK=1 python test_lookup_buffer.py & \ No newline at end of file diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 239cba19eba51..ad791b456c7ba 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -5,9 +5,9 @@ import torch from tqdm import tqdm -import vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ - as pnp from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector import ( + pynccl_pipe as pnp) def test_run(my_rank, pipe): @@ -38,14 +38,12 @@ def test_run(my_rank, pipe): assert torch.allclose(y, y2) - def stress_test(my_rank, pipe): torch.distributed.barrier() tensors: List[torch.Tensor] = [] - torch.manual_seed(0) for i in tqdm(range(500)): @@ -65,12 +63,8 @@ def stress_test(my_rank, pipe): tensors.append(x.mean().unsqueeze(0)) tensors.append(x.std().unsqueeze(0)) - - torch.distributed.barrier() - - for i in tqdm(range(500)): if my_rank == int((i % 10) > 3): pipe.send_tensor(tensors[3 * i]) @@ -80,7 +74,7 @@ def stress_test(my_rank, pipe): x = pipe.recv_tensor() mean = pipe.recv_tensor() std = pipe.recv_tensor() - + if x is None: assert mean is None assert std is None @@ -88,12 +82,10 @@ def stress_test(my_rank, pipe): assert torch.allclose(x, tensors[3 * i]) assert x.mean() == mean[0] assert x.std() == std[0] - torch.distributed.barrier() - def latency_test(my_rank, pipe, nelement, ntensor): latencies = [] @@ -140,14 +132,13 @@ def latency_test(my_rank, pipe, nelement, ntensor): world_size=2, rank=my_rank, ) - config = KVTransferConfig( kv_connector='PyNcclConnector', kv_buffer_device='cuda', kv_buffer_size=1e9, kv_rank=my_rank, - kv_role="kv_both", # this arg doesn't matter in this test + kv_role="kv_both", # this arg doesn't matter in this test kv_parallel_size=2, kv_ip="127.0.0.1", kv_port=12345, diff --git a/tests/kv_transfer/test_send_recv.sh b/tests/kv_transfer/test_send_recv.sh index c9335434473ea..935487bd00d6f 100644 --- a/tests/kv_transfer/test_send_recv.sh +++ b/tests/kv_transfer/test_send_recv.sh @@ -1,3 +1,3 @@ - +#!/bin/bash RANK=0 python test_send_recv.py & RANK=1 python test_send_recv.py & \ No newline at end of file diff --git a/vllm/config.py b/vllm/config.py index 8642da68c630f..98f9f798411e8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2053,7 +2053,7 @@ def __post_init__(self): "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " f"installed. Original error:\n{otel_import_error_traceback}") - + @dataclass class KVTransferConfig: """Configuration for distributed KV cache transfer.""" @@ -2067,7 +2067,7 @@ class KVTransferConfig: kv_parallel_size: int = 1 kv_ip: str = "127.0.0.1" kv_port: int = 14579 - + @property def is_kv_transfer_instance(self) -> bool: return self.kv_connector is not None and \ @@ -2078,40 +2078,35 @@ def need_kv_parallel_group(self) -> bool: # for those database-based connector, vLLM does not need to create # parallel group, and in that case the kv parallel size will be 1. return self.kv_connector is not None and self.kv_parallel_size > 1 - + @property def is_kv_producer(self) -> bool: return self.kv_connector is not None and \ self.kv_role in ["kv_producer", "kv_both"] - + @property def is_kv_consumer(self) -> bool: return self.kv_connector is not None and \ self.kv_role in ["kv_consumer", "kv_both"] - def __post_init__(self): - if self.kv_connector not in [None, - "PyNcclConnector"]: + if self.kv_connector not in [None, "PyNcclConnector"]: raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. " f"Supported connectors are " f"`PyNcclConnector`.") - - if self.kv_role not in [None, - "kv_producer", - "kv_consumer", - "kv_both"]: - raise ValueError(f"Unsupported kv_role: {self.kv_disagg_role}. " - f"Supported roles are `kv_producer`, `kv_consumer`, " - f"and `kv_both`") + if self.kv_role not in [None, "kv_producer", "kv_consumer", "kv_both"]: + raise ValueError( + f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are `kv_producer`, `kv_consumer`, " + f"and `kv_both`") if self.kv_connector is not None and self.kv_role is None: raise ValueError("Please specify kv_disagg_role when kv_connector " "is set, supported roles are `kv_producer`, " "`kv_consumer`, and `kv_both`") - + class CompilationLevel: # constants for the levels of the compilation process @@ -2316,7 +2311,7 @@ class VllmConfig: compilation_config: CompilationConfig = field(default=None, init=True) # type: ignore kv_transfer_config: KVTransferConfig = field(default=None, - init=True) # type: ignore + init=True) # type: ignore @staticmethod def _get_quantization_config( diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index 81783669c4d47..e92f1cc638969 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -8,10 +8,14 @@ """ from abc import ABC, abstractmethod -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional import torch +if TYPE_CHECKING: + from vllm.config import KVTransferConfig + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + class KVConnectorBase(ABC): """ @@ -36,7 +40,7 @@ class KVConnectorBase(ABC): - hidden: the final hidden state generated by model forwarding. This allows vLLM to bypass further model forwarding by transmitting the hidden state. """ - + @abstractmethod def __init__( self, @@ -77,9 +81,8 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, raise NotImplementedError @abstractmethod - def select( - self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + def select(self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: """Select KV cache entries from the connector. The functionality is similar to the following python statements @@ -114,8 +117,7 @@ def close(self) -> None: NotImplementedError: This method must be implemented in subclasses. """ raise NotImplementedError - - + @abstractmethod def build_partial_prefill_input( self, @@ -132,4 +134,3 @@ def build_partial_prefill_input( NotImplementedError: This method must be implemented in subclasses. """ raise NotImplementedError - diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index b2542e0db290a..ac8bf788b39ce 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -1,17 +1,14 @@ - from .base import KVConnectorBase + class KVConnectorFactory: - + @staticmethod - def create_connector( - rank: int, - local_rank: int, - config - ) -> KVConnectorBase: + def create_connector(rank: int, local_rank: int, + config) -> KVConnectorBase: if config.kv_connector == 'PyNcclConnector': - from .pynccl_connector.pynccl_connector import PyNcclConnector + from .pynccl_connector.connector import PyNcclConnector return PyNcclConnector(rank, local_rank, config) else: raise ValueError(f"Unsupported connector type: " - f"{config.kv_connector}") \ No newline at end of file + f"{config.kv_connector}") diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py similarity index 95% rename from vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py rename to vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py index f05e3ed8b17e3..883ff101d4b29 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/lookup_buffer.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/buffer.py @@ -10,27 +10,20 @@ import threading import time from collections import deque -from concurrent.futures import ThreadPoolExecutor from typing import Deque, List, Optional, Union -from copy import deepcopy import torch -from torch.distributed import Backend -from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ - import PyNcclPipe +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pipe import ( + PyNcclPipe) from vllm.logger import init_logger - - logger = init_logger(__name__) class LookupBuffer: - def __init__(self, - signal_pipe: PyNcclPipe, - data_pipe: PyNcclPipe, + def __init__(self, signal_pipe: PyNcclPipe, data_pipe: PyNcclPipe, buffer_size_thresh: float): """ signal_pipe: on CPU @@ -102,7 +95,7 @@ def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): # so this check needs to go after the check above return 0 - raise AssertionError("Unknown data type %s" % type(data)) + raise AssertionError(f"Unknown data type {type(data)}") def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -142,6 +135,8 @@ def drop_select_handler(self): input_tokens = self.data_pipe.recv_tensor() roi = self.data_pipe.recv_tensor() + assert roi is not None, "Please provide the roi when sending "\ + "drop-select request" roi = (roi > 0.5) tokens_roi_recver = [input_tokens, roi] @@ -192,11 +187,11 @@ def drop_select( if isinstance(input_tokens, torch.Tensor): input_tokens = input_tokens.clone() if isinstance(roi, torch.Tensor): - roi = roi.clone() + roi = roi.clone().float() self.signal_pipe.send_tensor(self.normal_signal) self.data_pipe.send_tensor(input_tokens) - self.data_pipe.send_tensor(roi.float()) + self.data_pipe.send_tensor(roi) input_tokens = self.data_pipe.recv_tensor() roi = self.data_pipe.recv_tensor() @@ -243,4 +238,3 @@ def close(self): # TODO: have a explicit close signal and have a explicit way to # check if it's requester self.signal_pipe.send_tensor(self.end_signal) - \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py similarity index 85% rename from vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py rename to vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py index a765eae62b517..26c8a54002102 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py @@ -7,33 +7,27 @@ - `TorchDistributedConnector`: a torch distributed connector between P/D instance, implemented on top of `TorchDistributedBuffer` """ -import threading -import time -from collections import deque -from concurrent.futures import ThreadPoolExecutor -from typing import Deque, List, Optional, Union from copy import deepcopy +from typing import TYPE_CHECKING, List, Optional import torch -from torch.distributed import Backend +from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pynccl_pipe \ - import PyNcclPipe -from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.lookup_buffer \ - import LookupBuffer +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.buffer import ( + LookupBuffer) +from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pipe import ( + PyNcclPipe) from vllm.logger import init_logger -from vllm.config import KVTransferConfig - +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata logger = init_logger(__name__) - - class PyNcclConnector(KVConnectorBase): - + def __init__( self, rank: int, @@ -51,7 +45,6 @@ def __init__( # 2 pipes for every rank in the world port_offset_base = 2 * rank - # In disaggregated prefill, the prefill vLLM only uses send pipe # and the decode vLLM only uses recv pipe if config.is_kv_producer: @@ -67,11 +60,10 @@ def __init__( port_offset=port_offset_base + 1, device="cpu", ) - self.producer_buffer = LookupBuffer( - self.producer_signal_pipe, - self.producer_data_pipe, - config.kv_buffer_size) - + self.producer_buffer = LookupBuffer(self.producer_signal_pipe, + self.producer_data_pipe, + config.kv_buffer_size) + else: # the current vLLM instance is KV consumer, so it needs to connect @@ -88,32 +80,27 @@ def __init__( device="cpu", ) self.consumer_buffer = LookupBuffer( - self.consumer_signal_pipe, + self.consumer_signal_pipe, self.consumer_data_pipe, config.kv_buffer_size, ) - - - def select( - self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + def select(self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + assert self.consumer_buffer is not None, "Please initialize the "\ + "consumer buffer before calling select." return self.consumer_buffer.drop_select(input_tokens, roi) - + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: - return self.producer_buffer.insert( - input_tokens, - roi, - key, - value, - hidden - ) + assert self.producer_buffer is not None, "Please initialize the "\ + "producer buffer before calling insert." + + self.producer_buffer.insert(input_tokens, roi, key, value, hidden) - - def build_partial_prefill_input( self, model_input: "ModelInputForGPUWithSamplingMetadata", @@ -148,7 +135,7 @@ def build_partial_prefill_input( token_tensor = input_tokens_list[idx] num_token = len(token_tensor) num_computed_token = num_computed_tokens_list[idx] - # currently attention kernel cannot handle the case where there is + # currently attention kernel cannot handle the case where there is # 0 query token. if num_computed_token == num_token: num_computed_token -= 1 @@ -167,8 +154,8 @@ def build_partial_prefill_input( rebuilt_num_prefills += 1 rebuilt_num_prefill_tokens += q_len new_slot_mapping = slot_mapping_flat[start_pos + - num_computed_token:start_pos + - num_token] + num_computed_token:start_pos + + num_token] rebuilt_slot_mapping.append(new_slot_mapping) rebuilt_max_query_len = max(q_len, rebuilt_max_query_len) # TODO(Jiayi): remove hard-code (block_size=16) @@ -184,14 +171,15 @@ def build_partial_prefill_input( # Sampling metadata related #seq_groups (use rebuilt query lens) - rebuilt_selected_token_indices.append(rebuilt_num_prefill_tokens - 1) + rebuilt_selected_token_indices.append(rebuilt_num_prefill_tokens - + 1) # rebuilt attn_metadata rebuilt_attn_metadata = deepcopy(model_input.attn_metadata) rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens - rebuilt_attn_metadata.slot_mapping = torch.cat(rebuilt_slot_mapping).to( - device) + rebuilt_attn_metadata.slot_mapping = torch.cat( + rebuilt_slot_mapping).to(device) rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len rebuilt_attn_metadata.block_tables = torch.tensor( @@ -220,7 +208,8 @@ def build_partial_prefill_input( ).to(device) # import here to avoid circular import. - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + from vllm.worker.model_runner import ( + ModelInputForGPUWithSamplingMetadata) rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.cat(rebuilt_input_tokens).to(device), input_positions=torch.cat(rebuilt_input_positions).to(device), @@ -240,11 +229,9 @@ def build_partial_prefill_input( ) return rebuilt_model_input - def close(self): self.producer_data_pipe.close() self.producer_signal_pipe.close() self.consumer_data_pipe.close() self.consumer_signal_pipe.close() - diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py similarity index 88% rename from vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py rename to vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py index fbe86a38bbe77..d74f0967e4731 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/pipe.py @@ -12,21 +12,20 @@ import threading import time from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Optional, Union -from copy import deepcopy +from typing import Callable, Dict, Optional, Tuple import torch -from torch.distributed import Backend -from vllm.distributed.utils import StatelessProcessGroup +from vllm.config import KVTransferConfig from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.config import KVTransferConfig logger = init_logger(__name__) class BrokenPipeException(Exception): + def __init__(self, message): self.message = message super().__init__(self.message) @@ -41,8 +40,8 @@ class PyNcclPipe: MAX_TENSOR_DIMENSIONS = 14 METADATA_DTYPE = torch.int64 - def __init__(self, - local_rank: int, + def __init__(self, + local_rank: int, config: KVTransferConfig, device: Optional[str] = None, port_offset: int = 0): @@ -76,13 +75,18 @@ def __init__(self, self.buffer_size_lock = threading.Lock() self.buffer_size_thresh = self.config.kv_buffer_size + def _get_device_send_recv_impl( + self, group: StatelessProcessGroup + ) -> Tuple[Callable[[torch.Tensor, int], None], Callable[ + [torch.Tensor, int], None]]: - def _get_device_send_recv_impl(self, group: StatelessProcessGroup): + send: Callable[[torch.Tensor, int], None] + recv: Callable[[torch.Tensor, int], None] if self.device.type == "cuda": # use PyNCCL for send / recv comm = PyNcclCommunicator(group, device=self.local_rank) comm.disabled = False - send, recv = comm.send, comm.recv + send, recv = comm.send, comm.recv # type: ignore else: # use cpu communication send = group.send @@ -125,10 +129,9 @@ def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor: - buffer: A tensor of the specified type and shape, allocated on self.device. """ - return torch.empty( - metadata["shape"], - dtype=metadata["dtype"], - device=self.device) + return torch.empty(metadata["shape"], + dtype=metadata["dtype"], + device=self.device) def _send_metadata(self, metadata: Metadata): """ @@ -161,7 +164,8 @@ def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: metadata = self._make_metadata(tensor) self._send_metadata(metadata) if tensor is not None: - self.device_send_func(tensor.to(self.device), self.target_rank_for_send) + self.device_send_func(tensor.to(self.device), + self.target_rank_for_send) def _recv_impl(self) -> Optional[torch.Tensor]: """ @@ -179,10 +183,8 @@ def _recv_impl(self) -> Optional[torch.Tensor]: return buffer - def send_tensor_wrapper( - self, - tensor: Optional[torch.Tensor], - tensor_size: int) -> None: + def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], + tensor_size: int) -> None: """ Wrapper for _send_impl to handle exceptions and update buffer size. """ @@ -216,7 +218,7 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: """ if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) - + if tensor is not None: tensor_size = tensor.element_size() * tensor.numel() else: @@ -227,11 +229,8 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: with self.buffer_size_lock: self.buffer_size += tensor_size - self.transport_thread.submit( - self.send_tensor_wrapper, - tensor, - tensor_size - ) + self.transport_thread.submit(self.send_tensor_wrapper, tensor, + tensor_size) def recv_tensor(self) -> Optional[torch.Tensor]: """ @@ -261,5 +260,6 @@ def close(self): """ Close the pipe and release associated resources. """ - if hasattr(self, "transport_thread") and self.transport_thread is not None: + if hasattr(self, + "transport_thread") and self.transport_thread is not None: self.transport_thread.shutdown() diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py index 057de698c59d9..61169753de619 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -18,26 +18,22 @@ - Delete the matched item in the lookup buffer to free up GPU memory. - The decode vLLM then store the KV cache into paged memory. """ -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Tuple, Union if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata -from copy import deepcopy - import torch -from torch.distributed import Backend -import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) from vllm.logger import init_logger from vllm.sequence import IntermediateTensors logger = init_logger(__name__) - class KVTransferAgent: """ A class designated for distributed KV transfer @@ -53,19 +49,14 @@ def __init__( local_rank: int, config, ): - + self.config = config assert self.config.is_kv_transfer_instance, "KV cache transfer "\ "agent should only be used when kv_connector is set." self.connector = KVConnectorFactory.create_connector( - rank, - local_rank, - config - ) - + rank, local_rank, config) - def send_kv_caches_and_hidden_states( self, model_executable: torch.nn.Module, @@ -106,11 +97,10 @@ def send_kv_caches_and_hidden_states( keys = torch.cat(keys, dim=0) values = torch.cat(values, dim=0) - + self.connector.insert( - current_tokens, torch.ones_like(current_tokens, - dtype=bool), keys, values, - hidden_or_intermediate_states[start_pos:end_pos]) + current_tokens, torch.ones_like(current_tokens, dtype=bool), + keys, values, hidden_or_intermediate_states[start_pos:end_pos]) logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 21f892ffd5e29..b308281494f45 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -27,18 +27,22 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Union) from unittest.mock import patch import torch import torch.distributed from torch.distributed import Backend, ProcessGroup +import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op, supports_custom_op -import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer + +if TYPE_CHECKING: + from vllm.config import KVTransferConfig @dataclass @@ -943,7 +947,6 @@ def get_pp_group() -> GroupCoordinator: # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group - _KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None @@ -1100,9 +1103,7 @@ def initialize_model_parallel( group_name="pp") -def ensure_kv_transfer_initialized( - config: "KVTransferConfig", -) -> None: +def ensure_kv_transfer_initialized(config: "KVTransferConfig") -> None: """ Initialize KV cache transfer parallel group. """ @@ -1112,9 +1113,7 @@ def ensure_kv_transfer_initialized( _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank, local_rank=get_world_group().local_rank, - config=config - ) - + config=config) def ensure_model_parallel_initialized( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 57b7623bdd30d..5a48bfef01cf6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -9,11 +9,12 @@ import vllm.envs as envs from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig, - DeviceConfig, HfOverrides, LoadConfig, LoadFormat, - LoRAConfig, ModelConfig, ObservabilityConfig, - ParallelConfig, PoolerConfig, PromptAdapterConfig, - SchedulerConfig, SpeculativeConfig, TaskOption, - TokenizerPoolConfig, KVTransferConfig, VllmConfig) + DeviceConfig, HfOverrides, KVTransferConfig, + LoadConfig, LoadFormat, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, PoolerConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, TaskOption, TokenizerPoolConfig, + VllmConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -193,7 +194,7 @@ class EngineArgs: # P/D disaggregation coonfiguration kv_connector: Optional[str] = None - kv_buffer_size: Optional[int] = 1e9 + kv_buffer_size: Optional[float] = 1e9 kv_buffer_device: Optional[str] = "cuda" kv_role: Optional[str] = None kv_rank: Optional[str] = None @@ -884,8 +885,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, default=EngineArgs.kv_parallel_size, help="The number of parallel instances for KV cache transfer. " - "For PyNcclConnector, this should be >1." - ) + "For PyNcclConnector, this should be >1.") parser.add_argument( '--kv-connector', @@ -900,8 +900,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=float, default=EngineArgs.kv_buffer_size, help="The buffer size for TorchDistributedConnector. Measured in " - "number of bytes. Recommended value: 1e9 (about 1GB)." - ) + "number of bytes. Recommended value: 1e9 (about 1GB).") parser.add_argument( '--kv-buffer-device', @@ -909,42 +908,32 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.kv_buffer_device, choices=["cpu", "cuda"], help="The device used by kv connector to buffer the KV cache. Can " - "be CPU or GPU. Recommended value: CPU." - ) + "be CPU or GPU. Recommended value: CPU.") parser.add_argument( '--kv-role', type=str, - default=None, + default=None, choices=["kv_producer", "kv_consumer", "both"], help="Whether this vLLM instance produces, consumes KV cache, or " - "both. Choices are 'kv_producer', 'kv_consumer', and 'both'." - ) + "both. Choices are 'kv_producer', 'kv_consumer', and 'both'.") parser.add_argument( '--kv-rank', type=int, default=None, help="The rank of this vLLM instance in the KV cache transfer." - " Typicall value: 0 for prefill instance, 1 for decode instance." - ) - - parser.add_argument( - '--kv-ip', - type=str, - default=EngineArgs.kv_ip, - help="The IP address of the KV cache producer." - ) + " Typical value: 0 for prefill instance, 1 for decode instance.") - - parser.add_argument( - '--kv-port', - type=int, - default=EngineArgs.kv_port, - help="The port of the KV cache producer." - ) + parser.add_argument('--kv-ip', + type=str, + default=EngineArgs.kv_ip, + help="The IP address of the KV cache producer.") - + parser.add_argument('--kv-port', + type=int, + default=EngineArgs.kv_port, + help="The port of the KV cache producer.") return parser @@ -1219,7 +1208,6 @@ def create_engine_config(self) -> VllmConfig: or "all" in detailed_trace_modules, ) - return VllmConfig( model_config=model_config, cache_config=cache_config, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 139ede64c7937..d58eb61d049d6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,9 +7,9 @@ import torch.distributed import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig, KVTransferConfig -from vllm.distributed import (ensure_model_parallel_initialized, - ensure_kv_transfer_initialized, +from vllm.config import KVTransferConfig, ParallelConfig, VllmConfig +from vllm.distributed import (ensure_kv_transfer_initialized, + ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.logger import init_logger @@ -143,9 +143,8 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, - self.kv_transfer_config, - self.rank, + init_worker_distributed_environment(self.parallel_config, + self.kv_transfer_config, self.rank, self.distributed_init_method, self.local_rank) # Set random seed. From 04ed58f07a7538cfcfefc9d077f95bfd42ec2078 Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Tue, 19 Nov 2024 06:05:24 +0000 Subject: [PATCH 302/303] adjust benchmarking files to incorporate with new CLI arg changes --- .../disagg_overhead_benchmark.sh | 24 +++++-- .../disagg_performance_benchmark.sh | 65 +++++++++---------- 2 files changed, 48 insertions(+), 41 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh index 90b772a6d9d0a..e7d30001e850c 100644 --- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -10,7 +10,7 @@ set -ex kill_gpu_processes() { # kill all processes on GPU. - pkill pt_main_thread + pkill -f pt_main_thread sleep 10 # remove vllm config file @@ -38,7 +38,6 @@ benchmark() { export VLLM_LOGGING_LEVEL=DEBUG export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') - export VLLM_PORT=12345 # compare chunked prefill with disaggregated prefill @@ -53,19 +52,30 @@ benchmark() { output_len=$2 - VLLM_DISTRIBUTED_KV_ROLE=producer CUDA_VISIBLE_DEVICES=0 python3 \ + CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8100 \ --max-model-len 10000 \ - --gpu-memory-utilization 0.8 & - - VLLM_DISTRIBUTED_KV_ROLE=consumer CUDA_VISIBLE_DEVICES=1 python3 \ + --gpu-memory-utilization 0.6 \ + --kv-connector PyNcclConnector \ + --kv-role kv_producer \ + --kv-rank 0 \ + --kv-parallel-size 2 \ + --kv-buffer-size 1e10 & + + + CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8200 \ --max-model-len 10000 \ - --gpu-memory-utilization 0.8 & + --gpu-memory-utilization 0.6 \ + --kv-connector PyNcclConnector \ + --kv-role kv_consumer \ + --kv-rank 1 \ + --kv-parallel-size 2 \ + --kv-buffer-size 1e10 & wait_for_server 8100 wait_for_server 8200 diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index bd9e51257e936..4837173741ebe 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -36,28 +36,22 @@ wait_for_server() { launch_chunked_prefill() { - model="meta-llama/Meta-Llama-3.1-70B-Instruct" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" # disagg prefill - CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model $model \ - --port 8100 \ - -tp 4 \ - --max-model-len 10000 \ - --disable-log-stats \ - --disable-log-requests \ - --enable-chunked-prefill \ - --gpu-memory-utilization 0.8 & - CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + --max-model-len 10000 \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.6 & + CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ --model $model \ --port 8200 \ - -tp 4 \ --max-model-len 10000 \ - --disable-log-stats \ - --disable-log-requests \ --enable-chunked-prefill \ - --gpu-memory-utilization 0.8 & + --gpu-memory-utilization 0.6 & wait_for_server 8100 wait_for_server 8200 python3 round_robin_proxy.py & @@ -66,26 +60,30 @@ launch_chunked_prefill() { launch_disagg_prefill() { - model="meta-llama/Meta-Llama-3.1-70B-Instruct" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" # disagg prefill - VLLM_PORT=12345 VLLM_DISTRIBUTED_KV_ROLE=producer CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ - -m vllm.entrypoints.openai.api_server \ - --model $model \ - --port 8100 \ - -tp 4 \ - --max-model-len 10000 \ - --disable-log-stats \ - --disable-log-requests \ - --gpu-memory-utilization 0.8 & - VLLM_PORT=12345 VLLM_DISTRIBUTED_KV_ROLE=consumer CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ - --model $model \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-connector PyNcclConnector \ + --kv-role kv_producer \ + --kv-rank 0 \ + --kv-parallel-size 2 \ + --kv-buffer-size 5e9 & + CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ --port 8200 \ - -tp 4 \ --max-model-len 10000 \ - --disable-log-stats \ - --disable-log-requests \ - --gpu-memory-utilization 0.8 & + --gpu-memory-utilization 0.6 \ + --kv-connector PyNcclConnector \ + --kv-role kv_consumer \ + --kv-rank 1 \ + --kv-parallel-size 2 \ + --kv-buffer-size 5e9 & wait_for_server 8100 wait_for_server 8200 python3 disagg_prefill_proxy_server.py & @@ -98,7 +96,7 @@ benchmark() { model="meta-llama/Meta-Llama-3.1-70B-Instruct" dataset_name="sonnet" dataset_path="../sonnet_4x.txt" - num_prompts=200 + num_prompts=100 qps=$1 prefix_len=50 input_len=1024 @@ -149,7 +147,6 @@ main() { default_output_len=6 - export VLLM_LOGGING_LEVEL=DEBUG export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') launch_chunked_prefill From 529c42591294ed7ac196ffe4061f298b769d47dc Mon Sep 17 00:00:00 2001 From: KuntaiDu Date: Wed, 20 Nov 2024 07:07:21 +0000 Subject: [PATCH 303/303] fall back to full prefill when any of the KV cache receive fails. --- .../disagg_performance_benchmark.sh | 11 +- .../kv_transfer/kv_connector/base.py | 80 ++++- .../pynccl_connector/connector.py | 277 ++++++++++-------- .../kv_transfer/kv_transfer_agent.py | 148 +--------- 4 files changed, 228 insertions(+), 288 deletions(-) diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index 4837173741ebe..2bd7aa14de5f0 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -17,9 +17,8 @@ set -ex kill_gpu_processes() { # kill all processes on GPU. - pkill -f pt_main_thread - pkill -f python3 - pgrep pt_main_thread | xargs kill -9 + pgrep pt_main_thread | xargs -r kill -9 + pgrep python3 | xargs -r kill -9 for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done sleep 1 } @@ -64,7 +63,7 @@ launch_disagg_prefill() { # disagg prefill CUDA_VISIBLE_DEVICES=0 python3 \ -m vllm.entrypoints.openai.api_server \ - --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --model $model \ --port 8100 \ --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ @@ -75,7 +74,7 @@ launch_disagg_prefill() { --kv-buffer-size 5e9 & CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ - --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --model $model \ --port 8200 \ --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ @@ -93,7 +92,7 @@ launch_disagg_prefill() { benchmark() { results_folder="./results" - model="meta-llama/Meta-Llama-3.1-70B-Instruct" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" dataset_name="sonnet" dataset_path="../sonnet_4x.txt" num_prompts=100 diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index e92f1cc638969..8254d3fdd5338 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -8,10 +8,12 @@ """ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch +from vllm.sequence import IntermediateTensors + if TYPE_CHECKING: from vllm.config import KVTransferConfig from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata @@ -119,18 +121,74 @@ def close(self) -> None: raise NotImplementedError @abstractmethod - def build_partial_prefill_input( + def send_kv_caches_and_hidden_states( self, + model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", - input_tokens_list: List[torch.Tensor], - num_computed_tokens_list: List[int], - start_pos_list: List[int], - slot_mapping_flat: torch.Tensor, - device: torch.device, - ) -> "ModelInputForGPUWithSamplingMetadata": - """Rebuild the model input based on how many KV caches are received + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + """ + Send KV caches and hidden states to the connector. + + This method processes the input tokens, KV caches, and + hidden/intermediate states for a given model and sends the data to the + decode instance. + + Args: + model_executable (torch.nn.Module): The model executable containing + start and end layer information. + model_input (ModelInputForGPUWithSamplingMetadata): The input + metadata from vLLM. + kv_caches (List[torch.Tensor]): List of KV caches (keys and values) + for each layer. + hidden_or_intermediate_states (Union[torch.Tensor, + IntermediateTensors]): + The hidden or intermediate states associated with the tokens. + + Returns: + None - Raises: - NotImplementedError: This method must be implemented in subclasses. """ + + raise NotImplementedError + + @abstractmethod + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + """ + Receive KV caches and hidden states from the connector. + + This method attempts to retrieve KV caches and hidden states for input + tokens. If all required KV caches and hidden states are received, it + will bypass model input, else it will fall back to normal vLLM model + forwarding. + + Args: + model_executable (torch.nn.Module): + The model executable from vLLM modelrunner. + model_input (ModelInputForGPUWithSamplingMetadata): + The model input from vLLM modelrunner. + kv_caches (List[torch.Tensor]): + List of KV caches for each layer. + + Returns: + - hidden_or_intermediate_states (torch.Tensor or + IntermediateTensors): + Concatenated hidden states if all required data is retrieved, + otherwise `None`. + - bypass_model_exec (bool): + Indicates whether the model execution can be skipped (True) or + needs to be redone (False). + - model_input (ModelInputForGPUWithSamplingMetadata): + Optionally adjusted input metadata for re-execution when + `bypass_model_exec=False`. + + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py index 26c8a54002102..fe45adba0605e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/pynccl_connector/connector.py @@ -7,11 +7,11 @@ - `TorchDistributedConnector`: a torch distributed connector between P/D instance, implemented on top of `TorchDistributedBuffer` """ -from copy import deepcopy -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch +from vllm import _custom_ops as ops from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.buffer import ( @@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.pynccl_connector.pipe import ( PyNcclPipe) from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata @@ -101,134 +102,154 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, self.producer_buffer.insert(input_tokens, roi, key, value, hidden) - def build_partial_prefill_input( + def send_kv_caches_and_hidden_states( self, + model_executable: torch.nn.Module, model_input: "ModelInputForGPUWithSamplingMetadata", - input_tokens_list: List[torch.Tensor], - num_computed_tokens_list: List[int], - start_pos_list: List[int], - slot_mapping_flat: torch.Tensor, - device: torch.device, - ) -> "ModelInputForGPUWithSamplingMetadata": - """ - Helper function to rebuild the model input for the current request. - Goal: avoid running redundant prefill on those tokens that already has - KV caches received. - """ - rebuilt_input_tokens = [] - rebuilt_input_positions = [] - rebuilt_query_lens = [] - - rebuilt_num_prefills = 0 - rebuilt_num_prefill_tokens = 0 - rebuilt_slot_mapping = [] - rebuilt_max_query_len = 0 - - rebuilt_block_tables = [] - - rebuilt_query_start_loc = [0] - rebuilt_context_lens_tensor = [] - rebuilt_selected_token_indices = [] - - # recounting query and context lengths - for idx in range(len(input_tokens_list)): - token_tensor = input_tokens_list[idx] - num_token = len(token_tensor) - num_computed_token = num_computed_tokens_list[idx] - # currently attention kernel cannot handle the case where there is - # 0 query token. - if num_computed_token == num_token: - num_computed_token -= 1 - start_pos = start_pos_list[idx] - - rebuilt_input_tokens.append(token_tensor[num_computed_token:]) - # TODO(Jiayi): please check the correctness of next line - rebuilt_input_positions.append( - model_input.input_positions[start_pos + - num_computed_token:start_pos + - num_token]) - q_len = num_token - num_computed_token - rebuilt_query_lens.append(q_len) - - # Attn metadata-related - rebuilt_num_prefills += 1 - rebuilt_num_prefill_tokens += q_len - new_slot_mapping = slot_mapping_flat[start_pos + - num_computed_token:start_pos + - num_token] - rebuilt_slot_mapping.append(new_slot_mapping) - rebuilt_max_query_len = max(q_len, rebuilt_max_query_len) - # TODO(Jiayi): remove hard-code (block_size=16) - blk_size = 16 - temp_block_table = [ - slot_mapping_flat[i] // blk_size - for i in range(start_pos, start_pos + num_token, blk_size) - ] - rebuilt_block_tables.append(temp_block_table) - rebuilt_query_start_loc.append( - rebuilt_num_prefill_tokens) #start with 0 - rebuilt_context_lens_tensor.append(num_computed_token) - - # Sampling metadata related - #seq_groups (use rebuilt query lens) - rebuilt_selected_token_indices.append(rebuilt_num_prefill_tokens - - 1) - - # rebuilt attn_metadata - rebuilt_attn_metadata = deepcopy(model_input.attn_metadata) - rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills - rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens - rebuilt_attn_metadata.slot_mapping = torch.cat( - rebuilt_slot_mapping).to(device) - rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len - - rebuilt_attn_metadata.block_tables = torch.tensor( - rebuilt_block_tables, - dtype=model_input.attn_metadata.block_tables.dtype).to(device) - - rebuilt_attn_metadata.query_start_loc = torch.tensor( - rebuilt_query_start_loc, - dtype=model_input.attn_metadata.query_start_loc.dtype).to(device) - rebuilt_attn_metadata.context_lens_tensor = torch.tensor( - rebuilt_context_lens_tensor, - dtype=model_input.attn_metadata.context_lens_tensor.dtype, - ).to(device) - - rebuilt_attn_metadata._cached_prefill_metadata = None - - # rebuilt sampling_metadata - rebuilt_sampling_metadata = deepcopy(model_input.sampling_metadata) - for idx, q_len in enumerate(rebuilt_query_lens): - if rebuilt_sampling_metadata.seq_groups is not None: - rebuilt_sampling_metadata.seq_groups[idx].query_len = q_len - - rebuilt_sampling_metadata.selected_token_indices = torch.tensor( - rebuilt_selected_token_indices, - dtype=model_input.sampling_metadata.selected_token_indices.dtype, - ).to(device) - - # import here to avoid circular import. - from vllm.worker.model_runner import ( - ModelInputForGPUWithSamplingMetadata) - rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( - input_tokens=torch.cat(rebuilt_input_tokens).to(device), - input_positions=torch.cat(rebuilt_input_positions).to(device), - seq_lens=model_input.seq_lens, - query_lens=rebuilt_query_lens, - lora_mapping=model_input.lora_mapping, - lora_requests=model_input.lora_requests, - attn_metadata=rebuilt_attn_metadata, - prompt_adapter_mapping=model_input.prompt_adapter_mapping, - prompt_adapter_requests=model_input.prompt_adapter_requests, - multi_modal_kwargs=model_input.multi_modal_kwargs, - request_ids_to_seq_ids=model_input.request_ids_to_seq_ids, - finished_requests_ids=model_input.finished_requests_ids, - virtual_engine=model_input.virtual_engine, - sampling_metadata=rebuilt_sampling_metadata, - is_prompt=model_input.is_prompt, - ) - - return rebuilt_model_input + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + + _, _, num_heads, head_size = kv_cache[0].shape + + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + self.insert(current_tokens, + torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + # When bypass_model_exec is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. + bypass_model_exec = True + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + num_computed_tokens_list = [] + start_pos_list = [] + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. + for idx, slen in enumerate(seq_lens): + + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + ret = self.select(current_tokens, + torch.ones_like(current_tokens, dtype=bool)) + if ret[0] is None: + # didn't find any match. + bypass_model_exec = False + num_computed_tokens_list.append(0) + continue + + roi: torch.Tensor = ret[1] + keys: torch.Tensor = ret[2] + values: torch.Tensor = ret[3] + hidden: torch.Tensor = ret[4] + + num_computed_tokens = roi.shape[0] + num_computed_tokens_list.append(num_computed_tokens) + + # check if both KV cache and the hidden states are received + # If not, need to redo the forwarding to compute missing states + if not all([(num_computed_tokens == num_tokens), hidden is not None + ]): + bypass_model_exec = False + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # put received KV caches into paged memory + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + kv_cache = kv_caches[i - model_executable.model.start_layer] + layer = model_executable.model.layers[i] + + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # so we need to adjust model_input and redo the forwarding. + logger.debug( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input def close(self): self.producer_data_pipe.close() diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py index 61169753de619..98ca06138ebd3 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -25,7 +25,6 @@ import torch -from vllm import _custom_ops as ops from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.logger import init_logger @@ -66,43 +65,9 @@ def send_kv_caches_and_hidden_states( IntermediateTensors], ) -> None: - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - - # query_lens contains new KV caches that are added to vLLM. - # so we will send them to decode instance - # FIXME(Kuntai): This assume that all requests are prefill. - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - current_tokens = input_tokens_tensor[start_pos:end_pos] - - keys, values = [], [] - - for layer_id in range(start_layer, end_layer): - kv_cache = kv_caches[layer_id - start_layer] - - _, _, num_heads, head_size = kv_cache[0].shape - - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) - - current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - - keys.append(key_cache[current_slot_mapping].unsqueeze(0)) - values.append(value_cache[current_slot_mapping].unsqueeze(0)) - - keys = torch.cat(keys, dim=0) - values = torch.cat(values, dim=0) - - self.connector.insert( - current_tokens, torch.ones_like(current_tokens, dtype=bool), - keys, values, hidden_or_intermediate_states[start_pos:end_pos]) - - logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + self.connector.send_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches, + hidden_or_intermediate_states) def close(self) -> None: self.connector.close() @@ -114,108 +79,5 @@ def recv_kv_caches_and_hidden_states( ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, "ModelInputForGPUWithSamplingMetadata"]: - # When this flag is set to False, it means that at least for one - # request its corresponding KV cache or hidden state is missing. - # In this case we need to do prefilling to recompute missing KV cache - # and hidden states. - bypass_model_exec = True - - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - - hidden_or_intermediate_states_for_one_req = [] - - input_tokens_list = [] - num_computed_tokens_list = [] - start_pos_list = [] - - # enumerate different requests - # FIXME(Kuntai): This impl assumes that all requests are prefill. - for idx, slen in enumerate(seq_lens): - - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - current_tokens = input_tokens_tensor[start_pos:end_pos] - num_tokens = slen - - # collecting data for rebuilding the input - input_tokens_list.append(current_tokens) - start_pos_list.append(start_pos) - - ret = self.connector.select( - current_tokens, torch.ones_like(current_tokens, dtype=bool)) - if ret[0] is None: - # didn't find any match. - bypass_model_exec = False - num_computed_tokens_list.append(0) - continue - - roi: torch.Tensor = ret[1] - keys: torch.Tensor = ret[2] - values: torch.Tensor = ret[3] - hidden: torch.Tensor = ret[4] - - num_computed_tokens = roi.shape[0] - num_computed_tokens_list.append(num_computed_tokens) - - # check if both KV cache and the hidden states are received - # If not, need to redo the forwarding to compute missing states - if not all([(num_computed_tokens == num_tokens), hidden is not None - ]): - bypass_model_exec = False - - # update the end position based on how many tokens are cached. - end_pos = start_pos + num_computed_tokens - - # put received KV caches into paged memory - for i in range(model_executable.model.start_layer, - model_executable.model.end_layer): - - kv_cache = kv_caches[i - model_executable.model.start_layer] - layer = model_executable.model.layers[i] - - key_cache, value_cache = kv_cache[0], kv_cache[1] - ops.reshape_and_cache_flash( - keys[i - model_executable.model.start_layer].to( - key_cache.device), - values[i - model_executable.model.start_layer].to( - value_cache.device), - key_cache, - value_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - layer.self_attn.attn._v_scale, - ) - - hidden_or_intermediate_states_for_one_req.append(hidden) - - if not bypass_model_exec: - # Some of the KV cache is not retrieved - # so we need to adjust model_input and redo the forwarding. - logger.debug( - "[rank%d]: Failed to receive all KVs and hidden " - "states, redo model forwarding.", torch.distributed.get_rank()) - - # allow the connector to mutate the model input - # useful for injecting memory movement / computation requests - rebuilt_model_input = self.connector.build_partial_prefill_input( - model_input, - input_tokens_list, - num_computed_tokens_list, - start_pos_list, - slot_mapping, - device=input_tokens_tensor.device, - ) - model_input = rebuilt_model_input - hidden_or_intermediate_states = None - - else: - logger.debug( - "[rank%d]: Successfully received all KVs and hidden " - "states, skip model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = torch.cat( - hidden_or_intermediate_states_for_one_req, dim=0) - - return hidden_or_intermediate_states, bypass_model_exec, model_input + return self.connector.recv_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches)