diff --git a/package.json b/package.json index d7d37ef..4816377 100644 --- a/package.json +++ b/package.json @@ -60,9 +60,10 @@ "@jupyterlab/notebook": "^4.2.0", "@jupyterlab/rendermime": "^4.2.0", "@jupyterlab/settingregistry": "^4.2.0", + "@langchain/core": "^0.3.13", + "@langchain/mistralai": "^0.1.1", "@lumino/coreutils": "^2.1.2", - "@lumino/polling": "^2.1.2", - "@mistralai/mistralai": "^0.5.0" + "@lumino/polling": "^2.1.2" }, "devDependencies": { "@jupyterlab/builder": "^4.0.0", diff --git a/src/handler.ts b/src/handler.ts index a839c20..5e9d8d3 100644 --- a/src/handler.ts +++ b/src/handler.ts @@ -10,7 +10,12 @@ import { INewMessage } from '@jupyter/chat'; import { UUID } from '@lumino/coreutils'; -import MistralClient from '@mistralai/mistralai'; +import type { ChatMistralAI } from '@langchain/mistralai'; +import { + AIMessage, + HumanMessage, + mergeMessageRuns +} from '@langchain/core/messages'; export type ConnectionMessage = { type: 'connection'; @@ -34,27 +39,27 @@ export class CodestralHandler extends ChatModel { }; this.messageAdded(msg); this._history.messages.push(msg); - const response = await this._mistralClient.chat({ - model: 'codestral-latest', - messages: this._history.messages.map(msg => { - return { - role: msg.sender.username === 'User' ? 'user' : 'assistant', - content: msg.body - }; + + const messages = mergeMessageRuns( + this._history.messages.map(msg => { + if (msg.sender.username === 'User') { + return new HumanMessage(msg.body); + } + return new AIMessage(msg.body); }) - }); - if (response.choices.length === 0) { - return false; - } - const botMessage = response.choices[0].message; + ); + const response = await this._mistralClient.invoke(messages); + // TODO: fix deprecated response.text + const content = response.text; const botMsg: IChatMessage = { id: UUID.uuid4(), - body: botMessage.content as string, + body: content, sender: { username: 'Codestral' }, time: Date.now(), type: 'msg' }; this.messageAdded(botMsg); + this._history.messages.push(botMsg); return true; } @@ -70,12 +75,12 @@ export class CodestralHandler extends ChatModel { super.messageAdded(message); } - private _mistralClient: MistralClient; + private _mistralClient: ChatMistralAI; private _history: IChatHistory = { messages: [] }; } export namespace CodestralHandler { export interface IOptions extends ChatModel.IOptions { - mistralClient: MistralClient; + mistralClient: ChatMistralAI; } } diff --git a/src/index.ts b/src/index.ts index 8666dae..1e9ec03 100644 --- a/src/index.ts +++ b/src/index.ts @@ -13,12 +13,10 @@ import { ICompletionProviderManager } from '@jupyterlab/completer'; import { INotebookTracker } from '@jupyterlab/notebook'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { ISettingRegistry } from '@jupyterlab/settingregistry'; -import { CodestralProvider } from './provider'; -import MistralClient from '@mistralai/mistralai'; +import { ChatMistralAI, MistralAI } from '@langchain/mistralai'; import { CodestralHandler } from './handler'; - -const mistralClient = new MistralClient(); +import { CodestralProvider } from './provider'; const inlineProviderPlugin: JupyterFrontEndPlugin = { id: 'jupyterlab-codestral:inline-provider', @@ -29,6 +27,10 @@ const inlineProviderPlugin: JupyterFrontEndPlugin = { manager: ICompletionProviderManager, settingRegistry: ISettingRegistry ): void => { + const mistralClient = new MistralAI({ + model: 'codestral-latest', + apiKey: 'TMP' + }); const provider = new CodestralProvider({ mistralClient }); manager.registerInlineProvider(provider); @@ -73,6 +75,10 @@ const chatPlugin: JupyterFrontEndPlugin = { }); } + const mistralClient = new ChatMistralAI({ + model: 'codestral-latest', + apiKey: 'TMP' + }); const chatHandler = new CodestralHandler({ mistralClient, activeCellManager: activeCellManager @@ -88,6 +94,25 @@ const chatPlugin: JupyterFrontEndPlugin = { chatHandler.config = { sendWithShiftEnter, enableCodeToolbar }; } + // TODO: handle the apiKey better + settingsRegistry + ?.load(inlineProviderPlugin.id) + .then(settings => { + const updateKey = () => { + const apiKey = settings.get('apiKey').composite as string; + mistralClient.apiKey = apiKey; + }; + + settings.changed.connect(() => updateKey()); + updateKey(); + }) + .catch(reason => { + console.error( + `Failed to load settings for ${inlineProviderPlugin.id}`, + reason + ); + }); + Promise.all([app.restored, settingsRegistry?.load(chatPlugin.id)]) .then(([, settings]) => { if (!settings) { diff --git a/src/provider.ts b/src/provider.ts index 213a88f..7c4c1e5 100644 --- a/src/provider.ts +++ b/src/provider.ts @@ -6,7 +6,9 @@ import { import { Throttler } from '@lumino/polling'; -import MistralClient, { CompletionRequest } from '@mistralai/mistralai'; +import { CompletionRequest } from '@mistralai/mistralai'; + +import type { MistralAI } from '@langchain/mistralai'; /* * The Mistral API has a rate limit of 1 request per second @@ -20,7 +22,11 @@ export class CodestralProvider implements IInlineCompletionProvider { constructor(options: CodestralProvider.IOptions) { this._mistralClient = options.mistralClient; this._throttler = new Throttler(async (data: CompletionRequest) => { - const response = await this._mistralClient.completion(data); + const response = await this._mistralClient.completionWithRetry( + data, + {}, + false + ); const items = response.choices.map((choice: any) => { return { insertText: choice.message.content as string }; }); @@ -61,11 +67,11 @@ export class CodestralProvider implements IInlineCompletionProvider { } private _throttler: Throttler; - private _mistralClient: MistralClient; + private _mistralClient: MistralAI; } export namespace CodestralProvider { export interface IOptions { - mistralClient: MistralClient; + mistralClient: MistralAI; } } diff --git a/yarn.lock b/yarn.lock index 355781c..7bf63b1 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1501,6 +1501,39 @@ __metadata: languageName: node linkType: hard +"@langchain/core@npm:^0.3.13": + version: 0.3.13 + resolution: "@langchain/core@npm:0.3.13" + dependencies: + ansi-styles: ^5.0.0 + camelcase: 6 + decamelize: 1.2.0 + js-tiktoken: ^1.0.12 + langsmith: ^0.1.65 + mustache: ^4.2.0 + p-queue: ^6.6.2 + p-retry: 4 + uuid: ^10.0.0 + zod: ^3.22.4 + zod-to-json-schema: ^3.22.3 + checksum: 9d6c25f9c276f8107c9b7f824d4b5688f23e4b047fc90af8c1bc4149f61ec2f82f5234c7d23a2029a1c2edc95c9a6caef6122017704be35f91c9a0f0498a2f75 + languageName: node + linkType: hard + +"@langchain/mistralai@npm:^0.1.1": + version: 0.1.1 + resolution: "@langchain/mistralai@npm:0.1.1" + dependencies: + "@mistralai/mistralai": ^0.4.0 + uuid: ^10.0.0 + zod: ^3.22.4 + zod-to-json-schema: ^3.22.4 + peerDependencies: + "@langchain/core": ">=0.2.21 <0.4.0" + checksum: f377fb4130adf435071da2fef36de3bf6850b4e7a41ec88b2096933364c103c38daf8a8d4becbe0a9c3194b2110a0094c2c017d04164af7861e973e3d088d466 + languageName: node + linkType: hard + "@lezer/common@npm:^1.0.0, @lezer/common@npm:^1.0.2, @lezer/common@npm:^1.1.0, @lezer/common@npm:^1.2.0, @lezer/common@npm:^1.2.1": version: 1.2.1 resolution: "@lezer/common@npm:1.2.1" @@ -1926,12 +1959,12 @@ __metadata: languageName: node linkType: hard -"@mistralai/mistralai@npm:^0.5.0": - version: 0.5.0 - resolution: "@mistralai/mistralai@npm:0.5.0" +"@mistralai/mistralai@npm:^0.4.0": + version: 0.4.0 + resolution: "@mistralai/mistralai@npm:0.4.0" dependencies: node-fetch: ^2.6.7 - checksum: b421f314bf22c4db883fe981517fed43e5177d83e9a5fb0fb6b1ff6915fc0e5503c0ec2e70e5302113d70f22477477fa6a9499300980e4ee9b71ca60b9ce8599 + checksum: 1b03fc0b55164c02e5fb29fb2d09ebe4ad44346fc313f7fb3ab09e48f73f975763d1ac9654098d433ea17d7caa20654b2b15510822276acc9fa46db461a254a6 languageName: node linkType: hard @@ -2298,6 +2331,13 @@ __metadata: languageName: node linkType: hard +"@types/retry@npm:0.12.0": + version: 0.12.0 + resolution: "@types/retry@npm:0.12.0" + checksum: 61a072c7639f6e8126588bf1eb1ce8835f2cb9c2aba795c4491cf6310e013267b0c8488039857c261c387e9728c1b43205099223f160bb6a76b4374f741b5603 + languageName: node + linkType: hard + "@types/semver@npm:^7.5.0": version: 7.5.8 resolution: "@types/semver@npm:7.5.8" @@ -2312,6 +2352,13 @@ __metadata: languageName: node linkType: hard +"@types/uuid@npm:^10.0.0": + version: 10.0.0 + resolution: "@types/uuid@npm:10.0.0" + checksum: e3958f8b0fe551c86c14431f5940c3470127293280830684154b91dc7eb3514aeb79fe3216968833cf79d4d1c67f580f054b5be2cd562bebf4f728913e73e944 + languageName: node + linkType: hard + "@types/webpack-sources@npm:^0.1.5": version: 0.1.12 resolution: "@types/webpack-sources@npm:0.1.12" @@ -2775,6 +2822,13 @@ __metadata: languageName: node linkType: hard +"ansi-styles@npm:^5.0.0": + version: 5.2.0 + resolution: "ansi-styles@npm:5.2.0" + checksum: d7f4e97ce0623aea6bc0d90dcd28881ee04cba06c570b97fd3391bd7a268eedfd9d5e2dd4fdcbdd82b8105df5faf6f24aaedc08eaf3da898e702db5948f63469 + languageName: node + linkType: hard + "ansi-styles@npm:^6.1.0": version: 6.2.1 resolution: "ansi-styles@npm:6.2.1" @@ -2870,6 +2924,13 @@ __metadata: languageName: node linkType: hard +"base64-js@npm:^1.5.1": + version: 1.5.1 + resolution: "base64-js@npm:1.5.1" + checksum: 669632eb3745404c2f822a18fc3a0122d2f9a7a13f7fb8b5823ee19d1d2ff9ee5b52c53367176ea4ad093c332fd5ab4bd0ebae5a8e27917a4105a4cfc86b1005 + languageName: node + linkType: hard + "big.js@npm:^5.2.2": version: 5.2.2 resolution: "big.js@npm:5.2.2" @@ -2958,7 +3019,7 @@ __metadata: languageName: node linkType: hard -"camelcase@npm:^6.3.0": +"camelcase@npm:6, camelcase@npm:^6.3.0": version: 6.3.0 resolution: "camelcase@npm:6.3.0" checksum: 8c96818a9076434998511251dcb2761a94817ea17dbdc37f47ac080bd088fc62c7369429a19e2178b993497132c8cbcf5cc1f44ba963e76782ba469c0474938d @@ -3313,7 +3374,7 @@ __metadata: languageName: node linkType: hard -"decamelize@npm:^1.1.0": +"decamelize@npm:1.2.0, decamelize@npm:^1.1.0": version: 1.2.0 resolution: "decamelize@npm:1.2.0" checksum: ad8c51a7e7e0720c70ec2eeb1163b66da03e7616d7b98c9ef43cce2416395e84c1e9548dd94f5f6ffecfee9f8b94251fc57121a8b021f2ff2469b2bae247b8aa @@ -3796,6 +3857,13 @@ __metadata: languageName: node linkType: hard +"eventemitter3@npm:^4.0.4": + version: 4.0.7 + resolution: "eventemitter3@npm:4.0.7" + checksum: 1875311c42fcfe9c707b2712c32664a245629b42bb0a5a84439762dd0fd637fc54d078155ea83c2af9e0323c9ac13687e03cfba79b03af9f40c89b4960099374 + languageName: node + linkType: hard + "events@npm:^3.2.0": version: 3.3.0 resolution: "events@npm:3.3.0" @@ -4676,6 +4744,15 @@ __metadata: languageName: node linkType: hard +"js-tiktoken@npm:^1.0.12": + version: 1.0.15 + resolution: "js-tiktoken@npm:1.0.15" + dependencies: + base64-js: ^1.5.1 + checksum: fb37641fcbec0386276e99459a4c94c9e790b3fe59143191e06e20a8069695999afe7da00fddeb591a731a68afcf068803a75d77825b2d697541bb165e4795eb + languageName: node + linkType: hard + "js-tokens@npm:^3.0.0 || ^4.0.0, js-tokens@npm:^4.0.0": version: 4.0.0 resolution: "js-tokens@npm:4.0.0" @@ -4806,9 +4883,10 @@ __metadata: "@jupyterlab/notebook": ^4.2.0 "@jupyterlab/rendermime": ^4.2.0 "@jupyterlab/settingregistry": ^4.2.0 + "@langchain/core": ^0.3.13 + "@langchain/mistralai": ^0.1.1 "@lumino/coreutils": ^2.1.2 "@lumino/polling": ^2.1.2 - "@mistralai/mistralai": ^0.5.0 "@types/json-schema": ^7.0.11 "@types/react": ^18.0.26 "@types/react-addons-linked-state-mixin": ^0.14.22 @@ -4856,6 +4934,25 @@ __metadata: languageName: node linkType: hard +"langsmith@npm:^0.1.65": + version: 0.1.66 + resolution: "langsmith@npm:0.1.66" + dependencies: + "@types/uuid": ^10.0.0 + commander: ^10.0.1 + p-queue: ^6.6.2 + p-retry: 4 + semver: ^7.6.3 + uuid: ^10.0.0 + peerDependencies: + openai: "*" + peerDependenciesMeta: + openai: + optional: true + checksum: 9c0cb365e2278a255dc0d045402fbec67cf3a6314e7858b398a267b1e97d3f24e5b4ba5286c80a3fe6c13443ac1fca8fd1555f3a0990b74b7c6be8c608581394 + languageName: node + linkType: hard + "levn@npm:^0.4.1": version: 0.4.1 resolution: "levn@npm:0.4.1" @@ -5207,6 +5304,15 @@ __metadata: languageName: node linkType: hard +"mustache@npm:^4.2.0": + version: 4.2.0 + resolution: "mustache@npm:4.2.0" + bin: + mustache: bin/mustache + checksum: 928fcb63e3aa44a562bfe9b59ba202cccbe40a46da50be6f0dd831b495be1dd7e38ca4657f0ecab2c1a89dc7bccba0885eab7ee7c1b215830da765758c7e0506 + languageName: node + linkType: hard + "nanoid@npm:^3.3.7": version: 3.3.7 resolution: "nanoid@npm:3.3.7" @@ -5366,6 +5472,13 @@ __metadata: languageName: node linkType: hard +"p-finally@npm:^1.0.0": + version: 1.0.0 + resolution: "p-finally@npm:1.0.0" + checksum: 93a654c53dc805dd5b5891bab16eb0ea46db8f66c4bfd99336ae929323b1af2b70a8b0654f8f1eae924b2b73d037031366d645f1fd18b3d30cbd15950cc4b1d4 + languageName: node + linkType: hard + "p-limit@npm:^2.2.0": version: 2.3.0 resolution: "p-limit@npm:2.3.0" @@ -5402,6 +5515,35 @@ __metadata: languageName: node linkType: hard +"p-queue@npm:^6.6.2": + version: 6.6.2 + resolution: "p-queue@npm:6.6.2" + dependencies: + eventemitter3: ^4.0.4 + p-timeout: ^3.2.0 + checksum: 832642fcc4ab6477b43e6d7c30209ab10952969ed211c6d6f2931be8a4f9935e3578c72e8cce053dc34f2eb6941a408a2c516a54904e989851a1a209cf19761c + languageName: node + linkType: hard + +"p-retry@npm:4": + version: 4.6.2 + resolution: "p-retry@npm:4.6.2" + dependencies: + "@types/retry": 0.12.0 + retry: ^0.13.1 + checksum: 45c270bfddaffb4a895cea16cb760dcc72bdecb6cb45fef1971fa6ea2e91ddeafddefe01e444ac73e33b1b3d5d29fb0dd18a7effb294262437221ddc03ce0f2e + languageName: node + linkType: hard + +"p-timeout@npm:^3.2.0": + version: 3.2.0 + resolution: "p-timeout@npm:3.2.0" + dependencies: + p-finally: ^1.0.0 + checksum: 3dd0eaa048780a6f23e5855df3dd45c7beacff1f820476c1d0d1bcd6648e3298752ba2c877aa1c92f6453c7dd23faaf13d9f5149fc14c0598a142e2c5e8d649c + languageName: node + linkType: hard + "p-try@npm:^2.0.0": version: 2.2.0 resolution: "p-try@npm:2.2.0" @@ -5932,6 +6074,13 @@ __metadata: languageName: node linkType: hard +"retry@npm:^0.13.1": + version: 0.13.1 + resolution: "retry@npm:0.13.1" + checksum: 47c4d5be674f7c13eee4cfe927345023972197dbbdfba5d3af7e461d13b44de1bfd663bfc80d2f601f8ef3fc8164c16dd99655a221921954a65d044a2fc1233b + languageName: node + linkType: hard + "reusify@npm:^1.0.4": version: 1.0.4 resolution: "reusify@npm:1.0.4" @@ -6082,6 +6231,15 @@ __metadata: languageName: node linkType: hard +"semver@npm:^7.6.3": + version: 7.6.3 + resolution: "semver@npm:7.6.3" + bin: + semver: bin/semver.js + checksum: 4110ec5d015c9438f322257b1c51fe30276e5f766a3f64c09edd1d7ea7118ecbc3f379f3b69032bacf13116dc7abc4ad8ce0d7e2bd642e26b0d271b56b61a7d8 + languageName: node + linkType: hard + "serialize-javascript@npm:^6.0.1": version: 6.0.2 resolution: "serialize-javascript@npm:6.0.2" @@ -6901,6 +7059,15 @@ __metadata: languageName: node linkType: hard +"uuid@npm:^10.0.0": + version: 10.0.0 + resolution: "uuid@npm:10.0.0" + bin: + uuid: dist/bin/uuid + checksum: 4b81611ade2885d2313ddd8dc865d93d8dccc13ddf901745edca8f86d99bc46d7a330d678e7532e7ebf93ce616679fb19b2e3568873ac0c14c999032acb25869 + languageName: node + linkType: hard + "validate-npm-package-license@npm:^3.0.1": version: 3.0.4 resolution: "validate-npm-package-license@npm:3.0.4" @@ -7329,3 +7496,19 @@ __metadata: checksum: f77b3d8d00310def622123df93d4ee654fc6a0096182af8bd60679ddcdfb3474c56c6c7190817c84a2785648cdee9d721c0154eb45698c62176c322fb46fc700 languageName: node linkType: hard + +"zod-to-json-schema@npm:^3.22.3, zod-to-json-schema@npm:^3.22.4": + version: 3.23.3 + resolution: "zod-to-json-schema@npm:3.23.3" + peerDependencies: + zod: ^3.23.3 + checksum: 0d51cf64b54fd39e86434cd5d2239c2981808e6461d022e4c68a1dec67fff28ef2b7bb5733dfd40eb50d6ce6d252288f3989d67134fa81401c36469bb26f13ec + languageName: node + linkType: hard + +"zod@npm:^3.22.4": + version: 3.23.8 + resolution: "zod@npm:3.23.8" + checksum: 15949ff82118f59c893dacd9d3c766d02b6fa2e71cf474d5aa888570c469dbf5446ac5ad562bb035bf7ac9650da94f290655c194f4a6de3e766f43febd432c5c + languageName: node + linkType: hard