From e111769632cff81682a63ca0e9987eded8029184 Mon Sep 17 00:00:00 2001 From: aliang <1098486429@qq.com> Date: Tue, 9 Apr 2024 20:11:30 +0800 Subject: [PATCH] fix(ui): prevent consecutive messages from same source (#1782) * fix(ui): prevent consecutive messages from same role * update --- ee/tabby-ui/lib/hooks/use-patch-fetch.ts | 56 +++++++++++++++++++++--- 1 file changed, 50 insertions(+), 6 deletions(-) diff --git a/ee/tabby-ui/lib/hooks/use-patch-fetch.ts b/ee/tabby-ui/lib/hooks/use-patch-fetch.ts index 43766a689709..8dccd8555d70 100644 --- a/ee/tabby-ui/lib/hooks/use-patch-fetch.ts +++ b/ee/tabby-ui/lib/hooks/use-patch-fetch.ts @@ -1,9 +1,16 @@ import { useEffect } from 'react' -import { OpenAIStream, OpenAIStreamCallbacks, StreamingTextResponse } from 'ai' +import { + Message, + OpenAIStream, + OpenAIStreamCallbacks, + StreamingTextResponse +} from 'ai' import fetcher from '../tabby/fetcher' -export function usePatchFetch(callbacks?: OpenAIStreamCallbacks) { +interface PatchFetchOptions extends OpenAIStreamCallbacks {} + +export function usePatchFetch(options?: PatchFetchOptions) { useEffect(() => { if (!window._originFetch) { window._originFetch = window.fetch @@ -11,9 +18,9 @@ export function usePatchFetch(callbacks?: OpenAIStreamCallbacks) { const fetch = window._originFetch - window.fetch = async function (url, options) { + window.fetch = async function (url, requestInit) { if (url !== '/api/chat') { - return fetch(url, options) + return fetch(url, requestInit) } const headers: HeadersInit = { @@ -21,12 +28,13 @@ export function usePatchFetch(callbacks?: OpenAIStreamCallbacks) { } const res = await fetcher(`/v1beta/chat/completions`, { - ...options, + ...requestInit, + body: mergeMessagesByRole(requestInit?.body), method: 'POST', headers, customFetch: fetch, responseFormatter(response) { - const stream = OpenAIStream(response, callbacks) + const stream = OpenAIStream(response, options) return new StreamingTextResponse(stream) } }) @@ -42,3 +50,39 @@ export function usePatchFetch(callbacks?: OpenAIStreamCallbacks) { } }, []) } + +function mergeMessagesByRole(body: BodyInit | null | undefined) { + if (typeof body !== 'string') return body + try { + const bodyObject = JSON.parse(body) + let messages: Message[] = bodyObject.messages?.slice() + if (Array.isArray(messages) && messages.length > 1) { + let previewCursor = 0 + let curCursor = 1 + while (curCursor < messages.length) { + let prevMessage = messages[previewCursor] + let curMessage = messages[curCursor] + if (curMessage.role === prevMessage.role) { + messages = [ + ...messages.slice(0, previewCursor), + { + ...prevMessage, + content: [prevMessage.content, curMessage.content].join('\n') + }, + ...messages.slice(curCursor + 1) + ] + } else { + previewCursor = curCursor++ + } + } + return JSON.stringify({ + ...bodyObject, + messages + }) + } else { + return body + } + } catch (e) { + return body + } +}