From 450766a44bf6868a716de7c012c7e42347c62898 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Thu, 3 Oct 2024 20:28:15 +0800 Subject: [PATCH] google gemini support function call --- app/client/platforms/google.ts | 193 ++++++++++++++------------------- app/utils.ts | 3 + app/utils/chat.ts | 1 + 3 files changed, 86 insertions(+), 111 deletions(-) diff --git a/app/client/platforms/google.ts b/app/client/platforms/google.ts index 3c2607271bf..3b82a569bc6 100644 --- a/app/client/platforms/google.ts +++ b/app/client/platforms/google.ts @@ -7,21 +7,25 @@ import { LLMUsage, SpeechOptions, } from "../api"; -import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; +import { + useAccessStore, + useAppConfig, + useChatStore, + usePluginStore, + ChatMessageTool, +} from "@/app/store"; +import { stream } from "@/app/utils/chat"; import { getClientConfig } from "@/app/config/client"; import { DEFAULT_API_HOST } from "@/app/constant"; -import Locale from "../../locales"; -import { - EventStreamContentType, - fetchEventSource, -} from "@fortaine/fetch-event-source"; -import { prettyObject } from "@/app/utils/format"; + import { getMessageTextContent, getMessageImages, isVisionModel, } from "@/app/utils"; import { preProcessImageContent } from "@/app/utils/chat"; +import { nanoid } from "nanoid"; +import { RequestPayload } from "./openai"; export class GeminiProApi implements LLMApi { path(path: string): string { @@ -177,114 +181,81 @@ export class GeminiProApi implements LLMApi { ); if (shouldStream) { - let responseText = ""; - let remainText = ""; - let finished = false; - - const finish = () => { - if (!finished) { - finished = true; - options.onFinish(responseText + remainText); - } - }; - - // animate response to make it looks smooth - function animateResponseText() { - if (finished || controller.signal.aborted) { - responseText += remainText; - finish(); - return; - } - - if (remainText.length > 0) { - const fetchCount = Math.max(1, Math.round(remainText.length / 60)); - const fetchText = remainText.slice(0, fetchCount); - responseText += fetchText; - remainText = remainText.slice(fetchCount); - options.onUpdate?.(responseText, fetchText); - } - - requestAnimationFrame(animateResponseText); - } - - // start animaion - animateResponseText(); - - controller.signal.onabort = finish; - - fetchEventSource(chatPath, { - ...chatPayload, - async onopen(res) { - clearTimeout(requestTimeoutId); - const contentType = res.headers.get("content-type"); - console.log( - "[Gemini] request response content type: ", - contentType, - ); - - if (contentType?.startsWith("text/plain")) { - responseText = await res.clone().text(); - return finish(); - } - - if ( - !res.ok || - !res.headers - .get("content-type") - ?.startsWith(EventStreamContentType) || - res.status !== 200 - ) { - const responseTexts = [responseText]; - let extraInfo = await res.clone().text(); - try { - const resJson = await res.clone().json(); - extraInfo = prettyObject(resJson); - } catch {} - - if (res.status === 401) { - responseTexts.push(Locale.Error.Unauthorized); - } - - if (extraInfo) { - responseTexts.push(extraInfo); - } - - responseText = responseTexts.join("\n\n"); - - return finish(); - } - }, - onmessage(msg) { - if (msg.data === "[DONE]" || finished) { - return finish(); - } - const text = msg.data; - try { - const json = JSON.parse(text); - const delta = apiClient.extractMessage(json); - - if (delta) { - remainText += delta; - } + const [tools, funcs] = usePluginStore + .getState() + .getAsTools( + useChatStore.getState().currentSession().mask?.plugin || [], + ); + return stream( + chatPath, + requestPayload, + getHeaders(), + // @ts-ignore + [{ functionDeclarations: tools.map((tool) => tool.function) }], + funcs, + controller, + // parseSSE + (text: string, runTools: ChatMessageTool[]) => { + // console.log("parseSSE", text, runTools); + const chunkJson = JSON.parse(text); - const blockReason = json?.promptFeedback?.blockReason; - if (blockReason) { - // being blocked - console.log(`[Google] [Safety Ratings] result:`, blockReason); - } - } catch (e) { - console.error("[Request] parse error", text, msg); + const functionCall = chunkJson?.candidates + ?.at(0) + ?.content.parts.at(0)?.functionCall; + if (functionCall) { + const { name, args } = functionCall; + runTools.push({ + id: nanoid(), + type: "function", + function: { + name, + arguments: JSON.stringify(args), // utils.chat call function, using JSON.parse + }, + }); } + return chunkJson?.candidates?.at(0)?.content.parts.at(0)?.text; }, - onclose() { - finish(); - }, - onerror(e) { - options.onError?.(e); - throw e; + // processToolMessage, include tool_calls message and tool call results + ( + requestPayload: RequestPayload, + toolCallMessage: any, + toolCallResult: any[], + ) => { + // @ts-ignore + requestPayload?.contents?.splice( + // @ts-ignore + requestPayload?.contents?.length, + 0, + { + role: "model", + parts: toolCallMessage.tool_calls.map( + (tool: ChatMessageTool) => ({ + functionCall: { + name: tool?.function?.name, + args: JSON.parse(tool?.function?.arguments as string), + }, + }), + ), + }, + // @ts-ignore + ...toolCallResult.map((result) => ({ + role: "function", + parts: [ + { + functionResponse: { + name: result.name, + response: { + name: result.name, + content: result.content, // TODO just text content... + }, + }, + }, + ], + })), + ); }, - openWhenHidden: true, - }); + options, + ); } else { const res = await fetch(chatPath, chatPayload); clearTimeout(requestTimeoutId); diff --git a/app/utils.ts b/app/utils.ts index 6b2f65952c7..a7c817767e0 100644 --- a/app/utils.ts +++ b/app/utils.ts @@ -284,6 +284,9 @@ export function showPlugins(provider: ServiceProvider, model: string) { if (provider == ServiceProvider.Anthropic && !model.includes("claude-2")) { return true; } + if (provider == ServiceProvider.Google && !model.includes("vision")) { + return true; + } return false; } diff --git a/app/utils/chat.ts b/app/utils/chat.ts index 7f3bb23c58e..c93334bb1af 100644 --- a/app/utils/chat.ts +++ b/app/utils/chat.ts @@ -240,6 +240,7 @@ export function stream( return e.toString(); }) .then((content) => ({ + name: tool.function.name, role: "tool", content, tool_call_id: tool.id,