Commit

Merge branch 'main' into omer/run-loader

omer-dayan authored Nov 17, 2024
2 parents efa4705 + 643ecf7 commit 11a2fe8
Showing 74 changed files with 994 additions and 388 deletions.
17 changes: 10 additions & 7 deletions .github/CODEOWNERS
@@ -3,13 +3,16 @@

# This lists cover the "core" components of vLLM that require careful review
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/core @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/engine/llm_engine.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/executor/executor_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/worker/worker_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/worker/worker.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/model_executor/layers/sampler.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
CMakeLists.txt @tlrmchlsmth @WoosukKwon
/vllm/core @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
CMakeLists.txt @tlrmchlsmth

# vLLM V1
/vllm/v1 @WoosukKwon @robertgshaw2-neuralmagic @njhill @ywang96 @comaniac @alexm-neuralmagic

# Test ownership
/tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo
8 changes: 5 additions & 3 deletions tests/compile/piecewise/test_simple.py
@@ -11,8 +11,8 @@
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.config import CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config
from vllm.utils import direct_register_custom_op

global_counter = 0
@@ -82,7 +82,9 @@ def test_simple_piecewise_compile():
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)

model = SillyModel(vllm_config=VllmConfig(), prefix='')
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')

inputs = torch.randn(100).cuda()

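The change above shows the pattern this commit applies across the compile tests: build a VllmConfig first, then construct the model inside set_current_vllm_config so that code running during construction sees the active config. A minimal sketch of that pattern, using only the vllm.config / vllm.plugins names already present in this diff (build_model and model_cls are illustrative names, not from the repo):

from vllm.config import VllmConfig
from vllm.plugins import set_current_vllm_config


def build_model(model_cls, **kwargs):
    # Create the config up front...
    vllm_config = VllmConfig()
    # ...and make it the "current" config while the model is constructed,
    # so layers created in __init__ can pick it up.
    with set_current_vllm_config(vllm_config):
        return model_cls(vllm_config=vllm_config, prefix="", **kwargs)
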
22 changes: 12 additions & 10 deletions tests/compile/piecewise/test_toy_llama.py
@@ -15,12 +15,10 @@
from torch.library import Library

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.plugins import set_compilation_config
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_compilation_config, set_current_vllm_config
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
@@ -272,9 +270,11 @@ def run_model(llama_config,
CompilationLevel.NO_COMPILATION)
set_compilation_config(None)

model = LlamaModel(config=llama_config,
vllm_config=VllmConfig(),
prefix="").eval().cuda()
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda()

B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
@@ -395,9 +395,11 @@ def benchmark():
else:
set_compilation_config(None)

model = LlamaModel(config=llama_config,
vllm_config=VllmConfig(),
prefix="").eval().cuda().to(torch.bfloat16)
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda().to(torch.bfloat16)

B = 256 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
2 changes: 1 addition & 1 deletion tests/compile/test_basic_correctness.py
@@ -3,7 +3,7 @@

import pytest

from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel
from vllm.utils import cuda_device_count_stateless

from ..utils import compare_all_settings
2 changes: 1 addition & 1 deletion tests/compile/test_full_graph.py
@@ -1,6 +1,6 @@
import pytest

from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel

from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS, check_full_graph_support
2 changes: 1 addition & 1 deletion tests/compile/test_fusion.py
@@ -3,10 +3,10 @@
from compressed_tensors.quantization import FP8_DTYPE

import vllm.envs as envs
from vllm.compilation.config import CompilationConfig
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
find_auto_fn_maybe)
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.config import CompilationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear)
4 changes: 3 additions & 1 deletion tests/compile/test_wrapper.py
@@ -3,6 +3,7 @@
import torch

from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationLevel


class MyMod(torch.nn.Module):
@@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
def __init__(self, model):
self.model = model
compiled_callable = torch.compile(self.forward, backend="eager")
super().__init__(compiled_callable)
super().__init__(compiled_callable,
compilation_level=CompilationLevel.DYNAMO_ONCE)

def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# this is the function to be compiled
2 changes: 1 addition & 1 deletion tests/compile/utils.py
@@ -4,7 +4,7 @@

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.compilation.levels import CompilationLevel
from vllm.config import CompilationLevel
from vllm.platforms import current_platform

TEST_MODELS = [
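The smaller diffs above (test_basic_correctness.py, test_full_graph.py, test_fusion.py, utils.py) are pure import moves: CompilationLevel and CompilationConfig now live in vllm.config rather than vllm.compilation.levels / vllm.compilation.config. A hedged before/after sketch, using only names that appear in this diff:

# Old locations (removed by this commit):
#   from vllm.compilation.levels import CompilationLevel
#   from vllm.compilation.config import CompilationConfig
# New, single location:
from vllm.config import CompilationConfig, CompilationLevel

level = CompilationLevel.PIECEWISE                     # same enum, new home
config = CompilationConfig(custom_ops=["+rms_norm"])   # same options container, new home
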
52 changes: 26 additions & 26 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -3,11 +3,13 @@

import pytest

from vllm.config import CompilationConfig, VllmConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config


# Registered subclass for test
@@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation):
])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
default_on: bool):
os.environ["VLLM_CUSTOM_OPS"] = env
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on

# Reset default_on (computed once):
CustomOp.default_on.cache_clear()
ops_enabled = [bool(x) for x in ops_enabled]

assert CustomOp.default_on() == default_on
assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]

ops_enabled = [bool(x) for x in ops_enabled]
assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]

assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]

assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
# If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
# Unregistered subclass
class SiluAndMul2(SiluAndMul):
pass

# If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

# Unregistered subclass
class SiluAndMul2(SiluAndMul):
pass

# Subclasses should not require registration
assert SiluAndMul2().enabled() == SiluAndMul().enabled()
# Subclasses should not require registration
assert SiluAndMul2().enabled() == SiluAndMul().enabled()


@pytest.mark.parametrize(
"env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
def test_enabled_ops_invalid(env: str):
os.environ["VLLM_CUSTOM_OPS"] = env
CustomOp.default_on.cache_clear()

with pytest.raises(AssertionError):
RMSNorm(1024).enabled()
with pytest.raises(Exception): # noqa
vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
RMSNorm(1024).enabled()
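
The rewritten test above moves custom-op selection from the VLLM_CUSTOM_OPS / VLLM_TORCH_COMPILE_LEVEL environment variables to CompilationConfig.custom_ops on the active VllmConfig. A minimal sketch of the new flow, assuming the "none" / "+op" spellings exercised by this test file; the expected values are an assumption based on that parametrization:

from vllm.config import CompilationConfig, VllmConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config

# Disable all custom ops by default, then opt rms_norm back in.
vllm_config = VllmConfig(compilation_config=CompilationConfig(
    custom_ops=["none", "+rms_norm"]))

with set_current_vllm_config(vllm_config):
    # enabled() is resolved against the current config, not env vars.
    print(RMSNorm(1024).enabled())                          # expected: True
    print(CustomOp.op_registry["gelu_and_mul"].enabled())   # expected: False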
136 changes: 132 additions & 4 deletions tests/models/decoder_only/vision_language/test_qwen2_vl.py
@@ -18,6 +18,7 @@

IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"
VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"
MODEL_HIDDEN_SIZE = 1536


def qwen2_vl_chat_template(*query):
@@ -230,7 +231,7 @@ def batch_make_video_embeddings(
return result


def run_test(
def run_embedding_input_test(
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]],
model: str,
@@ -326,7 +327,7 @@ def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
[],
) for image, prompt in zip(images, IMAGE_PROMPTS)]

run_test(
run_embedding_input_test(
vllm_runner,
inputs_per_case,
model,
@@ -371,7 +372,7 @@ def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets,
[],
)]

run_test(
run_embedding_input_test(
vllm_runner,
inputs_per_case,
model,
@@ -416,7 +417,134 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model,
[rescale_video_size(video, factor) for factor in size_factors],
) for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)]

run_test(
run_embedding_input_test(
vllm_runner,
inputs_per_case,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
mm_limit=1,
tensor_parallel_size=1,
)


def run_chunked_prefill_test(
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
mm_limit: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Compare inference result between
chunked prefill disabled and chunked prefill enabled
"""

# NOTE:
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
task="generate",
max_model_len=4000,
max_num_seqs=4,
dtype=dtype,
limit_mm_per_prompt={
"image": mm_limit,
"video": mm_limit
},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend
) as vllm_model:

outputs_per_case = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images or None,
videos=videos or None)
for prompts, images, videos in inputs
]

with vllm_runner(
model,
task="generate",
max_model_len=4000,
max_num_seqs=4,
dtype=dtype,
limit_mm_per_prompt={
"image": mm_limit,
"video": mm_limit
},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enable_chunked_prefill=True,
# should be small enough to ensure prefilling is chunked
max_num_batched_tokens=32,
mm_processor_kwargs={
"max_pixels": 16 * 28 * 28,
}) as vllm_model_chunked:
outputs_per_case_chunked = [
vllm_model_chunked.generate_greedy_logprobs(
prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images or None,
videos=videos or None) for prompts, images, videos in inputs
]

for outputs, \
outputs_chunked \
in zip(outputs_per_case,
outputs_per_case_chunked):
check_logprobs_close(
outputs_0_lst=outputs,
outputs_1_lst=outputs_chunked,
name_0="non_chunked",
name_1="chunked",
)


@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [1])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_mrope_chunked_prefill(vllm_runner, example_prompts,
model: str, dtype: str,
max_tokens: int,
num_logprobs: int) -> None:
"""
Test Qwen2-VL's chunked prefill with M-RoPE
"""
prompts = [
qwen2_vl_chat_template(IMAGE_PLACEHOLDER, prompt)
for prompt in example_prompts[:1]
]

# 1. Qwen2-VL's M-RoPE works only when there are some multi-modal inputs,
# so an image is included in the inputs
# 2. however, Qwen2-VL currently won't work properly
# when chunked prefill is enabled and there are some multi-modal inputs,
# here use a hacky way: provide a **zero-length** image to make it happy
#
# and finally we achieved:
# (1) chunked_prefill enabled; (2) M-RoPE works; to continue our tests
zero_len_image = {
"image_embeds": torch.empty((0, MODEL_HIDDEN_SIZE)),
"image_grid_thw": torch.tensor([[0, 0, 0]])
}
images = [zero_len_image] * len(prompts)

inputs_per_case: List[Tuple[List[str], PromptImageInput,
PromptVideoInput]] = [
(prompts, images, []),
]

run_chunked_prefill_test(
vllm_runner,
inputs_per_case,
model,