Skip to content

Commit

Permalink
Provides only one completion provider
Browse files Browse the repository at this point in the history
  • Loading branch information
brichet committed Oct 30, 2024
1 parent 9f2f249 commit 6e11c72
Show file tree
Hide file tree
Showing 12 changed files with 208 additions and 117 deletions.
2 changes: 1 addition & 1 deletion src/chat-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ import {
INewMessage
} from '@jupyter/chat';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { UUID } from '@lumino/coreutils';
import {
AIMessage,
HumanMessage,
mergeMessageRuns
} from '@langchain/core/messages';
import { UUID } from '@lumino/coreutils';

export type ConnectionMessage = {
type: 'connection';
Expand Down
61 changes: 61 additions & 0 deletions src/completion-provider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import {
CompletionHandler,
IInlineCompletionContext,
IInlineCompletionProvider
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';

import { getCompleter, IBaseCompleter } from './llm-models';

/**
* The generic completion provider to register to the completion provider manager.
*/
/**
 * The generic inline completion provider registered with the completion
 * provider manager. It delegates every request to the IBaseCompleter
 * selected by the current provider name.
 */
export class CompletionProvider implements IInlineCompletionProvider {
  readonly identifier = '@jupyterlite/ai';

  constructor(options: CompletionProvider.IOptions) {
    // Assign through the setter so the matching completer is created.
    this.name = options.name;
  }

  /**
   * The provider name.
   * Setting it swaps in the completer associated with the new name
   * (null when no completer is registered for that name).
   */
  get name(): string {
    return this._currentName;
  }
  set name(value: string) {
    this._currentName = value;
    this._activeCompleter = getCompleter(value);
  }

  /**
   * The active completer, if any.
   */
  get completer(): IBaseCompleter | null {
    return this._activeCompleter;
  }

  /**
   * The LLM client backing the active completer, if any.
   */
  get llmCompleter(): LLM | null {
    return this._activeCompleter?.client || null;
  }

  /**
   * Forward an inline completion request to the active completer.
   * Resolves to undefined when no completer is configured.
   */
  async fetch(
    request: CompletionHandler.IRequest,
    context: IInlineCompletionContext
  ) {
    return this._activeCompleter?.fetch(request, context);
  }

  private _currentName: string = 'None';
  private _activeCompleter: IBaseCompleter | null = null;
}

/**
 * Statics and option types for the CompletionProvider.
 */
export namespace CompletionProvider {
  /**
   * The constructor options for a CompletionProvider.
   */
  export interface IOptions {
    // The provider name; determines which completer is instantiated.
    name: string;
  }
}
6 changes: 0 additions & 6 deletions src/completion-providers/base-provider.ts

This file was deleted.

1 change: 0 additions & 1 deletion src/completion-providers/index.ts

This file was deleted.

6 changes: 3 additions & 3 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
import { ISettingRegistry } from '@jupyterlab/settingregistry';

import { ChatHandler } from './chat-handler';
import { ILlmProvider } from './token';
import { LlmProvider } from './provider';
import { ILlmProvider } from './token';

const chatPlugin: JupyterFrontEndPlugin<void> = {
id: 'jupyterlab-codestral:chat',
Expand Down Expand Up @@ -45,7 +45,7 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
activeCellManager: activeCellManager
});

llmProvider.providerChange.connect(() => {
llmProvider.modelChange.connect(() => {
chatHandler.llmClient = llmProvider.chatModel;
});

Expand Down Expand Up @@ -111,7 +111,7 @@ const llmProviderPlugin: JupyterFrontEndPlugin<ILlmProvider> = {
.then(settings => {
const updateProvider = () => {
const provider = settings.get('provider').composite as string;
llmProvider.setProvider(provider, settings.composite);
llmProvider.setModels(provider, settings.composite);
};

settings.changed.connect(() => updateProvider());
Expand Down
20 changes: 20 additions & 0 deletions src/llm-models/base-completer.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import {
CompletionHandler,
IInlineCompletionContext
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';

/**
 * The contract every LLM-backed inline completer implements.
 */
export interface IBaseCompleter {
  /**
   * The LLM client used to perform completions.
   */
  client: LLM;

  /**
   * Handle an inline completion request against the LLM.
   * NOTE(review): Promise<any> — presumably resolves to an inline
   * completion list; consider tightening once implementers agree on
   * the concrete result type.
   */
  fetch(
    request: CompletionHandler.IRequest,
    context: IInlineCompletionContext
  ): Promise<any>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@ import {
CompletionHandler,
IInlineCompletionContext
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';
import { MistralAI } from '@langchain/mistralai';
import { Throttler } from '@lumino/polling';
import { CompletionRequest } from '@mistralai/mistralai';
import type { MistralAI } from '@langchain/mistralai';

import { IBaseProvider } from './base-provider';
import { LLM } from '@langchain/core/language_models/llms';
import { IBaseCompleter } from './base-completer';

/*
* The Mistral API has a rate limit of 1 request per second
*/
const INTERVAL = 1000;

export class CodestralProvider implements IBaseProvider {
readonly identifier = 'Codestral';
readonly name = 'Codestral';

constructor(options: CodestralProvider.IOptions) {
this._mistralClient = options.mistralClient;
export class CodestralCompleter implements IBaseCompleter {
constructor() {
this._mistralClient = new MistralAI({
apiKey: 'TMP',
model: 'codestral-latest'
});
this._throttler = new Throttler(async (data: CompletionRequest) => {
const response = await this._mistralClient.completionWithRetry(
data,
Expand Down Expand Up @@ -51,7 +51,7 @@ export class CodestralProvider implements IBaseProvider {
const data = {
prompt,
suffix,
model: 'codestral-latest',
model: this._mistralClient.model,
// temperature: 0,
// top_p: 1,
// max_tokens: 1024,
Expand All @@ -72,9 +72,3 @@ export class CodestralProvider implements IBaseProvider {
private _throttler: Throttler;
private _mistralClient: MistralAI;
}

export namespace CodestralProvider {
export interface IOptions {
mistralClient: MistralAI;
}
}
3 changes: 3 additions & 0 deletions src/llm-models/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
export * from './base-completer';
export * from './codestral-completer';
export * from './utils';
24 changes: 24 additions & 0 deletions src/llm-models/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatMistralAI } from '@langchain/mistralai';
import { IBaseCompleter } from './base-completer';
import { CodestralCompleter } from './codestral-completer';

/**
* Get an LLM completer from the name.
*/
/**
 * Build the inline completer matching a provider name.
 *
 * @param name - the provider name.
 * @returns the completer, or null when no completer exists for that name.
 */
export function getCompleter(name: string): IBaseCompleter | null {
  switch (name) {
    case 'MistralAI':
      return new CodestralCompleter();
    default:
      return null;
  }
}

/**
* Get an LLM chat model from the name.
*/
/**
 * Build the chat model matching a provider name.
 *
 * @param name - the provider name.
 * @returns the chat model, or null when no model exists for that name.
 */
export function getChatModel(name: string): BaseChatModel | null {
  switch (name) {
    case 'MistralAI':
      return new ChatMistralAI({ apiKey: 'TMP' });
    default:
      return null;
  }
}
Loading

0 comments on commit 6e11c72

Please sign in to comment.