Skip to content

Commit

Permalink
Fix ChatMessage serialization with janky openai types (#16410)
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich authored Oct 7, 2024
1 parent 4655797 commit 62849af
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 18 deletions.
11 changes: 9 additions & 2 deletions .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,14 @@ jobs:
CHANGED_ROOTS=""
FILTER_PATTERNS="["
for file in $CHANGED_FILES; do
root=$(echo "$file" | cut -d'/' -f1,2,3)
# Start with the full path
root="$file"
# Keep going up the directory tree until we find a directory containing a marker file
# (e.g., 'pyproject.toml' for python projects)
while [[ ! -f "$root/pyproject.toml" && "$root" != "." && "$root" != "/" ]]; do
root=$(dirname "$root")
done
if [[ ! "$FILTER_PATTERNS" =~ "$root" ]]; then
FILTER_PATTERNS="${FILTER_PATTERNS}'${root}',"
CHANGED_ROOTS="${CHANGED_ROOTS} ${root}/::"
Expand All @@ -68,7 +75,7 @@ jobs:
echo "Coverage filter patterns: $FILTER_PATTERNS"
echo "Changed roots: $CHANGED_ROOTS"
pants --no-local-cache test \
pants --level=error --no-local-cache test \
--test-use-coverage \
--coverage-py-filter="${FILTER_PATTERNS}" \
${CHANGED_ROOTS}
28 changes: 12 additions & 16 deletions llama-index-core/llama_index/core/base/llms/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
Any,
)

from llama_index.core.bridge.pydantic import BaseModel, Field, ConfigDict
from llama_index.core.bridge.pydantic import (
BaseModel,
Field,
ConfigDict,
field_serializer,
)
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.core.schema import ImageType

Expand Down Expand Up @@ -95,7 +100,8 @@ def from_str(
return cls(role=role, content=content, **kwargs)

def _recursive_serialization(self, value: Any) -> Any:
if isinstance(value, (V1BaseModel, V2BaseModel)):
if isinstance(value, V2BaseModel):
value.model_rebuild() # ensures all fields are initialized and serializable
return value.model_dump() # type: ignore
if isinstance(value, dict):
return {
Expand All @@ -106,23 +112,13 @@ def _recursive_serialization(self, value: Any) -> Any:
return [self._recursive_serialization(item) for item in value]
return value

@field_serializer("additional_kwargs", check_fields=False)
def serialize_additional_kwargs(self, value: Any, _info: Any) -> Any:
    """Pydantic v2 field serializer for ``additional_kwargs``.

    Delegates to ``_recursive_serialization`` so that nested pydantic
    models (e.g. raw OpenAI SDK objects) and containers are reduced to
    plain serializable values whenever the message is dumped.
    ``check_fields=False`` because the field is declared on the class
    body outside this visible span — TODO confirm against full file.
    """
    return self._recursive_serialization(value)

def dict(self, **kwargs: Any) -> Dict[str, Any]:
    """Pydantic-v1-style alias; simply delegates to ``model_dump``."""
    return self.model_dump(**kwargs)

def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
    """Dump the message, forcing ``additional_kwargs`` values to be
    JSON-compatible.

    Each entry is passed through ``_recursive_serialization`` after the
    base-class dump, then gated to plain types.

    Raises:
        ValueError: if an entry is still not one of
            str/int/float/bool/dict/list/None after recursive
            serialization.
    """
    # ensure all additional_kwargs are serializable
    msg = super().model_dump(**kwargs)

    for key, value in msg.get("additional_kwargs", {}).items():
        # Unwrap pydantic models / nested containers before the type gate.
        value = self._recursive_serialization(value)
        if not isinstance(value, (str, int, float, bool, dict, list, type(None))):
            raise ValueError(
                f"Failed to serialize additional_kwargs value: {value}"
            )
        # Write the normalized value back into the dumped payload.
        msg["additional_kwargs"][key] = value

    return msg


class LogProb(BaseModel):
"""LogProb of a token."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,36 @@ def test_completion_model_with_retry(MockSyncOpenAI: MagicMock) -> None:
# The actual retry count is max_retries - 1
# see https://github.com/jd/tenacity/issues/459
assert mock_instance.completions.create.call_count == 3


@patch("llama_index.llms.openai.base.SyncOpenAI")
def test_ensure_chat_message_is_serializable(MockSyncOpenAI: MagicMock) -> None:
    """Regression test: ``ChatMessage.dict()`` must not blow up when a raw
    OpenAI SDK object (here a ``ChatCompletionChunk``) has been stashed in
    ``additional_kwargs``."""
    with CachedOpenAIApiKeys(set_fake_key=True):
        # Point the mocked client at a canned chat-completion payload.
        mock_client = MockSyncOpenAI.return_value
        mock_client.chat.completions.create.return_value = mock_chat_completion_v1()

        user_msg = ChatMessage(role="user", content="test message")
        chat_response = OpenAI(model="gpt-3.5-turbo").chat([user_msg])

        # Smuggle a raw SDK object into additional_kwargs, as streaming
        # code paths can do.
        chunk = ChatCompletionChunk(
            id="chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD",
            object="chat.completion.chunk",
            created=1677825464,
            model="gpt-3.5-turbo-0301",
            choices=[
                ChunkChoice(
                    delta=ChoiceDelta(role="assistant", content="test"),
                    finish_reason=None,
                    index=0,
                )
            ],
        )
        chat_response.message.additional_kwargs["test"] = chunk

        dumped = chat_response.message.dict()
        assert isinstance(dumped, dict)
        assert isinstance(dumped["additional_kwargs"], dict)
        chunk_choices = dumped["additional_kwargs"]["test"]["choices"]
        assert isinstance(chunk_choices, list)
        assert chunk_choices[0]["delta"]["content"] == "test"

0 comments on commit 62849af

Please sign in to comment.