Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel][LoRA]Punica prefill kernels fusion #11234

Merged
merged 70 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
ec3590d
Init
jeejeelee Dec 10, 2024
9474fb0
Sync main
jeejeelee Dec 10, 2024
8c2ac4c
Fix bug
jeejeelee Dec 10, 2024
2897d05
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 11, 2024
35aebea
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 11, 2024
628a567
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 11, 2024
d04121c
Back up
jeejeelee Dec 11, 2024
a306f42
shrink_sgmv Done
jeejeelee Dec 11, 2024
f6bccc7
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 11, 2024
e5cb72e
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 12, 2024
b6013db
Optimize ptr compute
jeejeelee Dec 12, 2024
7f088ec
Merge commit 'b6013db4' into punica-kernel-fusion
jeejeelee Dec 13, 2024
32c5279
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 13, 2024
8d3742b
Increase the tile size
jeejeelee Dec 13, 2024
9564b33
Clean up triton interface
jeejeelee Dec 13, 2024
3eb3ac3
Sync main
jeejeelee Dec 16, 2024
4012466
Backup
jeejeelee Dec 16, 2024
18bbadf
Optimize one sclice kernel
jeejeelee Dec 16, 2024
43aae70
Delete unused code
jeejeelee Dec 16, 2024
482de15
Refactor expand
jeejeelee Dec 16, 2024
259d382
format
jeejeelee Dec 16, 2024
00f1904
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 17, 2024
a0197e3
Optimize logic
jeejeelee Dec 17, 2024
38ba4f1
Add comments
jeejeelee Dec 17, 2024
3c37226
Fix bug
jeejeelee Dec 17, 2024
45180c1
Fix expand bug
jeejeelee Dec 17, 2024
2e52d2c
Backup
jeejeelee Dec 17, 2024
2146141
revert expand tile size
jeejeelee Dec 17, 2024
d724891
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 18, 2024
9719617
Clean up code
jeejeelee Dec 18, 2024
5d2c557
Optimize expand tile size
jeejeelee Dec 18, 2024
958500d
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 19, 2024
5c88ec4
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 19, 2024
3460308
improve expand (#3)
Abatom Dec 19, 2024
24e893c
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 20, 2024
c9747c6
Lora expand (#4)
Abatom Dec 20, 2024
f3ecfc6
Lora expand (#5)
Abatom Dec 20, 2024
5859da7
Fix K size
jeejeelee Dec 20, 2024
b3ea6fc
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 21, 2024
eb01089
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 23, 2024
ebc9519
revert (#6)
Abatom Dec 24, 2024
a4f46b6
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 24, 2024
2cdf459
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 24, 2024
ba2c444
Add unit test
jeejeelee Dec 24, 2024
394886d
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 24, 2024
36fbeac
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 25, 2024
0f7897b
Optimize unit test
jeejeelee Dec 25, 2024
3edb696
Optimize unit test
jeejeelee Dec 25, 2024
49c6c21
Fix comment
jeejeelee Dec 25, 2024
bf3b9ca
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 26, 2024
fe24a41
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 26, 2024
9d89f47
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 27, 2024
489eca1
Optimize code
jeejeelee Dec 28, 2024
04ae0dd
Add lock for unit test
jeejeelee Dec 28, 2024
fa489f2
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 30, 2024
ea19a7d
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 30, 2024
65d0f2f
Optimize arg
jeejeelee Dec 30, 2024
797ae77
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 30, 2024
2b9f928
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 31, 2024
09fb9a9
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 31, 2024
f446454
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 1, 2025
767b233
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 2, 2025
421382e
Fix expand bug
jeejeelee Jan 2, 2025
90a9117
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 3, 2025
2c79295
Reduce memory
jeejeelee Jan 3, 2025
7e8d3bd
Modify minicpmv test
jeejeelee Jan 4, 2025
02b1d80
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 4, 2025
bd8cc45
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 5, 2025
7ffd15e
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 6, 2025
c1c5b4b
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ steps:
source_file_dependencies:
- vllm/lora
- tests/lora
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py
parallelism: 4

- label: "PyTorch Fullgraph Smoke Test" # 9min
Expand Down Expand Up @@ -535,6 +535,7 @@ steps:
# requires multi-GPU testing for validation.
- pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_minicpmv_tp.py


- label: Weight Loading Multiple GPU Test # 33min
Expand Down
77 changes: 0 additions & 77 deletions tests/lora/test_minicpmv.py

This file was deleted.

63 changes: 44 additions & 19 deletions tests/lora/test_minicpmv_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import pytest

import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

from ..utils import multi_gpu_test
from vllm.platforms import current_platform

MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"

Expand All @@ -17,13 +17,11 @@

IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]

# After fine-tuning with LoRA, all generated content should start begin `A`.
EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
]


Expand All @@ -50,48 +48,75 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Generated text: {generated_text!r}")
return generated_texts


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="MiniCPM-V dependency xformers incompatible with ROCm")
@fork_new_process_for_each_test
def test_minicpmv_lora(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_num_seqs=2,
enable_lora=True,
max_loras=2,
max_lora_rank=8,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
)
output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output1[i])
output2 = do_sample(llm, minicpmv_lora_files, lora_id=2)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output2[i])


@pytest.mark.xfail(
current_platform.is_rocm(),
reason="MiniCPM-V dependency xformers incompatible with ROCm")
@fork_new_process_for_each_test
def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=2,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=2,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
enforce_eager=True,
enable_chunked_prefill=True,
)

output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)

for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])


@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
@pytest.mark.xfail(
current_platform.is_rocm(),
reason="MiniCPM-V dependency xformers incompatible with ROCm")
@fork_new_process_for_each_test
def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=2,
max_loras=4,
max_lora_rank=64,
max_loras=2,
max_lora_rank=8,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
fully_sharded_loras=True,
enable_chunked_prefill=True,
)
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=2)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
Loading
Loading