Skip to content

Commit

Permalink
Improve MIME type detection for image URLs (langgenius#6531)
Browse files: browse the repository at this point in the history
Co-authored-by: seayon <[email protected]>
  • Loading branch information
2 people authored and cuiks committed Sep 2, 2024
1 parent ef8b48f commit 83eb43e
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 9 deletions.
6 changes: 4 additions & 2 deletions api/core/model_runtime/model_providers/anthropic/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import base64
import io
import json
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast

Expand All @@ -18,6 +18,7 @@
)
from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout
from PIL import Image

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
Expand Down Expand Up @@ -462,7 +463,8 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
Expand Down
8 changes: 4 additions & 4 deletions api/core/model_runtime/model_providers/bedrock/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# standard import
import base64
import io
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast

Expand All @@ -17,6 +17,7 @@
ServiceNotInRegionError,
UnknownServiceError,
)
# Import the PIL.Image *module*, not the Image class: open() is a module-level
# function, so `Image.open(...)` only resolves when `Image` names the module.
from PIL import Image

# local import
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
Expand Down Expand Up @@ -381,9 +382,8 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
try:
url = message_content.data
image_content = requests.get(url).content
if '?' in url:
url = url.split('?')[0]
mime_type, _ = mimetypes.guess_type(url)
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
Expand Down
6 changes: 4 additions & 2 deletions api/core/model_runtime/model_providers/google/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import io
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union, cast

Expand All @@ -12,6 +12,7 @@
import requests
from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
from google.generativeai.types.content_types import to_part
from PIL import Image

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
Expand Down Expand Up @@ -371,7 +372,8 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
Expand Down
5 changes: 4 additions & 1 deletion api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import io
import json
import logging
from collections.abc import Generator
Expand All @@ -18,6 +19,7 @@
)
from google.cloud import aiplatform
from google.oauth2 import service_account
from PIL import Image
from vertexai.generative_models import HarmBlockThreshold, HarmCategory

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
Expand Down Expand Up @@ -332,7 +334,8 @@ def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
Expand Down

0 comments on commit 83eb43e

Please sign in to comment.