diff --git a/package.json b/package.json
index 539736a..0e11ed0 100644
--- a/package.json
+++ b/package.json
@@ -62,6 +62,7 @@
     "@jupyterlab/settingregistry": "^4.2.0",
     "@langchain/core": "^0.3.13",
     "@langchain/mistralai": "^0.1.1",
+    "@langchain/openai": "^0.3.12",
     "@lumino/coreutils": "^2.1.2",
     "@lumino/polling": "^2.1.2",
     "@lumino/signaling": "^2.1.2"
diff --git a/schema/ai-provider.json b/schema/ai-provider.json
index d4b9a04..77aac60 100644
--- a/schema/ai-provider.json
+++ b/schema/ai-provider.json
@@ -8,7 +8,7 @@
       "title": "The AI provider",
       "description": "The AI provider to use for chat and completion",
       "default": "None",
-      "enum": ["None", "MistralAI"]
+      "enum": ["None", "MistralAI", "OpenAI"]
     },
     "apiKey": {
       "type": "string",
diff --git a/src/completion-provider.ts b/src/completion-provider.ts
index e7000c5..8b9e8f8 100644
--- a/src/completion-provider.ts
+++ b/src/completion-provider.ts
@@ -3,7 +3,7 @@ import {
   IInlineCompletionContext,
   IInlineCompletionProvider
 } from '@jupyterlab/completer';
-import { LLM } from '@langchain/core/language_models/llms';
+import { BaseLLM } from '@langchain/core/language_models/llms';
 
 import { getCompleter, IBaseCompleter, BaseCompleter } from './llm-models';
 import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
@@ -53,7 +53,7 @@ export class CompletionProvider implements IInlineCompletionProvider {
   /**
    * Get the LLM completer.
    */
-  get llmCompleter(): LLM | null {
+  get llmCompleter(): BaseLLM | null {
     return this._completer?.provider || null;
   }
 
diff --git a/src/llm-models/base-completer.ts b/src/llm-models/base-completer.ts
index fb84f4f..91e396a 100644
--- a/src/llm-models/base-completer.ts
+++ b/src/llm-models/base-completer.ts
@@ -2,14 +2,14 @@ import {
   CompletionHandler,
   IInlineCompletionContext
 } from '@jupyterlab/completer';
-import { LLM } from '@langchain/core/language_models/llms';
+import { BaseLLM } from '@langchain/core/language_models/llms';
 import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
 
 export interface IBaseCompleter {
   /**
    * The LLM completer.
    */
-  provider: LLM;
+  provider: BaseLLM;
 
   /**
    * The fetch request for the LLM completer.
diff --git a/src/llm-models/codestral-completer.ts b/src/llm-models/codestral-completer.ts
index efa7934..087879f 100644
--- a/src/llm-models/codestral-completer.ts
+++ b/src/llm-models/codestral-completer.ts
@@ -2,7 +2,7 @@ import {
   CompletionHandler,
   IInlineCompletionContext
 } from '@jupyterlab/completer';
-import { LLM } from '@langchain/core/language_models/llms';
+import { BaseLLM } from '@langchain/core/language_models/llms';
 import { MistralAI } from '@langchain/mistralai';
 import { Throttler } from '@lumino/polling';
 import { CompletionRequest } from '@mistralai/mistralai';
@@ -33,7 +33,7 @@ export class CodestralCompleter implements IBaseCompleter {
     }, INTERVAL);
   }
 
-  get provider(): LLM {
+  get provider(): BaseLLM {
     return this._mistralProvider;
   }
 
diff --git a/src/llm-models/openai-completer.ts b/src/llm-models/openai-completer.ts
new file mode 100644
index 0000000..731c0a5
--- /dev/null
+++ b/src/llm-models/openai-completer.ts
@@ -0,0 +1,52 @@
+import {
+  CompletionHandler,
+  IInlineCompletionContext
+} from '@jupyterlab/completer';
+import { BaseLLM } from '@langchain/core/language_models/llms';
+import { OpenAI } from '@langchain/openai';
+
+import { BaseCompleter, IBaseCompleter } from './base-completer';
+
+export class OpenAICompleter implements IBaseCompleter {
+  constructor(options: BaseCompleter.IOptions) {
+    this._gptProvider = new OpenAI({ ...options.settings });
+  }
+
+  get provider(): BaseLLM {
+    return this._gptProvider;
+  }
+
+  async fetch(
+    request: CompletionHandler.IRequest,
+    context: IInlineCompletionContext
+  ) {
+    const { text, offset: cursorOffset } = request;
+    const prompt = text.slice(0, cursorOffset);
+    const suffix = text.slice(cursorOffset);
+
+    const data = {
+      prompt,
+      suffix,
+      model: this._gptProvider.model,
+      // temperature: 0,
+      // top_p: 1,
+      // max_tokens: 1024,
+      // min_tokens: 0,
+      // random_seed: 1337,
+      stop: []
+    };
+
+    try {
+      const response = await this._gptProvider.completionWithRetry(data, {});
+      // Completions API choices expose `text`, not `message.content`.
+      const items = response.choices.map((choice: any) => {
+        return { insertText: choice.text as string };
+      });
+      return { items };
+    } catch (error) {
+      console.error('Error fetching completions', error);
+      return { items: [] };
+    }
+  }
+
+  private _gptProvider: OpenAI;
+}
diff --git a/src/llm-models/utils.ts b/src/llm-models/utils.ts
index 544d684..2080fe3 100644
--- a/src/llm-models/utils.ts
+++ b/src/llm-models/utils.ts
@@ -1,8 +1,11 @@
 import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { ChatMistralAI } from '@langchain/mistralai';
+import { ChatOpenAI } from '@langchain/openai';
+
 import { IBaseCompleter } from './base-completer';
 import { CodestralCompleter } from './codestral-completer';
 import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
+import { OpenAICompleter } from './openai-completer';
 
 /**
  * Get an LLM completer from the name.
@@ -13,6 +16,8 @@ export function getCompleter(
 ): IBaseCompleter | null {
   if (name === 'MistralAI') {
     return new CodestralCompleter({ settings });
+  } else if (name === 'OpenAI') {
+    return new OpenAICompleter({ settings });
   }
   return null;
 }
@@ -26,6 +31,8 @@ export function getChatModel(
 ): BaseChatModel | null {
   if (name === 'MistralAI') {
     return new ChatMistralAI({ ...settings });
+  } else if (name === 'OpenAI') {
+    return new ChatOpenAI({ ...settings });
   }
   return null;
 }
@@ -36,6 +43,8 @@ export function getChatModel(
 export function getErrorMessage(name: string, error: any): string {
   if (name === 'MistralAI') {
     return error.message;
+  } else if (name === 'OpenAI') {
+    return error.message;
   }
   return 'Unknown provider';
 }
diff --git a/yarn.lock b/yarn.lock
index 47e6599..4f6e8f1 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -1534,6 +1534,20 @@ __metadata:
   languageName: node
   linkType: hard
 
+"@langchain/openai@npm:^0.3.12":
+  version: 0.3.12
+  resolution: "@langchain/openai@npm:0.3.12"
+  dependencies:
+    js-tiktoken: ^1.0.12
+    openai: ^4.71.0
+    zod: ^3.22.4
+    zod-to-json-schema: ^3.22.3
+  peerDependencies:
+    "@langchain/core": ">=0.2.26 <0.4.0"
+  checksum: b5050b82233dd0429f4c8e133962c3adaa8ac57f5298aa270e8df91184188c2b7fc233f017546cdf28c82a47d9309d776bb1cacebc929db4a5832d2df1e94554
+  languageName: node
+  linkType: hard
+
 "@lezer/common@npm:^1.0.0, @lezer/common@npm:^1.0.2, @lezer/common@npm:^1.1.0, @lezer/common@npm:^1.2.0, @lezer/common@npm:^1.2.1":
   version: 1.2.1
   resolution: "@lezer/common@npm:1.2.1"
@@ -2272,6 +2286,16 @@ __metadata:
   languageName: node
   linkType: hard
 
+"@types/node-fetch@npm:^2.6.4":
+  version: 2.6.11
+  resolution: "@types/node-fetch@npm:2.6.11"
+  dependencies:
+    "@types/node": "*"
+    form-data: ^4.0.0
+  checksum: 180e4d44c432839bdf8a25251ef8c47d51e37355ddd78c64695225de8bc5dc2b50b7bb855956d471c026bb84bd7295688a0960085e7158cbbba803053492568b
+  languageName: node
+  linkType: hard
+
 "@types/node@npm:*":
   version: 20.14.2
   resolution: "@types/node@npm:20.14.2"
@@ -2281,6 +2305,15 @@ __metadata:
   languageName: node
   linkType: hard
 
+"@types/node@npm:^18.11.18":
+  version: 18.19.64
+  resolution: "@types/node@npm:18.19.64"
+  dependencies:
+    undici-types: ~5.26.4
+  checksum: e7680215b03c9bee8a33947f03d06048e8e460f23b1b7b29c45350cf437faa5f8fcb7d8c3eb8dfec8427923e7a184df42bc710c1b6252b4852e3ed7064c6228f
+  languageName: node
+  linkType: hard
+
 "@types/normalize-package-data@npm:^2.4.0":
   version: 2.4.4
   resolution: "@types/normalize-package-data@npm:2.4.4"
@@ -2705,6 +2738,15 @@ __metadata:
   languageName: node
   linkType: hard
 
+"abort-controller@npm:^3.0.0":
+  version: 3.0.0
+  resolution: "abort-controller@npm:3.0.0"
+  dependencies:
+    event-target-shim: ^5.0.0
+  checksum: 170bdba9b47b7e65906a28c8ce4f38a7a369d78e2271706f020849c1bfe0ee2067d4261df8bbb66eb84f79208fd5b710df759d64191db58cfba7ce8ef9c54b75
+  languageName: node
+  linkType: hard
+
 "acorn-import-assertions@npm:^1.9.0":
   version: 1.9.0
   resolution: "acorn-import-assertions@npm:1.9.0"
@@ -2732,6 +2774,15 @@ __metadata:
   languageName: node
   linkType: hard
 
+"agentkeepalive@npm:^4.2.1":
+  version: 4.5.0
+  resolution: "agentkeepalive@npm:4.5.0"
+  dependencies:
+    humanize-ms: ^1.2.1
+  checksum: 13278cd5b125e51eddd5079f04d6fe0914ac1b8b91c1f3db2c1822f99ac1a7457869068997784342fe455d59daaff22e14fb7b8c3da4e741896e7e31faf92481
+  languageName: node
+  linkType: hard
+
 "ajv-formats@npm:^2.1.1":
   version: 2.1.1
   resolution: "ajv-formats@npm:2.1.1"
@@ -2890,6 +2941,13 @@ __metadata:
   languageName: node
   linkType: hard
 
+"asynckit@npm:^0.4.0":
+  version: 0.4.0
+  resolution: "asynckit@npm:0.4.0"
+  checksum: 7b78c451df768adba04e2d02e63e2d0bf3b07adcd6e42b4cf665cb7ce899bedd344c69a1dcbce355b5f972d597b25aaa1c1742b52cffd9caccb22f348114f6be
+  languageName: node
+  linkType: hard
+
 "available-typed-arrays@npm:^1.0.7":
   version: 1.0.7
   resolution: "available-typed-arrays@npm:1.0.7"
@@ -3125,6 +3183,15 @@ __metadata:
   languageName: node
   linkType: hard
 
+"combined-stream@npm:^1.0.8":
+  version: 1.0.8
+  resolution: "combined-stream@npm:1.0.8"
+  dependencies:
+    delayed-stream: ~1.0.0
+  checksum: 49fa4aeb4916567e33ea81d088f6584749fc90c7abec76fd516bf1c5aa5c79f3584b5ba3de6b86d26ddd64bae5329c4c7479343250cfe71c75bb366eae53bb7c
+  languageName: node
+  linkType: hard
+
 "commander@npm:^10.0.1":
   version: 10.0.1
   resolution: "commander@npm:10.0.1"
@@ -3424,6 +3491,13 @@ __metadata:
   languageName: node
   linkType: hard
 
+"delayed-stream@npm:~1.0.0":
+  version: 1.0.0
+  resolution: "delayed-stream@npm:1.0.0"
+  checksum: 46fe6e83e2cb1d85ba50bd52803c68be9bd953282fa7096f51fc29edd5d67ff84ff753c51966061e5ba7cb5e47ef6d36a91924eddb7f3f3483b1c560f77a0020
+  languageName: node
+  linkType: hard
+
 "dir-glob@npm:^3.0.1":
   version: 3.0.1
   resolution: "dir-glob@npm:3.0.1"
@@ -3857,6 +3931,13 @@ __metadata:
   languageName: node
   linkType: hard
 
+"event-target-shim@npm:^5.0.0":
+  version: 5.0.1
+  resolution: "event-target-shim@npm:5.0.1"
+  checksum: 1ffe3bb22a6d51bdeb6bf6f7cf97d2ff4a74b017ad12284cc9e6a279e727dc30a5de6bb613e5596ff4dc3e517841339ad09a7eec44266eccb1aa201a30448166
+  languageName: node
+  linkType: hard
+
 "eventemitter3@npm:^4.0.4":
   version: 4.0.7
   resolution: "eventemitter3@npm:4.0.7"
@@ -4035,6 +4116,34 @@ __metadata:
   languageName: node
   linkType: hard
 
+"form-data-encoder@npm:1.7.2":
+  version: 1.7.2
+  resolution: "form-data-encoder@npm:1.7.2"
+  checksum: aeebd87a1cb009e13cbb5e4e4008e6202ed5f6551eb6d9582ba8a062005178907b90f4887899d3c993de879159b6c0c940af8196725b428b4248cec5af3acf5f
+  languageName: node
+  linkType: hard
+
+"form-data@npm:^4.0.0":
+  version: 4.0.1
+  resolution: "form-data@npm:4.0.1"
+  dependencies:
+    asynckit: ^0.4.0
+    combined-stream: ^1.0.8
+    mime-types: ^2.1.12
+  checksum: ccee458cd5baf234d6b57f349fe9cc5f9a2ea8fd1af5ecda501a18fd1572a6dd3bf08a49f00568afd995b6a65af34cb8dec083cf9d582c4e621836499498dd84
+  languageName: node
+  linkType: hard
+
+"formdata-node@npm:^4.3.2":
+  version: 4.4.1
+  resolution: "formdata-node@npm:4.4.1"
+  dependencies:
+    node-domexception: 1.0.0
+    web-streams-polyfill: 4.0.0-beta.3
+  checksum: d91d4f667cfed74827fc281594102c0dabddd03c9f8b426fc97123eedbf73f5060ee43205d89284d6854e2fc5827e030cd352ef68b93beda8decc2d72128c576
+  languageName: node
+  linkType: hard
+
 "free-style@npm:3.1.0":
   version: 3.1.0
   resolution: "free-style@npm:3.1.0"
@@ -4381,6 +4490,15 @@ __metadata:
   languageName: node
   linkType: hard
 
+"humanize-ms@npm:^1.2.1":
+  version: 1.2.1
+  resolution: "humanize-ms@npm:1.2.1"
+  dependencies:
+    ms: ^2.0.0
+  checksum: 9c7a74a2827f9294c009266c82031030eae811ca87b0da3dceb8d6071b9bde22c9f3daef0469c3c533cc67a97d8a167cd9fc0389350e5f415f61a79b171ded16
+  languageName: node
+  linkType: hard
+
 "iconv-lite@npm:^0.6.2":
   version: 0.6.3
   resolution: "iconv-lite@npm:0.6.3"
@@ -4885,6 +5003,7 @@ __metadata:
     "@jupyterlab/settingregistry": ^4.2.0
     "@langchain/core": ^0.3.13
     "@langchain/mistralai": ^0.1.1
+    "@langchain/openai": ^0.3.12
     "@lumino/coreutils": ^2.1.2
     "@lumino/polling": ^2.1.2
     "@lumino/signaling": ^2.1.2
@@ -5209,7 +5328,7 @@ __metadata:
   languageName: node
   linkType: hard
 
-"mime-types@npm:^2.1.27":
+"mime-types@npm:^2.1.12, mime-types@npm:^2.1.27":
   version: 2.1.35
   resolution: "mime-types@npm:2.1.35"
   dependencies:
@@ -5305,6 +5424,13 @@ __metadata:
   languageName: node
   linkType: hard
 
+"ms@npm:^2.0.0":
+  version: 2.1.3
+  resolution: "ms@npm:2.1.3"
+  checksum: aa92de608021b242401676e35cfa5aa42dd70cbdc082b916da7fb925c542173e36bce97ea3e804923fe92c0ad991434e4a38327e15a1b5b5f945d66df615ae6d
+  languageName: node
+  linkType: hard
+
 "mustache@npm:^4.2.0":
   version: 4.2.0
   resolution: "mustache@npm:4.2.0"
@@ -5344,6 +5470,13 @@ __metadata:
   languageName: node
   linkType: hard
 
+"node-domexception@npm:1.0.0":
+  version: 1.0.0
+  resolution: "node-domexception@npm:1.0.0"
+  checksum: ee1d37dd2a4eb26a8a92cd6b64dfc29caec72bff5e1ed9aba80c294f57a31ba4895a60fd48347cf17dd6e766da0ae87d75657dfd1f384ebfa60462c2283f5c7f
+  languageName: node
+  linkType: hard
+
 "node-fetch@npm:^2.6.7":
   version: 2.7.0
   resolution: "node-fetch@npm:2.7.0"
@@ -5459,6 +5592,28 @@ __metadata:
   languageName: node
   linkType: hard
 
+"openai@npm:^4.71.0":
+  version: 4.71.1
+  resolution: "openai@npm:4.71.1"
+  dependencies:
+    "@types/node": ^18.11.18
+    "@types/node-fetch": ^2.6.4
+    abort-controller: ^3.0.0
+    agentkeepalive: ^4.2.1
+    form-data-encoder: 1.7.2
+    formdata-node: ^4.3.2
+    node-fetch: ^2.6.7
+  peerDependencies:
+    zod: ^3.23.8
+  peerDependenciesMeta:
+    zod:
+      optional: true
+  bin:
+    openai: bin/cli
+  checksum: 0021ae9122550e0479be3f28d2c7891a4b1a103db68d75b1d6a8117cfe73e0e0ce181b85065f8877c86ee4f9223b4104ca04c6cac2f830ef4c56a7a899031bc8
+  languageName: node
+  linkType: hard
+
 "optionator@npm:^0.9.3":
   version: 0.9.4
   resolution: "optionator@npm:0.9.4"
@@ -7183,6 +7338,13 @@ __metadata:
   languageName: node
   linkType: hard
 
+"web-streams-polyfill@npm:4.0.0-beta.3":
+  version: 4.0.0-beta.3
+  resolution: "web-streams-polyfill@npm:4.0.0-beta.3"
+  checksum: dfec1fbf52b9140e4183a941e380487b6c3d5d3838dd1259be81506c1c9f2abfcf5aeb670aeeecfd9dff4271a6d8fef931b193c7bedfb42542a3b05ff36c0d16
+  languageName: node
+  linkType: hard
+
 "webidl-conversions@npm:^3.0.0":
   version: 3.0.1
   resolution: "webidl-conversions@npm:3.0.1"