Skip to content

Commit

Permalink
Add a method to get the error message when catching an error
Browse files Browse the repository at this point in the history
  • Loading branch information
brichet committed Nov 5, 2024
1 parent 66d416f commit f6e88a9
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 23 deletions.
36 changes: 20 additions & 16 deletions src/chat-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import {
mergeMessageRuns
} from '@langchain/core/messages';
import { UUID } from '@lumino/coreutils';
import { getErrorMessage } from './llm-models';
import { IAIProvider } from './token';

export type ConnectionMessage = {
type: 'connection';
Expand All @@ -25,14 +27,14 @@ export type ConnectionMessage = {
export class ChatHandler extends ChatModel {
constructor(options: ChatHandler.IOptions) {
super(options);
this._provider = options.provider;
this._aiProvider = options.aiProvider;
this._aiProvider.modelChange.connect(() => {
this._errorMessage = this._aiProvider.chatError;
});
}

get provider(): BaseChatModel | null {
return this._provider;
}
set provider(provider: BaseChatModel | null) {
this._provider = provider;
return this._aiProvider.chatModel;
}

async sendMessage(message: INewMessage): Promise<boolean> {
Expand All @@ -46,10 +48,10 @@ export class ChatHandler extends ChatModel {
};
this.messageAdded(msg);

if (this._provider === null) {
if (this._aiProvider.chatModel === null) {
const errorMsg: IChatMessage = {
id: UUID.uuid4(),
body: `**${this.message ? this.message : this._defaultMessage}**`,
body: `**${this._errorMessage ? this._errorMessage : this._defaultErrorMessage}**`,
sender: { username: 'ERROR' },
time: Date.now(),
type: 'msg'
Expand All @@ -69,14 +71,15 @@ export class ChatHandler extends ChatModel {
})
);

return this._provider
this.updateWriters([{ username: 'AI' }]);
return this._aiProvider.chatModel
.invoke(messages)
.then(response => {
const content = response.content;
const botMsg: IChatMessage = {
id: UUID.uuid4(),
body: content.toString(),
sender: { username: 'Bot' },
sender: { username: 'AI' },
time: Date.now(),
type: 'msg'
};
Expand All @@ -85,9 +88,7 @@ export class ChatHandler extends ChatModel {
return true;
})
.catch(reason => {
const error = reason.error.error.message ?? 'Error with the chat API';
console.log('REASON', error);
console.log('REASON', typeof error);
const error = getErrorMessage(this._aiProvider.name, reason);
const errorMsg: IChatMessage = {
id: UUID.uuid4(),
body: `**${error}**`,
Expand All @@ -97,6 +98,9 @@ export class ChatHandler extends ChatModel {
};
this.messageAdded(errorMsg);
return false;
})
.finally(() => {
this.updateWriters([]);
});
}

Expand All @@ -112,14 +116,14 @@ export class ChatHandler extends ChatModel {
super.messageAdded(message);
}

message: string = '';
private _provider: BaseChatModel | null;
private _aiProvider: IAIProvider;
private _errorMessage: string = '';
private _history: IChatHistory = { messages: [] };
private _defaultMessage = 'AI provider not configured';
private _defaultErrorMessage = 'AI provider not configured';
}

export namespace ChatHandler {
export interface IOptions extends ChatModel.IOptions {
provider: BaseChatModel | null;
aiProvider: IAIProvider;
}
}
7 changes: 1 addition & 6 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,10 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
}

const chatHandler = new ChatHandler({
provider: aiProvider.chatModel,
aiProvider: aiProvider,
activeCellManager: activeCellManager
});

aiProvider.modelChange.connect(() => {
chatHandler.provider = aiProvider.chatModel;
chatHandler.message = aiProvider.chatError;
});

let sendWithShiftEnter = false;
let enableCodeToolbar = true;

Expand Down
10 changes: 10 additions & 0 deletions src/llm-models/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,13 @@ export function getChatModel(
}
return null;
}

/**
* Get the error message from provider.
*/
export function getErrorMessage(name: string, error: any): string {
if (name === 'MistralAI') {
return error.message;
}
return 'Unknown provider';
}
2 changes: 1 addition & 1 deletion src/token.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { ISignal } from '@lumino/signaling';
import { IBaseCompleter } from './llm-models';

export interface IAIProvider {
name: string | null;
name: string;
completer: IBaseCompleter | null;
chatModel: BaseChatModel | null;
modelChange: ISignal<IAIProvider, void>;
Expand Down

0 comments on commit f6e88a9

Please sign in to comment.