diff --git a/docs/docs/modules/memory/examples/dynamodb.mdx b/docs/docs/modules/memory/examples/dynamodb.mdx index a2bed03573f5..d84cc98f36a0 100644 --- a/docs/docs/modules/memory/examples/dynamodb.mdx +++ b/docs/docs/modules/memory/examples/dynamodb.mdx @@ -16,7 +16,7 @@ First, install the AWS DynamoDB client in your project: npm install @aws-sdk/client-dynamodb ``` -Next, sign into your AWS account and create a DynamoDB table. Name the table `langchain`, and name your partition key `id` and make sure it's a string. You can leave sort key and the other settings alone. +Next, sign into your AWS account and create a DynamoDB table. Name the table `langchain`, and name your partition key `id`. Make sure your partition key is a string. You can leave sort key and the other settings alone. You'll also need to retrieve an AWS access key and secret key for a role or user that has access to the table and add them to your environment variables. diff --git a/docs/docs/modules/memory/examples/redis.mdx b/docs/docs/modules/memory/examples/redis.mdx new file mode 100644 index 000000000000..bb3ce4ebbb6f --- /dev/null +++ b/docs/docs/modules/memory/examples/redis.mdx @@ -0,0 +1,36 @@ +--- +hide_table_of_contents: true +--- + +import CodeBlock from "@theme/CodeBlock"; + +# Redis-Backed Chat Memory + +For longer-term persistence across chat sessions, you can swap out the default in-memory `chatHistory` that backs chat memory classes like `BufferMemory` for a [Redis](https://redis.io/) instance. + +## Setup + +You will need to install [node-redis](https://github.com/redis/node-redis) in your project: + +```bash npm2yarn +npm install redis +``` + +You will also need a Redis instance to connect to. See instructions on [the official Redis website](https://redis.io/docs/getting-started/) for running the server locally. + +## Usage + +Each chat history session stored in Redis must have a unique id. You can provide an optional `sessionTTL` to make sessions expire after a give number of seconds. +The `config` parameter is passed directly into the `createClient` method of [node-redis](https://github.com/redis/node-redis), and takes all the same arguments. + +import Example from "@examples/memory/redis.ts"; + +{Example} + +## Advanced Usage + +You can also directly pass in a previously created [node-redis](https://github.com/redis/node-redis) client instance: + +import AdvancedExample from "@examples/memory/redis-advanced.ts"; + +{AdvancedExample} diff --git a/docs/docs/modules/memory/examples/vector_store_memory.mdx b/docs/docs/modules/memory/examples/vector_store_memory.mdx index 7bef42155f3f..79449441807c 100644 --- a/docs/docs/modules/memory/examples/vector_store_memory.mdx +++ b/docs/docs/modules/memory/examples/vector_store_memory.mdx @@ -5,7 +5,7 @@ hide_table_of_contents: true import CodeBlock from "@theme/CodeBlock"; import Example from "@examples/memory/vector_store.ts"; -# VectorStore-backed Memory +# VectorStore-Backed Memory `VectorStoreRetrieverMemory` stores memories in a VectorDB and queries the top-K most "salient" docs every time it is called. diff --git a/examples/package.json b/examples/package.json index 762d78377469..7440b08bb71e 100644 --- a/examples/package.json +++ b/examples/package.json @@ -38,6 +38,7 @@ "ml-distance": "^4.0.0", "mongodb": "^5.2.0", "prisma": "^4.11.0", + "redis": "^4.6.6", "sqlite3": "^5.1.4", "typeorm": "^0.3.12", "weaviate-ts-client": "^1.0.0", diff --git a/examples/src/memory/redis-advanced.ts b/examples/src/memory/redis-advanced.ts new file mode 100644 index 000000000000..0c123f3dbdac --- /dev/null +++ b/examples/src/memory/redis-advanced.ts @@ -0,0 +1,45 @@ +import { createClient } from "redis"; +import { BufferMemory } from "langchain/memory"; +import { RedisChatMessageHistory } from "langchain/stores/message/redis"; +import { ChatOpenAI } from "langchain/chat_models/openai"; +import { ConversationChain } from "langchain/chains"; + +const client = createClient({ + url: "redis://localhost:6379", +}); + +const memory = new BufferMemory({ + chatHistory: new RedisChatMessageHistory({ + sessionId: new Date().toISOString(), + sessionTTL: 300, + client, + }), +}); + +const model = new ChatOpenAI({ + modelName: "gpt-3.5-turbo", + temperature: 0, +}); + +const chain = new ConversationChain({ llm: model, memory }); + +const res1 = await chain.call({ input: "Hi! I'm Jim." }); +console.log({ res1 }); +/* +{ + res1: { + text: "Hello Jim! It's nice to meet you. My name is AI. How may I assist you today?" + } +} +*/ + +const res2 = await chain.call({ input: "What did I just say my name was?" }); +console.log({ res2 }); + +/* +{ + res1: { + text: "You said your name was Jim." + } +} +*/ diff --git a/examples/src/memory/redis.ts b/examples/src/memory/redis.ts new file mode 100644 index 000000000000..750679dcace1 --- /dev/null +++ b/examples/src/memory/redis.ts @@ -0,0 +1,42 @@ +import { BufferMemory } from "langchain/memory"; +import { RedisChatMessageHistory } from "langchain/stores/message/redis"; +import { ChatOpenAI } from "langchain/chat_models/openai"; +import { ConversationChain } from "langchain/chains"; + +const memory = new BufferMemory({ + chatHistory: new RedisChatMessageHistory({ + sessionId: new Date().toISOString(), // Or some other unique identifier for the conversation + sessionTTL: 300, // 5 minutes, omit this parameter to make sessions never expire + config: { + url: "redis://localhost:6379", // Default value, override with your own instance's URL + }, + }), +}); + +const model = new ChatOpenAI({ + modelName: "gpt-3.5-turbo", + temperature: 0, +}); + +const chain = new ConversationChain({ llm: model, memory }); + +const res1 = await chain.call({ input: "Hi! I'm Jim." }); +console.log({ res1 }); +/* +{ + res1: { + text: "Hello Jim! It's nice to meet you. My name is AI. How may I assist you today?" + } +} +*/ + +const res2 = await chain.call({ input: "What did I just say my name was?" }); +console.log({ res2 }); + +/* +{ + res1: { + text: "You said your name was Jim." + } +} +*/ diff --git a/langchain/.gitignore b/langchain/.gitignore index 515bd328750f..2fd6c7ef2b8c 100644 --- a/langchain/.gitignore +++ b/langchain/.gitignore @@ -265,6 +265,9 @@ stores/file/node.d.ts stores/message/dynamodb.cjs stores/message/dynamodb.js stores/message/dynamodb.d.ts +stores/message/redis.cjs +stores/message/redis.js +stores/message/redis.d.ts experimental/autogpt.cjs experimental/autogpt.js experimental/autogpt.d.ts diff --git a/langchain/package.json b/langchain/package.json index af99e166d9f0..df783169dd7a 100644 --- a/langchain/package.json +++ b/langchain/package.json @@ -277,6 +277,9 @@ "stores/message/dynamodb.cjs", "stores/message/dynamodb.js", "stores/message/dynamodb.d.ts", + "stores/message/redis.cjs", + "stores/message/redis.js", + "stores/message/redis.d.ts", "experimental/autogpt.cjs", "experimental/autogpt.js", "experimental/autogpt.d.ts", @@ -1007,6 +1010,11 @@ "import": "./stores/message/dynamodb.js", "require": "./stores/message/dynamodb.cjs" }, + "./stores/message/redis": { + "types": "./stores/message/redis.d.ts", + "import": "./stores/message/redis.js", + "require": "./stores/message/redis.cjs" + }, "./experimental/autogpt": { "types": "./experimental/autogpt.d.ts", "import": "./experimental/autogpt.js", diff --git a/langchain/scripts/create-entrypoints.js b/langchain/scripts/create-entrypoints.js index 10c57f4140fe..327396fb3c58 100644 --- a/langchain/scripts/create-entrypoints.js +++ b/langchain/scripts/create-entrypoints.js @@ -119,6 +119,7 @@ const entrypoints = { "stores/file/in_memory": "stores/file/in_memory", "stores/file/node": "stores/file/node", "stores/message/dynamodb": "stores/message/dynamodb", + "stores/message/redis": "stores/message/redis", // experimental "experimental/autogpt": "experimental/autogpt/index", "experimental/babyagi": "experimental/babyagi/index", @@ -191,6 +192,7 @@ const requiresOptionalDependency = [ "cache/redis", "stores/file/node", "stores/message/dynamodb", + "stores/message/redis", ]; // List of test-exports-* packages which we use to test that the exports field diff --git a/langchain/src/stores/message/redis.ts b/langchain/src/stores/message/redis.ts new file mode 100644 index 000000000000..8fbcc9dfbf4a --- /dev/null +++ b/langchain/src/stores/message/redis.ts @@ -0,0 +1,85 @@ +import { + createClient, + RedisClientOptions, + RedisClientType, + RedisModules, + RedisFunctions, + RedisScripts, +} from "redis"; +import { + BaseChatMessage, + BaseListChatMessageHistory, +} from "../../schema/index.js"; +import { + StoredMessage, + mapChatMessagesToStoredMessages, + mapStoredMessagesToChatMessages, +} from "./utils.js"; + +export type RedisChatMessageHistoryInput = { + sessionId: string; + sessionTTL?: number; + config?: RedisClientOptions; + // Typing issues with createClient output: https://github.com/redis/node-redis/issues/1865 + // eslint-disable-next-line @typescript-eslint/no-explicit-any + client?: any; +}; + +export class RedisChatMessageHistory extends BaseListChatMessageHistory { + public client: RedisClientType; + + private sessionId: string; + + private sessionTTL?: number; + + constructor(fields: RedisChatMessageHistoryInput) { + const { sessionId, sessionTTL, config, client } = fields; + super(); + this.client = (client ?? createClient(config ?? {})) as RedisClientType< + RedisModules, + RedisFunctions, + RedisScripts + >; + this.sessionId = sessionId; + this.sessionTTL = sessionTTL; + } + + async ensureReadiness() { + if (!this.client.isReady) { + await this.client.connect(); + } + return true; + } + + async getMessages(): Promise { + await this.ensureReadiness(); + const rawStoredMessages = await this.client.lRange(this.sessionId, 0, -1); + const orderedMessages = rawStoredMessages + .reverse() + .map((message) => JSON.parse(message)); + const previousMessages = orderedMessages + .map((item) => ({ + type: item.type, + role: item.role, + text: item.text, + })) + .filter( + (x): x is StoredMessage => x.type !== undefined && x.text !== undefined + ); + return mapStoredMessagesToChatMessages(previousMessages); + } + + async addMessage(message: BaseChatMessage): Promise { + await this.ensureReadiness(); + const messageToAdd = mapChatMessagesToStoredMessages([message]); + await this.client.lPush(this.sessionId, JSON.stringify(messageToAdd[0])); + if (this.sessionTTL) { + await this.client.expire(this.sessionId, this.sessionTTL); + } + } + + async clear(): Promise { + await this.ensureReadiness(); + await this.client.del(this.sessionId); + } +} diff --git a/langchain/src/stores/tests/redis.int.test.ts b/langchain/src/stores/tests/redis.int.test.ts new file mode 100644 index 000000000000..00e19a132f6d --- /dev/null +++ b/langchain/src/stores/tests/redis.int.test.ts @@ -0,0 +1,126 @@ +/* eslint-disable no-promise-executor-return */ + +import { test, expect } from "@jest/globals"; +import { createClient } from "redis"; +import { RedisChatMessageHistory } from "../message/redis.js"; +import { HumanChatMessage, AIChatMessage } from "../../schema/index.js"; +import { ChatOpenAI } from "../../chat_models/openai.js"; +import { ConversationChain } from "../../chains/conversation.js"; +import { BufferMemory } from "../../memory/buffer_memory.js"; + +afterAll(async () => { + const client = createClient(); + await client.connect(); + await client.flushDb(); + await client.disconnect(); +}); + +test("Test Redis history store", async () => { + const chatHistory = new RedisChatMessageHistory({ + sessionId: new Date().toISOString(), + }); + + const blankResult = await chatHistory.getMessages(); + expect(blankResult).toStrictEqual([]); + + await chatHistory.addUserMessage("Who is the best vocalist?"); + await chatHistory.addAIChatMessage("Ozzy Osbourne"); + + const expectedMessages = [ + new HumanChatMessage("Who is the best vocalist?"), + new AIChatMessage("Ozzy Osbourne"), + ]; + + const resultWithHistory = await chatHistory.getMessages(); + expect(resultWithHistory).toEqual(expectedMessages); +}); + +test("Test clear Redis history store", async () => { + const chatHistory = new RedisChatMessageHistory({ + sessionId: new Date().toISOString(), + }); + + await chatHistory.addUserMessage("Who is the best vocalist?"); + await chatHistory.addAIChatMessage("Ozzy Osbourne"); + + const expectedMessages = [ + new HumanChatMessage("Who is the best vocalist?"), + new AIChatMessage("Ozzy Osbourne"), + ]; + + const resultWithHistory = await chatHistory.getMessages(); + expect(resultWithHistory).toEqual(expectedMessages); + + await chatHistory.clear(); + + const blankResult = await chatHistory.getMessages(); + expect(blankResult).toStrictEqual([]); +}); + +test("Test Redis history with a TTL", async () => { + const chatHistory = new RedisChatMessageHistory({ + sessionId: new Date().toISOString(), + sessionTTL: 5, + }); + + const blankResult = await chatHistory.getMessages(); + expect(blankResult).toStrictEqual([]); + + await chatHistory.addUserMessage("Who is the best vocalist?"); + await chatHistory.addAIChatMessage("Ozzy Osbourne"); + + const expectedMessages = [ + new HumanChatMessage("Who is the best vocalist?"), + new AIChatMessage("Ozzy Osbourne"), + ]; + + const resultWithHistory = await chatHistory.getMessages(); + expect(resultWithHistory).toEqual(expectedMessages); + + await new Promise((resolve) => setTimeout(resolve, 5000)); + + const expiredResult = await chatHistory.getMessages(); + expect(expiredResult).toStrictEqual([]); +}); + +test("Test Redis memory with Buffer Memory", async () => { + const memory = new BufferMemory({ + returnMessages: true, + chatHistory: new RedisChatMessageHistory({ + sessionId: new Date().toISOString(), + }), + }); + + await memory.saveContext( + { input: "Who is the best vocalist?" }, + { response: "Ozzy Osbourne" } + ); + + const expectedHistory = [ + new HumanChatMessage("Who is the best vocalist?"), + new AIChatMessage("Ozzy Osbourne"), + ]; + + const result2 = await memory.loadMemoryVariables({}); + expect(result2).toStrictEqual({ history: expectedHistory }); +}); + +test("Test Redis memory with LLM Chain", async () => { + const memory = new BufferMemory({ + chatHistory: new RedisChatMessageHistory({ + sessionId: new Date().toISOString(), + }), + }); + + const model = new ChatOpenAI({ + modelName: "gpt-3.5-turbo", + temperature: 0, + }); + const chain = new ConversationChain({ llm: model, memory }); + + const res1 = await chain.call({ input: "Hi! I'm Jim." }); + console.log({ res1 }); + + const res2 = await chain.call({ input: "What did I just say my name was?" }); + console.log({ res2 }); +}); diff --git a/langchain/tsconfig.json b/langchain/tsconfig.json index f7fb4494e931..13158af27cf9 100644 --- a/langchain/tsconfig.json +++ b/langchain/tsconfig.json @@ -115,6 +115,7 @@ "src/stores/file/in_memory.ts", "src/stores/file/node.ts", "src/stores/message/dynamodb.ts", + "src/stores/message/redis.ts", "src/experimental/autogpt/index.ts", "src/experimental/babyagi/index.ts", "src/experimental/plan_and_execute/index.ts", diff --git a/yarn.lock b/yarn.lock index eefed52eee6a..aaa6fb36625c 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6706,6 +6706,17 @@ __metadata: languageName: node linkType: hard +"@redis/client@npm:1.5.7": + version: 1.5.7 + resolution: "@redis/client@npm:1.5.7" + dependencies: + cluster-key-slot: 1.1.2 + generic-pool: 3.9.0 + yallist: 4.0.0 + checksum: 3ded14d947b1aba9121fb2608c99960d89716b1ef9baf48ef68384a30fc8e2fe2f1be7a959247d81e7ce15a38ea9e3fce512f3fdce5a400da88dc96496b09e89 + languageName: node + linkType: hard + "@redis/graph@npm:1.1.0": version: 1.1.0 resolution: "@redis/graph@npm:1.1.0" @@ -6733,6 +6744,15 @@ __metadata: languageName: node linkType: hard +"@redis/search@npm:1.1.2": + version: 1.1.2 + resolution: "@redis/search@npm:1.1.2" + peerDependencies: + "@redis/client": ^1.0.0 + checksum: fc3c0bd62c150ea7f8b3f08b0e67893b4e8df71b4820d750de6ba00ccff3720fdc5d4f50618e385c9e183c784635185e2e98a3e6c3d20ac30f2c60996f38b992 + languageName: node + linkType: hard + "@redis/time-series@npm:1.0.4": version: 1.0.4 resolution: "@redis/time-series@npm:1.0.4" @@ -13870,6 +13890,7 @@ __metadata: mongodb: ^5.2.0 prettier: ^2.8.3 prisma: ^4.11.0 + redis: ^4.6.6 sqlite3: ^5.1.4 tsx: ^3.12.3 typeorm: ^0.3.12 @@ -22522,6 +22543,20 @@ __metadata: languageName: node linkType: hard +"redis@npm:^4.6.6": + version: 4.6.6 + resolution: "redis@npm:4.6.6" + dependencies: + "@redis/bloom": 1.2.0 + "@redis/client": 1.5.7 + "@redis/graph": 1.1.0 + "@redis/json": 1.0.4 + "@redis/search": 1.1.2 + "@redis/time-series": 1.0.4 + checksum: fb2e667d91406105c229964ddaf120a8eaae02c1f583f384f68962e166542199300fd75af9572285af988870636b8e75b04562e6d9c250ab95e43c3d87a1cb72 + languageName: node + linkType: hard + "reflect-metadata@npm:^0.1.13": version: 0.1.13 resolution: "reflect-metadata@npm:0.1.13"