Update changes in settings to the chat and completion LLM
brichet committed Oct 29, 2024
1 parent 1ff8eef commit f63cce0
Showing 6 changed files with 99 additions and 75 deletions.
12 changes: 1 addition & 11 deletions src/completion-providers/base-provider.ts
@@ -1,16 +1,6 @@
import { IInlineCompletionProvider } from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';
import { JSONValue } from '@lumino/coreutils';

export interface IBaseProvider extends IInlineCompletionProvider {
  configure(settings: { [property: string]: JSONValue }): void;
}

// https://stackoverflow.com/questions/54724875/can-we-check-whether-property-is-readonly-in-typescript
export function isWritable<T extends LLM>(obj: T, key: keyof T) {
  const desc =
    Object.getOwnPropertyDescriptor(obj, key) ||
    Object.getOwnPropertyDescriptor(Object.getPrototypeOf(obj), key) ||
    {};
  return Boolean(desc.writable);
  client: LLM;
}
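Read as a whole, this hunk drops the per-provider configure() contract and the isWritable helper (which moves to src/tools.ts below) and adds a single client member, so the resulting base-provider.ts is just the interface exposing the underlying LangChain client; roughly (a reading of the hunk above, not part of the diff):

import { IInlineCompletionProvider } from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';

export interface IBaseProvider extends IInlineCompletionProvider {
  client: LLM;
}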
24 changes: 7 additions & 17 deletions src/completion-providers/codestral-provider.ts
@@ -2,14 +2,12 @@ import {
  CompletionHandler,
  IInlineCompletionContext
} from '@jupyterlab/completer';

import { Throttler } from '@lumino/polling';

import { CompletionRequest } from '@mistralai/mistralai';

import type { MistralAI } from '@langchain/mistralai';
import { JSONValue } from '@lumino/coreutils';
import { IBaseProvider, isWritable } from './base-provider';

import { IBaseProvider } from './base-provider';
import { LLM } from '@langchain/core/language_models/llms';

/*
 * The Mistral API has a rate limit of 1 request per second
@@ -38,6 +36,10 @@ export class CodestralProvider implements IBaseProvider {
    }, INTERVAL);
  }

  get client(): LLM {
    return this._mistralClient;
  }

  async fetch(
    request: CompletionHandler.IRequest,
    context: IInlineCompletionContext
@@ -67,18 +69,6 @@
    }
  }

  configure(settings: { [property: string]: JSONValue }): void {
    Object.entries(settings).forEach(([key, value], index) => {
      if (key in this._mistralClient) {
        if (isWritable(this._mistralClient, key as keyof MistralAI)) {
          // eslint-disable-next-line @typescript-eslint/ban-ts-comment
          // @ts-ignore
          this._mistralClient[key as keyof MistralAI] = value;
        }
      }
    });
  }

  private _throttler: Throttler;
  private _mistralClient: MistralAI;
}
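With configure() removed, the new client getter is the only hook the central LlmProvider needs: settings are no longer pushed into each provider, they are written onto the provider's underlying LangChain model (see _updateConfig in src/provider.ts below). A minimal sketch of that flow; apart from CodestralProvider, MistralAI and the client getter, the names and values here are assumptions, not taken from this commit:

// Sketch only: the provider exposes its MistralAI client, and settings are
// applied from the outside instead of through a configure() method.
import { MistralAI } from '@langchain/mistralai';
import { CodestralProvider } from './completion-providers';

const mistralClient = new MistralAI({ apiKey: 'TMP' });
const provider = new CodestralProvider({ mistralClient });

// Equivalent of the removed configure() call, now done by the owner of the
// provider (assuming `temperature` is a writable field on the client):
// (provider.client as MistralAI).temperature = 0.2;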
17 changes: 0 additions & 17 deletions src/index.ts
@@ -18,23 +18,6 @@ import { ChatHandler } from './chat-handler';
import { ILlmProvider } from './token';
import { LlmProvider } from './provider';

// const inlineProviderPlugin: JupyterFrontEndPlugin<void> = {
//   id: 'jupyterlab-codestral:inline-provider',
//   autoStart: true,
//   requires: [ICompletionProviderManager, ILlmProvider, ISettingRegistry],
//   activate: (
//     app: JupyterFrontEnd,
//     manager: ICompletionProviderManager,
//     llmProvider: ILlmProvider
//   ): void => {
//     llmProvider.providerChange.connect(() => {
//       if (llmProvider.inlineCompleter !== null) {
//         manager.registerInlineProvider(llmProvider.inlineCompleter);
//       }
//     });
//   }
// };

const chatPlugin: JupyterFrontEndPlugin<void> = {
  id: 'jupyterlab-codestral:chat',
  description: 'LLM chat extension',
95 changes: 67 additions & 28 deletions src/provider.ts
@@ -1,14 +1,13 @@
import {
  ICompletionProviderManager,
  IInlineCompletionProvider
} from '@jupyterlab/completer';
import { ICompletionProviderManager } from '@jupyterlab/completer';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatMistralAI, MistralAI } from '@langchain/mistralai';
import { ISignal, Signal } from '@lumino/signaling';
import { JSONValue, ReadonlyPartialJSONObject } from '@lumino/coreutils';
import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
import * as completionProviders from './completion-providers';
import { ILlmProvider } from './token';
import { ILlmProvider, IProviders } from './token';
import { IBaseProvider } from './completion-providers/base-provider';
import { isWritable } from './tools';
import { BaseLanguageModel } from '@langchain/core/language_models/base';

export class LlmProvider implements ILlmProvider {
  constructor(options: LlmProvider.IOptions) {
@@ -19,42 +18,65 @@ export class LlmProvider implements ILlmProvider {
    return this._name;
  }

  get inlineProvider(): IInlineCompletionProvider | null {
    return this._inlineProvider;
  get completionProvider(): IBaseProvider | null {
    if (this._name === null) {
      return null;
    }
    return (
      this._completionProviders.get(this._name)?.completionProvider || null
    );
  }

  get chatModel(): BaseChatModel | null {
    return this._chatModel;
    if (this._name === null) {
      return null;
    }
    return this._completionProviders.get(this._name)?.chatModel || null;
  }

  setProvider(value: string | null, settings: ReadonlyPartialJSONObject) {
    if (value === null) {
      this._inlineProvider = null;
      this._chatModel = null;
  setProvider(name: string | null, settings: ReadonlyPartialJSONObject) {
    console.log('SET PROVIDER', name);
    if (name === null) {
      // TODO: the inline completion is not disabled, it should be removed/disabled
      // from the manager.
      this._providerChange.emit();
      return;
    }

    const provider = this._completionProviders.get(value) as IBaseProvider;
    if (provider) {
      provider.configure(settings as { [property: string]: JSONValue });
    const providers = this._completionProviders.get(name);
    if (providers !== undefined) {
      console.log('Provider defined');
      // Update the inline completion provider settings.
      this._updateConfig(providers.completionProvider.client, settings);

      // Update the chat LLM settings.
      this._updateConfig(providers.chatModel, settings);

      if (name !== this._name) {
        this._name = name;
        this._providerChange.emit();
      }
      return;
    }

    if (value === 'MistralAI') {
    console.log('Provider undefined');
    if (name === 'MistralAI') {
      this._name = 'MistralAI';
      const mistralClient = new MistralAI({ apiKey: 'TMP', ...settings });
      this._inlineProvider = new completionProviders.CodestralProvider({
      const mistralClient = new MistralAI({ apiKey: 'TMP' });
      this._updateConfig(mistralClient, settings);

      const completionProvider = new completionProviders.CodestralProvider({
        mistralClient
      });
      this._completionProviderManager.registerInlineProvider(
        this._inlineProvider
        completionProvider
      );
      this._completionProviders.set(value, this._inlineProvider);
      this._chatModel = new ChatMistralAI({ apiKey: 'TMP', ...settings });

      const chatModel = new ChatMistralAI({ apiKey: 'TMP' });
      this._updateConfig(chatModel as any, settings);

      this._completionProviders.set(name, { completionProvider, chatModel });
    } else {
      this._inlineProvider = null;
      this._chatModel = null;
      this._name = null;
    }
    this._providerChange.emit();
  }
@@ -63,11 +85,28 @@
    return this._providerChange;
  }

  private _updateConfig<T extends BaseLanguageModel>(
    model: T,
    settings: ReadonlyPartialJSONObject
  ) {
    Object.entries(settings).forEach(([key, value], index) => {
      if (key in model) {
        const modelKey = key as keyof typeof model;
        if (isWritable(model, modelKey)) {
          // eslint-disable-next-line @typescript-eslint/ban-ts-comment
          // @ts-ignore
          model[modelKey] = value;
        }
      }
    });
  }

  private _completionProviderManager: ICompletionProviderManager;
  private _completionProviders = new Map<string, IInlineCompletionProvider>();
  // The ICompletionProviderManager does not allow manipulating the providers,
  // like getting, removing or recreating them. This map stores the created
  // providers so that they can be modified later.
  private _completionProviders = new Map<string, IProviders>();
  private _name: string | null = null;
  private _inlineProvider: IBaseProvider | null = null;
  private _chatModel: BaseChatModel | null = null;
  private _providerChange = new Signal<ILlmProvider, void>(this);
}
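Taken together, setProvider and _updateConfig give the settings logic a single entry point: the first call for a given name builds the completion provider / chat model pair, later calls only copy the writable settings onto the existing LangChain objects. A usage sketch; the constructor field, settings keys and values below are assumptions, not taken from this commit:

// Hypothetical caller, e.g. a settings plugin reacting to a change in the
// jupyterlab-codestral settings (field and key names are assumed).
const llmProvider = new LlmProvider({ completionProviderManager });

llmProvider.setProvider('MistralAI', { apiKey: 'my-key', temperature: 0.2 });
// First call: creates the MistralAI client and the ChatMistralAI model, applies
// the settings to both via _updateConfig, registers the CodestralProvider with
// the completion manager, caches the pair and emits providerChange.

llmProvider.setProvider('MistralAI', { temperature: 0 });
// Later calls with the same name: the cached pair is reused and _updateConfig
// only overwrites the writable keys; providerChange is not re-emitted because
// the provider name did not change.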

9 changes: 7 additions & 2 deletions src/token.ts
@@ -1,15 +1,20 @@
import { IInlineCompletionProvider } from '@jupyterlab/completer';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { Token } from '@lumino/coreutils';
import { ISignal } from '@lumino/signaling';
import { IBaseProvider } from './completion-providers/base-provider';

export interface ILlmProvider {
  name: string | null;
  inlineProvider: IInlineCompletionProvider | null;
  completionProvider: IBaseProvider | null;
  chatModel: BaseChatModel | null;
  providerChange: ISignal<ILlmProvider, void>;
}

export interface IProviders {
  completionProvider: IBaseProvider;
  chatModel: BaseChatModel;
}

export const ILlmProvider = new Token<ILlmProvider>(
  'jupyterlab-codestral:LlmProvider',
  'Provider for chat and completion LLM client'
17 changes: 17 additions & 0 deletions src/tools.ts
@@ -0,0 +1,17 @@
import { BaseLanguageModel } from '@langchain/core/language_models/base';

/**
 * This function indicates whether a key is writable in an object.
 * https://stackoverflow.com/questions/54724875/can-we-check-whether-property-is-readonly-in-typescript
 *
 * @param obj - An object extending the BaseLanguageModel interface.
 * @param key - A string as a key of the object.
 * @returns a boolean whether the key is writable or not.
 */
export function isWritable<T extends BaseLanguageModel>(obj: T, key: keyof T) {
  const desc =
    Object.getOwnPropertyDescriptor(obj, key) ||
    Object.getOwnPropertyDescriptor(Object.getPrototypeOf(obj), key) ||
    {};
  return Boolean(desc.writable);
}
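The helper relies on property descriptors: a plain class field ends up as a writable data property on the instance, while a getter-only accessor lives on the prototype and its descriptor has no writable flag, so Boolean(desc.writable) is false and _updateConfig skips that key. A small illustration of that descriptor logic, using a throwaway class rather than a real BaseLanguageModel:

// Illustration only, not part of the commit.
class Example {
  temperature = 0.7; // instance data property -> descriptor.writable === true

  get name(): string {
    // accessor on the prototype -> descriptor has get/set but no `writable`
    return 'example';
  }
}

const ex = new Example();

const field = Object.getOwnPropertyDescriptor(ex, 'temperature');
console.log(Boolean(field?.writable)); // true  -> the setting would be applied

const accessor = Object.getOwnPropertyDescriptor(Object.getPrototypeOf(ex), 'name');
console.log(Boolean(accessor?.writable)); // false -> the key would be skipped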
