Skip to content

Commit

Permalink
Better handling of error with providers
Browse files Browse the repository at this point in the history
  • Loading branch information
brichet committed Nov 5, 2024
1 parent f0637eb commit 66d416f
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 53 deletions.
50 changes: 34 additions & 16 deletions src/chat-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ export class ChatHandler extends ChatModel {
this.messageAdded(msg);

if (this._provider === null) {
const botMsg: IChatMessage = {
const errorMsg: IChatMessage = {
id: UUID.uuid4(),
body: '**AI provider not configured for the chat**',
body: `**${this.message ? this.message : this._defaultMessage}**`,
sender: { username: 'ERROR' },
time: Date.now(),
type: 'msg'
};
this.messageAdded(botMsg);
this.messageAdded(errorMsg);
return false;
}

Expand All @@ -69,19 +69,35 @@ export class ChatHandler extends ChatModel {
})
);

const response = await this._provider.invoke(messages);
// TODO: fix deprecated response.text
const content = response.text;
const botMsg: IChatMessage = {
id: UUID.uuid4(),
body: content,
sender: { username: 'Bot' },
time: Date.now(),
type: 'msg'
};
this.messageAdded(botMsg);
this._history.messages.push(botMsg);
return true;
return this._provider
.invoke(messages)
.then(response => {
const content = response.content;
const botMsg: IChatMessage = {
id: UUID.uuid4(),
body: content.toString(),
sender: { username: 'Bot' },
time: Date.now(),
type: 'msg'
};
this.messageAdded(botMsg);
this._history.messages.push(botMsg);
return true;
})
.catch(reason => {
const error = reason.error.error.message ?? 'Error with the chat API';
console.log('REASON', error);
console.log('REASON', typeof error);
const errorMsg: IChatMessage = {
id: UUID.uuid4(),
body: `**${error}**`,
sender: { username: 'ERROR' },
time: Date.now(),
type: 'msg'
};
this.messageAdded(errorMsg);
return false;
});
}

async getHistory(): Promise<IChatHistory> {
Expand All @@ -96,8 +112,10 @@ export class ChatHandler extends ChatModel {
super.messageAdded(message);
}

message: string = '';
private _provider: BaseChatModel | null;
private _history: IChatHistory = { messages: [] };
private _defaultMessage = 'AI provider not configured';
}

