fix: make OpenAIMultiModal work with new ChatMessage (#17138)
* support multiple text blocks

* fix typing

* add more features to ImageBlock

* cosmetics

* fix type checking

* make OpenAIMultiModal work with new ChatMessage

* remove leftovers

* fix mypy

* catch unsupported block types and raise an error

* fix tests

* fix tests

* fix more tests

* fix possible typo
masci authored Dec 4, 2024
1 parent 0a94757 commit 024637b
Showing 12 changed files with 313 additions and 235 deletions.
74 changes: 33 additions & 41 deletions docs/docs/examples/multi_modal/openai_multi_modal.ipynb

Large diffs are not rendered by default.

61 changes: 46 additions & 15 deletions llama-index-core/llama_index/core/base/llms/types.py
@@ -14,6 +14,7 @@
Union,
)

import filetype
import requests
from typing_extensions import Self

@@ -25,9 +26,9 @@
FilePath,
field_serializer,
field_validator,
model_validator,
)
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.core.schema import ImageType


class MessageRole(str, Enum):
Expand All @@ -53,12 +54,7 @@ class ImageBlock(BaseModel):
path: FilePath | None = None
url: AnyUrl | str | None = None
image_mimetype: str | None = None

@field_validator("image", mode="after")
@classmethod
def image_to_base64(cls, image: bytes) -> bytes:
"""Store the image as base64."""
return base64.b64encode(image)
detail: str | None = None

@field_validator("url", mode="after")
@classmethod
@@ -68,16 +64,47 @@ def urlstr_to_anyurl(cls, url: str | AnyUrl) -> AnyUrl:
return url
return AnyUrl(url=url)

def resolve_image(self) -> ImageType:
"""Resolve an image such that PIL can read it."""
@model_validator(mode="after")
def image_to_base64(self) -> Self:
"""Store the image as base64 and guess the mimetype when possible.
In case the model was built passing image data but without a mimetype,
we try to guess it using the filetype library. To avoid resource-intensive
operations, we won't load the path or the URL to guess the mimetype.
"""
if not self.image:
return self

if not self.image_mimetype:
guess = filetype.guess(self.image)
self.image_mimetype = guess.mime if guess else None

self.image = base64.b64encode(self.image)

return self

def resolve_image(self, as_base64: bool = False) -> BytesIO:
"""Resolve an image such that PIL can read it.
Args:
as_base64 (bool): whether the resolved image should be returned as base64-encoded bytes
"""
if self.image is not None:
if as_base64:
return BytesIO(self.image)
return BytesIO(base64.b64decode(self.image))
elif self.path is not None:
return BytesIO(self.path.read_bytes())
img_bytes = self.path.read_bytes()
if as_base64:
return BytesIO(base64.b64encode(img_bytes))
return BytesIO(img_bytes)
elif self.url is not None:
# load image from URL
response = requests.get(str(self.url))
return BytesIO(response.content)
img_bytes = response.content
if as_base64:
return BytesIO(base64.b64encode(img_bytes))
return BytesIO(img_bytes)
else:
raise ValueError("No image found in the chat message!")

@@ -113,11 +140,15 @@ def content(self) -> str | None:
"""Keeps backward compatibility with the old `content` field.
Returns:
The block content if there's a single TextBlock, an empty string otherwise.
The cumulative content of the blocks if they're all of type TextBlock, None otherwise.
"""
if len(self.blocks) == 1 and isinstance(self.blocks[0], TextBlock):
return self.blocks[0].text
return None
content = ""
for block in self.blocks:
if not isinstance(block, TextBlock):
return None
content += block.text

return content

@content.setter
def content(self, content: str) -> None:
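Taken together, the `types.py` changes mean that raw image bytes passed to `ImageBlock` are now base64-encoded on construction, the mimetype is guessed with the `filetype` library when it isn't supplied, and `resolve_image()` can return either decoded or base64-encoded bytes. A minimal usage sketch, assembled from the new tests in this commit (the sample bytes are the same 1x1 PNG fixture the tests use):

```python
import base64

from llama_index.core.base.llms.types import ChatMessage, ImageBlock, TextBlock

# 1x1 PNG fixture from the new tests, decoded to raw bytes.
png_b64 = b"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
png_bytes = base64.b64decode(png_b64)

# The model validator stores the image base64-encoded and, since no mimetype
# was passed, guesses it from the magic bytes via the filetype library.
block = ImageBlock(image=png_bytes, detail="low")
print(block.image_mimetype)  # "image/png", assuming filetype recognizes the data

# resolve_image() yields decoded bytes; as_base64=True keeps the base64 form.
assert block.resolve_image().read() == png_bytes
assert block.resolve_image(as_base64=True).read() == png_b64

# ChatMessage.content now concatenates the text of multiple TextBlocks and
# returns None as soon as a non-text block is present.
msg = ChatMessage(content=[TextBlock(text="part one "), TextBlock(text="part two")])
assert msg.content == "part one part two"
```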
19 changes: 9 additions & 10 deletions llama-index-core/llama_index/core/llms/function_calling.py
@@ -1,19 +1,14 @@
from typing import Any, Dict, List, Optional, Sequence, Union, TYPE_CHECKING
from abc import abstractmethod
import asyncio

from llama_index.core.base.llms.types import (
ChatMessage,
)
from llama_index.core.llms.llm import LLM
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

from llama_index.core.base.llms.types import (
ChatMessage,
ChatResponse,
ChatResponseGen,
ChatResponseAsyncGen,
ChatResponseGen,
)
from llama_index.core.llms.llm import ToolSelection
from llama_index.core.llms.llm import LLM, ToolSelection

if TYPE_CHECKING:
from llama_index.core.chat_engine.types import AgentChatResponse
@@ -27,6 +22,10 @@ class FunctionCallingLLM(LLM):
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
# Help static checkers understand this class hierarchy
super().__init__(*args, **kwargs)

def chat_with_tools(
self,
tools: Sequence["BaseTool"],
@@ -222,10 +221,10 @@ async def apredict_and_call(
**kwargs: Any,
) -> "AgentChatResponse":
"""Predict and call the tool."""
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.tools.calling import (
acall_tool_with_selection,
)
from llama_index.core.chat_engine.types import AgentChatResponse

if not self.metadata.is_function_calling_model:
return await super().apredict_and_call(
8 changes: 6 additions & 2 deletions llama-index-core/llama_index/core/multi_modal_llms/base.py
@@ -19,8 +19,8 @@
)
from llama_index.core.bridge.pydantic import (
BaseModel,
Field,
ConfigDict,
Field,
)
from llama_index.core.callbacks import CallbackManager
from llama_index.core.constants import (
@@ -29,7 +29,7 @@
DEFAULT_NUM_OUTPUTS,
)
from llama_index.core.instrumentation import DispatcherSpanMixin
from llama_index.core.llms.callbacks import llm_completion_callback, llm_chat_callback
from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
from llama_index.core.schema import BaseComponent, ImageNode


@@ -86,6 +86,10 @@ class MultiModalLLM(ChainableMixin, BaseComponent, DispatcherSpanMixin):
default_factory=CallbackManager, exclude=True
)

def __init__(self, *args: Any, **kwargs: Any) -> None:
# Help static checkers understand this class hierarchy
super().__init__(*args, **kwargs)

@property
@abstractmethod
def metadata(self) -> MultiModalLLMMetadata:
56 changes: 40 additions & 16 deletions llama-index-core/tests/base/llms/test_types.py
@@ -1,4 +1,8 @@
import base64
from io import BytesIO
from pathlib import Path
from unittest import mock

import pytest
from llama_index.core.base.llms.types import (
ChatMessage,
@@ -7,16 +11,19 @@
TextBlock,
)
from llama_index.core.bridge.pydantic import BaseModel
from pathlib import Path
from unittest import mock
from pydantic import AnyUrl


@pytest.fixture()
def png_1px() -> bytes:
def png_1px_b64() -> bytes:
return b"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="


@pytest.fixture()
def png_1px(png_1px_b64) -> bytes:
return base64.b64decode(png_1px_b64)


def test_chat_message_from_str():
m = ChatMessage.from_str(content="test content")
assert m.content == "test content"
@@ -39,11 +46,12 @@ def test_chat_message_content_legacy_get():
assert type(m.blocks[0]) is TextBlock
assert m.blocks[0].text == "test content"

m = ChatMessage(content=[TextBlock(text="test content")])
assert m.content == "test content"
assert len(m.blocks) == 1
assert type(m.blocks[0]) is TextBlock
assert m.blocks[0].text == "test content"
m = ChatMessage(
content=[TextBlock(text="test content 1 "), TextBlock(text="test content 2")]
)
assert m.content == "test content 1 test content 2"
assert len(m.blocks) == 2
assert all(type(block) is TextBlock for block in m.blocks)


def test_chat_message_content_legacy_set():
@@ -105,14 +113,21 @@ def test_chat_message_legacy_roundtrip():
}


def test_image_block_resolve_image(png_1px: bytes):
def test_image_block_resolve_image(png_1px: bytes, png_1px_b64: bytes):
b = ImageBlock(image=png_1px)

img = b.resolve_image()
assert isinstance(img, BytesIO)
assert img.read() == png_1px

img = b.resolve_image(as_base64=True)
assert isinstance(img, BytesIO)
assert img.read() == png_1px_b64


def test_image_block_resolve_image_path(tmp_path: Path, png_1px: bytes):
def test_image_block_resolve_image_path(
tmp_path: Path, png_1px_b64: bytes, png_1px: bytes
):
png_path = tmp_path / "test.png"
png_path.write_bytes(png_1px)

@@ -121,23 +136,32 @@ def test_image_block_resolve_image_path(tmp_path: Path, png_1px: bytes):
assert isinstance(img, BytesIO)
assert img.read() == png_1px

img = b.resolve_image(as_base64=True)
assert isinstance(img, BytesIO)
assert img.read() == png_1px_b64


def test_image_block_resolve_image_url(png_1px):
def test_image_block_resolve_image_url(png_1px_b64: bytes, png_1px: bytes):
with mock.patch("llama_index.core.base.llms.types.requests") as mocked_req:
url_str = "http://example.com"
mocked_req.get.return_value = mock.MagicMock(content=png_1px)
b1 = ImageBlock(url=url_str)
img = b1.resolve_image()
b = ImageBlock(url=AnyUrl(url=url_str))
img = b.resolve_image()
assert isinstance(img, BytesIO)
assert img.read() == png_1px

b2 = ImageBlock(url=AnyUrl(url=url_str))
img = b1.resolve_image()
img = b.resolve_image(as_base64=True)
assert isinstance(img, BytesIO)
assert img.read() == png_1px
assert img.read() == png_1px_b64


def test_image_block_resolve_error():
with pytest.raises(ValueError, match="No image found in the chat message!"):
b = ImageBlock()
b.resolve_image()


def test_image_block_store_as_anyurl():
url_str = "http://example.com"
b = ImageBlock(url=url_str)
assert b.url == AnyUrl(url=url_str)
@@ -1,5 +1,5 @@
from typing import List, Dict, Any
from types import MappingProxyType
from typing import Any, Dict, List
from unittest.mock import MagicMock, call, patch

from llama_index.core.base.llms.types import ChatMessage, MessageRole
@@ -133,7 +133,9 @@ def test_chat(MockSyncOpenAI: MagicMock) -> None:
response = llm.chat([ChatMessage(role=MessageRole.USER, content="test message")])
assert response.message.content == content
mock_instance.chat.completions.create.assert_called_once_with(
messages=[{"role": MessageRole.USER, "content": "test message"}],
messages=[
{"role": "user", "content": [{"type": "text", "text": "test message"}]}
],
stream=False,
model=STUB_MODEL_NAME,
temperature=0.1,
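The updated assertion above shows that the OpenAI integration now sends message content as a list of typed blocks rather than a plain string. For reference, a sketch of what a mixed text-and-image payload could look like on the wire: the text entry mirrors the test assertion, while the `image_url` entry follows OpenAI's documented vision format and is an assumption about how `ImageBlock` (including its new `detail` field) is serialized, since that conversion code is not shown in this section.

```python
# Hypothetical OpenAI chat payload for a message holding a TextBlock and an
# ImageBlock. The "text" entry matches the updated test assertion; the
# "image_url" entry follows OpenAI's documented vision format and is an
# assumption about the integration's serialization, not taken from this diff.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {
                "type": "image_url",
                "image_url": {
                    "url": "data:image/png;base64,iVBORw0KGgo...",  # truncated base64 payload
                    "detail": "low",  # presumably taken from ImageBlock.detail
                },
            },
        ],
    }
]
```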