Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring AIProvider and handling errors #15

Merged
merged 2 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 46 additions & 24 deletions src/chat-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import {
mergeMessageRuns
} from '@langchain/core/messages';
import { UUID } from '@lumino/coreutils';
import { getErrorMessage } from './llm-models';
import { IAIProvider } from './token';

export type ConnectionMessage = {
type: 'connection';
Expand All @@ -25,14 +27,14 @@ export type ConnectionMessage = {
export class ChatHandler extends ChatModel {
constructor(options: ChatHandler.IOptions) {
super(options);
this._provider = options.provider;
this._aiProvider = options.aiProvider;
this._aiProvider.modelChange.connect(() => {
this._errorMessage = this._aiProvider.chatError;
});
}

get provider(): BaseChatModel | null {
return this._provider;
}
set provider(provider: BaseChatModel | null) {
this._provider = provider;
return this._aiProvider.chatModel;
}

async sendMessage(message: INewMessage): Promise<boolean> {
Expand All @@ -46,15 +48,15 @@ export class ChatHandler extends ChatModel {
};
this.messageAdded(msg);

if (this._provider === null) {
const botMsg: IChatMessage = {
if (this._aiProvider.chatModel === null) {
const errorMsg: IChatMessage = {
id: UUID.uuid4(),
body: '**AI provider not configured for the chat**',
body: `**${this._errorMessage ? this._errorMessage : this._defaultErrorMessage}**`,
sender: { username: 'ERROR' },
time: Date.now(),
type: 'msg'
};
this.messageAdded(botMsg);
this.messageAdded(errorMsg);
return false;
}

Expand All @@ -69,19 +71,37 @@ export class ChatHandler extends ChatModel {
})
);

const response = await this._provider.invoke(messages);
// TODO: fix deprecated response.text
const content = response.text;
const botMsg: IChatMessage = {
id: UUID.uuid4(),
body: content,
sender: { username: 'Bot' },
time: Date.now(),
type: 'msg'
};
this.messageAdded(botMsg);
this._history.messages.push(botMsg);
return true;
this.updateWriters([{ username: 'AI' }]);
return this._aiProvider.chatModel
.invoke(messages)
.then(response => {
const content = response.content;
const botMsg: IChatMessage = {
id: UUID.uuid4(),
body: content.toString(),
sender: { username: 'AI' },
time: Date.now(),
type: 'msg'
};
this.messageAdded(botMsg);
this._history.messages.push(botMsg);
return true;
})
.catch(reason => {
const error = getErrorMessage(this._aiProvider.name, reason);
const errorMsg: IChatMessage = {
id: UUID.uuid4(),
body: `**${error}**`,
sender: { username: 'ERROR' },
time: Date.now(),
type: 'msg'
};
this.messageAdded(errorMsg);
return false;
})
.finally(() => {
this.updateWriters([]);
});
}

async getHistory(): Promise<IChatHistory> {
Expand All @@ -96,12 +116,14 @@ export class ChatHandler extends ChatModel {
super.messageAdded(message);
}

private _provider: BaseChatModel | null;
private _aiProvider: IAIProvider;
private _errorMessage: string = '';
private _history: IChatHistory = { messages: [] };
private _defaultErrorMessage = 'AI provider not configured';
}

export namespace ChatHandler {
export interface IOptions extends ChatModel.IOptions {
provider: BaseChatModel | null;
aiProvider: IAIProvider;
}
}
34 changes: 24 additions & 10 deletions src/completion-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import {
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';

import { getCompleter, IBaseCompleter } from './llm-models';
import { getCompleter, IBaseCompleter, BaseCompleter } from './llm-models';
import { ReadonlyPartialJSONObject } from '@lumino/coreutils';

/**
* The generic completion provider to register to the completion provider manager.
Expand All @@ -14,23 +15,36 @@ export class CompletionProvider implements IInlineCompletionProvider {
readonly identifier = '@jupyterlite/ai';

constructor(options: CompletionProvider.IOptions) {
this.name = options.name;
const { name, settings } = options;
this.setCompleter(name, settings);
}

/**
* Getter and setter of the name.
* The setter will create the appropriate completer, accordingly to the name.
* Set the completer.
*
* @param name - the name of the completer.
* @param settings - The settings associated to the completer.
*/
setCompleter(name: string, settings: ReadonlyPartialJSONObject) {
try {
this._completer = getCompleter(name, settings);
this._name = this._completer === null ? 'None' : name;
} catch (e: any) {
this._completer = null;
this._name = 'None';
throw e;
}
}

/**
* Get the current completer name.
*/
get name(): string {
return this._name;
}
set name(name: string) {
this._name = name;
this._completer = getCompleter(name);
}

/**
* get the current completer.
* Get the current completer.
*/
get completer(): IBaseCompleter | null {
return this._completer;
Expand All @@ -55,7 +69,7 @@ export class CompletionProvider implements IInlineCompletionProvider {
}

export namespace CompletionProvider {
export interface IOptions {
export interface IOptions extends BaseCompleter.IOptions {
name: string;
}
}
6 changes: 1 addition & 5 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,10 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
}

const chatHandler = new ChatHandler({
provider: aiProvider.chatModel,
aiProvider: aiProvider,
activeCellManager: activeCellManager
});

aiProvider.modelChange.connect(() => {
chatHandler.provider = aiProvider.chatModel;
});

let sendWithShiftEnter = false;
let enableCodeToolbar = true;

Expand Down
13 changes: 13 additions & 0 deletions src/llm-models/base-completer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
IInlineCompletionContext
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';
import { ReadonlyPartialJSONObject } from '@lumino/coreutils';

export interface IBaseCompleter {
/**
Expand All @@ -18,3 +19,15 @@ export interface IBaseCompleter {
context: IInlineCompletionContext
): Promise<any>;
}

/**
* The namespace for the base completer.
*/
export namespace BaseCompleter {
/**
* The options for the constructor of a completer.
*/
export interface IOptions {
settings: ReadonlyPartialJSONObject;
}
}
9 changes: 3 additions & 6 deletions src/llm-models/codestral-completer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,16 @@ import { MistralAI } from '@langchain/mistralai';
import { Throttler } from '@lumino/polling';
import { CompletionRequest } from '@mistralai/mistralai';

import { IBaseCompleter } from './base-completer';
import { BaseCompleter, IBaseCompleter } from './base-completer';

/*
* The Mistral API has a rate limit of 1 request per second
*/
const INTERVAL = 1000;

export class CodestralCompleter implements IBaseCompleter {
constructor() {
this._mistralProvider = new MistralAI({
apiKey: 'TMP',
model: 'codestral-latest'
});
constructor(options: BaseCompleter.IOptions) {
this._mistralProvider = new MistralAI({ ...options.settings });
this._throttler = new Throttler(async (data: CompletionRequest) => {
const response = await this._mistralProvider.completionWithRetry(
data,
Expand Down
25 changes: 21 additions & 4 deletions src/llm-models/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,40 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatMistralAI } from '@langchain/mistralai';
import { IBaseCompleter } from './base-completer';
import { CodestralCompleter } from './codestral-completer';
import { ReadonlyPartialJSONObject } from '@lumino/coreutils';

/**
* Get an LLM completer from the name.
*/
export function getCompleter(name: string): IBaseCompleter | null {
export function getCompleter(
name: string,
settings: ReadonlyPartialJSONObject
): IBaseCompleter | null {
if (name === 'MistralAI') {
return new CodestralCompleter();
return new CodestralCompleter({ settings });
}
return null;
}

/**
* Get an LLM chat model from the name.
*/
export function getChatModel(name: string): BaseChatModel | null {
export function getChatModel(
name: string,
settings: ReadonlyPartialJSONObject
): BaseChatModel | null {
if (name === 'MistralAI') {
return new ChatMistralAI({ apiKey: 'TMP' });
return new ChatMistralAI({ ...settings });
}
return null;
}

/**
* Get the error message from provider.
*/
export function getErrorMessage(name: string, error: any): string {
if (name === 'MistralAI') {
return error.message;
}
return 'Unknown provider';
}
52 changes: 35 additions & 17 deletions src/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import { IAIProvider } from './token';

export class AIProvider implements IAIProvider {
constructor(options: AIProvider.IOptions) {
this._completionProvider = new CompletionProvider({ name: 'None' });
this._completionProvider = new CompletionProvider({
name: 'None',
settings: {}
});
options.completionProviderManager.registerInlineProvider(
this._completionProvider
);
Expand All @@ -21,7 +24,7 @@ export class AIProvider implements IAIProvider {
}

/**
* get the current completer of the completion provider.
* Get the current completer of the completion provider.
*/
get completer(): IBaseCompleter | null {
if (this._name === null) {
Expand All @@ -31,7 +34,7 @@ export class AIProvider implements IAIProvider {
}

/**
* get the current llm chat model.
* Get the current llm chat model.
*/
get chatModel(): BaseChatModel | null {
if (this._name === null) {
Expand All @@ -40,6 +43,20 @@ export class AIProvider implements IAIProvider {
return this._llmChatModel;
}

/**
* Get the current chat error;
*/
get chatError(): string {
return this._chatError;
}

/**
* get the current completer error.
*/
get completerError(): string {
return this._completerError;
}

/**
* Set the models (chat model and completer).
* Creates the models if the name has changed, otherwise only updates their config.
Expand All @@ -48,22 +65,21 @@ export class AIProvider implements IAIProvider {
* @param settings - the settings for the models.
*/
setModels(name: string, settings: ReadonlyPartialJSONObject) {
if (name !== this._name) {
this._name = name;
this._completionProvider.name = name;
this._llmChatModel = getChatModel(name);
this._modelChange.emit();
try {
this._completionProvider.setCompleter(name, settings);
this._completerError = '';
} catch (e: any) {
this._completerError = e.message;
}

// Update the inline completion provider settings.
if (this._completionProvider.llmCompleter) {
AIProvider.updateConfig(this._completionProvider.llmCompleter, settings);
}

// Update the chat LLM settings.
if (this._llmChatModel) {
AIProvider.updateConfig(this._llmChatModel, settings);
try {
this._llmChatModel = getChatModel(name, settings);
this._chatError = '';
} catch (e: any) {
this._chatError = e.message;
this._llmChatModel = null;
}
this._name = name;
this._modelChange.emit();
}

get modelChange(): ISignal<IAIProvider, void> {
Expand All @@ -74,6 +90,8 @@ export class AIProvider implements IAIProvider {
private _llmChatModel: BaseChatModel | null = null;
private _name: string = 'None';
private _modelChange = new Signal<IAIProvider, void>(this);
private _chatError: string = '';
private _completerError: string = '';
}

export namespace AIProvider {
Expand Down
4 changes: 3 additions & 1 deletion src/token.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ import { ISignal } from '@lumino/signaling';
import { IBaseCompleter } from './llm-models';

export interface IAIProvider {
name: string | null;
name: string;
completer: IBaseCompleter | null;
chatModel: BaseChatModel | null;
modelChange: ISignal<IAIProvider, void>;
chatError: string;
completerError: string;
}

export const IAIProvider = new Token<IAIProvider>(
Expand Down
Loading