From 6ffc423e2b757adf47f562599512a3bd4b094025 Mon Sep 17 00:00:00 2001 From: Luis Otavio Martins Date: Wed, 18 Sep 2024 05:08:35 -0300 Subject: [PATCH] Initial implementation of core memory (#307) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Initial implementation of core memory * stop resetting the db in every build * stop resetting the db in every build * delete unused code * changed prompt * re-added commented verbose * changed prompt² --- package.json | 4 +- .../migration.sql | 13 +++ prisma/schema.prisma | 1 + src/clients/open-router.ts | 25 ++++- src/clients/tools/tool-core-memory.ts | 98 +++++++++++++++++++ src/clients/tools/tools-openrouter.ts | 8 ++ src/crud/conversation.ts | 17 ++++ src/handlers/context/index.ts | 1 + 8 files changed, 159 insertions(+), 8 deletions(-) create mode 100644 prisma/migrations/20240918062447_add_core_memory_to_open_router_conversation/migration.sql create mode 100644 src/clients/tools/tool-core-memory.ts diff --git a/package.json b/package.json index 4de2b99..ece2754 100644 --- a/package.json +++ b/package.json @@ -7,14 +7,12 @@ "scripts": { "start": "ts-node src/index.ts", "dev": "nodemon src/index.ts", - "generate": "prisma generate", - "migrate": "prisma migrate reset", "studio": "prisma studio", "docker:build": "docker build -t sydney-whatsapp-chatbot .", "docker:run": "docker run -it -d --name sydney-whatsapp-chatbot sydney-whatsapp-chatbot", "docker:build:run": "pnpm docker:build && pnpm docker:run", "docker:stop": "docker stop sydney-whatsapp-chatbot", - "build": "npm run generate && npm run migrate" + "build": "prisma generate && prisma migrate" }, "repository": "https://github.com/WAppAI/assistant.git", "author": "Luis Otavio and Matheus Veiga ", diff --git a/prisma/migrations/20240918062447_add_core_memory_to_open_router_conversation/migration.sql b/prisma/migrations/20240918062447_add_core_memory_to_open_router_conversation/migration.sql new file mode 100644 index 
0000000..3cfe0b2 --- /dev/null +++ b/prisma/migrations/20240918062447_add_core_memory_to_open_router_conversation/migration.sql @@ -0,0 +1,13 @@ +-- RedefineTables +PRAGMA foreign_keys=OFF; +CREATE TABLE "new_OpenRouterConversation" ( + "waChatId" TEXT NOT NULL PRIMARY KEY, + "memory" TEXT NOT NULL, + "coreMemory" TEXT NOT NULL DEFAULT '', + CONSTRAINT "OpenRouterConversation_waChatId_fkey" FOREIGN KEY ("waChatId") REFERENCES "WAChat" ("id") ON DELETE RESTRICT ON UPDATE CASCADE +); +INSERT INTO "new_OpenRouterConversation" ("memory", "waChatId") SELECT "memory", "waChatId" FROM "OpenRouterConversation"; +DROP TABLE "OpenRouterConversation"; +ALTER TABLE "new_OpenRouterConversation" RENAME TO "OpenRouterConversation"; +PRAGMA foreign_key_check; +PRAGMA foreign_keys=ON; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 72f3c9f..6e11efb 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -34,6 +34,7 @@ model BingConversation { model OpenRouterConversation { waChatId String @id memory String + coreMemory String @default("") waChat WAChat @relation(fields: [waChatId], references: [id]) } diff --git a/src/clients/open-router.ts b/src/clients/open-router.ts index c8596c2..6dc8219 100644 --- a/src/clients/open-router.ts +++ b/src/clients/open-router.ts @@ -32,6 +32,7 @@ import { getLLMModel, getOpenRouterConversationFor, getOpenRouterMemoryFor, + getCoreMemoryFor, } from "../crud/conversation"; import { anthropicToolCallingModels, @@ -43,9 +44,10 @@ import { import { tools } from "./tools/tools-openrouter"; function parseMessageHistory( - rawHistory: { [key: string]: string }[] + rawHistory: { [key: string]: string }[], + coreMemory: string ): (HumanMessage | AIMessage)[] { - return rawHistory.map((messageObj) => { + const parsedMessages = rawHistory.map((messageObj) => { const messageType = Object.keys(messageObj)[0]; const messageContent = messageObj[messageType]; @@ -55,6 +57,14 @@ function parseMessageHistory( return new 
AIMessage(messageContent); } }); + + // Prepend the core memory as an AIMessage + const coreMemoryMessage = new AIMessage(`Your current core memory: + + + ${coreMemory} +`); + return [coreMemoryMessage, ...parsedMessages]; } async function createMemoryForOpenRouter(chat: string) { @@ -91,6 +101,7 @@ async function createMemoryForOpenRouter(chat: string) { } if (conversation) { + const coreMemory = (await getCoreMemoryFor(chat)) || ""; if (memory instanceof ConversationSummaryMemory) { let memoryString = await getOpenRouterMemoryFor(chat); if (memoryString === undefined) return; @@ -99,11 +110,15 @@ async function createMemoryForOpenRouter(chat: string) { let memoryString = await getOpenRouterMemoryFor(chat); if (memoryString === undefined) return; - const pastMessages = parseMessageHistory(JSON.parse(memoryString)); + const pastMessages = parseMessageHistory( + JSON.parse(memoryString), + coreMemory + ); memory.chatHistory = new ChatMessageHistory(pastMessages); } } else { - let memoryString: BaseMessage[] = []; + const coreMemory = (await getCoreMemoryFor(chat)) || ""; + let memoryString: BaseMessage[] = [new AIMessage(coreMemory)]; memory.chatHistory = new ChatMessageHistory(memoryString); } @@ -122,7 +137,7 @@ export async function createExecutorForOpenRouter( const memory = await createMemoryForOpenRouter(chat); const toolCallingPrompt = await pull( - "luisotee/wa-assistant-tool-calling" + "luisotee/wa-assistant-tool-calling-core-memory" ); const defaultPrompt = await pull("luisotee/wa-assistant"); diff --git a/src/clients/tools/tool-core-memory.ts b/src/clients/tools/tool-core-memory.ts new file mode 100644 index 0000000..2973dc0 --- /dev/null +++ b/src/clients/tools/tool-core-memory.ts @@ -0,0 +1,98 @@ +// src/clients/tools/tool-core-memory.ts + +import { StructuredTool } from "langchain/tools"; +import { z } from "zod"; +import { getCoreMemoryFor, updateCoreMemory } from "../../crud/conversation"; + +const AddToCoreMemorySchema = z.object({ + chat: 
z.string().describe("The chat ID to which the message will be added."), + message: z.string().describe("The message to add to the core memory."), +}); + +export class AddToCoreMemoryTool extends StructuredTool { + name = "AddToCoreMemoryTool"; + description = "Adds a message to the core memory for a given chat."; + schema = AddToCoreMemorySchema; + + async _call({ + chat, + message, + }: z.infer): Promise { + try { + let coreMemory = await getCoreMemoryFor(chat); + if (!coreMemory) { + coreMemory = ""; + } + coreMemory += ` ${message}`; + await updateCoreMemory(chat, coreMemory.trim()); + return `Message added to core memory for chat: ${chat}`; + } catch (error) { + console.error("Error adding to core memory:", error); + throw error; + } + } +} + +const DeleteFromCoreMemorySchema = z.object({ + chat: z.string().describe("The chat ID from which the part will be deleted."), + part: z.string().describe("The specific part of the core memory to delete."), +}); + +export class DeleteFromCoreMemoryTool extends StructuredTool { + name = "DeleteFromCoreMemoryTool"; + description = "Deletes a specific part of the core memory for a given chat."; + schema = DeleteFromCoreMemorySchema; + + async _call({ + chat, + part, + }: z.infer): Promise { + try { + let coreMemory = await getCoreMemoryFor(chat); + if (!coreMemory) { + return `No core memory found for chat: ${chat}`; + } + coreMemory = coreMemory.replace(part, "").trim(); + await updateCoreMemory(chat, coreMemory); + return `Part deleted from core memory for chat: ${chat}`; + } catch (error) { + console.error("Error deleting from core memory:", error); + throw error; + } + } +} + +const ReplaceInCoreMemorySchema = z.object({ + chat: z + .string() + .describe("The chat ID for which the core memory will be replaced."), + oldPart: z + .string() + .describe("The specific part of the core memory to replace."), + newPart: z.string().describe("The new part to replace the old part with."), +}); + +export class 
ReplaceInCoreMemoryTool extends StructuredTool { + name = "ReplaceInCoreMemoryTool"; + description = "Replaces a specific part of the core memory for a given chat."; + schema = ReplaceInCoreMemorySchema; + + async _call({ + chat, + oldPart, + newPart, + }: z.infer): Promise { + try { + let coreMemory = await getCoreMemoryFor(chat); + if (!coreMemory) { + return `No core memory found for chat: ${chat}`; + } + coreMemory = coreMemory.replace(oldPart, newPart).trim(); + await updateCoreMemory(chat, coreMemory); + return `Part replaced in core memory for chat: ${chat}`; + } catch (error) { + console.error("Error replacing in core memory:", error); + throw error; + } + } +} diff --git a/src/clients/tools/tools-openrouter.ts b/src/clients/tools/tools-openrouter.ts index c727d0f..06d55b1 100644 --- a/src/clients/tools/tools-openrouter.ts +++ b/src/clients/tools/tools-openrouter.ts @@ -27,6 +27,11 @@ import { } from "../../constants"; import { GoogleRoutesAPI } from "@langchain/community/tools/google_routes"; import { WeatherTool } from "./tool-weather"; +import { + AddToCoreMemoryTool, + DeleteFromCoreMemoryTool, + ReplaceInCoreMemoryTool, +} from "./tool-core-memory"; let googleCalendarCreateTool = null; let googleCalendarViewTool = null; @@ -107,6 +112,9 @@ export const tools = [ ...(googleCalendarViewTool ? [googleCalendarViewTool] : []), ...(dalleTool ? [dalleTool] : []), ...(googleRoutesTool ? 
[googleRoutesTool] : []), + new AddToCoreMemoryTool(), + new DeleteFromCoreMemoryTool(), + new ReplaceInCoreMemoryTool(), weatherTool, wikipediaTool, calculatorTool, diff --git a/src/crud/conversation.ts b/src/crud/conversation.ts index 4a241d4..cf0ff49 100644 --- a/src/crud/conversation.ts +++ b/src/crud/conversation.ts @@ -118,3 +118,20 @@ export async function updateOpenRouterConversation( }, }); } + +export async function getCoreMemoryFor(chatId: string): Promise { + const conversation = await prisma.openRouterConversation.findFirst({ + where: { waChatId: chatId }, + }); + return conversation?.coreMemory || null; +} + +export async function updateCoreMemory( + chatId: string, + coreMemory: string +): Promise { + await prisma.openRouterConversation.update({ + data: { coreMemory }, + where: { waChatId: chatId }, + }); +} diff --git a/src/handlers/context/index.ts b/src/handlers/context/index.ts index 9e96b25..94169fb 100644 --- a/src/handlers/context/index.ts +++ b/src/handlers/context/index.ts @@ -20,6 +20,7 @@ export async function createContextFromMessage(message: Message) { const chatContext = await getChatContext(message); const context = stripIndents`[system](#context) + - The chat ID is '${chat.id._serialized}' ${chatContext} - The user's timezone is '${timezone}' - The user's local date and time is: ${timestampLocal}