export namespace ChatHandler {
Expand Down
34 changes: 24 additions & 10 deletions src/completion-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import {
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';

import { getCompleter, IBaseCompleter } from './llm-models';
import { getCompleter, IBaseCompleter, BaseCompleter } from './llm-models';
import { ReadonlyPartialJSONObject } from '@lumino/coreutils';

/**
* The generic completion provider to register to the completion provider manager.
Expand All @@ -14,23 +15,36 @@ export class CompletionProvider implements IInlineCompletionProvider {
readonly identifier = '@jupyterlite/ai';

constructor(options: CompletionProvider.IOptions) {
this.name = options.name;
const { name, settings } = options;
this.setCompleter(name, settings);
}

/**
* Getter and setter of the name.
* The setter will create the appropriate completer, accordingly to the name.
* Set the completer.
*
* @param name - the name of the completer.
* @param settings - The settings associated to the completer.
*/
setCompleter(name: string, settings: ReadonlyPartialJSONObject) {
try {
this._completer = getCompleter(name, settings);
this._name = this._completer === null ? 'None' : name;
} catch (e: any) {
this._completer = null;
this._name = 'None';
throw e;
}
}

/**
 * Get the current completer name.
 *
 * Returns 'None' when no completer is set (see `setCompleter`).
 */
get name(): string {
  return this._name;
}
set name(name: string) {
this._name = name;
this._completer = getCompleter(name);
}

/**
* get the current completer.
* Get the current completer.
*/
get completer(): IBaseCompleter | null {
return this._completer;
Expand All @@ -55,7 +69,7 @@ export class CompletionProvider implements IInlineCompletionProvider {
}

export namespace CompletionProvider {
export interface IOptions {
export interface IOptions extends BaseCompleter.IOptions {
name: string;
}
}
1 change: 1 addition & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {

aiProvider.modelChange.connect(() => {
chatHandler.provider = aiProvider.chatModel;
chatHandler.message = aiProvider.chatError;
});

let sendWithShiftEnter = false;
Expand Down
13 changes: 13 additions & 0 deletions src/llm-models/base-completer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
IInlineCompletionContext
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';
import { ReadonlyPartialJSONObject } from '@lumino/coreutils';

export interface IBaseCompleter {
/**
Expand All @@ -18,3 +19,15 @@ export interface IBaseCompleter {
context: IInlineCompletionContext
): Promise<any>;
}

/**
 * The namespace for the base completer.
 */
export namespace BaseCompleter {
  /**
   * The options for the constructor of a completer.
   */
  export interface IOptions {
    /**
     * The settings associated to the completer, spread into the underlying
     * provider's constructor (e.g. `MistralAI` in `CodestralCompleter`).
     */
    settings: ReadonlyPartialJSONObject;
  }
}
9 changes: 3 additions & 6 deletions src/llm-models/codestral-completer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,16 @@ import { MistralAI } from '@langchain/mistralai';
import { Throttler } from '@lumino/polling';
import { CompletionRequest } from '@mistralai/mistralai';

import { IBaseCompleter } from './base-completer';
import { BaseCompleter, IBaseCompleter } from './base-completer';

/*
* The Mistral API has a rate limit of 1 request per second
*/
const INTERVAL = 1000;

export class CodestralCompleter implements IBaseCompleter {
constructor() {
this._mistralProvider = new MistralAI({
apiKey: 'TMP',
model: 'codestral-latest'
});
constructor(options: BaseCompleter.IOptions) {
this._mistralProvider = new MistralAI({ ...options.settings });
this._throttler = new Throttler(async (data: CompletionRequest) => {
const response = await this._mistralProvider.completionWithRetry(
data,
Expand Down
15 changes: 11 additions & 4 deletions src/llm-models/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,30 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatMistralAI } from '@langchain/mistralai';
import { IBaseCompleter } from './base-completer';
import { CodestralCompleter } from './codestral-completer';
import { ReadonlyPartialJSONObject } from '@lumino/coreutils';

/**
* Get an LLM completer from the name.
*/
export function getCompleter(name: string): IBaseCompleter | null {
export function getCompleter(
name: string,
settings: ReadonlyPartialJSONObject
): IBaseCompleter | null {
if (name === 'MistralAI') {
return new CodestralCompleter();
return new CodestralCompleter({ settings });
}
return null;
}

/**
* Get an LLM chat model from the name.
*/
export function getChatModel(name: string): BaseChatModel | null {
export function getChatModel(
name: string,
settings: ReadonlyPartialJSONObject
): BaseChatModel | null {
if (name === 'MistralAI') {
return new ChatMistralAI({ apiKey: 'TMP' });
return new ChatMistralAI({ ...settings });
}
return null;
}
52 changes: 35 additions & 17 deletions src/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import { IAIProvider } from './token';

export class AIProvider implements IAIProvider {
constructor(options: AIProvider.IOptions) {
this._completionProvider = new CompletionProvider({ name: 'None' });
this._completionProvider = new CompletionProvider({
name: 'None',
settings: {}
});
options.completionProviderManager.registerInlineProvider(
this._completionProvider
);
Expand All @@ -21,7 +24,7 @@ export class AIProvider implements IAIProvider {
}

/**
* get the current completer of the completion provider.
* Get the current completer of the completion provider.
*/
get completer(): IBaseCompleter | null {
if (this._name === null) {
Expand All @@ -31,7 +34,7 @@ export class AIProvider implements IAIProvider {
}

/**
* get the current llm chat model.
* Get the current llm chat model.
*/
get chatModel(): BaseChatModel | null {
if (this._name === null) {
Expand All @@ -40,6 +43,20 @@ export class AIProvider implements IAIProvider {
return this._llmChatModel;
}

/**
 * Get the current chat error.
 *
 * Empty string when the chat model was created without error.
 */
get chatError(): string {
  return this._chatError;
}

/**
 * Get the current completer error.
 *
 * Empty string when the completer was set without error.
 */
get completerError(): string {
  return this._completerError;
}

/**
* Set the models (chat model and completer).
* Creates the models if the name has changed, otherwise only updates their config.
Expand All @@ -48,22 +65,21 @@ export class AIProvider implements IAIProvider {
* @param settings - the settings for the models.
*/
setModels(name: string, settings: ReadonlyPartialJSONObject) {
if (name !== this._name) {
this._name = name;
this._completionProvider.name = name;
this._llmChatModel = getChatModel(name);
this._modelChange.emit();
try {
this._completionProvider.setCompleter(name, settings);
this._completerError = '';
} catch (e: any) {
this._completerError = e.message;
}

// Update the inline completion provider settings.
if (this._completionProvider.llmCompleter) {
AIProvider.updateConfig(this._completionProvider.llmCompleter, settings);
}

// Update the chat LLM settings.
if (this._llmChatModel) {
AIProvider.updateConfig(this._llmChatModel, settings);
try {
this._llmChatModel = getChatModel(name, settings);
this._chatError = '';
} catch (e: any) {
this._chatError = e.message;
this._llmChatModel = null;
}
this._name = name;
this._modelChange.emit();
}

get modelChange(): ISignal<IAIProvider, void> {
Expand All @@ -74,6 +90,8 @@ export class AIProvider implements IAIProvider {
private _llmChatModel: BaseChatModel | null = null;
private _name: string = 'None';
private _modelChange = new Signal<IAIProvider, void>(this);
private _chatError: string = '';
private _completerError: string = '';
}

export namespace AIProvider {
Expand Down
2 changes: 2 additions & 0 deletions src/token.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ export interface IAIProvider {
completer: IBaseCompleter | null;
chatModel: BaseChatModel | null;
modelChange: ISignal<IAIProvider, void>;
chatError: string;
completerError: string;
}

export const IAIProvider = new Token<IAIProvider>(
Expand Down

0 comments on commit 66d416f

Please sign in to comment.