From 6e11c72dc07b77bcea1398f720cb85ab2c9e5ae7 Mon Sep 17 00:00:00 2001
From: Nicolas Brichet
Date: Wed, 30 Oct 2024 16:01:33 +0100
Subject: [PATCH] Provides only one completion provider

---
 src/chat-handler.ts                       |   2 +-
 src/completion-provider.ts                |  61 ++++++++
 src/completion-providers/base-provider.ts |   6 -
 src/completion-providers/index.ts         |   1 -
 src/index.ts                              |   6 +-
 src/llm-models/base-completer.ts          |  20 +++
 .../codestral-completer.ts}               |  26 ++--
 src/llm-models/index.ts                   |   3 +
 src/llm-models/utils.ts                   |  24 +++
 src/provider.ts                           | 147 ++++++++++--------
 src/token.ts                              |  12 +-
 src/tools.ts                              |  17 --
 12 files changed, 208 insertions(+), 117 deletions(-)
 create mode 100644 src/completion-provider.ts
 delete mode 100644 src/completion-providers/base-provider.ts
 delete mode 100644 src/completion-providers/index.ts
 create mode 100644 src/llm-models/base-completer.ts
 rename src/{completion-providers/codestral-provider.ts => llm-models/codestral-completer.ts} (76%)
 create mode 100644 src/llm-models/index.ts
 create mode 100644 src/llm-models/utils.ts
 delete mode 100644 src/tools.ts

diff --git a/src/chat-handler.ts b/src/chat-handler.ts
index 47c8867..0191302 100644
--- a/src/chat-handler.ts
+++ b/src/chat-handler.ts
@@ -10,12 +10,12 @@ import {
   INewMessage
 } from '@jupyter/chat';
 import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
-import { UUID } from '@lumino/coreutils';
 import {
   AIMessage,
   HumanMessage,
   mergeMessageRuns
 } from '@langchain/core/messages';
+import { UUID } from '@lumino/coreutils';
 
 export type ConnectionMessage = {
   type: 'connection';
diff --git a/src/completion-provider.ts b/src/completion-provider.ts
new file mode 100644
index 0000000..53b2051
--- /dev/null
+++ b/src/completion-provider.ts
@@ -0,0 +1,61 @@
+import {
+  CompletionHandler,
+  IInlineCompletionContext,
+  IInlineCompletionProvider
+} from '@jupyterlab/completer';
+import { LLM } from '@langchain/core/language_models/llms';
+
+import { getCompleter, IBaseCompleter } from './llm-models';
+
+/**
+ * The generic completion provider to register with the completion provider manager.
+ */
+export class CompletionProvider implements IInlineCompletionProvider {
+  readonly identifier = '@jupyterlite/ai';
+
+  constructor(options: CompletionProvider.IOptions) {
+    this.name = options.name;
+  }
+
+  /**
+   * Getter and setter of the name.
+   * The setter will create the appropriate completer, according to the name.
+   */
+  get name(): string {
+    return this._name;
+  }
+  set name(name: string) {
+    this._name = name;
+    this._completer = getCompleter(name);
+  }
+
+  /**
+   * Get the current completer.
+   */
+  get completer(): IBaseCompleter | null {
+    return this._completer;
+  }
+
+  /**
+   * Get the LLM completer.
+   */
+  get llmCompleter(): LLM | null {
+    return this._completer?.client || null;
+  }
+
+  async fetch(
+    request: CompletionHandler.IRequest,
+    context: IInlineCompletionContext
+  ) {
+    return this._completer?.fetch(request, context);
+  }
+
+  private _name: string = 'None';
+  private _completer: IBaseCompleter | null = null;
+}
+
+export namespace CompletionProvider {
+  export interface IOptions {
+    name: string;
+  }
+}
diff --git a/src/completion-providers/base-provider.ts b/src/completion-providers/base-provider.ts
deleted file mode 100644
index f312f17..0000000
--- a/src/completion-providers/base-provider.ts
+++ /dev/null
@@ -1,6 +0,0 @@
-import { IInlineCompletionProvider } from '@jupyterlab/completer';
-import { LLM } from '@langchain/core/language_models/llms';
-
-export interface IBaseProvider extends IInlineCompletionProvider {
-  client: LLM;
-}
diff --git a/src/completion-providers/index.ts b/src/completion-providers/index.ts
deleted file mode 100644
index fdb3eeb..0000000
--- a/src/completion-providers/index.ts
+++ /dev/null
@@ -1 +0,0 @@
-export * from './codestral-provider';
diff --git a/src/index.ts b/src/index.ts
index b37d29b..fa939a3 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -15,8 +15,8 @@ import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
 import { ISettingRegistry } from '@jupyterlab/settingregistry';
 
 import { ChatHandler } from './chat-handler';
-import { ILlmProvider } from './token';
 import { LlmProvider } from './provider';
+import { ILlmProvider } from './token';
 
 const chatPlugin: JupyterFrontEndPlugin<void> = {
   id: 'jupyterlab-codestral:chat',
@@ -45,7 +45,7 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
       activeCellManager: activeCellManager
    });
 
-    llmProvider.providerChange.connect(() => {
+    llmProvider.modelChange.connect(() => {
       chatHandler.llmClient = llmProvider.chatModel;
     });
 
@@ -111,7 +111,7 @@ const llmProviderPlugin: JupyterFrontEndPlugin<ILlmProvider> = {
       .then(settings => {
         const updateProvider = () => {
           const provider = settings.get('provider').composite as string;
-          llmProvider.setProvider(provider, settings.composite);
+          llmProvider.setModels(provider, settings.composite);
         };
 
         settings.changed.connect(() => updateProvider());
diff --git a/src/llm-models/base-completer.ts b/src/llm-models/base-completer.ts
new file mode 100644
index 0000000..8374db4
--- /dev/null
+++ b/src/llm-models/base-completer.ts
@@ -0,0 +1,20 @@
+import {
+  CompletionHandler,
+  IInlineCompletionContext
+} from '@jupyterlab/completer';
+import { LLM } from '@langchain/core/language_models/llms';
+
+export interface IBaseCompleter {
+  /**
+   * The LLM completer.
+   */
+  client: LLM;
+
+  /**
+   * The fetch request for the LLM completer.
+   */
+  fetch(
+    request: CompletionHandler.IRequest,
+    context: IInlineCompletionContext
+  ): Promise<any>;
+}
diff --git a/src/completion-providers/codestral-provider.ts b/src/llm-models/codestral-completer.ts
similarity index 76%
rename from src/completion-providers/codestral-provider.ts
rename to src/llm-models/codestral-completer.ts
index 6143bd6..8f3e6ee 100644
--- a/src/completion-providers/codestral-provider.ts
+++ b/src/llm-models/codestral-completer.ts
@@ -2,24 +2,24 @@ import {
   CompletionHandler,
   IInlineCompletionContext
 } from '@jupyterlab/completer';
+import { LLM } from '@langchain/core/language_models/llms';
+import { MistralAI } from '@langchain/mistralai';
 import { Throttler } from '@lumino/polling';
 import { CompletionRequest } from '@mistralai/mistralai';
-import type { MistralAI } from '@langchain/mistralai';
 
-import { IBaseProvider } from './base-provider';
-import { LLM } from '@langchain/core/language_models/llms';
+import { IBaseCompleter } from './base-completer';
 
 /*
  * The Mistral API has a rate limit of 1 request per second
  */
 const INTERVAL = 1000;
 
-export class CodestralProvider implements IBaseProvider {
-  readonly identifier = 'Codestral';
-  readonly name = 'Codestral';
-
-  constructor(options: CodestralProvider.IOptions) {
-    this._mistralClient = options.mistralClient;
+export class CodestralCompleter implements IBaseCompleter {
+  constructor() {
+    this._mistralClient = new MistralAI({
+      apiKey: 'TMP',
+      model: 'codestral-latest'
+    });
     this._throttler = new Throttler(async (data: CompletionRequest) => {
       const response = await this._mistralClient.completionWithRetry(
         data,
@@ -51,7 +51,7 @@
     const data = {
       prompt,
       suffix,
-      model: 'codestral-latest',
+      model: this._mistralClient.model,
       // temperature: 0,
       // top_p: 1,
       // max_tokens: 1024,
@@ -72,9 +72,3 @@
   private _throttler: Throttler;
   private _mistralClient: MistralAI;
 }
-
-export namespace CodestralProvider {
-  export interface IOptions {
-    mistralClient: MistralAI;
-  }
-}
diff --git a/src/llm-models/index.ts b/src/llm-models/index.ts
new file mode 100644
index 0000000..ae6b725
--- /dev/null
+++ b/src/llm-models/index.ts
@@ -0,0 +1,3 @@
+export * from './base-completer';
+export * from './codestral-completer';
+export * from './utils';
diff --git a/src/llm-models/utils.ts b/src/llm-models/utils.ts
new file mode 100644
index 0000000..6d9b9f4
--- /dev/null
+++ b/src/llm-models/utils.ts
@@ -0,0 +1,24 @@
+import { BaseChatModel } from '@langchain/core/language_models/chat_models';
+import { ChatMistralAI } from '@langchain/mistralai';
+import { IBaseCompleter } from './base-completer';
+import { CodestralCompleter } from './codestral-completer';
+
+/**
+ * Get an LLM completer from the name.
+ */
+export function getCompleter(name: string): IBaseCompleter | null {
+  if (name === 'MistralAI') {
+    return new CodestralCompleter();
+  }
+  return null;
+}
+
+/**
+ * Get an LLM chat model from the name.
+ */
+export function getChatModel(name: string): BaseChatModel | null {
+  if (name === 'MistralAI') {
+    return new ChatMistralAI({ apiKey: 'TMP' });
+  }
+  return null;
+}
diff --git a/src/provider.ts b/src/provider.ts
index 7030eed..1eed586 100644
--- a/src/provider.ts
+++ b/src/provider.ts
@@ -1,88 +1,119 @@
 import { ICompletionProviderManager } from '@jupyterlab/completer';
+import { BaseLanguageModel } from '@langchain/core/language_models/base';
 import { BaseChatModel } from '@langchain/core/language_models/chat_models';
-import { ChatMistralAI, MistralAI } from '@langchain/mistralai';
 import { ISignal, Signal } from '@lumino/signaling';
 import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
-import * as completionProviders from './completion-providers';
-import { ILlmProvider, IProviders } from './token';
-import { IBaseProvider } from './completion-providers/base-provider';
-import { isWritable } from './tools';
-import { BaseLanguageModel } from '@langchain/core/language_models/base';
+
+import { CompletionProvider } from './completion-provider';
+import { getChatModel, IBaseCompleter } from './llm-models';
+import { ILlmProvider } from './token';
 
 export class LlmProvider implements ILlmProvider {
   constructor(options: LlmProvider.IOptions) {
-    this._completionProviderManager = options.completionProviderManager;
+    this._completionProvider = new CompletionProvider({ name: 'None' });
+    options.completionProviderManager.registerInlineProvider(
+      this._completionProvider
+    );
   }
 
-  get name(): string | null {
+  get name(): string {
     return this._name;
   }
 
-  get completionProvider(): IBaseProvider | null {
+  /**
+   * Get the current completer of the completion provider.
+   */
+  get completer(): IBaseCompleter | null {
     if (this._name === null) {
       return null;
     }
-    return (
-      this._completionProviders.get(this._name)?.completionProvider || null
-    );
+    return this._completionProvider.completer;
   }
 
+  /**
+   * Get the current LLM chat model.
+   */
   get chatModel(): BaseChatModel | null {
     if (this._name === null) {
       return null;
     }
-    return this._completionProviders.get(this._name)?.chatModel || null;
+    return this._llmChatModel;
  }
 
-  setProvider(name: string | null, settings: ReadonlyPartialJSONObject) {
-    if (name === null) {
-      // TODO: the inline completion is not disabled.
-      // It should be removed/disabled from the manager.
-      this._providerChange.emit();
-      return;
+  /**
+   * Set the models (chat model and completer).
+   * Creates the models if the name has changed, otherwise only updates their config.
+   *
+   * @param name - the name of the model to use.
+   * @param settings - the settings for the models.
+   */
+  setModels(name: string, settings: ReadonlyPartialJSONObject) {
+    if (name !== this._name) {
+      this._name = name;
+      this._completionProvider.name = name;
+      this._llmChatModel = getChatModel(name);
+      this._modelChange.emit();
     }
-    const providers = this._completionProviders.get(name);
-    if (providers !== undefined) {
-      // Update the inline completion provider settings.
-      this._updateConfig(providers.completionProvider.client, settings);
-
-      // Update the chat LLM settings.
-      this._updateConfig(providers.chatModel, settings);
+    // Update the inline completion provider settings.
+    if (this._completionProvider.llmCompleter) {
+      LlmProvider.updateConfig(this._completionProvider.llmCompleter, settings);
+    }
 
-      if (name !== this._name) {
-        this._name = name;
-        this._providerChange.emit();
-      }
-      return;
+    // Update the chat LLM settings.
+    if (this._llmChatModel) {
+      LlmProvider.updateConfig(this._llmChatModel, settings);
     }
-    if (name === 'MistralAI') {
-      this._name = 'MistralAI';
-      const mistralClient = new MistralAI({ apiKey: 'TMP' });
-      this._updateConfig(mistralClient, settings);
+  }
 
-      const completionProvider = new completionProviders.CodestralProvider({
-        mistralClient
-      });
-      this._completionProviderManager.registerInlineProvider(
-        completionProvider
-      );
+  get modelChange(): ISignal<LlmProvider, void> {
+    return this._modelChange;
+  }
 
-      const chatModel = new ChatMistralAI({ apiKey: 'TMP' });
-      this._updateConfig(chatModel as any, settings);
+  private _completionProvider: CompletionProvider;
+  private _llmChatModel: BaseChatModel | null = null;
+  private _name: string = 'None';
+  private _modelChange = new Signal<LlmProvider, void>(this);
+}
 
-      this._completionProviders.set(name, { completionProvider, chatModel });
-    } else {
-      this._name = null;
-    }
-    this._providerChange.emit();
+export namespace LlmProvider {
+  /**
+   * The options for the LLM provider.
+   */
+  export interface IOptions {
+    /**
+     * The completion provider manager in which to register the LLM completer.
+     */
+    completionProviderManager: ICompletionProviderManager;
   }
 
-  get providerChange(): ISignal<LlmProvider, void> {
-    return this._providerChange;
+  /**
+   * This function indicates whether a key is writable in an object.
+   * https://stackoverflow.com/questions/54724875/can-we-check-whether-property-is-readonly-in-typescript
+   *
+   * @param obj - An object extending the BaseLanguageModel interface.
+   * @param key - A string as a key of the object.
+   * @returns a boolean indicating whether the key is writable.
+   */
+  export function isWritable<T extends BaseLanguageModel>(
+    obj: T,
+    key: keyof T
+  ) {
+    const desc =
+      Object.getOwnPropertyDescriptor(obj, key) ||
+      Object.getOwnPropertyDescriptor(Object.getPrototypeOf(obj), key) ||
+      {};
+    return Boolean(desc.writable);
   }
 
-  private _updateConfig<T extends BaseLanguageModel>(
+  /**
+   * Update the config of a language model.
+   * It only updates the writable attributes of the model.
+   *
+   * @param model - the model to update.
+   * @param settings - the configuration as a JSON object.
+   */
+  export function updateConfig<T extends BaseLanguageModel>(
     model: T,
     settings: ReadonlyPartialJSONObject
   ) {
@@ -97,18 +128,4 @@
       }
     });
   }
-
-  private _completionProviderManager: ICompletionProviderManager;
-  // The ICompletionProviderManager does not allow manipulating the providers,
-  // like getting, removing or recreating them. This map store the created providers to
-  // be able to modify them.
-  private _completionProviders = new Map<string, IProviders>();
-  private _name: string | null = null;
-  private _providerChange = new Signal<LlmProvider, void>(this);
-}
-
-export namespace LlmProvider {
-  export interface IOptions {
-    completionProviderManager: ICompletionProviderManager;
-  }
 }
diff --git a/src/token.ts b/src/token.ts
index 5e1ae1d..3148938 100644
--- a/src/token.ts
+++ b/src/token.ts
@@ -1,18 +1,14 @@
 import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { Token } from '@lumino/coreutils';
 import { ISignal } from '@lumino/signaling';
-import { IBaseProvider } from './completion-providers/base-provider';
+
+import { IBaseCompleter } from './llm-models';
 
 export interface ILlmProvider {
   name: string | null;
-  completionProvider: IBaseProvider | null;
+  completer: IBaseCompleter | null;
   chatModel: BaseChatModel | null;
-  providerChange: ISignal<ILlmProvider, void>;
-}
-
-export interface IProviders {
-  completionProvider: IBaseProvider;
-  chatModel: BaseChatModel;
+  modelChange: ISignal<ILlmProvider, void>;
 }
 
 export const ILlmProvider = new Token<ILlmProvider>(
diff --git a/src/tools.ts b/src/tools.ts
deleted file mode 100644
index 4369aec..0000000
--- a/src/tools.ts
+++ /dev/null
@@ -1,17 +0,0 @@
-import { BaseLanguageModel } from '@langchain/core/language_models/base';
-
-/**
- * This function indicates whether a key is writable in an object.
- * https://stackoverflow.com/questions/54724875/can-we-check-whether-property-is-readonly-in-typescript
- *
- * @param obj - An object extending the BaseLanguageModel interface.
- * @param key - A string as a key of the object.
- * @returns a boolean whether the key is writable or not.
- */
-export function isWritable<T extends BaseLanguageModel>(obj: T, key: keyof T) {
-  const desc =
-    Object.getOwnPropertyDescriptor(obj, key) ||
-    Object.getOwnPropertyDescriptor(Object.getPrototypeOf(obj), key) ||
-    {};
-  return Boolean(desc.writable);
-}
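For context when reviewing, below is a minimal sketch of how the refactored provider is meant to be driven, mirroring the wiring in src/index.ts above. It is illustrative only and not part of the patch: the completion provider manager is normally injected through the plugin's requires list (a `declare` stands in for it here), and the settings literal reuses the 'TMP' placeholder from the patch rather than a real API key.

import { ICompletionProviderManager } from '@jupyterlab/completer';
import { LlmProvider } from './provider';

// Normally provided by JupyterLab through the plugin's `requires` list.
declare const completionProviderManager: ICompletionProviderManager;

// The provider now registers a single inline completion provider once,
// at construction time, instead of registering one provider per model.
const llmProvider = new LlmProvider({ completionProviderManager });

// Consumers follow model switches through the renamed `modelChange` signal.
llmProvider.modelChange.connect(() => {
  console.log('Chat model changed:', llmProvider.chatModel);
});

// Selecting a model by name swaps the underlying completer and chat model;
// calling it again with the same name only re-applies the writable settings.
llmProvider.setModels('MistralAI', { apiKey: 'TMP' });

This keeps the ICompletionProviderManager limitation noted in the deleted comment (providers cannot be retrieved or removed once registered) from mattering: the single registered CompletionProvider stays in place and only its internal completer is swapped.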