From 2a21ad2cf50913e9289542d7c3f2c92a4146826f Mon Sep 17 00:00:00 2001 From: rerorero Date: Wed, 12 Jun 2024 14:14:53 +0900 Subject: [PATCH] fix: remote_url doesn't work for gemini (#5090) --- .../model_providers/google/llm/llm.py | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 5a674fdeeeb77e..2dfde70816eeee 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -1,18 +1,22 @@ +import base64 import json import logging +import mimetypes from collections.abc import Generator -from typing import Optional, Union +from typing import Optional, Union, cast import google.ai.generativelanguage as glm import google.api_core.exceptions as exceptions import google.generativeai as genai import google.generativeai.client as client +import requests from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory from google.generativeai.types.content_types import to_part from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, + ImagePromptMessageContent, PromptMessage, PromptMessageContentType, PromptMessageTool, @@ -361,11 +365,22 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType: for c in message.content: if c.type == PromptMessageContentType.TEXT: glm_content['parts'].append(to_part(c.data)) - else: - metadata, data = c.data.split(',', 1) - mime_type = metadata.split(';', 1)[0].split(':')[1] - blob = {"inline_data":{"mime_type":mime_type,"data":data}} + elif c.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, c) + if message_content.data.startswith("data:"): + metadata, base64_data = c.data.split(',', 1) + mime_type = metadata.split(';', 1)[0].split(':')[1] + else: + # fetch image data from url + try: + image_content = requests.get(message_content.data).content + mime_type, _ = mimetypes.guess_type(message_content.data) + base64_data = base64.b64encode(image_content).decode('utf-8') + except Exception as ex: + raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") + blob = {"inline_data":{"mime_type":mime_type,"data":base64_data}} glm_content['parts'].append(blob) + return glm_content elif isinstance(message, AssistantPromptMessage): glm_content = { @@ -444,4 +459,4 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] exceptions.RequestRangeNotSatisfiable, exceptions.Cancelled, ] - } \ No newline at end of file + }