Skip to content

Commit

Permalink
Bug tool (#406)
Browse files Browse the repository at this point in the history
* Abort.

* CallingNice
  • Loading branch information
joseplayero authored Sep 20, 2024
1 parent 644db14 commit bf6b92c
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 94 deletions.
6 changes: 1 addition & 5 deletions src/components/Chat/ChatMessages.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,18 @@ interface MessageProps {
index: number
currentChat: Chat
setCurrentChat: React.Dispatch<React.SetStateAction<Chat | undefined>>
handleNewChatMessage: (userTextFieldInput?: string, chatFilters?: AgentConfig) => void
}

const Message: React.FC<MessageProps> = ({ message, index, currentChat, setCurrentChat, handleNewChatMessage }) => {
const Message: React.FC<MessageProps> = ({ message, index, currentChat, setCurrentChat }) => {
return (
<>
{message.role === 'user' && <UserMessage key={`user-${index}`} message={message} />}
{message.role === 'assistant' && (
<AssistantMessage
key={`assistant-${index}`}
messageIndex={index}
message={message}
setCurrentChat={setCurrentChat}
currentChat={currentChat}
handleNewChatMessage={handleNewChatMessage}
/>
)}
{message.role === 'system' && <SystemMessage key={`system-${index}`} message={message} />}
Expand Down Expand Up @@ -65,7 +62,6 @@ const ChatMessages: React.FC<ChatMessagesProps> = ({
index={index}
currentChat={currentChat}
setCurrentChat={setCurrentChat}
handleNewChatMessage={handleNewChatMessage}
/>
))}
</div>
Expand Down
49 changes: 5 additions & 44 deletions src/components/Chat/MessageComponents/AssistantMessage.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import React, { useCallback, useEffect, useMemo } from 'react'
import React, { useCallback, useMemo } from 'react'
import { HiOutlinePencilAlt } from 'react-icons/hi'
import { toast } from 'react-toastify'
import { ToolCallPart } from 'ai'
import { FaRegCopy } from 'react-icons/fa'
import { Chat, AgentConfig, ReorChatMessage } from '../types'
import { Chat, ReorChatMessage } from '../types'
import {
addToolResultToMessages,
makeAndAddToolResultToMessages,
extractMessagePartsFromAssistantMessage,
findToolResultMatchingToolCall,
getClassNameBasedOnMessageRole,
Expand All @@ -19,17 +19,9 @@ interface AssistantMessageProps {
message: ReorChatMessage
setCurrentChat: React.Dispatch<React.SetStateAction<Chat | undefined>>
currentChat: Chat
messageIndex: number
handleNewChatMessage: (userTextFieldInput?: string, chatFilters?: AgentConfig) => void
}

const AssistantMessage: React.FC<AssistantMessageProps> = ({
message,
setCurrentChat,
currentChat,
messageIndex,
handleNewChatMessage,
}) => {
const AssistantMessage: React.FC<AssistantMessageProps> = ({ message, setCurrentChat, currentChat }) => {
if (message.role !== 'assistant') {
throw new Error('Message is not an assistant message')
}
Expand Down Expand Up @@ -60,7 +52,7 @@ const AssistantMessage: React.FC<AssistantMessageProps> = ({
return
}

const updatedMessages = await addToolResultToMessages(currentChat.messages, toolCallPart, message)
const updatedMessages = await makeAndAddToolResultToMessages(currentChat.messages, toolCallPart, message)

setCurrentChat((prevChat) => {
if (!prevChat) return prevChat
Expand All @@ -75,37 +67,6 @@ const AssistantMessage: React.FC<AssistantMessageProps> = ({
[currentChat, setCurrentChat, saveChat, message],
)

const isLatestAssistantMessage = (index: number, messages: ReorChatMessage[]) => {
return messages.slice(index + 1).every((msg) => msg.role !== 'assistant')
}

useEffect(() => {
if (!isLatestAssistantMessage(messageIndex, currentChat.messages)) return
toolCalls.forEach((toolCall) => {
const existingToolCall = findToolResultMatchingToolCall(toolCall.toolCallId, currentChat.messages)
const toolDefinition = currentChat.toolDefinitions.find((definition) => definition.name === toolCall.toolName)
if (toolDefinition && toolDefinition.autoExecute && !existingToolCall) {
executeToolCall(toolCall)
}
})
}, [currentChat, currentChat.toolDefinitions, executeToolCall, toolCalls, messageIndex])

useEffect(() => {
if (!isLatestAssistantMessage(messageIndex, currentChat.messages)) return

const shouldLLMRespondToToolResults =
toolCalls.length > 0 &&
toolCalls.every((toolCall) => {
const existingToolResult = findToolResultMatchingToolCall(toolCall.toolCallId, currentChat.messages)
const toolDefinition = currentChat.toolDefinitions.find((definition) => definition.name === toolCall.toolName)
return existingToolResult && toolDefinition?.autoExecute
})

if (shouldLLMRespondToToolResults) {
handleNewChatMessage()
}
}, [currentChat, currentChat.toolDefinitions, executeToolCall, toolCalls, messageIndex, handleNewChatMessage])

const renderContent = () => {
return (
<>
Expand Down
39 changes: 33 additions & 6 deletions src/components/Chat/index.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import React, { useCallback, useEffect, useState } from 'react'
import React, { useCallback, useEffect, useState, useRef } from 'react'

import { streamText } from 'ai'
import { appendToOrCreateChat, appendTextContentToMessages, removeUncalledToolsFromMessages } from './utils'
import {
appendToolCallsAndAutoExecuteTools,
appendStringContentToMessages,
appendToOrCreateChat,
removeUncalledToolsFromMessages,
} from './utils'

import '../../styles/chat.css'
import ChatMessages from './ChatMessages'
Expand All @@ -16,6 +21,7 @@ const ChatComponent: React.FC = () => {
const [defaultModelName, setDefaultLLMName] = useState<string>('')
const [currentChat, setCurrentChat] = useState<Chat | undefined>(undefined)
const { saveChat, currentOpenChatID, setCurrentOpenChatID } = useChatContext()
const abortControllerRef = useRef<AbortController | null>(null)

useEffect(() => {
const fetchDefaultLLM = async () => {
Expand All @@ -27,6 +33,10 @@ const ChatComponent: React.FC = () => {

useEffect(() => {
const fetchChat = async () => {
if (abortControllerRef.current) {
abortControllerRef.current.abort()
}

const chat = await window.electronStore.getChat(currentOpenChatID)
setCurrentChat((oldChat) => {
if (oldChat) {
Expand Down Expand Up @@ -62,28 +72,45 @@ const ChatComponent: React.FC = () => {

const llmClient = await resolveLLMClient(defaultLLMName)

abortControllerRef.current = new AbortController()

const { textStream, toolCalls } = await streamText({
model: llmClient,
messages: removeUncalledToolsFromMessages(outputChat.messages),
tools: Object.assign({}, ...outputChat.toolDefinitions.map(convertToolConfigToZodSchema)),
abortSignal: abortControllerRef.current.signal,
})

// eslint-disable-next-line no-restricted-syntax
for await (const text of textStream) {
if (abortControllerRef.current.signal.aborted) {
return
}

outputChat = {
...outputChat,
messages: appendTextContentToMessages(outputChat.messages || [], text),
messages: appendStringContentToMessages(outputChat.messages || [], text),
}
setCurrentChat(outputChat)
setLoadingState('generating')
}
outputChat.messages = appendTextContentToMessages(outputChat.messages, await toolCalls)
setCurrentChat(outputChat)
await saveChat(outputChat)

if (!abortControllerRef.current.signal.aborted) {
outputChat.messages = await appendToolCallsAndAutoExecuteTools(
outputChat.messages,
outputChat.toolDefinitions,
await toolCalls,
)
setCurrentChat(outputChat)
await saveChat(outputChat)
}

setLoadingState('idle')
} catch (error) {
setLoadingState('idle')
throw error
} finally {
abortControllerRef.current = null
}
},
[setCurrentOpenChatID, saveChat, currentChat],
Expand Down
135 changes: 98 additions & 37 deletions src/components/Chat/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,61 @@ import { FileInfoWithContent } from 'electron/main/filesystem/types'
import generateChatName from '@shared/utils'
import { AssistantContent, CoreAssistantMessage, CoreToolMessage, ToolCallPart } from 'ai'
import posthog from 'posthog-js'
import { AnonymizedAgentConfig, Chat, AgentConfig, PromptTemplate, ReorChatMessage } from './types'
import { AnonymizedAgentConfig, Chat, AgentConfig, PromptTemplate, ReorChatMessage, ToolDefinition } from './types'
import { retreiveFromVectorDB } from '@/utils/db'
import { createToolResult } from './tools'

export const appendTextContentToMessages = (
messages: ReorChatMessage[],
content: string | ToolCallPart[],
): ReorChatMessage[] => {
if (content === '' || (Array.isArray(content) && content.length === 0)) {
export const appendStringContentToMessages = (messages: ReorChatMessage[], content: string): ReorChatMessage[] => {
if (content === '') {
return messages
}

const appendContent = (existingContent: AssistantContent, newContent: string | ToolCallPart[]): AssistantContent => {
if (typeof existingContent === 'string') {
return typeof newContent === 'string'
? existingContent + newContent
: [{ type: 'text' as const, text: existingContent }, ...newContent]
}
if (messages.length === 0) {
return [
...existingContent,
...(typeof newContent === 'string' ? [{ type: 'text' as const, text: newContent }] : newContent),
{
role: 'assistant',
content,
},
]
}

const lastMessage = messages[messages.length - 1]

if (lastMessage.role === 'assistant') {
return [
...messages.slice(0, -1),
{
...lastMessage,
content:
typeof lastMessage.content === 'string'
? lastMessage.content + content
: [...lastMessage.content, { type: 'text' as const, text: content }],
},
]
}

return [
...messages,
{
role: 'assistant',
content,
},
]
}

export const appendToolCallPartsToMessages = (
messages: ReorChatMessage[],
toolCalls: ToolCallPart[],
): ReorChatMessage[] => {
if (toolCalls.length === 0) {
return messages
}

if (messages.length === 0) {
return [
{
role: 'assistant',
content: typeof content === 'string' ? content : content,
content: toolCalls,
},
]
}
Expand All @@ -43,7 +69,9 @@ export const appendTextContentToMessages = (
...messages.slice(0, -1),
{
...lastMessage,
content: appendContent(lastMessage.content, content),
content: Array.isArray(lastMessage.content)
? [...lastMessage.content, ...toolCalls]
: [{ type: 'text' as const, text: lastMessage.content }, ...toolCalls],
},
]
}
Expand All @@ -52,11 +80,64 @@ export const appendTextContentToMessages = (
...messages,
{
role: 'assistant',
content: typeof content === 'string' ? content : content,
content: toolCalls,
},
]
}

export const makeAndAddToolResultToMessages = async (
messages: ReorChatMessage[],
toolCallPart: ToolCallPart,
assistantMessage: ReorChatMessage,
): Promise<ReorChatMessage[]> => {
const toolResult = await createToolResult(toolCallPart.toolName, toolCallPart.args as any, toolCallPart.toolCallId)

const toolMessage: CoreToolMessage = {
role: 'tool',
content: [toolResult],
}

const assistantIndex = messages.findIndex((msg) => msg === assistantMessage)
if (assistantIndex === -1) {
throw new Error('Assistant message not found')
}

return [...messages.slice(0, assistantIndex + 1), toolMessage, ...messages.slice(assistantIndex + 1)]
}

const autoExecuteTools = async (
messages: ReorChatMessage[],
toolDefinitions: ToolDefinition[],
toolCalls: ToolCallPart[],
) => {
const toolsThatNeedExecuting = toolCalls.filter((toolCall) => {
const toolDefinition = toolDefinitions.find((definition) => definition.name === toolCall.toolName)
return toolDefinition?.autoExecute
})
let outputMessages = messages
const lastMessage = messages[messages.length - 1]

if (lastMessage.role !== 'assistant') {
throw new Error('Last message is not an assistant message')
}
// eslint-disable-next-line no-restricted-syntax
for (const toolCall of toolsThatNeedExecuting) {
// eslint-disable-next-line no-await-in-loop
outputMessages = await makeAndAddToolResultToMessages(outputMessages, toolCall, lastMessage)
}
return outputMessages
}

export const appendToolCallsAndAutoExecuteTools = async (
messages: ReorChatMessage[],
toolDefinitions: ToolDefinition[],
toolCalls: ToolCallPart[],
): Promise<ReorChatMessage[]> => {
const messagesWithToolCalls = appendToolCallPartsToMessages(messages, toolCalls)
const messagesWithToolResults = await autoExecuteTools(messagesWithToolCalls, toolDefinitions, toolCalls)
return messagesWithToolResults
}

export const convertMessageToString = (message: ReorChatMessage | undefined): string => {
if (!message) {
return ''
Expand Down Expand Up @@ -248,23 +329,3 @@ export const removeUncalledToolsFromMessages = (messages: ReorChatMessage[]): Re
return message
})
}

export const addToolResultToMessages = async (
messages: ReorChatMessage[],
toolCallPart: ToolCallPart,
assistantMessage: ReorChatMessage,
): Promise<ReorChatMessage[]> => {
const toolResult = await createToolResult(toolCallPart.toolName, toolCallPart.args as any, toolCallPart.toolCallId)

const toolMessage: CoreToolMessage = {
role: 'tool',
content: [toolResult],
}

const assistantIndex = messages.findIndex((msg) => msg === assistantMessage)
if (assistantIndex === -1) {
throw new Error('Assistant message not found')
}

return [...messages.slice(0, assistantIndex + 1), toolMessage, ...messages.slice(assistantIndex + 1)]
}
4 changes: 2 additions & 2 deletions src/components/WritingAssistant/WritingAssistant.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import TextField from '@mui/material/TextField'
import Button from '@mui/material/Button'
import posthog from 'posthog-js'
import { streamText } from 'ai'
import { appendTextContentToMessages, convertMessageToString } from '../Chat/utils'
import { appendStringContentToMessages, convertMessageToString } from '../Chat/utils'
import useOutsideClick from './hooks/use-outside-click'
import getClassNames, { generatePromptString, getLastMessage } from './utils'
import { ReorChatMessage } from '../Chat/types'
Expand Down Expand Up @@ -232,7 +232,7 @@ const WritingAssistant: React.FC = () => {
let updatedMessages = messages
// eslint-disable-next-line no-restricted-syntax
for await (const textPart of textStream) {
updatedMessages = appendTextContentToMessages(updatedMessages, textPart)
updatedMessages = appendStringContentToMessages(updatedMessages, textPart)
setMessages(updatedMessages)
}

Expand Down

0 comments on commit bf6b92c

Please sign in to comment.