Improves the relevance of codestral completion (#18)
* Improves the relevance of codestral completion

* Add a timeout to avoid endless requests

* Remove unused dependency

* Fetch again if the prompt has changed between the request and the response

* lint
brichet authored Dec 3, 2024
1 parent 8043bf9 commit a5a3bd6
Showing 5 changed files with 78 additions and 17 deletions.
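
The core of the change (in src/llm-models/codestral-completer.ts below) is to race each Codestral request against a timeout and to discard a response whose prompt has gone stale, signalling the caller to fetch again. A minimal sketch of that pattern in isolation: fetchCompletion, callProvider, getCurrentPrompt and ISketchResult are hypothetical names used only for illustration, not part of this repository; only the REQUEST_TIMEOUT value and the Promise.race / fetchAgain logic mirror the diff.

// Sketch only: hypothetical helpers, not the repository's API.
const REQUEST_TIMEOUT = 3000; // same timeout value as in the diff below

interface ISketchResult {
  items: { insertText: string }[];
  fetchAgain?: boolean;
}

async function fetchCompletion(
  prompt: string,
  getCurrentPrompt: () => string,
  callProvider: (prompt: string) => Promise<string[]>
): Promise<ISketchResult> {
  // Start the provider request and a timeout in parallel.
  const request = callProvider(prompt);
  const timeout = new Promise<null>(resolve =>
    setTimeout(() => resolve(null), REQUEST_TIMEOUT)
  );

  const completions = await Promise.race([request, timeout]);

  // Timed out, or the prompt changed while waiting: drop the result and
  // ask the caller to request a fresh completion.
  if (completions === null || prompt !== getCurrentPrompt()) {
    return { items: [], fetchAgain: true };
  }

  return { items: completions.map(insertText => ({ insertText })) };
}

In the commit itself, the fetchAgain flag travels back through the Throttler, and the provider reacts by calling the requestCompletion callback wired up in src/index.ts, which executes the JupyterLab command 'inline-completer:invoke'.
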
6 changes: 6 additions & 0 deletions src/completion-provider.ts
@@ -16,6 +16,7 @@ export class CompletionProvider implements IInlineCompletionProvider {
 
   constructor(options: CompletionProvider.IOptions) {
     const { name, settings } = options;
+    this._requestCompletion = options.requestCompletion;
     this.setCompleter(name, settings);
   }
 
@@ -28,6 +29,9 @@ export class CompletionProvider implements IInlineCompletionProvider {
   setCompleter(name: string, settings: ReadonlyPartialJSONObject) {
     try {
       this._completer = getCompleter(name, settings);
+      if (this._completer) {
+        this._completer.requestCompletion = this._requestCompletion;
+      }
       this._name = this._completer === null ? 'None' : name;
     } catch (e: any) {
       this._completer = null;
@@ -65,11 +69,13 @@ export class CompletionProvider implements IInlineCompletionProvider {
   }
 
   private _name: string = 'None';
+  private _requestCompletion: () => void;
   private _completer: IBaseCompleter | null = null;
 }
 
 export namespace CompletionProvider {
   export interface IOptions extends BaseCompleter.IOptions {
     name: string;
+    requestCompletion: () => void;
   }
 }
5 changes: 4 additions & 1 deletion src/index.ts
@@ -100,7 +100,10 @@ const aiProviderPlugin: JupyterFrontEndPlugin<IAIProvider> = {
     manager: ICompletionProviderManager,
     settingRegistry: ISettingRegistry
   ): IAIProvider => {
-    const aiProvider = new AIProvider({ completionProviderManager: manager });
+    const aiProvider = new AIProvider({
+      completionProviderManager: manager,
+      requestCompletion: () => app.commands.execute('inline-completer:invoke')
+    });
 
     settingRegistry
       .load(aiProviderPlugin.id)
5 changes: 5 additions & 0 deletions src/llm-models/base-completer.ts
@@ -11,6 +11,11 @@ export interface IBaseCompleter {
    */
   provider: LLM;
 
+  /**
+   * The function to fetch a new completion.
+   */
+  requestCompletion?: () => void;
+
   /**
    * The fetch request for the LLM completer.
    */
72 changes: 57 additions & 15 deletions src/llm-models/codestral-completer.ts
@@ -9,34 +9,67 @@ import { CompletionRequest } from '@mistralai/mistralai';
 
 import { BaseCompleter, IBaseCompleter } from './base-completer';
 
-/*
+/**
  * The Mistral API has a rate limit of 1 request per second
  */
 const INTERVAL = 1000;
 
+/**
+ * Timeout to avoid endless requests
+ */
+const REQUEST_TIMEOUT = 3000;
+
 export class CodestralCompleter implements IBaseCompleter {
   constructor(options: BaseCompleter.IOptions) {
+    // this._requestCompletion = options.requestCompletion;
     this._mistralProvider = new MistralAI({ ...options.settings });
-    this._throttler = new Throttler(async (data: CompletionRequest) => {
-      const response = await this._mistralProvider.completionWithRetry(
-        data,
-        {},
-        false
-      );
-      const items = response.choices.map((choice: any) => {
-        return { insertText: choice.message.content as string };
-      });
+    this._throttler = new Throttler(
+      async (data: CompletionRequest) => {
+        const invokedData = data;
+
+        // Request completion.
+        const request = this._mistralProvider.completionWithRetry(
+          data,
+          {},
+          false
+        );
+        const timeoutPromise = new Promise<null>(resolve => {
+          return setTimeout(() => resolve(null), REQUEST_TIMEOUT);
+        });
+
+        // Fetch again if the request is too long or if the prompt has changed.
+        const response = await Promise.race([request, timeoutPromise]);
+        if (
+          response === null ||
+          invokedData.prompt !== this._currentData?.prompt
+        ) {
+          return {
+            items: [],
+            fetchAgain: true
+          };
+        }
 
-      return {
-        items
-      };
-    }, INTERVAL);
+        // Extract results of completion request.
+        const items = response.choices.map((choice: any) => {
+          return { insertText: choice.message.content as string };
+        });
+
+        return {
+          items
+        };
+      },
+      { limit: INTERVAL }
+    );
   }
 
   get provider(): LLM {
     return this._mistralProvider;
   }
 
+  set requestCompletion(value: () => void) {
+    this._requestCompletion = value;
+  }
+
   async fetch(
     request: CompletionHandler.IRequest,
     context: IInlineCompletionContext
@@ -59,13 +92,22 @@ export class CodestralCompleter implements IBaseCompleter {
     };
 
     try {
-      return this._throttler.invoke(data);
+      this._currentData = data;
+      const completionResult = await this._throttler.invoke(data);
+      if (completionResult.fetchAgain) {
+        if (this._requestCompletion) {
+          this._requestCompletion();
+        }
+      }
+      return { items: completionResult.items };
     } catch (error) {
       console.error('Error fetching completions', error);
       return { items: [] };
     }
   }
 
+  private _requestCompletion?: () => void;
   private _throttler: Throttler;
   private _mistralProvider: MistralAI;
+  private _currentData: CompletionRequest | null = null;
 }
7 changes: 6 additions & 1 deletion src/provider.ts
@@ -12,7 +12,8 @@ export class AIProvider implements IAIProvider {
   constructor(options: AIProvider.IOptions) {
     this._completionProvider = new CompletionProvider({
       name: 'None',
-      settings: {}
+      settings: {},
+      requestCompletion: options.requestCompletion
     });
     options.completionProviderManager.registerInlineProvider(
       this._completionProvider
@@ -103,6 +104,10 @@ export namespace AIProvider {
    * The completion provider manager in which register the LLM completer.
    */
   completionProviderManager: ICompletionProviderManager;
+  /**
+   * The application commands registry.
+   */
+  requestCompletion: () => void;
 }
 
 /**
