Skip to content

Commit

Permalink
Switch to using langchain.js
Browse files Browse the repository at this point in the history
  • Loading branch information
jtpio authored and brichet committed Oct 23, 2024
1 parent 6b2cae1 commit 2642e3f
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 33 deletions.
5 changes: 3 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@
"@jupyterlab/notebook": "^4.2.0",
"@jupyterlab/rendermime": "^4.2.0",
"@jupyterlab/settingregistry": "^4.2.0",
"@langchain/core": "^0.3.13",
"@langchain/mistralai": "^0.1.1",
"@lumino/coreutils": "^2.1.2",
"@lumino/polling": "^2.1.2",
"@mistralai/mistralai": "^0.5.0"
"@lumino/polling": "^2.1.2"
},
"devDependencies": {
"@jupyterlab/builder": "^4.0.0",
Expand Down
44 changes: 28 additions & 16 deletions src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ import {
INewMessage
} from '@jupyter/chat';
import { UUID } from '@lumino/coreutils';
import MistralClient from '@mistralai/mistralai';
import type { ChatMistralAI } from '@langchain/mistralai';
import {
HumanMessage,
mergeMessageRuns,
SystemMessage
} from '@langchain/core/messages';

export type ConnectionMessage = {
type: 'connection';
Expand All @@ -34,22 +39,29 @@ export class CodestralHandler extends ChatModel {
};
this.messageAdded(msg);
this._history.messages.push(msg);
const response = await this._mistralClient.chat({
model: 'codestral-latest',
messages: this._history.messages.map(msg => {
return {
role: msg.sender.username === 'User' ? 'user' : 'assistant',
content: msg.body
};
// const response = await this._mistralClient.chat({
// model: 'codestral-latest',
// messages: this._history.messages.map(msg => {
// return {
// role: msg.sender === 'User' ? 'user' : 'assistant',
// content: msg.body
// };
// })
// });
const messages = mergeMessageRuns(
this._history.messages.map(msg => {
if (msg.sender.username === 'User') {
return new HumanMessage(msg.body);
}
return new SystemMessage(msg.body);
})
});
if (response.choices.length === 0) {
return false;
}
const botMessage = response.choices[0].message;
);
const response = await this._mistralClient.invoke(messages);
// TODO: fix deprecated response.text
const content = response.text;
const botMsg: IChatMessage = {
id: UUID.uuid4(),
body: botMessage.content as string,
body: content,
sender: { username: 'Codestral' },
time: Date.now(),
type: 'msg'
Expand All @@ -70,12 +82,12 @@ export class CodestralHandler extends ChatModel {
super.messageAdded(message);
}

private _mistralClient: MistralClient;
private _mistralClient: ChatMistralAI;
private _history: IChatHistory = { messages: [] };
}

export namespace CodestralHandler {
export interface IOptions extends ChatModel.IOptions {
mistralClient: MistralClient;
mistralClient: ChatMistralAI;
}
}
33 changes: 29 additions & 4 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ import { ICompletionProviderManager } from '@jupyterlab/completer';
import { INotebookTracker } from '@jupyterlab/notebook';
import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
import { ISettingRegistry } from '@jupyterlab/settingregistry';
import { CodestralProvider } from './provider';
import MistralClient from '@mistralai/mistralai';
import { ChatMistralAI, MistralAI } from '@langchain/mistralai';

import { CodestralHandler } from './handler';

const mistralClient = new MistralClient();
import { CodestralProvider } from './provider';

const inlineProviderPlugin: JupyterFrontEndPlugin<void> = {
id: 'jupyterlab-codestral:inline-provider',
Expand All @@ -29,6 +27,10 @@ const inlineProviderPlugin: JupyterFrontEndPlugin<void> = {
manager: ICompletionProviderManager,
settingRegistry: ISettingRegistry
): void => {
const mistralClient = new MistralAI({
model: 'codestral-latest',
apiKey: 'TMP'
});
const provider = new CodestralProvider({ mistralClient });
manager.registerInlineProvider(provider);

Expand Down Expand Up @@ -73,6 +75,10 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
});
}

const mistralClient = new ChatMistralAI({
model: 'codestral-latest',
apiKey: 'TMP'
});
const chatHandler = new CodestralHandler({
mistralClient,
activeCellManager: activeCellManager
Expand All @@ -88,6 +94,25 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
chatHandler.config = { sendWithShiftEnter, enableCodeToolbar };
}

// TODO: handle the apiKey better
settingsRegistry
?.load(inlineProviderPlugin.id)
.then(settings => {
const updateKey = () => {
const apiKey = settings.get('apiKey').composite as string;
mistralClient.apiKey = apiKey;
};

settings.changed.connect(() => updateKey());
updateKey();
})
.catch(reason => {
console.error(
`Failed to load settings for ${inlineProviderPlugin.id}`,
reason
);
});

Promise.all([app.restored, settingsRegistry?.load(chatPlugin.id)])
.then(([, settings]) => {
if (!settings) {
Expand Down
14 changes: 10 additions & 4 deletions src/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import {

import { Throttler } from '@lumino/polling';

import MistralClient, { CompletionRequest } from '@mistralai/mistralai';
import { CompletionRequest } from '@mistralai/mistralai';

import type { MistralAI } from '@langchain/mistralai';

/*
* The Mistral API has a rate limit of 1 request per second
Expand All @@ -20,7 +22,11 @@ export class CodestralProvider implements IInlineCompletionProvider {
constructor(options: CodestralProvider.IOptions) {
this._mistralClient = options.mistralClient;
this._throttler = new Throttler(async (data: CompletionRequest) => {
const response = await this._mistralClient.completion(data);
const response = await this._mistralClient.completionWithRetry(
data,
{},
false
);
const items = response.choices.map((choice: any) => {
return { insertText: choice.message.content as string };
});
Expand Down Expand Up @@ -61,11 +67,11 @@ export class CodestralProvider implements IInlineCompletionProvider {
}

private _throttler: Throttler;
private _mistralClient: MistralClient;
private _mistralClient: MistralAI;
}

export namespace CodestralProvider {
export interface IOptions {
mistralClient: MistralClient;
mistralClient: MistralAI;
}
}
Loading

0 comments on commit 2642e3f

Please sign in to comment.