[V1] Make v1 more testable (vllm-project#9888)
Signed-off-by: Joe Runde <[email protected]>
joerunde authored Nov 6, 2024
1 parent 87bd7e0 commit d58268c
Showing 75 changed files with 243 additions and 165 deletions.
3 changes: 3 additions & 0 deletions Dockerfile
@@ -191,6 +191,9 @@ ADD . /vllm-workspace/
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-dev.txt

# Copy in the v1 package for testing (it isn't distributed yet)
COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1

# doc requires source code
# we hide them inside `test_docs/` , so that this source code
# will not be imported by other tests
1 change: 1 addition & 0 deletions pyproject.toml
@@ -97,4 +97,5 @@ markers = [
"skip_global_cleanup",
"core_model: run this model test in each PR instead of just daily",
"distributed_2_gpus: run this test only in distributed tests for 2 GPUs",
"skip_v1: do not run this test with v1",
]
18 changes: 18 additions & 0 deletions tests/conftest.py
@@ -5,6 +5,7 @@
from enum import Enum
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
TypedDict, TypeVar, Union)
from unittest.mock import patch

import numpy as np
import pytest
@@ -108,6 +109,23 @@ def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
"""Singleton instance of :class:`_VideoAssets`."""


@pytest.fixture(params=[True, False])
def run_with_both_engines(request):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
skip_v1 = request.node.get_closest_marker("skip_v1")

if use_v1:
if skip_v1:
pytest.skip("Skipping test on vllm V1")
with patch('vllm.envs.VLLM_USE_V1', True):
yield
else:
with patch('vllm.envs.VLLM_USE_V1', False):
yield


@pytest.fixture(autouse=True)
def init_test_http_connection():
# pytest_asyncio may use a different event loop per test
9 changes: 9 additions & 0 deletions tests/entrypoints/llm/test_prompt_validation.py
@@ -3,12 +3,21 @@
from vllm import LLM


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_empty_prompt():
llm = LLM(model="gpt2", enforce_eager=True)
with pytest.raises(ValueError, match='Prompt cannot be empty'):
llm.generate([""])


@pytest.mark.skip_v1
def test_out_of_vocab_token():
llm = LLM(model="gpt2", enforce_eager=True)
with pytest.raises(ValueError, match='out of vocabulary'):
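
The wrapper fixture above notes that it can be promoted into a package-level conftest.py so that every test in the package runs against both engines. A minimal sketch of what that promotion might look like; the file placement here is an assumption for illustration, not part of this commit:

# Hypothetical tests/entrypoints/llm/conftest.py
import pytest


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    """Run every test in this package once with V1 and once without.

    `run_with_both_engines` is defined in the top-level tests/conftest.py,
    so a nested conftest can request it directly; the patching of
    vllm.envs.VLLM_USE_V1 happens inside that fixture.
    """

Individual tests that cannot run on V1 yet would still opt out with the `@pytest.mark.skip_v1` marker registered in pyproject.toml.
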
2 changes: 2 additions & 0 deletions tests/kernels/test_attention_selector.py
@@ -44,6 +44,8 @@ def test_env(name: str, device: str, monkeypatch):

def test_flash_attn(monkeypatch):
"""Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to
# which_attn_to_use

override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)

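
The TODO above refers to the new `use_v1` parameter that this commit adds to `which_attn_to_use` (see vllm/attention/selector.py below). A future test exercising the V1 path might call it along these lines; this is only a sketch under the assumption of a default CUDA setup, not a test from this commit:

import torch

from vllm.attention.selector import _Backend, which_attn_to_use


def check_v1_flash_attn_selection():
    # Pass use_v1 explicitly rather than reading envs.VLLM_USE_V1 inside the
    # selector. On a CUDA setup that would otherwise pick FLASH_ATTN, the V1
    # flag is expected to route to the V1 flash-attention backend.
    backend = which_attn_to_use(head_size=64,
                                dtype=torch.float16,
                                kv_cache_dtype=None,
                                block_size=16,
                                is_attention_free=False,
                                use_v1=True)
    assert backend == _Backend.FLASH_ATTN_VLLM_V1
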
4 changes: 2 additions & 2 deletions tests/kernels/test_encoder_decoder_attn.py
@@ -16,7 +16,7 @@
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
AttentionType)
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, get_attn_backend,
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform
@@ -774,7 +774,7 @@ def set_reset_environment(attn_backend):
default_dtype = torch.get_default_dtype()
if attn_backend.name == 'FLASH_ATTN':
torch.set_default_dtype(torch.bfloat16)
get_attn_backend.cache_clear()
_cached_get_attn_backend.cache_clear()
yield
# Reset the torch datatype to what it was before the test
# so as not to impact the remaining tests.
43 changes: 33 additions & 10 deletions vllm/attention/selector.py
@@ -89,7 +89,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
return forced_attn_backend


@lru_cache(maxsize=None)
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
@@ -99,14 +98,39 @@ def get_attn_backend(
is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls.
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
# private function.
return _cached_get_attn_backend(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
block_size=block_size,
is_attention_free=is_attention_free,
is_blocksparse=is_blocksparse,
use_v1=envs.VLLM_USE_V1,
)


@lru_cache(maxsize=None)
def _cached_get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
is_blocksparse: bool = False,
use_v1: bool = False,
) -> Type[AttentionBackend]:
if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import (
BlocksparseFlashAttentionBackend)
return BlocksparseFlashAttentionBackend

backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
is_attention_free)
is_attention_free, use_v1)
if backend == _Backend.FLASH_ATTN:
logger.info("Using Flash Attention backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401
@@ -162,13 +186,12 @@ def get_attn_backend(
raise ValueError("Invalid attention backend.")


def which_attn_to_use(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
) -> _Backend:
def which_attn_to_use(head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
use_v1: bool = False) -> _Backend:
"""Returns which flash attention backend to use."""
# Default case.
selected_backend = _Backend.FLASH_ATTN
@@ -228,7 +251,7 @@ def which_attn_to_use(
if current_platform.is_hpu():
return _Backend.HPU_ATTN

if envs.VLLM_USE_V1:
if use_v1:
return _Backend.FLASH_ATTN_VLLM_V1

# FlashAttn in NVIDIA GPUs.
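
The comment in `get_attn_backend` above is the heart of this change: an environment value read inside an `@lru_cache`-decorated function is frozen into the cache on the first call, so later patches of `VLLM_USE_V1` would be ignored. A small self-contained sketch of the pitfall and the fix, using toy names rather than vLLM code:

from functools import lru_cache

USE_V1 = False  # stands in for envs.VLLM_USE_V1, which tests patch at runtime


@lru_cache(maxsize=None)
def _cached_select(head_size: int, use_v1: bool) -> str:
    # use_v1 is part of the cache key, so each flag value gets its own entry.
    return "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"


def select(head_size: int) -> str:
    # Read the mutable flag at call time and forward it explicitly; placing
    # @lru_cache directly on this function would pin whichever value was seen
    # first, even after a test patches USE_V1.
    return _cached_select(head_size, USE_V1)

This mirrors why `get_attn_backend` keeps its public signature but now forwards `envs.VLLM_USE_V1` into the cached `_cached_get_attn_backend`.
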
18 changes: 10 additions & 8 deletions vllm/engine/multiprocessing/engine.py
@@ -6,7 +6,9 @@
import cloudpickle
import zmq

import vllm.envs
from vllm import AsyncEngineArgs, SamplingParams
from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
@@ -17,17 +19,11 @@
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable
from vllm.envs import VLLM_USE_V1
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext

if VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine
else:
from vllm.engine.llm_engine import LLMEngine

logger = init_logger(__name__)

POLLING_TIMEOUT_MS = 10000
@@ -117,11 +113,17 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs,
load_general_plugins()

engine_config = engine_args.create_engine_config()
if vllm.envs.VLLM_USE_V1:
# Lazy import: the v1 package isn't distributed
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
engine_class = V1LLMEngine
else:
engine_class = LLMEngine

executor_class = LLMEngine._get_executor_cls(engine_config)
executor_class = engine_class._get_executor_cls(engine_config)

use_async_sockets = (engine_config.model_config.use_async_output_proc
and not VLLM_USE_V1)
and not vllm.envs.VLLM_USE_V1)

return cls(ipc_path=ipc_path,
use_async_sockets=use_async_sockets,
26 changes: 17 additions & 9 deletions vllm/entrypoints/llm.py
@@ -1,7 +1,7 @@
import itertools
import warnings
from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Union, cast, overload)

from tqdm import tqdm
@@ -10,6 +10,7 @@
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs, TaskOption
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
@@ -31,11 +32,6 @@
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of

if envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine # type: ignore
else:
from vllm.engine.llm_engine import LLMEngine # type: ignore

logger = init_logger(__name__)


@@ -206,10 +202,21 @@ def __init__(
pooling_returned_token_ids=pooling_returned_token_ids,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
self.engine_class = self.get_engine_class()
self.llm_engine = self.engine_class.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()

@staticmethod
def get_engine_class() -> Type[LLMEngine]:
if envs.VLLM_USE_V1:
# Lazy import: the v1 package isn't distributed
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
return V1LLMEngine # type: ignore
return LLMEngine

def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer

@@ -394,7 +401,7 @@ def generate(
priority=priority)

outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)
return self.engine_class.validate_outputs(outputs, RequestOutput)

def beam_search(
self,
@@ -769,7 +776,8 @@ def encode(
)

outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
return self.engine_class.validate_outputs(outputs,
EmbeddingRequestOutput)

def start_profile(self) -> None:
self.llm_engine.start_profile()
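
Because the V0/V1 choice now happens inside `get_engine_class()` rather than at import time, tests can flip the flag per test without re-importing `vllm.entrypoints.llm`. A rough sketch of how the resolution behaves under the same `unittest.mock.patch` pattern used in tests/conftest.py (illustrative only, not code from this commit):

from unittest.mock import patch

from vllm.entrypoints.llm import LLM

# With the flag off, the classic engine class from vllm.engine.llm_engine
# is returned.
with patch("vllm.envs.VLLM_USE_V1", False):
    v0_cls = LLM.get_engine_class()

# With the flag on, the lazily imported vllm.v1 engine class is returned.
with patch("vllm.envs.VLLM_USE_V1", True):
    v1_cls = LLM.get_engine_class()

assert v0_cls.__module__ != v1_cls.__module__
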
9 changes: 9 additions & 0 deletions vllm/model_executor/layers/sampler.py
@@ -30,6 +30,15 @@
else:
flashinfer_top_k_top_p_sampling = None


def get_sampler() -> torch.nn.Module:
if envs.VLLM_USE_V1:
# Lazy import: the v1 package isn't distributed
from vllm.v1.sample.sampler import Sampler as V1Sampler
return V1Sampler()
return Sampler()


# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]

4 changes: 2 additions & 2 deletions vllm/model_executor/models/arctic.py
@@ -23,7 +23,7 @@
from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig, DeepSpeedFPParameter)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -436,7 +436,7 @@ def __init__(self,
self.unpadded_vocab_size = config.vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

4 changes: 2 additions & 2 deletions vllm/model_executor/models/baichuan.py
@@ -37,7 +37,7 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -352,7 +352,7 @@ def __init__(
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

4 changes: 2 additions & 2 deletions vllm/model_executor/models/bart.py
@@ -34,7 +34,7 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -838,7 +838,7 @@ def __init__(self,

self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
self.sampler = get_sampler()

def forward(
self,
4 changes: 2 additions & 2 deletions vllm/model_executor/models/blip2.py
@@ -13,7 +13,7 @@
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import consecutive_placeholder_ranges
@@ -525,7 +525,7 @@ def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler

return Sampler()
return get_sampler()

def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
4 changes: 2 additions & 2 deletions vllm/model_executor/models/bloom.py
@@ -33,7 +33,7 @@
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -298,7 +298,7 @@ def __init__(
self.config.hidden_size)

self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)

(The remaining changed files in this commit are not shown here.)
