
Commit

Add openAI provider
brichet committed Nov 8, 2024
1 parent 8043bf9 commit c5ae0a1
Showing 8 changed files with 232 additions and 8 deletions.
1 change: 1 addition & 0 deletions package.json
@@ -62,6 +62,7 @@
     "@jupyterlab/settingregistry": "^4.2.0",
     "@langchain/core": "^0.3.13",
     "@langchain/mistralai": "^0.1.1",
+    "@langchain/openai": "^0.3.12",
     "@lumino/coreutils": "^2.1.2",
     "@lumino/polling": "^2.1.2",
     "@lumino/signaling": "^2.1.2"

2 changes: 1 addition & 1 deletion schema/ai-provider.json
@@ -8,7 +8,7 @@
       "title": "The AI provider",
       "description": "The AI provider to use for chat and completion",
       "default": "None",
-      "enum": ["None", "MistralAI"]
+      "enum": ["None", "MistralAI", "OpenAI"]
     },
     "apiKey": {
       "type": "string",

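Side note, not part of the diff: with the enum extended, selecting OpenAI in the Settings Editor plus the existing apiKey field is all the configuration the provider needs. A hypothetical TypeScript view of such a settings object, assuming the enum property is named provider (the visible hunk does not confirm the property name; the key value is a placeholder):

// Sketch of the user settings this schema validates; not part of the commit.
const aiProviderSettings = {
  provider: 'OpenAI' as 'None' | 'MistralAI' | 'OpenAI',
  apiKey: 'sk-...' // user-supplied secret, placeholder here
};
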
4 changes: 2 additions & 2 deletions src/completion-provider.ts
@@ -3,7 +3,7 @@ import {
   IInlineCompletionContext,
   IInlineCompletionProvider
 } from '@jupyterlab/completer';
-import { LLM } from '@langchain/core/language_models/llms';
+import { BaseLLM } from '@langchain/core/language_models/llms';
 
 import { getCompleter, IBaseCompleter, BaseCompleter } from './llm-models';
 import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
@@ -53,7 +53,7 @@ export class CompletionProvider implements IInlineCompletionProvider {
   /**
    * Get the LLM completer.
    */
-  get llmCompleter(): LLM | null {
+  get llmCompleter(): BaseLLM | null {
     return this._completer?.provider || null;
   }
 
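Not stated in the diff, but the type widening from LLM to BaseLLM is what lets both providers flow through this getter: MistralAI extends LLM (itself a subclass of BaseLLM), while OpenAI from @langchain/openai extends BaseLLM directly. A minimal sketch of the shared typing (placeholder keys and models):

import { BaseLLM } from '@langchain/core/language_models/llms';
import { MistralAI } from '@langchain/mistralai';
import { OpenAI } from '@langchain/openai';

// Both concrete providers are assignable to the common supertype.
let provider: BaseLLM;
provider = new MistralAI({ apiKey: '...', model: 'codestral-latest' });
provider = new OpenAI({ apiKey: '...', model: 'gpt-3.5-turbo-instruct' });
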
4 changes: 2 additions & 2 deletions src/llm-models/base-completer.ts
@@ -2,14 +2,14 @@ import {
   CompletionHandler,
   IInlineCompletionContext
 } from '@jupyterlab/completer';
-import { LLM } from '@langchain/core/language_models/llms';
+import { BaseLLM } from '@langchain/core/language_models/llms';
 import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
 
 export interface IBaseCompleter {
   /**
    * The LLM completer.
    */
-  provider: LLM;
+  provider: BaseLLM;
 
   /**
    * The fetch request for the LLM completer.

4 changes: 2 additions & 2 deletions src/llm-models/codestral-completer.ts
@@ -2,7 +2,7 @@ import {
   CompletionHandler,
   IInlineCompletionContext
 } from '@jupyterlab/completer';
-import { LLM } from '@langchain/core/language_models/llms';
+import { BaseLLM } from '@langchain/core/language_models/llms';
 import { MistralAI } from '@langchain/mistralai';
 import { Throttler } from '@lumino/polling';
 import { CompletionRequest } from '@mistralai/mistralai';
@@ -33,7 +33,7 @@ export class CodestralCompleter implements IBaseCompleter {
     }, INTERVAL);
   }
 
-  get provider(): LLM {
+  get provider(): BaseLLM {
     return this._mistralProvider;
   }
 
52 changes: 52 additions & 0 deletions src/llm-models/openai-completer.ts
@@ -0,0 +1,52 @@
+import {
+  CompletionHandler,
+  IInlineCompletionContext
+} from '@jupyterlab/completer';
+import { BaseLLM } from '@langchain/core/language_models/llms';
+import { OpenAI } from '@langchain/openai';
+
+import { BaseCompleter, IBaseCompleter } from './base-completer';
+
+export class OpenAICompleter implements IBaseCompleter {
+  constructor(options: BaseCompleter.IOptions) {
+    this._gptProvider = new OpenAI({ ...options.settings });
+  }
+
+  get provider(): BaseLLM {
+    return this._gptProvider;
+  }
+
+  async fetch(
+    request: CompletionHandler.IRequest,
+    context: IInlineCompletionContext
+  ) {
+    const { text, offset: cursorOffset } = request;
+    const prompt = text.slice(0, cursorOffset);
+    const suffix = text.slice(cursorOffset);
+
+    const data = {
+      prompt,
+      suffix,
+      model: this._gptProvider.model,
+      // temperature: 0,
+      // top_p: 1,
+      // max_tokens: 1024,
+      // min_tokens: 0,
+      // random_seed: 1337,
+      stop: []
+    };
+
+    try {
+      const response = await this._gptProvider.completionWithRetry(data, {});
+      const items = response.choices.map((choice: any) => {
+        return { insertText: choice.text as string };
+      });
+      return { items };
+    } catch (error) {
+      console.error('Error fetching completions', error);
+      return { items: [] };
+    }
+  }
+
+  private _gptProvider: OpenAI;
+}

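For orientation, a hypothetical direct call into the new completer (not part of the commit; the API key and code fragment are placeholders, and the unused inline-completion context is stubbed):

import { IInlineCompletionContext } from '@jupyterlab/completer';
import { OpenAICompleter } from './openai-completer';

// Sketch only: assumes BaseCompleter.IOptions carries the raw settings
// object, as the constructor above suggests.
const completer = new OpenAICompleter({
  settings: { model: 'gpt-3.5-turbo-instruct', apiKey: 'sk-...' }
});

const request = { text: 'def add(a, b):\n    return ', offset: 26 };
const context = {} as IInlineCompletionContext;

completer.fetch(request, context).then(result => {
  console.log(result); // e.g. { items: [{ insertText: 'a + b' }] }
});
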
9 changes: 9 additions & 0 deletions src/llm-models/utils.ts
@@ -1,8 +1,11 @@
 import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { ChatMistralAI } from '@langchain/mistralai';
+import { ChatOpenAI } from '@langchain/openai';
 
 import { IBaseCompleter } from './base-completer';
 import { CodestralCompleter } from './codestral-completer';
 import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
+import { OpenAICompleter } from './openai-completer';
+
 /**
  * Get an LLM completer from the name.
@@ -13,6 +16,8 @@ export function getCompleter(
 ): IBaseCompleter | null {
   if (name === 'MistralAI') {
     return new CodestralCompleter({ settings });
+  } else if (name === 'OpenAI') {
+    return new OpenAICompleter({ settings });
   }
   return null;
 }
@@ -26,6 +31,8 @@ export function getChatModel(
 ): BaseChatModel | null {
   if (name === 'MistralAI') {
     return new ChatMistralAI({ ...settings });
+  } else if (name === 'OpenAI') {
+    return new ChatOpenAI({ ...settings });
   }
   return null;
 }
@@ -36,6 +43,8 @@ export function getChatModel(
 export function getErrorMessage(name: string, error: any): string {
   if (name === 'MistralAI') {
     return error.message;
+  } else if (name === 'OpenAI') {
+    return error.message;
   }
   return 'Unknown provider';
 }

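These three switches are the only places a provider name is interpreted, so a hypothetical caller needs nothing beyond the string from the settings schema (argument values below are placeholders):

import { getChatModel, getCompleter, getErrorMessage } from './utils';

const settings = { model: 'gpt-4o-mini', apiKey: 'sk-...' };

const completer = getCompleter('OpenAI', settings); // OpenAICompleter
const chat = getChatModel('OpenAI', settings); // ChatOpenAI
const fallback = getChatModel('Claude', settings); // null: unknown provider
const message = getErrorMessage('OpenAI', new Error('401')); // '401'

Both branches of getErrorMessage currently return error.message; keeping them separate presumably leaves room for provider-specific error handling later.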
