From e3fb9f44d3f009ed6413cbccdf70ac94fc868747 Mon Sep 17 00:00:00 2001 From: ErisWS Date: Sun, 5 May 2024 09:06:57 -0400 Subject: [PATCH] Add function support to LLaMA and mixtral (Stage 1) --- bots/llama_3.ts | 82 +++++++++++++++++++++++++++++++++++++++++++++++-- bots/mixtral.ts | 82 +++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 160 insertions(+), 4 deletions(-) diff --git a/bots/llama_3.ts b/bots/llama_3.ts index 53b45df..fd4102c 100644 --- a/bots/llama_3.ts +++ b/bots/llama_3.ts @@ -1,9 +1,30 @@ import * as types from "../main.d.ts"; +const tools: types.Tool[] = [{ + type: "function", + function: { + name: "sanitycheck", + description: + "Returns true, as a sanity check to make sure function support is OK. If this fails, something's fucked.", + parameters: { + type: "object", + properties: { + useless: { + type: "string", + description: + "You can put whatever here, it's not going to do anything.", + }, + }, + required: ["useless"], + }, + }, +}]; + export const information: types.information = { llmFileVersion: "1.0", env: ["GROQ_API_KEY"], - functions: false, + functions: true, + functionsData: tools, multiModal: false, callbackSupport: true, streamingSupport: false, @@ -15,6 +36,52 @@ export const information: types.information = { // const db = await Deno.openKv("./db.sqlite") +async function doTools( + res: types.Response, + callback?: + | ((information: types.callbackData, complete: boolean) => void) + | null, + requirements?: types.Requirements, +): Promise { + if (res.choices[0].finish_reason !== "tool_calls") { + throw "What The Shit?"; + } + + const toolCalls = res.choices[0].message.tool_calls!; + + // What if they happen to use it later? + // deno-lint-ignore require-await + const promises = toolCalls.map(async (tool) => { + if (tool.function.name === "sanitycheck") { + return { + role: "tool", + content: "true", + tool_call_id: tool.id, + }; + } else { + return { + role: "tool", + content: "Unknown tool or not implemented", + tool_call_id: tool.id, + //}; + }; + } + }); + + // Use Promise.all to wait for all promises to resolve + const results = await Promise.all(promises); + + results.forEach((result) => { + res.messages.push(result); + }); + + const newres = await send(null, res.messages, callback, requirements); + + console.log(newres); + + return newres; +} + export async function send( prompt: string | null, messages: types.Message[], @@ -57,7 +124,7 @@ export async function send( }), }); - const resp: types.Response = await res.json(); + let resp: types.Response = await res.json(); if (resp.error) { // Fuck. @@ -68,6 +135,17 @@ export async function send( resp.messages = messages; + if (resp.choices[0].finish_reason === "tool_calls") { + if (callback) { + callback({ + toolCalls: resp.choices[0].message.tool_calls, + data: resp.choices[0].message.content, + }, false); + } + resp = await doTools(resp, null, requirements); + resp.choices[0].message.content = resp.choices[0].message.content as string; + } + if (callback) callback({ data: resp.choices[0].message.content }, true); return resp; diff --git a/bots/mixtral.ts b/bots/mixtral.ts index 54a7850..b873d0d 100644 --- a/bots/mixtral.ts +++ b/bots/mixtral.ts @@ -1,9 +1,30 @@ import * as types from "../main.d.ts"; +const tools: types.Tool[] = [{ + type: "function", + function: { + name: "sanitycheck", + description: + "Returns true, as a sanity check to make sure function support is OK. If this fails, something's fucked.", + parameters: { + type: "object", + properties: { + useless: { + type: "string", + description: + "You can put whatever here, it's not going to do anything.", + }, + }, + required: ["useless"], + }, + }, +}]; + export const information: types.information = { llmFileVersion: "1.0", env: ["GROQ_API_KEY"], - functions: false, + functions: true, + functionsData: tools, multiModal: false, callbackSupport: true, streamingSupport: false, @@ -15,6 +36,52 @@ export const information: types.information = { // const db = await Deno.openKv("./db.sqlite") +async function doTools( + res: types.Response, + callback?: + | ((information: types.callbackData, complete: boolean) => void) + | null, + requirements?: types.Requirements, +): Promise { + if (res.choices[0].finish_reason !== "tool_calls") { + throw "What The Shit?"; + } + + const toolCalls = res.choices[0].message.tool_calls!; + + // What if they happen to use it later? + // deno-lint-ignore require-await + const promises = toolCalls.map(async (tool) => { + if (tool.function.name === "sanitycheck") { + return { + role: "tool", + content: "true", + tool_call_id: tool.id, + }; + } else { + return { + role: "tool", + content: "Unknown tool or not implemented", + tool_call_id: tool.id, + //}; + }; + } + }); + + // Use Promise.all to wait for all promises to resolve + const results = await Promise.all(promises); + + results.forEach((result) => { + res.messages.push(result); + }); + + const newres = await send(null, res.messages, callback, requirements); + + console.log(newres); + + return newres; +} + export async function send( prompt: string | null, messages: types.Message[], @@ -57,7 +124,7 @@ export async function send( }), }); - const resp: types.Response = await res.json(); + let resp: types.Response = await res.json(); if (resp.error) { // Fuck. @@ -68,6 +135,17 @@ export async function send( resp.messages = messages; + if (resp.choices[0].finish_reason === "tool_calls") { + if (callback) { + callback({ + toolCalls: resp.choices[0].message.tool_calls, + data: resp.choices[0].message.content, + }, false); + } + resp = await doTools(resp, null, requirements); + resp.choices[0].message.content = resp.choices[0].message.content as string; + } + if (callback) callback({ data: resp.choices[0].message.content }, true); return resp;