From f0637eb68d31b7668f1289548004a5168538456f Mon Sep 17 00:00:00 2001
From: Nicolas Brichet <32258950+brichet@users.noreply.github.com>
Date: Mon, 4 Nov 2024 15:02:31 +0100
Subject: [PATCH] Making the LLM providers more generic (#10)

* WIP on making the LLM providers more generic

* Update changes in settings to the chat and completion LLM

* Cleaning and lint

* Provides only one completion provider

* Rename 'client' to 'provider' and LlmProvider to AIProvider for better
  readability
---
 package.json                          |   3 +-
 schema/ai-provider.json               |  21 ++++
 schema/inline-provider.json           |  14 ---
 src/{handler.ts => chat-handler.ts}   |  41 ++++--
 src/completion-provider.ts            |  61 +++++++++
 src/index.ts                          | 114 +++++++----------
 src/llm-models/base-completer.ts      |  20 +++
 src/llm-models/codestral-completer.ts |  74 +++++++++++
 src/llm-models/index.ts               |   3 +
 src/llm-models/utils.ts               |  24 ++++
 src/provider.ts                       | 174 +++++++++++++++++---------
 src/token.ts                          |  17 +++
 tsconfig.json                         |   2 +-
 yarn.lock                             |   1 +
 14 files changed, 416 insertions(+), 153 deletions(-)
 create mode 100644 schema/ai-provider.json
 delete mode 100644 schema/inline-provider.json
 rename src/{handler.ts => chat-handler.ts} (64%)
 create mode 100644 src/completion-provider.ts
 create mode 100644 src/llm-models/base-completer.ts
 create mode 100644 src/llm-models/codestral-completer.ts
 create mode 100644 src/llm-models/index.ts
 create mode 100644 src/llm-models/utils.ts
 create mode 100644 src/token.ts

diff --git a/package.json b/package.json
index 4816377..539736a 100644
--- a/package.json
+++ b/package.json
@@ -63,7 +63,8 @@
     "@langchain/core": "^0.3.13",
     "@langchain/mistralai": "^0.1.1",
     "@lumino/coreutils": "^2.1.2",
-    "@lumino/polling": "^2.1.2"
+    "@lumino/polling": "^2.1.2",
+    "@lumino/signaling": "^2.1.2"
   },
   "devDependencies": {
     "@jupyterlab/builder": "^4.0.0",
diff --git a/schema/ai-provider.json b/schema/ai-provider.json
new file mode 100644
index 0000000..d4b9a04
--- /dev/null
+++ b/schema/ai-provider.json
@@ -0,0 +1,21 @@
+{
+  "title": "AI provider",
+  "description": "Provider settings",
+  "type": "object",
+  "properties": {
+    "provider": {
+      "type": "string",
+      "title": "The AI provider",
+      "description": "The AI provider to use for chat and completion",
+      "default": "None",
+      "enum": ["None", "MistralAI"]
+    },
+    "apiKey": {
+      "type": "string",
+      "title": "The Codestral API key",
+      "description": "The API key to use for Codestral",
+      "default": ""
+    }
+  },
+  "additionalProperties": false
+}
diff --git a/schema/inline-provider.json b/schema/inline-provider.json
deleted file mode 100644
index 12a7219..0000000
--- a/schema/inline-provider.json
+++ /dev/null
@@ -1,14 +0,0 @@
-{
-  "title": "Codestral",
-  "description": "Codestral settings",
-  "type": "object",
-  "properties": {
-    "apiKey": {
-      "type": "string",
-      "title": "The Codestral API key",
-      "description": "The API key to use for Codestral",
-      "default": ""
-    }
-  },
-  "additionalProperties": false
-}
diff --git a/src/handler.ts b/src/chat-handler.ts
similarity index 64%
rename from src/handler.ts
rename to src/chat-handler.ts
index 5e9d8d3..18417f6 100644
--- a/src/handler.ts
+++ b/src/chat-handler.ts
@@ -9,23 +9,30 @@ import {
   IChatMessage,
   INewMessage
 } from '@jupyter/chat';
-import { UUID } from '@lumino/coreutils';
-import type { ChatMistralAI } from '@langchain/mistralai';
+import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import {
   AIMessage,
   HumanMessage,
   mergeMessageRuns
 } from '@langchain/core/messages';
+import { UUID } from '@lumino/coreutils';
 
 export type ConnectionMessage = {
   type: 'connection';
   client_id: string;
 };
 
-export class CodestralHandler extends ChatModel {
-  constructor(options: CodestralHandler.IOptions) {
+export class ChatHandler extends ChatModel {
+  constructor(options: ChatHandler.IOptions) {
     super(options);
-    this._mistralClient = options.mistralClient;
+    this._provider = options.provider;
+  }
+
+  get provider(): BaseChatModel | null {
+    return this._provider;
+  }
+  set provider(provider: BaseChatModel | null) {
+    this._provider = provider;
   }
 
   async sendMessage(message: INewMessage): Promise<boolean> {
@@ -38,6 +45,19 @@ export class CodestralHandler extends ChatModel {
       type: 'msg'
     };
     this.messageAdded(msg);
+
+    if (this._provider === null) {
+      const botMsg: IChatMessage = {
+        id: UUID.uuid4(),
+        body: '**AI provider not configured for the chat**',
+        sender: { username: 'ERROR' },
+        time: Date.now(),
+        type: 'msg'
+      };
+      this.messageAdded(botMsg);
+      return false;
+    }
+
     this._history.messages.push(msg);
 
     const messages = mergeMessageRuns(
@@ -48,13 +68,14 @@ export class CodestralHandler extends ChatModel {
         return new AIMessage(msg.body);
       })
     );
-    const response = await this._mistralClient.invoke(messages);
+
+    const response = await this._provider.invoke(messages);
     // TODO: fix deprecated response.text
     const content = response.text;
     const botMsg: IChatMessage = {
       id: UUID.uuid4(),
       body: content,
-      sender: { username: 'Codestral' },
+      sender: { username: 'Bot' },
       time: Date.now(),
       type: 'msg'
     };
@@ -75,12 +96,12 @@ export class CodestralHandler extends ChatModel {
     super.messageAdded(message);
   }
 
-  private _mistralClient: ChatMistralAI;
+  private _provider: BaseChatModel | null;
   private _history: IChatHistory = { messages: [] };
 }
 
-export namespace CodestralHandler {
+export namespace ChatHandler {
   export interface IOptions extends ChatModel.IOptions {
-    mistralClient: ChatMistralAI;
+    provider: BaseChatModel | null;
   }
 }
diff --git a/src/completion-provider.ts b/src/completion-provider.ts
new file mode 100644
index 0000000..b2ac0b1
--- /dev/null
+++ b/src/completion-provider.ts
@@ -0,0 +1,61 @@
+import {
+  CompletionHandler,
+  IInlineCompletionContext,
+  IInlineCompletionProvider
+} from '@jupyterlab/completer';
+import { LLM } from '@langchain/core/language_models/llms';
+
+import { getCompleter, IBaseCompleter } from './llm-models';
+
+/**
+ * The generic completion provider to register to the completion provider manager.
+ */
+export class CompletionProvider implements IInlineCompletionProvider {
+  readonly identifier = '@jupyterlite/ai';
+
+  constructor(options: CompletionProvider.IOptions) {
+    this.name = options.name;
+  }
+
+  /**
+   * Getter and setter of the name.
+   * The setter will create the appropriate completer, according to the name.
+   */
+  get name(): string {
+    return this._name;
+  }
+  set name(name: string) {
+    this._name = name;
+    this._completer = getCompleter(name);
+  }
+
+  /**
+   * Get the current completer.
+   */
+  get completer(): IBaseCompleter | null {
+    return this._completer;
+  }
+
+  /**
+   * Get the LLM completer.
+ */ + get llmCompleter(): LLM | null { + return this._completer?.provider || null; + } + + async fetch( + request: CompletionHandler.IRequest, + context: IInlineCompletionContext + ) { + return this._completer?.fetch(request, context); + } + + private _name: string = 'None'; + private _completer: IBaseCompleter | null = null; +} + +export namespace CompletionProvider { + export interface IOptions { + name: string; + } +} diff --git a/src/index.ts b/src/index.ts index 1e9ec03..2cc8bdc 100644 --- a/src/index.ts +++ b/src/index.ts @@ -13,55 +13,20 @@ import { ICompletionProviderManager } from '@jupyterlab/completer'; import { INotebookTracker } from '@jupyterlab/notebook'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { ISettingRegistry } from '@jupyterlab/settingregistry'; -import { ChatMistralAI, MistralAI } from '@langchain/mistralai'; -import { CodestralHandler } from './handler'; -import { CodestralProvider } from './provider'; - -const inlineProviderPlugin: JupyterFrontEndPlugin = { - id: 'jupyterlab-codestral:inline-provider', - autoStart: true, - requires: [ICompletionProviderManager, ISettingRegistry], - activate: ( - app: JupyterFrontEnd, - manager: ICompletionProviderManager, - settingRegistry: ISettingRegistry - ): void => { - const mistralClient = new MistralAI({ - model: 'codestral-latest', - apiKey: 'TMP' - }); - const provider = new CodestralProvider({ mistralClient }); - manager.registerInlineProvider(provider); - - settingRegistry - .load(inlineProviderPlugin.id) - .then(settings => { - const updateKey = () => { - const apiKey = settings.get('apiKey').composite as string; - mistralClient.apiKey = apiKey; - }; - - settings.changed.connect(() => updateKey()); - updateKey(); - }) - .catch(reason => { - console.error( - `Failed to load settings for ${inlineProviderPlugin.id}`, - reason - ); - }); - } -}; +import { ChatHandler } from './chat-handler'; +import { AIProvider } from './provider'; +import { IAIProvider } from './token'; const chatPlugin: JupyterFrontEndPlugin = { id: 'jupyterlab-codestral:chat', - description: 'Codestral chat extension', + description: 'LLM chat extension', autoStart: true, optional: [INotebookTracker, ISettingRegistry, IThemeManager], - requires: [IRenderMimeRegistry], + requires: [IAIProvider, IRenderMimeRegistry], activate: async ( app: JupyterFrontEnd, + aiProvider: IAIProvider, rmRegistry: IRenderMimeRegistry, notebookTracker: INotebookTracker | null, settingsRegistry: ISettingRegistry | null, @@ -75,15 +40,15 @@ const chatPlugin: JupyterFrontEndPlugin = { }); } - const mistralClient = new ChatMistralAI({ - model: 'codestral-latest', - apiKey: 'TMP' - }); - const chatHandler = new CodestralHandler({ - mistralClient, + const chatHandler = new ChatHandler({ + provider: aiProvider.chatModel, activeCellManager: activeCellManager }); + aiProvider.modelChange.connect(() => { + chatHandler.provider = aiProvider.chatModel; + }); + let sendWithShiftEnter = false; let enableCodeToolbar = true; @@ -94,25 +59,6 @@ const chatPlugin: JupyterFrontEndPlugin = { chatHandler.config = { sendWithShiftEnter, enableCodeToolbar }; } - // TODO: handle the apiKey better - settingsRegistry - ?.load(inlineProviderPlugin.id) - .then(settings => { - const updateKey = () => { - const apiKey = settings.get('apiKey').composite as string; - mistralClient.apiKey = apiKey; - }; - - settings.changed.connect(() => updateKey()); - updateKey(); - }) - .catch(reason => { - console.error( - `Failed to load settings for ${inlineProviderPlugin.id}`, - reason - ); - 
-      });
-
     Promise.all([app.restored, settingsRegistry?.load(chatPlugin.id)])
       .then(([, settings]) => {
         if (!settings) {
@@ -148,4 +94,38 @@
   }
 };
 
-export default [inlineProviderPlugin, chatPlugin];
+const aiProviderPlugin: JupyterFrontEndPlugin<IAIProvider> = {
+  id: 'jupyterlab-codestral:ai-provider',
+  autoStart: true,
+  requires: [ICompletionProviderManager, ISettingRegistry],
+  provides: IAIProvider,
+  activate: (
+    app: JupyterFrontEnd,
+    manager: ICompletionProviderManager,
+    settingRegistry: ISettingRegistry
+  ): IAIProvider => {
+    const aiProvider = new AIProvider({ completionProviderManager: manager });
+
+    settingRegistry
+      .load(aiProviderPlugin.id)
+      .then(settings => {
+        const updateProvider = () => {
+          const provider = settings.get('provider').composite as string;
+          aiProvider.setModels(provider, settings.composite);
+        };
+
+        settings.changed.connect(() => updateProvider());
+        updateProvider();
+      })
+      .catch(reason => {
+        console.error(
+          `Failed to load settings for ${aiProviderPlugin.id}`,
+          reason
+        );
+      });
+
+    return aiProvider;
+  }
+};
+
+export default [chatPlugin, aiProviderPlugin];
diff --git a/src/llm-models/base-completer.ts b/src/llm-models/base-completer.ts
new file mode 100644
index 0000000..498abf6
--- /dev/null
+++ b/src/llm-models/base-completer.ts
@@ -0,0 +1,20 @@
+import {
+  CompletionHandler,
+  IInlineCompletionContext
+} from '@jupyterlab/completer';
+import { LLM } from '@langchain/core/language_models/llms';
+
+export interface IBaseCompleter {
+  /**
+   * The LLM completer.
+   */
+  provider: LLM;
+
+  /**
+   * The fetch request for the LLM completer.
+   */
+  fetch(
+    request: CompletionHandler.IRequest,
+    context: IInlineCompletionContext
+  ): Promise<any>;
+}
diff --git a/src/llm-models/codestral-completer.ts b/src/llm-models/codestral-completer.ts
new file mode 100644
index 0000000..f1168c8
--- /dev/null
+++ b/src/llm-models/codestral-completer.ts
@@ -0,0 +1,74 @@
+import {
+  CompletionHandler,
+  IInlineCompletionContext
+} from '@jupyterlab/completer';
+import { LLM } from '@langchain/core/language_models/llms';
+import { MistralAI } from '@langchain/mistralai';
+import { Throttler } from '@lumino/polling';
+import { CompletionRequest } from '@mistralai/mistralai';
+
+import { IBaseCompleter } from './base-completer';
+
+/*
+ * The Mistral API has a rate limit of 1 request per second
+ */
+const INTERVAL = 1000;
+
+export class CodestralCompleter implements IBaseCompleter {
+  constructor() {
+    this._mistralProvider = new MistralAI({
+      apiKey: 'TMP',
+      model: 'codestral-latest'
+    });
+    this._throttler = new Throttler(async (data: CompletionRequest) => {
+      const response = await this._mistralProvider.completionWithRetry(
+        data,
+        {},
+        false
+      );
+      const items = response.choices.map((choice: any) => {
+        return { insertText: choice.message.content as string };
+      });
+
+      return {
+        items
+      };
+    }, INTERVAL);
+  }
+
+  get provider(): LLM {
+    return this._mistralProvider;
+  }
+
+  async fetch(
+    request: CompletionHandler.IRequest,
+    context: IInlineCompletionContext
+  ) {
+    const { text, offset: cursorOffset } = request;
+    const prompt = text.slice(0, cursorOffset);
+    const suffix = text.slice(cursorOffset);
+
+    const data = {
+      prompt,
+      suffix,
+      model: this._mistralProvider.model,
+      // temperature: 0,
+      // top_p: 1,
+      // max_tokens: 1024,
+      // min_tokens: 0,
+      stream: false,
+      // random_seed: 1337,
+      stop: []
+    };
+
+    try {
+      return await this._throttler.invoke(data);
+    } catch (error) {
+      console.error('Error fetching completions', error);
+      return { items: [] };
+    }
+  }
+
+  private _throttler: Throttler;
+  private _mistralProvider: MistralAI;
+}
diff --git a/src/llm-models/index.ts b/src/llm-models/index.ts
new file mode 100644
index 0000000..ae6b725
--- /dev/null
+++ b/src/llm-models/index.ts
@@ -0,0 +1,3 @@
+export * from './base-completer';
+export * from './codestral-completer';
+export * from './utils';
diff --git a/src/llm-models/utils.ts b/src/llm-models/utils.ts
new file mode 100644
index 0000000..6d9b9f4
--- /dev/null
+++ b/src/llm-models/utils.ts
@@ -0,0 +1,24 @@
+import { BaseChatModel } from '@langchain/core/language_models/chat_models';
+import { ChatMistralAI } from '@langchain/mistralai';
+import { IBaseCompleter } from './base-completer';
+import { CodestralCompleter } from './codestral-completer';
+
+/**
+ * Get an LLM completer from the name.
+ */
+export function getCompleter(name: string): IBaseCompleter | null {
+  if (name === 'MistralAI') {
+    return new CodestralCompleter();
+  }
+  return null;
+}
+
+/**
+ * Get an LLM chat model from the name.
+ */
+export function getChatModel(name: string): BaseChatModel | null {
+  if (name === 'MistralAI') {
+    return new ChatMistralAI({ apiKey: 'TMP' });
+  }
+  return null;
+}
diff --git a/src/provider.ts b/src/provider.ts
index 7c4c1e5..de88ba3 100644
--- a/src/provider.ts
+++ b/src/provider.ts
@@ -1,77 +1,131 @@
-import {
-  CompletionHandler,
-  IInlineCompletionContext,
-  IInlineCompletionProvider
-} from '@jupyterlab/completer';
+import { ICompletionProviderManager } from '@jupyterlab/completer';
+import { BaseLanguageModel } from '@langchain/core/language_models/base';
+import { BaseChatModel } from '@langchain/core/language_models/chat_models';
+import { ISignal, Signal } from '@lumino/signaling';
+import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
 
-import { Throttler } from '@lumino/polling';
+import { CompletionProvider } from './completion-provider';
+import { getChatModel, IBaseCompleter } from './llm-models';
+import { IAIProvider } from './token';
 
-import { CompletionRequest } from '@mistralai/mistralai';
-
-import type { MistralAI } from '@langchain/mistralai';
-
-/*
- * The Mistral API has a rate limit of 1 request per second
- */
-const INTERVAL = 1000;
+export class AIProvider implements IAIProvider {
+  constructor(options: AIProvider.IOptions) {
+    this._completionProvider = new CompletionProvider({ name: 'None' });
+    options.completionProviderManager.registerInlineProvider(
+      this._completionProvider
+    );
+  }
 
-export class CodestralProvider implements IInlineCompletionProvider {
-  readonly identifier = 'Codestral';
-  readonly name = 'Codestral';
+  get name(): string {
+    return this._name;
+  }
 
-  constructor(options: CodestralProvider.IOptions) {
-    this._mistralClient = options.mistralClient;
-    this._throttler = new Throttler(async (data: CompletionRequest) => {
-      const response = await this._mistralClient.completionWithRetry(
-        data,
-        {},
-        false
-      );
-      const items = response.choices.map((choice: any) => {
-        return { insertText: choice.message.content as string };
-      });
+  /**
+   * Get the current completer of the completion provider.
+   */
+  get completer(): IBaseCompleter | null {
+    if (this._name === null) {
+      return null;
+    }
+    return this._completionProvider.completer;
+  }
 
-      return {
-        items
-      };
-    }, INTERVAL);
+  /**
+   * Get the current LLM chat model.
+ */ + get chatModel(): BaseChatModel | null { + if (this._name === null) { + return null; + } + return this._llmChatModel; } - async fetch( - request: CompletionHandler.IRequest, - context: IInlineCompletionContext - ) { - const { text, offset: cursorOffset } = request; - const prompt = text.slice(0, cursorOffset); - const suffix = text.slice(cursorOffset); + /** + * Set the models (chat model and completer). + * Creates the models if the name has changed, otherwise only updates their config. + * + * @param name - the name of the model to use. + * @param settings - the settings for the models. + */ + setModels(name: string, settings: ReadonlyPartialJSONObject) { + if (name !== this._name) { + this._name = name; + this._completionProvider.name = name; + this._llmChatModel = getChatModel(name); + this._modelChange.emit(); + } - const data = { - prompt, - suffix, - model: 'codestral-latest', - // temperature: 0, - // top_p: 1, - // max_tokens: 1024, - // min_tokens: 0, - stream: false, - // random_seed: 1337, - stop: [] - }; + // Update the inline completion provider settings. + if (this._completionProvider.llmCompleter) { + AIProvider.updateConfig(this._completionProvider.llmCompleter, settings); + } - try { - return this._throttler.invoke(data); - } catch (error) { - console.error('Error fetching completions', error); - return { items: [] }; + // Update the chat LLM settings. + if (this._llmChatModel) { + AIProvider.updateConfig(this._llmChatModel, settings); } } - private _throttler: Throttler; - private _mistralClient: MistralAI; + get modelChange(): ISignal { + return this._modelChange; + } + + private _completionProvider: CompletionProvider; + private _llmChatModel: BaseChatModel | null = null; + private _name: string = 'None'; + private _modelChange = new Signal(this); } -export namespace CodestralProvider { +export namespace AIProvider { + /** + * The options for the LLM provider. + */ export interface IOptions { - mistralClient: MistralAI; + /** + * The completion provider manager in which register the LLM completer. + */ + completionProviderManager: ICompletionProviderManager; + } + + /** + * This function indicates whether a key is writable in an object. + * https://stackoverflow.com/questions/54724875/can-we-check-whether-property-is-readonly-in-typescript + * + * @param obj - An object extending the BaseLanguageModel interface. + * @param key - A string as a key of the object. + * @returns a boolean whether the key is writable or not. + */ + export function isWritable( + obj: T, + key: keyof T + ) { + const desc = + Object.getOwnPropertyDescriptor(obj, key) || + Object.getOwnPropertyDescriptor(Object.getPrototypeOf(obj), key) || + {}; + return Boolean(desc.writable); + } + + /** + * Update the config of a language model. + * It only updates the writable attributes of the model. + * + * @param model - the model to update. + * @param settings - the configuration s a JSON object. 
+ */ + export function updateConfig( + model: T, + settings: ReadonlyPartialJSONObject + ) { + Object.entries(settings).forEach(([key, value], index) => { + if (key in model) { + const modelKey = key as keyof typeof model; + if (isWritable(model, modelKey)) { + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + model[modelKey] = value; + } + } + }); } } diff --git a/src/token.ts b/src/token.ts new file mode 100644 index 0000000..626be4a --- /dev/null +++ b/src/token.ts @@ -0,0 +1,17 @@ +import { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { Token } from '@lumino/coreutils'; +import { ISignal } from '@lumino/signaling'; + +import { IBaseCompleter } from './llm-models'; + +export interface IAIProvider { + name: string | null; + completer: IBaseCompleter | null; + chatModel: BaseChatModel | null; + modelChange: ISignal; +} + +export const IAIProvider = new Token( + 'jupyterlab-codestral:AIProvider', + 'Provider for chat and completion LLM provider' +); diff --git a/tsconfig.json b/tsconfig.json index 9897917..bcaac9d 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -19,5 +19,5 @@ "strictNullChecks": true, "target": "ES2018" }, - "include": ["src/*"] + "include": ["src/*", "src/**/*"] } diff --git a/yarn.lock b/yarn.lock index 7bf63b1..47e6599 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4887,6 +4887,7 @@ __metadata: "@langchain/mistralai": ^0.1.1 "@lumino/coreutils": ^2.1.2 "@lumino/polling": ^2.1.2 + "@lumino/signaling": ^2.1.2 "@types/json-schema": ^7.0.11 "@types/react": ^18.0.26 "@types/react-addons-linked-state-mixin": ^0.14.22