From f6e88a91fd4367ca88163d4b2cb24467cd998ef7 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Tue, 5 Nov 2024 17:08:26 +0100 Subject: [PATCH] Add a method to get the error message when catching an error --- src/chat-handler.ts | 36 ++++++++++++++++++++---------------- src/index.ts | 7 +------ src/llm-models/utils.ts | 10 ++++++++++ src/token.ts | 2 +- 4 files changed, 32 insertions(+), 23 deletions(-) diff --git a/src/chat-handler.ts b/src/chat-handler.ts index c5b9074..a9b0ef8 100644 --- a/src/chat-handler.ts +++ b/src/chat-handler.ts @@ -16,6 +16,8 @@ import { mergeMessageRuns } from '@langchain/core/messages'; import { UUID } from '@lumino/coreutils'; +import { getErrorMessage } from './llm-models'; +import { IAIProvider } from './token'; export type ConnectionMessage = { type: 'connection'; @@ -25,14 +27,14 @@ export type ConnectionMessage = { export class ChatHandler extends ChatModel { constructor(options: ChatHandler.IOptions) { super(options); - this._provider = options.provider; + this._aiProvider = options.aiProvider; + this._aiProvider.modelChange.connect(() => { + this._errorMessage = this._aiProvider.chatError; + }); } get provider(): BaseChatModel | null { - return this._provider; - } - set provider(provider: BaseChatModel | null) { - this._provider = provider; + return this._aiProvider.chatModel; } async sendMessage(message: INewMessage): Promise { @@ -46,10 +48,10 @@ export class ChatHandler extends ChatModel { }; this.messageAdded(msg); - if (this._provider === null) { + if (this._aiProvider.chatModel === null) { const errorMsg: IChatMessage = { id: UUID.uuid4(), - body: `**${this.message ? this.message : this._defaultMessage}**`, + body: `**${this._errorMessage ? this._errorMessage : this._defaultErrorMessage}**`, sender: { username: 'ERROR' }, time: Date.now(), type: 'msg' @@ -69,14 +71,15 @@ export class ChatHandler extends ChatModel { }) ); - return this._provider + this.updateWriters([{ username: 'AI' }]); + return this._aiProvider.chatModel .invoke(messages) .then(response => { const content = response.content; const botMsg: IChatMessage = { id: UUID.uuid4(), body: content.toString(), - sender: { username: 'Bot' }, + sender: { username: 'AI' }, time: Date.now(), type: 'msg' }; @@ -85,9 +88,7 @@ export class ChatHandler extends ChatModel { return true; }) .catch(reason => { - const error = reason.error.error.message ?? 'Error with the chat API'; - console.log('REASON', error); - console.log('REASON', typeof error); + const error = getErrorMessage(this._aiProvider.name, reason); const errorMsg: IChatMessage = { id: UUID.uuid4(), body: `**${error}**`, @@ -97,6 +98,9 @@ export class ChatHandler extends ChatModel { }; this.messageAdded(errorMsg); return false; + }) + .finally(() => { + this.updateWriters([]); }); } @@ -112,14 +116,14 @@ export class ChatHandler extends ChatModel { super.messageAdded(message); } - message: string = ''; - private _provider: BaseChatModel | null; + private _aiProvider: IAIProvider; + private _errorMessage: string = ''; private _history: IChatHistory = { messages: [] }; - private _defaultMessage = 'AI provider not configured'; + private _defaultErrorMessage = 'AI provider not configured'; } export namespace ChatHandler { export interface IOptions extends ChatModel.IOptions { - provider: BaseChatModel | null; + aiProvider: IAIProvider; } } diff --git a/src/index.ts b/src/index.ts index 435101a..76d2ab2 100644 --- a/src/index.ts +++ b/src/index.ts @@ -41,15 +41,10 @@ const chatPlugin: JupyterFrontEndPlugin = { } const chatHandler = new ChatHandler({ - provider: aiProvider.chatModel, + aiProvider: aiProvider, activeCellManager: activeCellManager }); - aiProvider.modelChange.connect(() => { - chatHandler.provider = aiProvider.chatModel; - chatHandler.message = aiProvider.chatError; - }); - let sendWithShiftEnter = false; let enableCodeToolbar = true; diff --git a/src/llm-models/utils.ts b/src/llm-models/utils.ts index 308f5d3..544d684 100644 --- a/src/llm-models/utils.ts +++ b/src/llm-models/utils.ts @@ -29,3 +29,13 @@ export function getChatModel( } return null; } + +/** + * Get the error message from provider. + */ +export function getErrorMessage(name: string, error: any): string { + if (name === 'MistralAI') { + return error.message; + } + return 'Unknown provider'; +} diff --git a/src/token.ts b/src/token.ts index b94510e..09f5a6e 100644 --- a/src/token.ts +++ b/src/token.ts @@ -5,7 +5,7 @@ import { ISignal } from '@lumino/signaling'; import { IBaseCompleter } from './llm-models'; export interface IAIProvider { - name: string | null; + name: string; completer: IBaseCompleter | null; chatModel: BaseChatModel | null; modelChange: ISignal;