diff --git a/src/components/Chat/ChatMessages.tsx b/src/components/Chat/ChatMessages.tsx index e32cea4c..894cf62a 100644 --- a/src/components/Chat/ChatMessages.tsx +++ b/src/components/Chat/ChatMessages.tsx @@ -20,21 +20,18 @@ interface MessageProps { index: number currentChat: Chat setCurrentChat: React.Dispatch> - handleNewChatMessage: (userTextFieldInput?: string, chatFilters?: AgentConfig) => void } -const Message: React.FC = ({ message, index, currentChat, setCurrentChat, handleNewChatMessage }) => { +const Message: React.FC = ({ message, index, currentChat, setCurrentChat }) => { return ( <> {message.role === 'user' && } {message.role === 'assistant' && ( )} {message.role === 'system' && } @@ -65,7 +62,6 @@ const ChatMessages: React.FC = ({ index={index} currentChat={currentChat} setCurrentChat={setCurrentChat} - handleNewChatMessage={handleNewChatMessage} /> ))} diff --git a/src/components/Chat/MessageComponents/AssistantMessage.tsx b/src/components/Chat/MessageComponents/AssistantMessage.tsx index f8b8eb29..ff679c6e 100644 --- a/src/components/Chat/MessageComponents/AssistantMessage.tsx +++ b/src/components/Chat/MessageComponents/AssistantMessage.tsx @@ -1,11 +1,11 @@ -import React, { useCallback, useEffect, useMemo } from 'react' +import React, { useCallback, useMemo } from 'react' import { HiOutlinePencilAlt } from 'react-icons/hi' import { toast } from 'react-toastify' import { ToolCallPart } from 'ai' import { FaRegCopy } from 'react-icons/fa' -import { Chat, AgentConfig, ReorChatMessage } from '../types' +import { Chat, ReorChatMessage } from '../types' import { - addToolResultToMessages, + makeAndAddToolResultToMessages, extractMessagePartsFromAssistantMessage, findToolResultMatchingToolCall, getClassNameBasedOnMessageRole, @@ -19,17 +19,9 @@ interface AssistantMessageProps { message: ReorChatMessage setCurrentChat: React.Dispatch> currentChat: Chat - messageIndex: number - handleNewChatMessage: (userTextFieldInput?: string, chatFilters?: AgentConfig) => void } -const AssistantMessage: React.FC = ({ - message, - setCurrentChat, - currentChat, - messageIndex, - handleNewChatMessage, -}) => { +const AssistantMessage: React.FC = ({ message, setCurrentChat, currentChat }) => { if (message.role !== 'assistant') { throw new Error('Message is not an assistant message') } @@ -60,7 +52,7 @@ const AssistantMessage: React.FC = ({ return } - const updatedMessages = await addToolResultToMessages(currentChat.messages, toolCallPart, message) + const updatedMessages = await makeAndAddToolResultToMessages(currentChat.messages, toolCallPart, message) setCurrentChat((prevChat) => { if (!prevChat) return prevChat @@ -75,37 +67,6 @@ const AssistantMessage: React.FC = ({ [currentChat, setCurrentChat, saveChat, message], ) - const isLatestAssistantMessage = (index: number, messages: ReorChatMessage[]) => { - return messages.slice(index + 1).every((msg) => msg.role !== 'assistant') - } - - useEffect(() => { - if (!isLatestAssistantMessage(messageIndex, currentChat.messages)) return - toolCalls.forEach((toolCall) => { - const existingToolCall = findToolResultMatchingToolCall(toolCall.toolCallId, currentChat.messages) - const toolDefinition = currentChat.toolDefinitions.find((definition) => definition.name === toolCall.toolName) - if (toolDefinition && toolDefinition.autoExecute && !existingToolCall) { - executeToolCall(toolCall) - } - }) - }, [currentChat, currentChat.toolDefinitions, executeToolCall, toolCalls, messageIndex]) - - useEffect(() => { - if (!isLatestAssistantMessage(messageIndex, currentChat.messages)) return - - const shouldLLMRespondToToolResults = - toolCalls.length > 0 && - toolCalls.every((toolCall) => { - const existingToolResult = findToolResultMatchingToolCall(toolCall.toolCallId, currentChat.messages) - const toolDefinition = currentChat.toolDefinitions.find((definition) => definition.name === toolCall.toolName) - return existingToolResult && toolDefinition?.autoExecute - }) - - if (shouldLLMRespondToToolResults) { - handleNewChatMessage() - } - }, [currentChat, currentChat.toolDefinitions, executeToolCall, toolCalls, messageIndex, handleNewChatMessage]) - const renderContent = () => { return ( <> diff --git a/src/components/Chat/index.tsx b/src/components/Chat/index.tsx index 8380d0e9..fdcf1a35 100644 --- a/src/components/Chat/index.tsx +++ b/src/components/Chat/index.tsx @@ -1,7 +1,12 @@ -import React, { useCallback, useEffect, useState } from 'react' +import React, { useCallback, useEffect, useState, useRef } from 'react' import { streamText } from 'ai' -import { appendToOrCreateChat, appendTextContentToMessages, removeUncalledToolsFromMessages } from './utils' +import { + appendToolCallsAndAutoExecuteTools, + appendStringContentToMessages, + appendToOrCreateChat, + removeUncalledToolsFromMessages, +} from './utils' import '../../styles/chat.css' import ChatMessages from './ChatMessages' @@ -16,6 +21,7 @@ const ChatComponent: React.FC = () => { const [defaultModelName, setDefaultLLMName] = useState('') const [currentChat, setCurrentChat] = useState(undefined) const { saveChat, currentOpenChatID, setCurrentOpenChatID } = useChatContext() + const abortControllerRef = useRef(null) useEffect(() => { const fetchDefaultLLM = async () => { @@ -27,6 +33,10 @@ const ChatComponent: React.FC = () => { useEffect(() => { const fetchChat = async () => { + if (abortControllerRef.current) { + abortControllerRef.current.abort() + } + const chat = await window.electronStore.getChat(currentOpenChatID) setCurrentChat((oldChat) => { if (oldChat) { @@ -62,28 +72,45 @@ const ChatComponent: React.FC = () => { const llmClient = await resolveLLMClient(defaultLLMName) + abortControllerRef.current = new AbortController() + const { textStream, toolCalls } = await streamText({ model: llmClient, messages: removeUncalledToolsFromMessages(outputChat.messages), tools: Object.assign({}, ...outputChat.toolDefinitions.map(convertToolConfigToZodSchema)), + abortSignal: abortControllerRef.current.signal, }) + // eslint-disable-next-line no-restricted-syntax for await (const text of textStream) { + if (abortControllerRef.current.signal.aborted) { + return + } + outputChat = { ...outputChat, - messages: appendTextContentToMessages(outputChat.messages || [], text), + messages: appendStringContentToMessages(outputChat.messages || [], text), } setCurrentChat(outputChat) setLoadingState('generating') } - outputChat.messages = appendTextContentToMessages(outputChat.messages, await toolCalls) - setCurrentChat(outputChat) - await saveChat(outputChat) + + if (!abortControllerRef.current.signal.aborted) { + outputChat.messages = await appendToolCallsAndAutoExecuteTools( + outputChat.messages, + outputChat.toolDefinitions, + await toolCalls, + ) + setCurrentChat(outputChat) + await saveChat(outputChat) + } setLoadingState('idle') } catch (error) { setLoadingState('idle') throw error + } finally { + abortControllerRef.current = null } }, [setCurrentOpenChatID, saveChat, currentChat], diff --git a/src/components/Chat/utils.ts b/src/components/Chat/utils.ts index 392f7dd8..bdf7cb38 100644 --- a/src/components/Chat/utils.ts +++ b/src/components/Chat/utils.ts @@ -3,35 +3,61 @@ import { FileInfoWithContent } from 'electron/main/filesystem/types' import generateChatName from '@shared/utils' import { AssistantContent, CoreAssistantMessage, CoreToolMessage, ToolCallPart } from 'ai' import posthog from 'posthog-js' -import { AnonymizedAgentConfig, Chat, AgentConfig, PromptTemplate, ReorChatMessage } from './types' +import { AnonymizedAgentConfig, Chat, AgentConfig, PromptTemplate, ReorChatMessage, ToolDefinition } from './types' import { retreiveFromVectorDB } from '@/utils/db' import { createToolResult } from './tools' -export const appendTextContentToMessages = ( - messages: ReorChatMessage[], - content: string | ToolCallPart[], -): ReorChatMessage[] => { - if (content === '' || (Array.isArray(content) && content.length === 0)) { +export const appendStringContentToMessages = (messages: ReorChatMessage[], content: string): ReorChatMessage[] => { + if (content === '') { return messages } - const appendContent = (existingContent: AssistantContent, newContent: string | ToolCallPart[]): AssistantContent => { - if (typeof existingContent === 'string') { - return typeof newContent === 'string' - ? existingContent + newContent - : [{ type: 'text' as const, text: existingContent }, ...newContent] - } + if (messages.length === 0) { return [ - ...existingContent, - ...(typeof newContent === 'string' ? [{ type: 'text' as const, text: newContent }] : newContent), + { + role: 'assistant', + content, + }, ] } + const lastMessage = messages[messages.length - 1] + + if (lastMessage.role === 'assistant') { + return [ + ...messages.slice(0, -1), + { + ...lastMessage, + content: + typeof lastMessage.content === 'string' + ? lastMessage.content + content + : [...lastMessage.content, { type: 'text' as const, text: content }], + }, + ] + } + + return [ + ...messages, + { + role: 'assistant', + content, + }, + ] +} + +export const appendToolCallPartsToMessages = ( + messages: ReorChatMessage[], + toolCalls: ToolCallPart[], +): ReorChatMessage[] => { + if (toolCalls.length === 0) { + return messages + } + if (messages.length === 0) { return [ { role: 'assistant', - content: typeof content === 'string' ? content : content, + content: toolCalls, }, ] } @@ -43,7 +69,9 @@ export const appendTextContentToMessages = ( ...messages.slice(0, -1), { ...lastMessage, - content: appendContent(lastMessage.content, content), + content: Array.isArray(lastMessage.content) + ? [...lastMessage.content, ...toolCalls] + : [{ type: 'text' as const, text: lastMessage.content }, ...toolCalls], }, ] } @@ -52,11 +80,64 @@ export const appendTextContentToMessages = ( ...messages, { role: 'assistant', - content: typeof content === 'string' ? content : content, + content: toolCalls, }, ] } +export const makeAndAddToolResultToMessages = async ( + messages: ReorChatMessage[], + toolCallPart: ToolCallPart, + assistantMessage: ReorChatMessage, +): Promise => { + const toolResult = await createToolResult(toolCallPart.toolName, toolCallPart.args as any, toolCallPart.toolCallId) + + const toolMessage: CoreToolMessage = { + role: 'tool', + content: [toolResult], + } + + const assistantIndex = messages.findIndex((msg) => msg === assistantMessage) + if (assistantIndex === -1) { + throw new Error('Assistant message not found') + } + + return [...messages.slice(0, assistantIndex + 1), toolMessage, ...messages.slice(assistantIndex + 1)] +} + +const autoExecuteTools = async ( + messages: ReorChatMessage[], + toolDefinitions: ToolDefinition[], + toolCalls: ToolCallPart[], +) => { + const toolsThatNeedExecuting = toolCalls.filter((toolCall) => { + const toolDefinition = toolDefinitions.find((definition) => definition.name === toolCall.toolName) + return toolDefinition?.autoExecute + }) + let outputMessages = messages + const lastMessage = messages[messages.length - 1] + + if (lastMessage.role !== 'assistant') { + throw new Error('Last message is not an assistant message') + } + // eslint-disable-next-line no-restricted-syntax + for (const toolCall of toolsThatNeedExecuting) { + // eslint-disable-next-line no-await-in-loop + outputMessages = await makeAndAddToolResultToMessages(outputMessages, toolCall, lastMessage) + } + return outputMessages +} + +export const appendToolCallsAndAutoExecuteTools = async ( + messages: ReorChatMessage[], + toolDefinitions: ToolDefinition[], + toolCalls: ToolCallPart[], +): Promise => { + const messagesWithToolCalls = appendToolCallPartsToMessages(messages, toolCalls) + const messagesWithToolResults = await autoExecuteTools(messagesWithToolCalls, toolDefinitions, toolCalls) + return messagesWithToolResults +} + export const convertMessageToString = (message: ReorChatMessage | undefined): string => { if (!message) { return '' @@ -248,23 +329,3 @@ export const removeUncalledToolsFromMessages = (messages: ReorChatMessage[]): Re return message }) } - -export const addToolResultToMessages = async ( - messages: ReorChatMessage[], - toolCallPart: ToolCallPart, - assistantMessage: ReorChatMessage, -): Promise => { - const toolResult = await createToolResult(toolCallPart.toolName, toolCallPart.args as any, toolCallPart.toolCallId) - - const toolMessage: CoreToolMessage = { - role: 'tool', - content: [toolResult], - } - - const assistantIndex = messages.findIndex((msg) => msg === assistantMessage) - if (assistantIndex === -1) { - throw new Error('Assistant message not found') - } - - return [...messages.slice(0, assistantIndex + 1), toolMessage, ...messages.slice(assistantIndex + 1)] -} diff --git a/src/components/WritingAssistant/WritingAssistant.tsx b/src/components/WritingAssistant/WritingAssistant.tsx index 83b93e89..8d8e2ae3 100644 --- a/src/components/WritingAssistant/WritingAssistant.tsx +++ b/src/components/WritingAssistant/WritingAssistant.tsx @@ -6,7 +6,7 @@ import TextField from '@mui/material/TextField' import Button from '@mui/material/Button' import posthog from 'posthog-js' import { streamText } from 'ai' -import { appendTextContentToMessages, convertMessageToString } from '../Chat/utils' +import { appendStringContentToMessages, convertMessageToString } from '../Chat/utils' import useOutsideClick from './hooks/use-outside-click' import getClassNames, { generatePromptString, getLastMessage } from './utils' import { ReorChatMessage } from '../Chat/types' @@ -232,7 +232,7 @@ const WritingAssistant: React.FC = () => { let updatedMessages = messages // eslint-disable-next-line no-restricted-syntax for await (const textPart of textStream) { - updatedMessages = appendTextContentToMessages(updatedMessages, textPart) + updatedMessages = appendStringContentToMessages(updatedMessages, textPart) setMessages(updatedMessages) }