Skip to content

Commit

Permalink
fix(ui): prevent consecutive messages from same source (#1782)
Browse files Browse the repository at this point in the history
* fix(ui): prevent consecutive messages from same role

* update
  • Loading branch information
liangfung authored Apr 9, 2024
1 parent f2702fe commit e111769
Showing 1 changed file with 50 additions and 6 deletions.
56 changes: 50 additions & 6 deletions ee/tabby-ui/lib/hooks/use-patch-fetch.ts
Original file line number Diff line number Diff line change
@@ -1,32 +1,40 @@
import { useEffect } from 'react'
import { OpenAIStream, OpenAIStreamCallbacks, StreamingTextResponse } from 'ai'
import {
Message,
OpenAIStream,
OpenAIStreamCallbacks,
StreamingTextResponse
} from 'ai'

import fetcher from '../tabby/fetcher'

export function usePatchFetch(callbacks?: OpenAIStreamCallbacks) {
interface PatchFetchOptions extends OpenAIStreamCallbacks {}

export function usePatchFetch(options?: PatchFetchOptions) {
useEffect(() => {
if (!window._originFetch) {
window._originFetch = window.fetch
}

const fetch = window._originFetch

window.fetch = async function (url, options) {
window.fetch = async function (url, requestInit) {
if (url !== '/api/chat') {
return fetch(url, options)
return fetch(url, requestInit)
}

const headers: HeadersInit = {
'Content-Type': 'application/json'
}

const res = await fetcher(`/v1beta/chat/completions`, {
...options,
...requestInit,
body: mergeMessagesByRole(requestInit?.body),
method: 'POST',
headers,
customFetch: fetch,
responseFormatter(response) {
const stream = OpenAIStream(response, callbacks)
const stream = OpenAIStream(response, options)
return new StreamingTextResponse(stream)
}
})
Expand All @@ -42,3 +50,39 @@ export function usePatchFetch(callbacks?: OpenAIStreamCallbacks) {
}
}, [])
}

function mergeMessagesByRole(body: BodyInit | null | undefined) {
if (typeof body !== 'string') return body
try {
const bodyObject = JSON.parse(body)
let messages: Message[] = bodyObject.messages?.slice()
if (Array.isArray(messages) && messages.length > 1) {
let previewCursor = 0
let curCursor = 1
while (curCursor < messages.length) {
let prevMessage = messages[previewCursor]
let curMessage = messages[curCursor]
if (curMessage.role === prevMessage.role) {
messages = [
...messages.slice(0, previewCursor),
{
...prevMessage,
content: [prevMessage.content, curMessage.content].join('\n')
},
...messages.slice(curCursor + 1)
]
} else {
previewCursor = curCursor++
}
}
return JSON.stringify({
...bodyObject,
messages
})
} else {
return body
}
} catch (e) {
return body
}
}

0 comments on commit e111769

Please sign in to comment.