Skip to content

Commit

Permalink
[Frontend] Multimodal support in offline chat (vllm-project#8098)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored Sep 4, 2024
1 parent 2be8ec6 commit 855c262
Show file tree
Hide file tree
Showing 8 changed files with 356 additions and 112 deletions.
34 changes: 34 additions & 0 deletions tests/entrypoints/llm/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from vllm import LLM, RequestOutput, SamplingParams

from ...conftest import cleanup
from ..openai.test_vision import TEST_IMAGE_URLS

MODEL_NAME = "facebook/opt-125m"

Expand Down Expand Up @@ -159,3 +160,36 @@ def test_chat():
]
outputs = llm.chat(messages)
assert len(outputs) == 1


@pytest.mark.parametrize("image_urls",
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
dtype="bfloat16",
max_model_len=4096,
max_num_seqs=5,
enforce_eager=True,
trust_remote_code=True,
limit_mm_per_prompt={"image": 2},
)

messages = [{
"role":
"user",
"content": [
*({
"type": "image_url",
"image_url": {
"url": image_url
}
} for image_url in image_urls),
{
"type": "text",
"text": "What's in this image?"
},
],
}]
outputs = llm.chat(messages)
assert len(outputs) >= 0
164 changes: 124 additions & 40 deletions tests/entrypoints/test_chat_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import warnings
from typing import Optional

import pytest
from PIL import Image

from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import parse_chat_messages
from vllm.entrypoints.chat_utils import (parse_chat_messages,
parse_chat_messages_futures)
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import encode_image_base64
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

Expand Down Expand Up @@ -42,10 +45,28 @@ def image_url():
return f"data:image/jpeg;base64,{base64}"


@pytest.mark.asyncio
async def test_parse_chat_messages_with_image_url(phi3v_model_config,
phi3v_tokenizer, image_url):
conversation, mm_future = parse_chat_messages([{
def _assert_mm_data_is_image_input(
mm_data: Optional[MultiModalDataDict],
image_count: int,
) -> None:
assert mm_data is not None
assert set(mm_data.keys()) == {"image"}

image_data = mm_data.get("image")
assert image_data is not None

if image_count == 1:
assert isinstance(image_data, Image.Image)
else:
assert isinstance(image_data, list) and len(image_data) == image_count


def test_parse_chat_messages_single_image(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
Expand All @@ -63,15 +84,42 @@ async def test_parse_chat_messages_with_image_url(phi3v_model_config,
"role": "user",
"content": "<|image_1|>\nWhat's in the image?"
}]
mm_data = await mm_future
assert set(mm_data.keys()) == {"image"}
assert isinstance(mm_data["image"], Image.Image)
_assert_mm_data_is_image_input(mm_data, 1)


@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images(phi3v_model_config,
phi3v_tokenizer, image_url):
conversation, mm_future = parse_chat_messages([{
async def test_parse_chat_messages_single_image_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_future = parse_chat_messages_futures([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in the image?"
}]
}], phi3v_model_config, phi3v_tokenizer)

assert conversation == [{
"role": "user",
"content": "<|image_1|>\nWhat's in the image?"
}]
_assert_mm_data_is_image_input(await mm_future, 1)


def test_parse_chat_messages_multiple_images(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
Expand All @@ -96,15 +144,49 @@ async def test_parse_chat_messages_multiple_images(phi3v_model_config,
"content":
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
}]
mm_data = await mm_future
assert set(mm_data.keys()) == {"image"}
assert len(mm_data["image"]) == 2
_assert_mm_data_is_image_input(mm_data, 2)


@pytest.mark.asyncio
async def test_parse_chat_messages_placeholder_already_in_prompt(
phi3v_model_config, phi3v_tokenizer, image_url):
conversation, mm_future = parse_chat_messages([{
async def test_parse_chat_messages_multiple_images_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_future = parse_chat_messages_futures([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in these images?"
}]
}], phi3v_model_config, phi3v_tokenizer)

assert conversation == [{
"role":
"user",
"content":
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
}]
_assert_mm_data_is_image_input(await mm_future, 2)


def test_parse_chat_messages_placeholder_already_in_prompt(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
Expand All @@ -131,15 +213,15 @@ async def test_parse_chat_messages_placeholder_already_in_prompt(
"content":
"What's in <|image_1|> and how does it compare to <|image_2|>?"
}]
mm_data = await mm_future
assert set(mm_data.keys()) == {"image"}
assert len(mm_data["image"]) == 2
_assert_mm_data_is_image_input(mm_data, 2)


@pytest.mark.asyncio
async def test_parse_chat_messages_placeholder_one_already_in_prompt(
phi3v_model_config, phi3v_tokenizer, image_url):
conversation, mm_future = parse_chat_messages([{
def test_parse_chat_messages_placeholder_one_already_in_prompt(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
Expand Down Expand Up @@ -167,15 +249,15 @@ async def test_parse_chat_messages_placeholder_one_already_in_prompt(
"<|image_2|>\nWhat's in <|image_1|> and how does it compare to the "
"other one?"
}]
mm_data = await mm_future
assert set(mm_data.keys()) == {"image"}
assert len(mm_data["image"]) == 2
_assert_mm_data_is_image_input(mm_data, 2)


@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_across_messages(
phi3v_model_config, phi3v_tokenizer, image_url):
conversation, mm_future = parse_chat_messages([{
def test_parse_chat_messages_multiple_images_across_messages(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
Expand Down Expand Up @@ -218,14 +300,14 @@ async def test_parse_chat_messages_multiple_images_across_messages(
"content": "<|image_2|>\nWhat about this one?"
},
]
mm_data = await mm_future
assert set(mm_data.keys()) == {"image"}
assert len(mm_data["image"]) == 2
_assert_mm_data_is_image_input(mm_data, 2)


@pytest.mark.asyncio
async def test_parse_chat_messages_rejects_too_many_images_in_one_message(
phi3v_model_config, phi3v_tokenizer, image_url):
def test_parse_chat_messages_rejects_too_many_images_in_one_message(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
Expand Down Expand Up @@ -259,9 +341,11 @@ async def test_parse_chat_messages_rejects_too_many_images_in_one_message(
}], phi3v_model_config, phi3v_tokenizer)


@pytest.mark.asyncio
async def test_parse_chat_messages_rejects_too_many_images_across_messages(
phi3v_model_config, phi3v_tokenizer, image_url):
def test_parse_chat_messages_rejects_too_many_images_across_messages(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
Expand Down
Loading

0 comments on commit 855c262

Please sign in to comment.