diff --git a/src/completion-provider.ts b/src/completion-provider.ts
index e7000c5..f036c0e 100644
--- a/src/completion-provider.ts
+++ b/src/completion-provider.ts
@@ -16,6 +16,7 @@ export class CompletionProvider implements IInlineCompletionProvider {
 
   constructor(options: CompletionProvider.IOptions) {
     const { name, settings } = options;
+    this._requestCompletion = options.requestCompletion;
     this.setCompleter(name, settings);
   }
 
@@ -28,6 +29,9 @@ export class CompletionProvider implements IInlineCompletionProvider {
   setCompleter(name: string, settings: ReadonlyPartialJSONObject) {
     try {
       this._completer = getCompleter(name, settings);
+      if (this._completer) {
+        this._completer.requestCompletion = this._requestCompletion;
+      }
       this._name = this._completer === null ? 'None' : name;
     } catch (e: any) {
       this._completer = null;
@@ -65,11 +69,13 @@ export class CompletionProvider implements IInlineCompletionProvider {
   }
 
   private _name: string = 'None';
+  private _requestCompletion: () => void;
   private _completer: IBaseCompleter | null = null;
 }
 
 export namespace CompletionProvider {
   export interface IOptions extends BaseCompleter.IOptions {
     name: string;
+    requestCompletion: () => void;
   }
 }
diff --git a/src/index.ts b/src/index.ts
index 76d2ab2..b056487 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -100,7 +100,10 @@ const aiProviderPlugin: JupyterFrontEndPlugin<IAIProvider> = {
     manager: ICompletionProviderManager,
     settingRegistry: ISettingRegistry
   ): IAIProvider => {
-    const aiProvider = new AIProvider({ completionProviderManager: manager });
+    const aiProvider = new AIProvider({
+      completionProviderManager: manager,
+      requestCompletion: () => app.commands.execute('inline-completer:invoke')
+    });
 
     settingRegistry
       .load(aiProviderPlugin.id)
diff --git a/src/llm-models/base-completer.ts b/src/llm-models/base-completer.ts
index fb84f4f..0828a9c 100644
--- a/src/llm-models/base-completer.ts
+++ b/src/llm-models/base-completer.ts
@@ -11,6 +11,11 @@ export interface IBaseCompleter {
    */
   provider: LLM;
 
+  /**
+   * The function to fetch a new completion.
+   */
+  requestCompletion?: () => void;
+
   /**
    * The fetch request for the LLM completer.
    */
diff --git a/src/llm-models/codestral-completer.ts b/src/llm-models/codestral-completer.ts
index efa7934..4db7313 100644
--- a/src/llm-models/codestral-completer.ts
+++ b/src/llm-models/codestral-completer.ts
@@ -9,34 +9,67 @@ import { CompletionRequest } from '@mistralai/mistralai';
 
 import { BaseCompleter, IBaseCompleter } from './base-completer';
 
-/*
+/**
  * The Mistral API has a rate limit of 1 request per second
  */
 const INTERVAL = 1000;
 
+/**
+ * Timeout to avoid endless requests
+ */
+const REQUEST_TIMEOUT = 3000;
+
 export class CodestralCompleter implements IBaseCompleter {
   constructor(options: BaseCompleter.IOptions) {
+    // this._requestCompletion = options.requestCompletion;
     this._mistralProvider = new MistralAI({ ...options.settings });
-    this._throttler = new Throttler(async (data: CompletionRequest) => {
-      const response = await this._mistralProvider.completionWithRetry(
-        data,
-        {},
-        false
-      );
-      const items = response.choices.map((choice: any) => {
-        return { insertText: choice.message.content as string };
-      });
+    this._throttler = new Throttler(
+      async (data: CompletionRequest) => {
+        const invokedData = data;
+
+        // Request completion.
+        const request = this._mistralProvider.completionWithRetry(
+          data,
+          {},
+          false
+        );
+        const timeoutPromise = new Promise<null>(resolve => {
+          return setTimeout(() => resolve(null), REQUEST_TIMEOUT);
+        });
+
+        // Fetch again if the request takes too long or if the prompt has changed.
+        const response = await Promise.race([request, timeoutPromise]);
+        if (
+          response === null ||
+          invokedData.prompt !== this._currentData?.prompt
+        ) {
+          return {
+            items: [],
+            fetchAgain: true
+          };
+        }
 
-      return {
-        items
-      };
-    }, INTERVAL);
+        // Extract results of completion request.
+        const items = response.choices.map((choice: any) => {
+          return { insertText: choice.message.content as string };
+        });
+
+        return {
+          items
+        };
+      },
+      { limit: INTERVAL }
+    );
   }
 
   get provider(): LLM {
     return this._mistralProvider;
   }
 
+  set requestCompletion(value: () => void) {
+    this._requestCompletion = value;
+  }
+
   async fetch(
     request: CompletionHandler.IRequest,
     context: IInlineCompletionContext
@@ -59,13 +92,22 @@ export class CodestralCompleter implements IBaseCompleter {
     };
 
     try {
-      return this._throttler.invoke(data);
+      this._currentData = data;
+      const completionResult = await this._throttler.invoke(data);
+      if (completionResult.fetchAgain) {
+        if (this._requestCompletion) {
+          this._requestCompletion();
+        }
+      }
+      return { items: completionResult.items };
     } catch (error) {
       console.error('Error fetching completions', error);
       return { items: [] };
     }
   }
 
+  private _requestCompletion?: () => void;
   private _throttler: Throttler;
   private _mistralProvider: MistralAI;
+  private _currentData: CompletionRequest | null = null;
 }
diff --git a/src/provider.ts b/src/provider.ts
index 6019785..1347b5b 100644
--- a/src/provider.ts
+++ b/src/provider.ts
@@ -12,7 +12,8 @@ export class AIProvider implements IAIProvider {
   constructor(options: AIProvider.IOptions) {
     this._completionProvider = new CompletionProvider({
       name: 'None',
-      settings: {}
+      settings: {},
+      requestCompletion: options.requestCompletion
     });
     options.completionProviderManager.registerInlineProvider(
       this._completionProvider
@@ -103,6 +104,10 @@ export namespace AIProvider {
     /**
      * The completion provider manager in which register the LLM completer.
      */
     completionProviderManager: ICompletionProviderManager;
+    /**
+     * The function to request an inline completion.
+     */
+    requestCompletion: () => void;
   }
   /**
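
Note on the codestral-completer.ts hunks above: the throttled task now races the Mistral request against a fixed timeout, and when the request times out or the prompt has gone stale it returns a fetchAgain flag so that fetch() can trigger a new invocation through requestCompletion. The sketch below is a minimal, self-contained illustration of that race-with-timeout pattern; the names raceWithTimeout, complete and REQUEST_TIMEOUT_MS are hypothetical and not part of the patch.

    // Minimal sketch (hypothetical names): resolve to null when a request
    // takes longer than a fixed timeout, instead of awaiting it forever.
    const REQUEST_TIMEOUT_MS = 3000;

    function raceWithTimeout<T>(request: Promise<T>): Promise<T | null> {
      const timeout = new Promise<null>(resolve =>
        setTimeout(() => resolve(null), REQUEST_TIMEOUT_MS)
      );
      return Promise.race([request, timeout]);
    }

    // Usage: a null result means the request timed out; the caller can then
    // schedule a fresh request, mirroring the patch's fetchAgain flag.
    async function complete(request: Promise<string>): Promise<string[]> {
      const result = await raceWithTimeout(request);
      return result === null ? [] : [result];
    }

One caveat this sketch shares with the patch: Promise.race does not cancel the losing promise, so a timed-out HTTP request keeps running in the background until it settles.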