Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OpenAI provider #19

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
"@jupyterlab/settingregistry": "^4.2.0",
"@langchain/core": "^0.3.13",
"@langchain/mistralai": "^0.1.1",
"@langchain/openai": "^0.3.12",
"@lumino/coreutils": "^2.1.2",
"@lumino/polling": "^2.1.2",
"@lumino/signaling": "^2.1.2"
Expand Down
2 changes: 1 addition & 1 deletion schema/ai-provider.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"title": "The AI provider",
"description": "The AI provider to use for chat and completion",
"default": "None",
"enum": ["None", "MistralAI"]
"enum": ["None", "MistralAI", "OpenAI"]
},
"apiKey": {
"type": "string",
Expand Down
4 changes: 2 additions & 2 deletions src/completion-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import {
IInlineCompletionContext,
IInlineCompletionProvider
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';
import { BaseLLM } from '@langchain/core/language_models/llms';

import { getCompleter, IBaseCompleter, BaseCompleter } from './llm-models';
import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
Expand Down Expand Up @@ -53,7 +53,7 @@ export class CompletionProvider implements IInlineCompletionProvider {
/**
* Get the LLM completer.
*/
get llmCompleter(): LLM | null {
get llmCompleter(): BaseLLM | null {
return this._completer?.provider || null;
}

Expand Down
4 changes: 2 additions & 2 deletions src/llm-models/base-completer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ import {
CompletionHandler,
IInlineCompletionContext
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';
import { BaseLLM } from '@langchain/core/language_models/llms';
import { ReadonlyPartialJSONObject } from '@lumino/coreutils';

export interface IBaseCompleter {
/**
* The LLM completer.
*/
provider: LLM;
provider: BaseLLM;

/**
* The fetch request for the LLM completer.
Expand Down
4 changes: 2 additions & 2 deletions src/llm-models/codestral-completer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import {
CompletionHandler,
IInlineCompletionContext
} from '@jupyterlab/completer';
import { LLM } from '@langchain/core/language_models/llms';
import { BaseLLM } from '@langchain/core/language_models/llms';
import { MistralAI } from '@langchain/mistralai';
import { Throttler } from '@lumino/polling';
import { CompletionRequest } from '@mistralai/mistralai';
Expand Down Expand Up @@ -33,7 +33,7 @@ export class CodestralCompleter implements IBaseCompleter {
}, INTERVAL);
}

get provider(): LLM {
get provider(): BaseLLM {
return this._mistralProvider;
}

Expand Down
52 changes: 52 additions & 0 deletions src/llm-models/openai-completer.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import {
CompletionHandler,
IInlineCompletionContext
} from '@jupyterlab/completer';
import { BaseLLM } from '@langchain/core/language_models/llms';
import { OpenAI } from '@langchain/openai';

import { BaseCompleter, IBaseCompleter } from './base-completer';

export class OpenAICompleter implements IBaseCompleter {
constructor(options: BaseCompleter.IOptions) {
this._gptProvider = new OpenAI({ ...options.settings });
}

get provider(): BaseLLM {
return this._gptProvider;
}

async fetch(
request: CompletionHandler.IRequest,
context: IInlineCompletionContext
) {
const { text, offset: cursorOffset } = request;
const prompt = text.slice(0, cursorOffset);
const suffix = text.slice(cursorOffset);

const data = {
prompt,
suffix,
model: this._gptProvider.model,
// temperature: 0,
// top_p: 1,
// max_tokens: 1024,
// min_tokens: 0,
// random_seed: 1337,
stop: []
};

try {
const response = await this._gptProvider.completionWithRetry(data, {});
const items = response.choices.map((choice: any) => {
return { insertText: choice.message.content as string };
});
return items;
} catch (error) {
console.error('Error fetching completions', error);
return { items: [] };
}
}

private _gptProvider: OpenAI;
}
9 changes: 9 additions & 0 deletions src/llm-models/utils.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatMistralAI } from '@langchain/mistralai';
import { ChatOpenAI } from '@langchain/openai';

import { IBaseCompleter } from './base-completer';
import { CodestralCompleter } from './codestral-completer';
import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
import { OpenAICompleter } from './openai-completer';

/**
* Get an LLM completer from the name.
Expand All @@ -13,6 +16,8 @@ export function getCompleter(
): IBaseCompleter | null {
if (name === 'MistralAI') {
return new CodestralCompleter({ settings });
} else if (name === 'OpenAI') {
return new OpenAICompleter({ settings });
}
return null;
}
Expand All @@ -26,6 +31,8 @@ export function getChatModel(
): BaseChatModel | null {
if (name === 'MistralAI') {
return new ChatMistralAI({ ...settings });
} else if (name === 'OpenAI') {
return new ChatOpenAI({ ...settings });
}
return null;
}
Expand All @@ -36,6 +43,8 @@ export function getChatModel(
export function getErrorMessage(name: string, error: any): string {
if (name === 'MistralAI') {
return error.message;
} else if (name === 'OpenAI') {
return error.message;
}
return 'Unknown provider';
}
Loading
Loading