Skip to content

Commit

Permalink
Initial implementation of core memory (#307)
Browse files Browse the repository at this point in the history
* Initial implementation of core memory

* stop reseting the db in every build

* stop reseting the db in every build

* delete unused code

* changed prompt

* readded commented verbose

* changed prompt²
  • Loading branch information
Luisotee authored Sep 18, 2024
1 parent 5876207 commit 6ffc423
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 8 deletions.
4 changes: 1 addition & 3 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
"scripts": {
"start": "ts-node src/index.ts",
"dev": "nodemon src/index.ts",
"generate": "prisma generate",
"migrate": "prisma migrate reset",
"studio": "prisma studio",
"docker:build": "docker build -t sydney-whatsapp-chatbot .",
"docker:run": "docker run -it -d --name sydney-whatsapp-chatbot sydney-whatsapp-chatbot",
"docker:build:run": "pnpm docker:build && pnpm docker:run",
"docker:stop": "docker stop sydney-whatsapp-chatbot",
"build": "npm run generate && npm run migrate"
"build": "prisma generate && prisma migrate"
},
"repository": "https://github.com/WAppAI/assistant.git",
"author": "Luis Otavio <[email protected]> and Matheus Veiga <[email protected]>",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- RedefineTables
PRAGMA foreign_keys=OFF;
CREATE TABLE "new_OpenRouterConversation" (
"waChatId" TEXT NOT NULL PRIMARY KEY,
"memory" TEXT NOT NULL,
"coreMemory" TEXT NOT NULL DEFAULT '',
CONSTRAINT "OpenRouterConversation_waChatId_fkey" FOREIGN KEY ("waChatId") REFERENCES "WAChat" ("id") ON DELETE RESTRICT ON UPDATE CASCADE
);
INSERT INTO "new_OpenRouterConversation" ("memory", "waChatId") SELECT "memory", "waChatId" FROM "OpenRouterConversation";
DROP TABLE "OpenRouterConversation";
ALTER TABLE "new_OpenRouterConversation" RENAME TO "OpenRouterConversation";
PRAGMA foreign_key_check;
PRAGMA foreign_keys=ON;
1 change: 1 addition & 0 deletions prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ model BingConversation {
model OpenRouterConversation {
waChatId String @id
memory String
coreMemory String @default("")
waChat WAChat @relation(fields: [waChatId], references: [id])
}

Expand Down
25 changes: 20 additions & 5 deletions src/clients/open-router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {
getLLMModel,
getOpenRouterConversationFor,
getOpenRouterMemoryFor,
getCoreMemoryFor,
} from "../crud/conversation";
import {
anthropicToolCallingModels,
Expand All @@ -43,9 +44,10 @@ import {
import { tools } from "./tools/tools-openrouter";

function parseMessageHistory(
rawHistory: { [key: string]: string }[]
rawHistory: { [key: string]: string }[],
coreMemory: string
): (HumanMessage | AIMessage)[] {
return rawHistory.map((messageObj) => {
const parsedMessages = rawHistory.map((messageObj) => {
const messageType = Object.keys(messageObj)[0];
const messageContent = messageObj[messageType];

Expand All @@ -55,6 +57,14 @@ function parseMessageHistory(
return new AIMessage(messageContent);
}
});

// Prepend the core memory as an AIMessage
const coreMemoryMessage = new AIMessage(`Your current core memory:
<core_memory>
${coreMemory}
</core_memory>`);
return [coreMemoryMessage, ...parsedMessages];
}

async function createMemoryForOpenRouter(chat: string) {
Expand Down Expand Up @@ -91,6 +101,7 @@ async function createMemoryForOpenRouter(chat: string) {
}

if (conversation) {
const coreMemory = (await getCoreMemoryFor(chat)) || "";
if (memory instanceof ConversationSummaryMemory) {
let memoryString = await getOpenRouterMemoryFor(chat);
if (memoryString === undefined) return;
Expand All @@ -99,11 +110,15 @@ async function createMemoryForOpenRouter(chat: string) {
let memoryString = await getOpenRouterMemoryFor(chat);
if (memoryString === undefined) return;

const pastMessages = parseMessageHistory(JSON.parse(memoryString));
const pastMessages = parseMessageHistory(
JSON.parse(memoryString),
coreMemory
);
memory.chatHistory = new ChatMessageHistory(pastMessages);
}
} else {
let memoryString: BaseMessage[] = [];
const coreMemory = (await getCoreMemoryFor(chat)) || "";
let memoryString: BaseMessage[] = [new AIMessage(coreMemory)];
memory.chatHistory = new ChatMessageHistory(memoryString);
}

Expand All @@ -122,7 +137,7 @@ export async function createExecutorForOpenRouter(
const memory = await createMemoryForOpenRouter(chat);

const toolCallingPrompt = await pull<ChatPromptTemplate>(
"luisotee/wa-assistant-tool-calling"
"luisotee/wa-assistant-tool-calling-core-memory"
);
const defaultPrompt = await pull<ChatPromptTemplate>("luisotee/wa-assistant");

Expand Down
98 changes: 98 additions & 0 deletions src/clients/tools/tool-core-memory.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// src/clients/tools/tool-core-memory.ts

import { StructuredTool } from "langchain/tools";
import { z } from "zod";
import { getCoreMemoryFor, updateCoreMemory } from "../../crud/conversation";

const AddToCoreMemorySchema = z.object({
chat: z.string().describe("The chat ID to which the message will be added."),
message: z.string().describe("The message to add to the core memory."),
});

export class AddToCoreMemoryTool extends StructuredTool {
name = "AddToCoreMemoryTool";
description = "Adds a message to the core memory for a given chat.";
schema = AddToCoreMemorySchema;

async _call({
chat,
message,
}: z.infer<typeof AddToCoreMemorySchema>): Promise<string> {
try {
let coreMemory = await getCoreMemoryFor(chat);
if (!coreMemory) {
coreMemory = "";
}
coreMemory += ` ${message}`;
await updateCoreMemory(chat, coreMemory.trim());
return `Message added to core memory for chat: ${chat}`;
} catch (error) {
console.error("Error adding to core memory:", error);
throw error;
}
}
}

const DeleteFromCoreMemorySchema = z.object({
chat: z.string().describe("The chat ID from which the part will be deleted."),
part: z.string().describe("The specific part of the core memory to delete."),
});

export class DeleteFromCoreMemoryTool extends StructuredTool {
name = "DeleteFromCoreMemoryTool";
description = "Deletes a specific part of the core memory for a given chat.";
schema = DeleteFromCoreMemorySchema;

async _call({
chat,
part,
}: z.infer<typeof DeleteFromCoreMemorySchema>): Promise<string> {
try {
let coreMemory = await getCoreMemoryFor(chat);
if (!coreMemory) {
return `No core memory found for chat: ${chat}`;
}
coreMemory = coreMemory.replace(part, "").trim();
await updateCoreMemory(chat, coreMemory);
return `Part deleted from core memory for chat: ${chat}`;
} catch (error) {
console.error("Error deleting from core memory:", error);
throw error;
}
}
}

const ReplaceInCoreMemorySchema = z.object({
chat: z
.string()
.describe("The chat ID for which the core memory will be replaced."),
oldPart: z
.string()
.describe("The specific part of the core memory to replace."),
newPart: z.string().describe("The new part to replace the old part with."),
});

export class ReplaceInCoreMemoryTool extends StructuredTool {
name = "ReplaceInCoreMemoryTool";
description = "Replaces a specific part of the core memory for a given chat.";
schema = ReplaceInCoreMemorySchema;

async _call({
chat,
oldPart,
newPart,
}: z.infer<typeof ReplaceInCoreMemorySchema>): Promise<string> {
try {
let coreMemory = await getCoreMemoryFor(chat);
if (!coreMemory) {
return `No core memory found for chat: ${chat}`;
}
coreMemory = coreMemory.replace(oldPart, newPart).trim();
await updateCoreMemory(chat, coreMemory);
return `Part replaced in core memory for chat: ${chat}`;
} catch (error) {
console.error("Error replacing in core memory:", error);
throw error;
}
}
}
8 changes: 8 additions & 0 deletions src/clients/tools/tools-openrouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ import {
} from "../../constants";
import { GoogleRoutesAPI } from "@langchain/community/tools/google_routes";
import { WeatherTool } from "./tool-weather";
import {
AddToCoreMemoryTool,
DeleteFromCoreMemoryTool,
ReplaceInCoreMemoryTool,
} from "./tool-core-memory";

let googleCalendarCreateTool = null;
let googleCalendarViewTool = null;
Expand Down Expand Up @@ -107,6 +112,9 @@ export const tools = [
...(googleCalendarViewTool ? [googleCalendarViewTool] : []),
...(dalleTool ? [dalleTool] : []),
...(googleRoutesTool ? [googleRoutesTool] : []),
new AddToCoreMemoryTool(),
new DeleteFromCoreMemoryTool(),
new ReplaceInCoreMemoryTool(),
weatherTool,
wikipediaTool,
calculatorTool,
Expand Down
17 changes: 17 additions & 0 deletions src/crud/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,20 @@ export async function updateOpenRouterConversation(
},
});
}

export async function getCoreMemoryFor(chatId: string): Promise<string | null> {
const conversation = await prisma.openRouterConversation.findFirst({
where: { waChatId: chatId },
});
return conversation?.coreMemory || null;
}

export async function updateCoreMemory(
chatId: string,
coreMemory: string
): Promise<void> {
await prisma.openRouterConversation.update({
data: { coreMemory },
where: { waChatId: chatId },
});
}
1 change: 1 addition & 0 deletions src/handlers/context/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export async function createContextFromMessage(message: Message) {
const chatContext = await getChatContext(message);

const context = stripIndents`[system](#context)
- The chat ID is '${chat.id._serialized}'
${chatContext}
- The user's timezone is '${timezone}'
- The user's local date and time is: ${timestampLocal}
Expand Down

0 comments on commit 6ffc423

Please sign in to comment.