Making the LLM providers more generic (#10)
* WIP on making the LLM providers more generic

* Propagate settings changes to the chat and completion LLMs

* Clean up and lint

* Provide only one completion provider

* Rename 'client' to 'provider' and LlmProvider to AIProvider for better readability
brichet authored Nov 4, 2024
1 parent f68be53 commit f0637eb
Showing 14 changed files with 416 additions and 153 deletions.
3 changes: 2 additions & 1 deletion package.json
@@ -63,7 +63,8 @@
"@langchain/core": "^0.3.13",
"@langchain/mistralai": "^0.1.1",
"@lumino/coreutils": "^2.1.2",
"@lumino/polling": "^2.1.2"
"@lumino/polling": "^2.1.2",
"@lumino/signaling": "^2.1.2"
},
"devDependencies": {
"@jupyterlab/builder": "^4.0.0",
21 changes: 21 additions & 0 deletions schema/ai-provider.json
@@ -0,0 +1,21 @@
{
"title": "AI provider",
"description": "Provider settings",
"type": "object",
"properties": {
"provider": {
"type": "string",
"title": "The AI provider",
"description": "The AI provider to use for chat and completion",
"default": "None",
"enum": ["None", "MistralAI"]
},
"apiKey": {
"type": "string",
"title": "The Codestral API key",
"description": "The API key to use for Codestral",
"default": ""
}
},
"additionalProperties": false
}
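
For reference, a user settings object satisfying this schema might look like the following minimal sketch (TypeScript; the key value is a placeholder, not a real credential):

// Hypothetical overrides for the "AI provider" settings defined above.
const aiProviderSettings = {
  provider: 'MistralAI', // one of "None" | "MistralAI"
  apiKey: '<codestral-api-key>' // placeholder value, supplied by the user
};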
14 changes: 0 additions & 14 deletions schema/inline-provider.json

This file was deleted.

41 changes: 31 additions & 10 deletions src/handler.ts → src/chat-handler.ts
@@ -9,23 +9,30 @@ import {
IChatMessage,
INewMessage
} from '@jupyter/chat';
import { UUID } from '@lumino/coreutils';
import type { ChatMistralAI } from '@langchain/mistralai';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import {
AIMessage,
HumanMessage,
mergeMessageRuns
} from '@langchain/core/messages';
import { UUID } from '@lumino/coreutils';

export type ConnectionMessage = {
type: 'connection';
client_id: string;
};

export class CodestralHandler extends ChatModel {
constructor(options: CodestralHandler.IOptions) {
export class ChatHandler extends ChatModel {
constructor(options: ChatHandler.IOptions) {
super(options);
this._mistralClient = options.mistralClient;
this._provider = options.provider;
}

get provider(): BaseChatModel | null {
return this._provider;
}
set provider(provider: BaseChatModel | null) {
this._provider = provider;
}

async sendMessage(message: INewMessage): Promise<boolean> {
@@ -38,6 +45,19 @@ export class CodestralHandler extends ChatModel {
type: 'msg'
};
this.messageAdded(msg);

if (this._provider === null) {
const botMsg: IChatMessage = {
id: UUID.uuid4(),
body: '**AI provider not configured for the chat**',
sender: { username: 'ERROR' },
time: Date.now(),
type: 'msg'
};
this.messageAdded(botMsg);
return false;
}

this._history.messages.push(msg);

const messages = mergeMessageRuns(
@@ -48,13 +68,14 @@ export class CodestralHandler extends ChatModel {
return new AIMessage(msg.body);
})
);
const response = await this._mistralClient.invoke(messages);

const response = await this._provider.invoke(messages);
// TODO: fix deprecated response.text
const content = response.text;
const botMsg: IChatMessage = {
id: UUID.uuid4(),
body: content,
sender: { username: 'Codestral' },
sender: { username: 'Bot' },
time: Date.now(),
type: 'msg'
};
@@ -75,12 +96,12 @@ export class CodestralHandler extends ChatModel {
super.messageAdded(message);
}

private _mistralClient: ChatMistralAI;
private _provider: BaseChatModel | null;
private _history: IChatHistory = { messages: [] };
}

export namespace CodestralHandler {
export namespace ChatHandler {
export interface IOptions extends ChatModel.IOptions {
mistralClient: ChatMistralAI;
provider: BaseChatModel | null;
}
}
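
To show what the decoupling buys, here is a minimal usage sketch of the renamed handler (assuming a ChatMistralAI instance; any LangChain BaseChatModel works, other constructor options are omitted, and null disables the chat):

import { ChatMistralAI } from '@langchain/mistralai';
import { ChatHandler } from './chat-handler';

// The handler no longer hard-codes Codestral: any chat model can back it.
const handler = new ChatHandler({
  provider: new ChatMistralAI({ model: 'codestral-latest', apiKey: 'TMP' })
});

// Swapping (or removing) the model at runtime only needs the setter.
handler.provider = null; // the chat now replies with the "not configured" error message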
61 changes: 61 additions & 0 deletions src/completion-provider.ts
@@ -0,0 +1,61 @@
import {
CompletionHandler,
IInlineCompletionContext,
IInlineCompletionProvider
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';

import { getCompleter, IBaseCompleter } from './llm-models';

/**
* The generic completion provider to register with the completion provider manager.
*/
export class CompletionProvider implements IInlineCompletionProvider {
readonly identifier = '@jupyterlite/ai';

constructor(options: CompletionProvider.IOptions) {
this.name = options.name;
}

/**
* Getter and setter of the name.
* The setter will create the appropriate completer, according to the name.
*/
get name(): string {
return this._name;
}
set name(name: string) {
this._name = name;
this._completer = getCompleter(name);
}

/**
* Get the current completer.
*/
get completer(): IBaseCompleter | null {
return this._completer;
}

/**
* Get the LLM completer.
*/
get llmCompleter(): LLM | null {
return this._completer?.provider || null;
}

async fetch(
request: CompletionHandler.IRequest,
context: IInlineCompletionContext
) {
return this._completer?.fetch(request, context);
}

private _name: string = 'None';
private _completer: IBaseCompleter | null = null;
}

export namespace CompletionProvider {
export interface IOptions {
name: string;
}
}
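
For context, wiring this generic provider into JupyterLab's completion system would look roughly like the sketch below (in this commit that wiring presumably happens inside AIProvider, whose file is not expanded in this view):

import { ICompletionProviderManager } from '@jupyterlab/completer';

import { CompletionProvider } from './completion-provider';

// Register a single inline provider; its backing completer is chosen later
// by setting `name` from the settings (e.g. 'MistralAI').
function registerCompletion(manager: ICompletionProviderManager): CompletionProvider {
  const provider = new CompletionProvider({ name: 'None' });
  manager.registerInlineProvider(provider);
  return provider;
}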
114 changes: 47 additions & 67 deletions src/index.ts
@@ -13,55 +13,20 @@ import { ICompletionProviderManager } from '@jupyterlab/completer';
import { INotebookTracker } from '@jupyterlab/notebook';
import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
import { ISettingRegistry } from '@jupyterlab/settingregistry';
import { ChatMistralAI, MistralAI } from '@langchain/mistralai';

import { CodestralHandler } from './handler';
import { CodestralProvider } from './provider';

const inlineProviderPlugin: JupyterFrontEndPlugin<void> = {
id: 'jupyterlab-codestral:inline-provider',
autoStart: true,
requires: [ICompletionProviderManager, ISettingRegistry],
activate: (
app: JupyterFrontEnd,
manager: ICompletionProviderManager,
settingRegistry: ISettingRegistry
): void => {
const mistralClient = new MistralAI({
model: 'codestral-latest',
apiKey: 'TMP'
});
const provider = new CodestralProvider({ mistralClient });
manager.registerInlineProvider(provider);

settingRegistry
.load(inlineProviderPlugin.id)
.then(settings => {
const updateKey = () => {
const apiKey = settings.get('apiKey').composite as string;
mistralClient.apiKey = apiKey;
};

settings.changed.connect(() => updateKey());
updateKey();
})
.catch(reason => {
console.error(
`Failed to load settings for ${inlineProviderPlugin.id}`,
reason
);
});
}
};
import { ChatHandler } from './chat-handler';
import { AIProvider } from './provider';
import { IAIProvider } from './token';

const chatPlugin: JupyterFrontEndPlugin<void> = {
id: 'jupyterlab-codestral:chat',
description: 'Codestral chat extension',
description: 'LLM chat extension',
autoStart: true,
optional: [INotebookTracker, ISettingRegistry, IThemeManager],
requires: [IRenderMimeRegistry],
requires: [IAIProvider, IRenderMimeRegistry],
activate: async (
app: JupyterFrontEnd,
aiProvider: IAIProvider,
rmRegistry: IRenderMimeRegistry,
notebookTracker: INotebookTracker | null,
settingsRegistry: ISettingRegistry | null,
@@ -75,15 +40,15 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
});
}

const mistralClient = new ChatMistralAI({
model: 'codestral-latest',
apiKey: 'TMP'
});
const chatHandler = new CodestralHandler({
mistralClient,
const chatHandler = new ChatHandler({
provider: aiProvider.chatModel,
activeCellManager: activeCellManager
});

aiProvider.modelChange.connect(() => {
chatHandler.provider = aiProvider.chatModel;
});

let sendWithShiftEnter = false;
let enableCodeToolbar = true;

@@ -94,25 +59,6 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
chatHandler.config = { sendWithShiftEnter, enableCodeToolbar };
}

// TODO: handle the apiKey better
settingsRegistry
?.load(inlineProviderPlugin.id)
.then(settings => {
const updateKey = () => {
const apiKey = settings.get('apiKey').composite as string;
mistralClient.apiKey = apiKey;
};

settings.changed.connect(() => updateKey());
updateKey();
})
.catch(reason => {
console.error(
`Failed to load settings for ${inlineProviderPlugin.id}`,
reason
);
});

Promise.all([app.restored, settingsRegistry?.load(chatPlugin.id)])
.then(([, settings]) => {
if (!settings) {
@@ -148,4 +94,38 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
}
};

export default [inlineProviderPlugin, chatPlugin];
const aiProviderPlugin: JupyterFrontEndPlugin<IAIProvider> = {
id: 'jupyterlab-codestral:ai-provider',
autoStart: true,
requires: [ICompletionProviderManager, ISettingRegistry],
provides: IAIProvider,
activate: (
app: JupyterFrontEnd,
manager: ICompletionProviderManager,
settingRegistry: ISettingRegistry
): IAIProvider => {
const aiProvider = new AIProvider({ completionProviderManager: manager });

settingRegistry
.load(aiProviderPlugin.id)
.then(settings => {
const updateProvider = () => {
const provider = settings.get('provider').composite as string;
aiProvider.setModels(provider, settings.composite);
};

settings.changed.connect(() => updateProvider());
updateProvider();
})
.catch(reason => {
console.error(
`Failed to load settings for ${aiProviderPlugin.id}`,
reason
);
});

return aiProvider;
}
};

export default [chatPlugin, aiProviderPlugin];
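
src/token.ts and src/provider.ts are among the changed files but are not expanded in this view; judging from how the plugins above use them, the token's surface is roughly the following (a sketch inferred from usage, not the committed code — the signal type parameters and the token id are assumptions):

import { Token } from '@lumino/coreutils';
import { ISignal } from '@lumino/signaling';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';

export interface IAIProvider {
  /** The configured chat model, or null when the provider is 'None'. */
  chatModel: BaseChatModel | null;
  /** Emitted when the models are rebuilt from new settings. */
  modelChange: ISignal<IAIProvider, void>;
  /** Rebuild the chat and completion models for the named provider. */
  setModels(name: string, settings: unknown): void;
}

// Token id is illustrative.
export const IAIProvider = new Token<IAIProvider>('jupyterlab-codestral:AIProvider');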
20 changes: 20 additions & 0 deletions src/llm-models/base-completer.ts
@@ -0,0 +1,20 @@
import {
CompletionHandler,
IInlineCompletionContext
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';

export interface IBaseCompleter {
/**
* The LLM completer.
*/
provider: LLM;

/**
* The fetch request for the LLM completer.
*/
fetch(
request: CompletionHandler.IRequest,
context: IInlineCompletionContext
): Promise<any>;
}
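
The concrete completers returned by getCompleter live under src/llm-models and are not expanded in this view. A minimal implementation of the interface could look like the sketch below, assuming a Codestral-backed completer along the lines of the code this commit refactors (prompt construction and error handling are simplified):

import {
  CompletionHandler,
  IInlineCompletionContext
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';
import { MistralAI } from '@langchain/mistralai';

import { IBaseCompleter } from './base-completer';

export class CodestralCompleter implements IBaseCompleter {
  constructor() {
    // 'TMP' mirrors the placeholder used elsewhere; the real key comes from settings.
    this._mistralProvider = new MistralAI({ model: 'codestral-latest', apiKey: 'TMP' });
  }

  get provider(): LLM {
    return this._mistralProvider;
  }

  async fetch(
    request: CompletionHandler.IRequest,
    context: IInlineCompletionContext
  ) {
    // Complete the text before the cursor and return it as a single suggestion.
    const prompt = request.text.slice(0, request.offset);
    const completion = await this._mistralProvider.invoke(prompt);
    return { items: [{ insertText: completion }] };
  }

  private _mistralProvider: MistralAI;
}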