diff --git a/index.d.ts b/index.d.ts
index 51c984e..f41811e 100644
--- a/index.d.ts
+++ b/index.d.ts
@@ -158,6 +158,14 @@ export interface Message {
   content: string
   copilot_references: MessageCopilotReference[]
   copilot_confirmations?: MessageCopilotConfirmation[]
+  tool_calls?: {
+    "function": {
+      "arguments": string,
+      "name": string
+    },
+    "id": string,
+    "type": "function"
+  }[]
   name?: string
 }
 
@@ -251,9 +259,21 @@ export type ModelName =
   | "gpt-4"
   | "gpt-3.5-turbo"
 
+export interface PromptFunction {
+  type: "function"
+  function: {
+    name: string;
+    description?: string;
+    /** @see https://platform.openai.com/docs/guides/structured-outputs/supported-schemas */
+    parameters?: Record<string, unknown>;
+    strict?: boolean | null;
+  }
+}
+
 export type PromptOptions = {
   model: ModelName
   token: string
+  tools?: PromptFunction[]
   request?: {
     fetch?: Function
   }
diff --git a/index.test-d.ts b/index.test-d.ts
index f8a16bd..267d21c 100644
--- a/index.test-d.ts
+++ b/index.test-d.ts
@@ -295,4 +295,24 @@ export async function promptTest() {
 
   // @ts-expect-error - token argument is required
   prompt("What is the capital of France?", { model: "" })
+}
+
+export async function promptWithToolsTest() {
+  await prompt("What is the capital of France?", {
+    model: "gpt-4",
+    token: "secret",
+    tools: [
+      {
+        type: "function",
+        function: {
+          name: "",
+          description: "",
+          parameters: {
+
+          },
+          strict: true,
+        }
+      }
+    ]
+  })
 }
\ No newline at end of file
diff --git a/lib/prompt.js b/lib/prompt.js
index 5e9bf8d..77bbab9 100644
--- a/lib/prompt.js
+++ b/lib/prompt.js
@@ -3,6 +3,11 @@
 /** @type {import('..').PromptInterface} */
 export async function prompt(userPrompt, promptOptions) {
   const promptFetch = promptOptions.request?.fetch || fetch;
+
+  const systemMessage = promptOptions.tools
+    ? "You are a helpful assistant. Use the supplied tools to assist the user."
+    : "You are a helpful assistant.";
+
   const response = await promptFetch(
     "https://api.githubcopilot.com/chat/completions",
     {
@@ -17,7 +22,7 @@ export async function prompt(userPrompt, promptOptions) {
         messages: [
           {
             role: "system",
-            content: "You are a helpful assistant.",
+            content: systemMessage,
           },
           {
             role: "user",
@@ -25,6 +30,8 @@ export async function prompt(userPrompt, promptOptions) {
           },
         ],
         model: promptOptions.model,
+        tool_choice: promptOptions.tools ? "auto" : undefined,
+        tools: promptOptions.tools,
       }),
     }
   );
diff --git a/test/prompt.test.js b/test/prompt.test.js
index 81ba603..fb8d723 100644
--- a/test/prompt.test.js
+++ b/test/prompt.test.js
@@ -33,7 +33,7 @@ test("minimal usage", async (t) => {
           content: "What is the capital of France?",
         },
       ],
-      model: "gpt-4o-mini",
+      model: "gpt-4",
     }),
   })
   .reply(
@@ -57,7 +57,82 @@ test("minimal usage", async (t) => {
 
   const result = await prompt("What is the capital of France?", {
     token: "secret",
-    model: "gpt-4o-mini",
+    model: "gpt-4",
+    request: { fetch: fetchMock },
+  });
+
+  t.assert.deepEqual(result, {
+    requestId: "",
+    message: {
+      content: "",
+    },
+  });
+});
+
+test("function calling", async (t) => {
+  const mockAgent = new MockAgent();
+  function fetchMock(url, opts) {
+    opts ||= {};
+    opts.dispatcher = mockAgent;
+    return fetch(url, opts);
+  }
+
+  mockAgent.disableNetConnect();
+  const mockPool = mockAgent.get("https://api.githubcopilot.com");
+  mockPool
+    .intercept({
+      method: "post",
+      path: `/chat/completions`,
+      body: JSON.stringify({
+        messages: [
+          {
+            role: "system",
+            content:
+              "You are a helpful assistant. Use the supplied tools to assist the user.",
+          },
+          { role: "user", content: "Call the function" },
+        ],
+        model: "gpt-4",
+        tool_choice: "auto",
+        tools: [
+          {
+            type: "function",
+            function: { name: "the_function", description: "The function" },
+          },
+        ],
+      }),
+    })
+    .reply(
+      200,
+      {
+        choices: [
+          {
+            message: {
+              content: "",
+            },
+          },
+        ],
+      },
+      {
+        headers: {
+          "content-type": "application/json",
+          "x-request-id": "",
+        },
+      }
+    );
+
+  const result = await prompt("Call the function", {
+    token: "secret",
+    model: "gpt-4",
+    tools: [
+      {
+        type: "function",
+        function: {
+          name: "the_function",
+          description: "The function",
+        },
+      },
+    ],
     request: { fetch: fetchMock },
   });