Commit

Merge branch 'main' into omer/run-loader
omer-dayan committed Dec 19, 2024
2 parents 4ea40d1 + a985f7a commit e2fd6bf
Showing 17 changed files with 354 additions and 692 deletions.
15 changes: 15 additions & 0 deletions .buildkite/release-pipeline.yaml
@@ -55,3 +55,18 @@ steps:
password-env: DOCKERHUB_TOKEN
env:
DOCKER_BUILDKIT: "1"

- block: "Build CPU release image"
key: block-cpu-release-image-build
depends_on: ~

- label: "Build and publish CPU release image"
depends_on: block-cpu-release-image-build
agents:
queue: cpu_queue_postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION --progress plain -f Dockerfile.cpu ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION"
env:
DOCKER_BUILDKIT: "1"
8 changes: 6 additions & 2 deletions examples/offline_inference_vision_language.py
@@ -447,7 +447,6 @@ def run_qwen_vl(question: str, modality: str):

# Qwen2-VL
def run_qwen2_vl(question: str, modality: str):
assert modality == "image"

model_name = "Qwen/Qwen2-VL-7B-Instruct"

@@ -463,8 +462,13 @@ def run_qwen2_vl(question: str, modality: str):
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

if modality == "image":
placeholder = "<|image_pad|>"
elif modality == "video":
placeholder = "<|video_pad|>"

prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = None
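
The hunk above lets the example serve both image and video prompts. As a minimal standalone sketch of the resulting prompt construction (placeholder strings and chat template taken directly from the diff; the helper name is illustrative, not vLLM API):

```python
# Sketch of the prompt construction after this change: the vision placeholder
# now follows the requested modality instead of being hard-coded to <|image_pad|>.
def build_qwen2_vl_prompt(question: str, modality: str) -> str:
    if modality == "image":
        placeholder = "<|image_pad|>"
    elif modality == "video":
        placeholder = "<|video_pad|>"
    else:
        raise ValueError(f"Unsupported modality: {modality}")
    return ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
            f"{question}<|im_end|>\n"
            "<|im_start|>assistant\n")

print(build_qwen2_vl_prompt("What happens in this clip?", "video"))
```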
@@ -1,12 +1,9 @@
from typing import Any, Dict, Tuple

import pytest
import torch
from PIL.Image import Image
from transformers import AutoTokenizer

from vllm.inputs import InputContext, token_inputs
from vllm.multimodal import MultiModalRegistry
from vllm.inputs import InputContext, InputProcessingContext

from .....conftest import _ImageAssets
from ....utils import build_model_context
@@ -20,22 +17,9 @@
# NOTE: Qwen2VL supports multiple input modalities, so it registers multiple
# input mappers.
@pytest.fixture()
def image_input_mapper_for_qwen2_vl():
from vllm.model_executor.models.qwen2_vl import (
image_input_mapper_for_qwen2_vl)
return image_input_mapper_for_qwen2_vl


@pytest.fixture()
def input_processor_for_qwen2_vl():
from vllm.model_executor.models.qwen2_vl import (
input_processor_for_qwen2_vl)
return input_processor_for_qwen2_vl


@pytest.fixture()
def qwen2_vl_context() -> InputContext:
return build_model_context(model_name=MODEL)
def processor_for_qwen2_vl():
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
return Qwen2VLMultiModalProcessor


@pytest.fixture()
@@ -45,123 +29,77 @@ def get_max_qwen2_vl_image_tokens():
return get_max_qwen2_vl_image_tokens


@pytest.fixture()
def dummy_data_for_qwen2_vl():
from vllm.model_executor.models.qwen2_vl import dummy_data_for_qwen2_vl
return dummy_data_for_qwen2_vl


@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
({}, 1225),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 324),
])
def test_qwen2_vl_max_image_tokens(get_max_qwen2_vl_image_tokens,
qwen2_vl_context: InputContext,
mm_processor_kwargs: Dict[str, Any],
expected_max_tokens: int):
@pytest.mark.parametrize("model", [MODEL])
def test_qwen2_vl_max_image_tokens(
get_max_qwen2_vl_image_tokens,
model: str,
mm_processor_kwargs: Dict[str, Any],
expected_max_tokens: int,
):
"""Ensure that the max token calc handles min/max pixels properly."""
actual_max_tokens = get_max_qwen2_vl_image_tokens(qwen2_vl_context,
**mm_processor_kwargs)
assert actual_max_tokens == expected_max_tokens


@pytest.mark.parametrize("mm_processor_kwargs,token_count,img_size", [
[{}, 1225, (980, 980)],
[{
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 324, (504, 504)],
])
def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl,
qwen2_vl_context: InputContext,
mm_processor_kwargs: Dict[str, Any],
token_count: int, img_size: Tuple[int, int]):
"""Ensure that the dummy data handles min/max pixels properly."""
seq_len = 3000
hf_config = qwen2_vl_context.get_hf_config()
image_token_id = hf_config.image_token_id

# NOTE: video value is required, but isn't actually used
# when making the dummy data except for error handling currently
dummy_data = dummy_data_for_qwen2_vl(
ctx=qwen2_vl_context,
seq_len=seq_len,
mm_counts={
"image": 1,
"video": 0
},
**mm_processor_kwargs,
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
mm_processor_kwargs=None,
)
seq_data = dummy_data.seq_data
mm_data = dummy_data.multi_modal_data

# Ensure we have the right number of placeholders for min/max pixel values
assert seq_data.get_token_ids().count(image_token_id) == token_count

# Ensure the images were resized correctly
image = mm_data["image"]
assert isinstance(image, Image)
assert image.size == img_size
actual_max_tokens = get_max_qwen2_vl_image_tokens(
InputContext(ctx.model_config), **mm_processor_kwargs)
assert actual_max_tokens == expected_max_tokens


@pytest.mark.parametrize("mm_processor_kwargs,num_placeholders", [
({}, 1426),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 330),
])
def test_input_processor(input_processor_for_qwen2_vl,
qwen2_vl_context: InputContext,
image_assets: _ImageAssets, num_placeholders: int,
mm_processor_kwargs: Dict[str, Any]):
"""Ensure that the image processor handles min/max pixels properly."""
tokenizer = AutoTokenizer.from_pretrained(MODEL)
prompt = "<|vision_start|><|image_pad|><|vision_end|>"

image = image_assets[0].pil_image
hf_config = qwen2_vl_context.get_hf_config()
image_token_id = hf_config.image_token_id

inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": [image]})

processed_inputs = input_processor_for_qwen2_vl(qwen2_vl_context, inputs,
**mm_processor_kwargs)
assert processed_inputs["prompt_token_ids"].count(
image_token_id) == num_placeholders
assert len(processed_inputs["multi_modal_data"]["image"]) == 1


@pytest.mark.parametrize("mm_processor_kwargs,pixels_shape", [
({}, [5704, 1176]),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, [1320, 1176]),
])
def test_image_mapper_override(qwen2_vl_context: InputContext,
image_assets: _ImageAssets,
mm_processor_kwargs: Dict[str, Any],
pixels_shape: Tuple[int, int]):
"""Ensure that the image mapper handles min/max pixels properly."""
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(qwen2_vl_context.model_config)

image = image_assets[0].pil_image

mapped_output = mm_registry.map_input(
qwen2_vl_context.model_config,
{"image": image},
mm_processor_kwargs=mm_processor_kwargs,
@pytest.mark.parametrize(
"mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [
({}, 1426, (5704, 1176)),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 330, (1320, 1176)),
])
@pytest.mark.parametrize("model", [MODEL])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
processor_for_qwen2_vl,
image_assets: _ImageAssets,
model: str,
mm_processor_kwargs: Dict[str, Any],
expected_toks_per_img: int,
expected_pixels_shape: Tuple[int, int],
num_imgs: int,
):
"""Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
mm_processor_kwargs=None,
)

# Dimension 0 of pixel values should match the product of image_grid_thw
actual_pixels_shape = mapped_output["pixel_values"].shape
assert list(actual_pixels_shape) == pixels_shape
assert actual_pixels_shape[0] == torch.prod(
mapped_output["image_grid_thw"])
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
# Build the image str / prompt based on the number of images we pass
prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
images = [image_assets[0].pil_image] * num_imgs

mm_data = {"image": images}

processor = processor_for_qwen2_vl(ctx)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)

# Ensure we have the right number of placeholders per num_crops size
hf_processor = processor._get_hf_processor(**mm_processor_kwargs)
image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape

assert img_tok_count == expected_toks_per_img * num_imgs
assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
assert pixel_shape[1] == expected_pixels_shape[1]
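
For reference, the min/max pixel behavior that the rewritten test exercises can be reproduced outside vLLM with the Hugging Face processor alone. The numbers in the parametrization above apply to the test's image assets; exact token counts and pixel_values shapes in this sketch depend on the dummy image resolution:

```python
# Rough standalone sketch (HF transformers only, no vLLM) of what the new
# processor-based test checks: min_pixels/max_pixels bound the image resize,
# which changes pixel_values.shape[0] and the number of <|image_pad|> tokens.
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct",
                                          min_pixels=64**2,
                                          max_pixels=512**2)
prompt = "<|vision_start|><|image_pad|><|vision_end|>Describe the image."
image = Image.new("RGB", (980, 980))  # dummy image; the test's assets differ

inputs = processor(text=[prompt], images=[image], return_tensors="pt")
image_pad_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
print(inputs["pixel_values"].shape)                        # bounded by max_pixels
print((inputs["input_ids"] == image_pad_id).sum().item())  # placeholder count
```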
9 changes: 4 additions & 5 deletions vllm/adapter_commons/models.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
from typing import Any, Callable, Dict, Optional, TypeVar

from torch import nn

@@ -24,14 +24,13 @@ def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
T = TypeVar('T')


class AdapterLRUCache(LRUCache[T]):
class AdapterLRUCache(LRUCache[int, T]):

def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
None]):
def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn

def _on_remove(self, key: Hashable, value: Optional[T]):
def _on_remove(self, key: int, value: Optional[T]):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)
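
The signature change above narrows the cache key type from Hashable to int (adapter ids). A self-contained sketch of that pattern, using only the standard library rather than vLLM's actual LRUCache from vllm.utils:

```python
# Standalone illustration of the AdapterLRUCache pattern: an LRU cache keyed
# by integer adapter ids that calls a deactivate callback on eviction. This is
# not vLLM's LRUCache; it only mirrors the Callable[[int], object] contract.
from collections import OrderedDict
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")

class TinyAdapterCache(Generic[T]):

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        self.capacity = capacity
        self.deactivate_fn = deactivate_fn
        self._entries: "OrderedDict[int, T]" = OrderedDict()

    def put(self, key: int, value: T) -> None:
        self._entries[key] = value
        self._entries.move_to_end(key)
        while len(self._entries) > self.capacity:
            evicted_key, _ = self._entries.popitem(last=False)
            self.deactivate_fn(evicted_key)  # deactivate the evicted adapter

    def get(self, key: int) -> Optional[T]:
        if key not in self._entries:
            return None
        self._entries.move_to_end(key)  # mark as most recently used
        return self._entries[key]

cache = TinyAdapterCache[str](capacity=2,
                              deactivate_fn=lambda k: print("deactivated", k))
cache.put(1, "lora-a")
cache.put(2, "lora-b")
cache.put(3, "lora-c")  # evicts id 1 and prints "deactivated 1"
```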
88 changes: 0 additions & 88 deletions vllm/block.py

This file was deleted.

4 changes: 2 additions & 2 deletions vllm/core/evictor.py
@@ -13,7 +13,7 @@ class EvictionPolicy(enum.Enum):

class Evictor(ABC):
"""The Evictor subclasses should be used by the BlockAllocator class to
handle eviction of freed PhysicalTokenBlocks.
handle eviction of freed Blocks.
"""

@abstractmethod
@@ -70,7 +70,7 @@ def __init__(self, content_hash: int, num_hashed_tokens: int,

class LRUEvictor(Evictor):
"""Evicts in a least-recently-used order using the last_accessed timestamp
that's recorded in the PhysicalTokenBlock. If there are multiple blocks with
that's recorded in the Block. If there are multiple blocks with
the same last_accessed time, then the one with the largest num_hashed_tokens
will be evicted. If two blocks each have the lowest last_accessed time and
highest num_hashed_tokens value, then one will be chosen arbitrarily
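
A small illustration of the eviction order this docstring describes (standalone, using a hypothetical BlockMeta record rather than vLLM's Block type):

```python
# Hypothetical sketch of the LRU eviction order described above: evict the
# block with the oldest last_accessed timestamp, breaking ties in favor of
# the block with the largest num_hashed_tokens.
from dataclasses import dataclass
from typing import Iterable

@dataclass
class BlockMeta:  # stand-in for vLLM's Block metadata
    block_id: int
    num_hashed_tokens: int
    last_accessed: float

def pick_victim(blocks: Iterable[BlockMeta]) -> BlockMeta:
    # Oldest access wins; negating num_hashed_tokens prefers larger counts on ties.
    return min(blocks, key=lambda b: (b.last_accessed, -b.num_hashed_tokens))

blocks = [BlockMeta(0, 16, 10.0), BlockMeta(1, 32, 10.0), BlockMeta(2, 8, 12.0)]
print(pick_victim(blocks).block_id)  # -> 1 (tied timestamp, more hashed tokens)
```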
