From bd32dd9e06b5b27318b3420ed232f517390c48e3 Mon Sep 17 00:00:00 2001 From: seayon Date: Sun, 25 Aug 2024 12:55:03 +0800 Subject: [PATCH] Improve MIME type detection for image URLs Implement fallback to content-based detection using Pillow when URL-based guess fails, enhancing accuracy while maintaining efficiency. --- .../model_runtime/model_providers/anthropic/llm/llm.py | 6 ++++-- api/core/model_runtime/model_providers/bedrock/llm/llm.py | 8 ++++---- api/core/model_runtime/model_providers/google/llm/llm.py | 6 ++++-- .../model_runtime/model_providers/vertex_ai/llm/llm.py | 5 ++++- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 19ce401999c50f..81be1a06a7cd0e 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -1,6 +1,6 @@ import base64 +import io import json -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -18,6 +18,7 @@ ) from anthropic.types.beta.tools import ToolsBetaMessage from httpx import Timeout +from PIL import Image from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -462,7 +463,8 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 335fa493cded9f..3f7266f6002025 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -1,8 +1,8 @@ # standard import import base64 +import io import json import logging -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -17,6 +17,7 @@ ServiceNotInRegionError, UnknownServiceError, ) +from PIL.Image import Image # local import from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -381,9 +382,8 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: try: url = message_content.data image_content = requests.get(url).content - if '?' in url: - url = url.split('?')[0] - mime_type, _ = mimetypes.guess_type(url) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index ebcd0af35b2138..84241fb6c877a4 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -1,7 +1,7 @@ import base64 +import io import json import logging -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -12,6 +12,7 @@ import requests from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory from google.generativeai.types.content_types import to_part +from PIL import Image from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -371,7 +372,8 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType: # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index 8901549110ee07..1a7368a2cf843d 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -1,4 +1,5 @@ import base64 +import io import json import logging from collections.abc import Generator @@ -18,6 +19,7 @@ ) from google.cloud import aiplatform from google.oauth2 import service_account +from PIL import Image from vertexai.generative_models import HarmBlockThreshold, HarmCategory from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -332,7 +334,8 @@ def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" base64_data = base64.b64encode(image_content).decode('utf-8') except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")