Refactoring AIProvider and handling errors (#15)
* Better handling of errors with providers

* Add a method to get the error message when catching an error
brichet authored Nov 6, 2024
1 parent f0637eb commit 8043bf9
Showing 8 changed files with 146 additions and 67 deletions.
70 changes: 46 additions & 24 deletions src/chat-handler.ts
@@ -16,6 +16,8 @@ import {
   mergeMessageRuns
 } from '@langchain/core/messages';
 import { UUID } from '@lumino/coreutils';
+import { getErrorMessage } from './llm-models';
+import { IAIProvider } from './token';
 
 export type ConnectionMessage = {
   type: 'connection';
@@ -25,14 +27,14 @@ export type ConnectionMessage = {
 export class ChatHandler extends ChatModel {
   constructor(options: ChatHandler.IOptions) {
     super(options);
-    this._provider = options.provider;
+    this._aiProvider = options.aiProvider;
+    this._aiProvider.modelChange.connect(() => {
+      this._errorMessage = this._aiProvider.chatError;
+    });
   }
 
   get provider(): BaseChatModel | null {
-    return this._provider;
-  }
-  set provider(provider: BaseChatModel | null) {
-    this._provider = provider;
+    return this._aiProvider.chatModel;
   }
 
   async sendMessage(message: INewMessage): Promise<boolean> {
@@ -46,15 +48,15 @@
     };
     this.messageAdded(msg);
 
-    if (this._provider === null) {
-      const botMsg: IChatMessage = {
+    if (this._aiProvider.chatModel === null) {
+      const errorMsg: IChatMessage = {
         id: UUID.uuid4(),
-        body: '**AI provider not configured for the chat**',
+        body: `**${this._errorMessage ? this._errorMessage : this._defaultErrorMessage}**`,
         sender: { username: 'ERROR' },
         time: Date.now(),
         type: 'msg'
       };
-      this.messageAdded(botMsg);
+      this.messageAdded(errorMsg);
       return false;
     }
 
@@ -69,19 +71,37 @@
       })
     );
 
-    const response = await this._provider.invoke(messages);
-    // TODO: fix deprecated response.text
-    const content = response.text;
-    const botMsg: IChatMessage = {
-      id: UUID.uuid4(),
-      body: content,
-      sender: { username: 'Bot' },
-      time: Date.now(),
-      type: 'msg'
-    };
-    this.messageAdded(botMsg);
-    this._history.messages.push(botMsg);
-    return true;
+    this.updateWriters([{ username: 'AI' }]);
+    return this._aiProvider.chatModel
+      .invoke(messages)
+      .then(response => {
+        const content = response.content;
+        const botMsg: IChatMessage = {
+          id: UUID.uuid4(),
+          body: content.toString(),
+          sender: { username: 'AI' },
+          time: Date.now(),
+          type: 'msg'
+        };
+        this.messageAdded(botMsg);
+        this._history.messages.push(botMsg);
+        return true;
+      })
+      .catch(reason => {
+        const error = getErrorMessage(this._aiProvider.name, reason);
+        const errorMsg: IChatMessage = {
+          id: UUID.uuid4(),
+          body: `**${error}**`,
+          sender: { username: 'ERROR' },
+          time: Date.now(),
+          type: 'msg'
+        };
+        this.messageAdded(errorMsg);
+        return false;
+      })
+      .finally(() => {
+        this.updateWriters([]);
+      });
   }
 
   async getHistory(): Promise<IChatHistory> {
@@ -96,12 +116,14 @@
     super.messageAdded(message);
   }
 
-  private _provider: BaseChatModel | null;
+  private _aiProvider: IAIProvider;
+  private _errorMessage: string = '';
   private _history: IChatHistory = { messages: [] };
+  private _defaultErrorMessage = 'AI provider not configured';
 }
 
 export namespace ChatHandler {
   export interface IOptions extends ChatModel.IOptions {
-    provider: BaseChatModel | null;
+    aiProvider: IAIProvider;
   }
 }
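A note on the new send flow: the writers indicator is switched on before invoke(), every rejection is converted into an ERROR message instead of propagating, and the indicator is always cleared in finally(). A minimal standalone sketch of the same promise chain, where post() and setWriters() are hypothetical stand-ins for the ChatModel machinery:

type Msg = { body: string; sender: string };

async function send(
  model: { invoke(text: string): Promise<{ content: unknown }> },
  text: string,
  post: (msg: Msg) => void,
  setWriters: (writing: boolean) => void
): Promise<boolean> {
  setWriters(true); // show the "AI is writing" indicator
  return model
    .invoke(text)
    .then(response => {
      post({ body: String(response.content), sender: 'AI' });
      return true;
    })
    .catch(reason => {
      // Like getErrorMessage(): surface the provider's own message.
      post({ body: `**${(reason as Error).message}**`, sender: 'ERROR' });
      return false;
    })
    .finally(() => setWriters(false)); // always clear the indicator
}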
34 changes: 24 additions & 10 deletions src/completion-provider.ts
@@ -5,7 +5,8 @@ import {
 } from '@jupyterlab/completer';
 import { LLM } from '@langchain/core/language_models/llms';
 
-import { getCompleter, IBaseCompleter } from './llm-models';
+import { getCompleter, IBaseCompleter, BaseCompleter } from './llm-models';
+import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
 
 /**
  * The generic completion provider to register to the completion provider manager.
@@ -14,23 +15,36 @@ export class CompletionProvider implements IInlineCompletionProvider {
   readonly identifier = '@jupyterlite/ai';
 
   constructor(options: CompletionProvider.IOptions) {
-    this.name = options.name;
+    const { name, settings } = options;
+    this.setCompleter(name, settings);
   }
 
   /**
-   * Getter and setter of the name.
-   * The setter will create the appropriate completer, accordingly to the name.
+   * Set the completer.
+   *
+   * @param name - the name of the completer.
+   * @param settings - The settings associated to the completer.
+   */
+  setCompleter(name: string, settings: ReadonlyPartialJSONObject) {
+    try {
+      this._completer = getCompleter(name, settings);
+      this._name = this._completer === null ? 'None' : name;
+    } catch (e: any) {
+      this._completer = null;
+      this._name = 'None';
+      throw e;
+    }
+  }
+
+  /**
+   * Get the current completer name.
    */
   get name(): string {
     return this._name;
   }
-  set name(name: string) {
-    this._name = name;
-    this._completer = getCompleter(name);
-  }
 
   /**
-   * get the current completer.
+   * Get the current completer.
   */
   get completer(): IBaseCompleter | null {
     return this._completer;
@@ -55,7 +69,7 @@
 }
 
 export namespace CompletionProvider {
-  export interface IOptions {
+  export interface IOptions extends BaseCompleter.IOptions {
     name: string;
   }
 }
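The contract of setCompleter matters for the provider refactoring below: on failure it resets the completer to a safe 'None' state before rethrowing, so the caller can rely on a consistent state and still capture the message. A hedged usage sketch (the settings values are illustrative placeholders):

const provider = new CompletionProvider({ name: 'None', settings: {} });
try {
  // A registered name creates the completer; a bad configuration may throw.
  provider.setCompleter('MistralAI', { apiKey: 'MY_API_KEY' });
} catch (e: any) {
  // The provider has already fallen back to name === 'None' at this point.
  console.error(`Completer not created: ${e.message}`);
}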
6 changes: 1 addition & 5 deletions src/index.ts
@@ -41,14 +41,10 @@
    }
 
    const chatHandler = new ChatHandler({
-      provider: aiProvider.chatModel,
+      aiProvider: aiProvider,
      activeCellManager: activeCellManager
    });
 
-    aiProvider.modelChange.connect(() => {
-      chatHandler.provider = aiProvider.chatModel;
-    });
-
    let sendWithShiftEnter = false;
    let enableCodeToolbar = true;
 
13 changes: 13 additions & 0 deletions src/llm-models/base-completer.ts
@@ -3,6 +3,7 @@ import {
   IInlineCompletionContext
 } from '@jupyterlab/completer';
 import { LLM } from '@langchain/core/language_models/llms';
+import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
 
 export interface IBaseCompleter {
   /**
@@ -18,3 +19,15 @@ export interface IBaseCompleter {
     context: IInlineCompletionContext
   ): Promise<any>;
 }
+
+/**
+ * The namespace for the base completer.
+ */
+export namespace BaseCompleter {
+  /**
+   * The options for the constructor of a completer.
+   */
+  export interface IOptions {
+    settings: ReadonlyPartialJSONObject;
+  }
+}
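BaseCompleter.IOptions gives every concrete completer a uniform constructor signature, so settings flow untouched from the settings registry into the completer instead of being hard-coded. A hypothetical skeleton (EchoCompleter is illustrative, not part of the commit):

import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
import { BaseCompleter } from './base-completer';

class EchoCompleter {
  constructor(options: BaseCompleter.IOptions) {
    // Keep the raw JSON settings; concrete completers pick what they need.
    this._settings = options.settings;
  }
  private _settings: ReadonlyPartialJSONObject;
}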
9 changes: 3 additions & 6 deletions src/llm-models/codestral-completer.ts
@@ -7,19 +7,16 @@ import { MistralAI } from '@langchain/mistralai';
 import { Throttler } from '@lumino/polling';
 import { CompletionRequest } from '@mistralai/mistralai';
 
-import { IBaseCompleter } from './base-completer';
+import { BaseCompleter, IBaseCompleter } from './base-completer';
 
 /*
  * The Mistral API has a rate limit of 1 request per second
  */
 const INTERVAL = 1000;
 
 export class CodestralCompleter implements IBaseCompleter {
-  constructor() {
-    this._mistralProvider = new MistralAI({
-      apiKey: 'TMP',
-      model: 'codestral-latest'
-    });
+  constructor(options: BaseCompleter.IOptions) {
+    this._mistralProvider = new MistralAI({ ...options.settings });
     this._throttler = new Throttler(async (data: CompletionRequest) => {
       const response = await this._mistralProvider.completionWithRetry(
         data,
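The hard-coded apiKey: 'TMP' placeholder is gone: whatever object the user configures is spread directly into the MistralAI constructor. Assuming settings of the shape the LangChain MistralAI provider accepts:

const completer = new CodestralCompleter({
  settings: {
    apiKey: 'MY_API_KEY',     // illustrative placeholder, not a real key
    model: 'codestral-latest' // any field the MistralAI constructor accepts
  }
});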
25 changes: 21 additions & 4 deletions src/llm-models/utils.ts
@@ -2,23 +2,40 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { ChatMistralAI } from '@langchain/mistralai';
 import { IBaseCompleter } from './base-completer';
 import { CodestralCompleter } from './codestral-completer';
+import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
 
 /**
  * Get an LLM completer from the name.
  */
-export function getCompleter(name: string): IBaseCompleter | null {
+export function getCompleter(
+  name: string,
+  settings: ReadonlyPartialJSONObject
+): IBaseCompleter | null {
   if (name === 'MistralAI') {
-    return new CodestralCompleter();
+    return new CodestralCompleter({ settings });
   }
   return null;
 }
 
 /**
  * Get an LLM chat model from the name.
  */
-export function getChatModel(name: string): BaseChatModel | null {
+export function getChatModel(
+  name: string,
+  settings: ReadonlyPartialJSONObject
+): BaseChatModel | null {
   if (name === 'MistralAI') {
-    return new ChatMistralAI({ apiKey: 'TMP' });
+    return new ChatMistralAI({ ...settings });
   }
   return null;
 }
+
+/**
+ * Get the error message from the provider.
+ */
+export function getErrorMessage(name: string, error: any): string {
+  if (name === 'MistralAI') {
+    return error.message;
+  }
+  return 'Unknown provider';
+}
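getErrorMessage centralizes provider-specific error extraction: MistralAI errors expose a usable .message, and any unrecognized provider name falls back to a generic string. A small hedged helper showing the intended call site (invokeSafely is hypothetical):

async function invokeSafely<T>(
  providerName: string,
  call: () => Promise<T>
): Promise<T | string> {
  try {
    return await call();
  } catch (error: any) {
    // Returns error.message for 'MistralAI', 'Unknown provider' otherwise.
    return getErrorMessage(providerName, error);
  }
}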
52 changes: 35 additions & 17 deletions src/provider.ts
@@ -10,7 +10,10 @@ import { IAIProvider } from './token';
 
 export class AIProvider implements IAIProvider {
   constructor(options: AIProvider.IOptions) {
-    this._completionProvider = new CompletionProvider({ name: 'None' });
+    this._completionProvider = new CompletionProvider({
+      name: 'None',
+      settings: {}
+    });
     options.completionProviderManager.registerInlineProvider(
       this._completionProvider
     );
@@ -21,7 +24,7 @@
   }
 
   /**
-   * get the current completer of the completion provider.
+   * Get the current completer of the completion provider.
    */
   get completer(): IBaseCompleter | null {
     if (this._name === null) {
@@ -31,7 +34,7 @@
   }
 
   /**
-   * get the current llm chat model.
+   * Get the current LLM chat model.
    */
   get chatModel(): BaseChatModel | null {
     if (this._name === null) {
@@ -40,6 +43,20 @@
     return this._llmChatModel;
   }
 
+  /**
+   * Get the current chat error.
+   */
+  get chatError(): string {
+    return this._chatError;
+  }
+
+  /**
+   * Get the current completer error.
+   */
+  get completerError(): string {
+    return this._completerError;
+  }
+
   /**
    * Set the models (chat model and completer).
    * Creates the models if the name has changed, otherwise only updates their config.
@@ -48,22 +65,21 @@
    * @param settings - the settings for the models.
    */
   setModels(name: string, settings: ReadonlyPartialJSONObject) {
-    if (name !== this._name) {
-      this._name = name;
-      this._completionProvider.name = name;
-      this._llmChatModel = getChatModel(name);
-      this._modelChange.emit();
+    try {
+      this._completionProvider.setCompleter(name, settings);
+      this._completerError = '';
+    } catch (e: any) {
+      this._completerError = e.message;
     }
-
-    // Update the inline completion provider settings.
-    if (this._completionProvider.llmCompleter) {
-      AIProvider.updateConfig(this._completionProvider.llmCompleter, settings);
-    }
-
-    // Update the chat LLM settings.
-    if (this._llmChatModel) {
-      AIProvider.updateConfig(this._llmChatModel, settings);
+    try {
+      this._llmChatModel = getChatModel(name, settings);
+      this._chatError = '';
+    } catch (e: any) {
+      this._chatError = e.message;
+      this._llmChatModel = null;
     }
+    this._name = name;
+    this._modelChange.emit();
   }
 
   get modelChange(): ISignal<IAIProvider, void> {
@@ -74,6 +90,8 @@
   private _llmChatModel: BaseChatModel | null = null;
   private _name: string = 'None';
   private _modelChange = new Signal<IAIProvider, void>(this);
+  private _chatError: string = '';
+  private _completerError: string = '';
 }
 
 export namespace AIProvider {
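Note the behavioral change in setModels: it no longer short-circuits when the name is unchanged. Both models are rebuilt from the settings on every call, each failure is recorded in its own error field, and modelChange is emitted exactly once at the end. A sketch of a consumer reacting to that signal (given an AIProvider instance aiProvider; the settings value is illustrative):

aiProvider.modelChange.connect(() => {
  if (aiProvider.chatError) {
    console.warn(`Chat model unavailable: ${aiProvider.chatError}`);
  }
  if (aiProvider.completerError) {
    console.warn(`Completer unavailable: ${aiProvider.completerError}`);
  }
});
aiProvider.setModels('MistralAI', { apiKey: 'MY_API_KEY' });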
4 changes: 3 additions & 1 deletion src/token.ts
@@ -5,10 +5,12 @@ import { ISignal } from '@lumino/signaling';
 import { IBaseCompleter } from './llm-models';
 
 export interface IAIProvider {
-  name: string | null;
+  name: string;
   completer: IBaseCompleter | null;
   chatModel: BaseChatModel | null;
   modelChange: ISignal<IAIProvider, void>;
+  chatError: string;
+  completerError: string;
 }
 
 export const IAIProvider = new Token<IAIProvider>(
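With chatError and completerError on the token interface, any dependent plugin can require IAIProvider and read the error state without importing the concrete AIProvider class. A minimal consumer sketch (the plugin id is hypothetical):

import {
  JupyterFrontEnd,
  JupyterFrontEndPlugin
} from '@jupyterlab/application';
import { IAIProvider } from './token';

const statusPlugin: JupyterFrontEndPlugin<void> = {
  id: '@jupyterlite/ai:status-example', // illustrative id
  autoStart: true,
  requires: [IAIProvider],
  activate: (app: JupyterFrontEnd, aiProvider: IAIProvider) => {
    aiProvider.modelChange.connect(() => {
      console.log(aiProvider.chatError || `Chat model ready: ${aiProvider.name}`);
    });
  }
};

export default statusPlugin;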
