Skip to content

Commit

Permalink
Add function support to LLaMA and mixtral (Stage 1)
Browse files Browse the repository at this point in the history
  • Loading branch information
Erisfiregamer1 committed May 5, 2024
1 parent b1bf378 commit e3fb9f4
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 4 deletions.
82 changes: 80 additions & 2 deletions bots/llama_3.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,30 @@
import * as types from "../main.d.ts";

const tools: types.Tool[] = [{
type: "function",
function: {
name: "sanitycheck",
description:
"Returns true, as a sanity check to make sure function support is OK. If this fails, something's fucked.",
parameters: {
type: "object",
properties: {
useless: {
type: "string",
description:
"You can put whatever here, it's not going to do anything.",
},
},
required: ["useless"],
},
},
}];

export const information: types.information = {
llmFileVersion: "1.0",
env: ["GROQ_API_KEY"],
functions: false,
functions: true,
functionsData: tools,
multiModal: false,
callbackSupport: true,
streamingSupport: false,
Expand All @@ -15,6 +36,52 @@ export const information: types.information = {

// const db = await Deno.openKv("./db.sqlite")

async function doTools(
res: types.Response,
callback?:
| ((information: types.callbackData, complete: boolean) => void)
| null,
requirements?: types.Requirements,
): Promise<types.Response> {
if (res.choices[0].finish_reason !== "tool_calls") {
throw "What The Shit?";
}

const toolCalls = res.choices[0].message.tool_calls!;

// What if they happen to use it later?
// deno-lint-ignore require-await
const promises = toolCalls.map(async (tool) => {
if (tool.function.name === "sanitycheck") {
return {
role: "tool",
content: "true",
tool_call_id: tool.id,
};
} else {
return {
role: "tool",
content: "Unknown tool or not implemented",
tool_call_id: tool.id,
//};
};
}
});

// Use Promise.all to wait for all promises to resolve
const results = await Promise.all(promises);

results.forEach((result) => {
res.messages.push(result);
});

const newres = await send(null, res.messages, callback, requirements);

console.log(newres);

return newres;
}

export async function send(
prompt: string | null,
messages: types.Message[],
Expand Down Expand Up @@ -57,7 +124,7 @@ export async function send(
}),
});

const resp: types.Response = await res.json();
let resp: types.Response = await res.json();

if (resp.error) {
// Fuck.
Expand All @@ -68,6 +135,17 @@ export async function send(

resp.messages = messages;

if (resp.choices[0].finish_reason === "tool_calls") {
if (callback) {
callback({
toolCalls: resp.choices[0].message.tool_calls,
data: resp.choices[0].message.content,
}, false);
}
resp = await doTools(resp, null, requirements);
resp.choices[0].message.content = resp.choices[0].message.content as string;
}

if (callback) callback({ data: resp.choices[0].message.content }, true);

return resp;
Expand Down
82 changes: 80 additions & 2 deletions bots/mixtral.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,30 @@
import * as types from "../main.d.ts";

const tools: types.Tool[] = [{
type: "function",
function: {
name: "sanitycheck",
description:
"Returns true, as a sanity check to make sure function support is OK. If this fails, something's fucked.",
parameters: {
type: "object",
properties: {
useless: {
type: "string",
description:
"You can put whatever here, it's not going to do anything.",
},
},
required: ["useless"],
},
},
}];

export const information: types.information = {
llmFileVersion: "1.0",
env: ["GROQ_API_KEY"],
functions: false,
functions: true,
functionsData: tools,
multiModal: false,
callbackSupport: true,
streamingSupport: false,
Expand All @@ -15,6 +36,52 @@ export const information: types.information = {

// const db = await Deno.openKv("./db.sqlite")

async function doTools(
res: types.Response,
callback?:
| ((information: types.callbackData, complete: boolean) => void)
| null,
requirements?: types.Requirements,
): Promise<types.Response> {
if (res.choices[0].finish_reason !== "tool_calls") {
throw "What The Shit?";
}

const toolCalls = res.choices[0].message.tool_calls!;

// What if they happen to use it later?
// deno-lint-ignore require-await
const promises = toolCalls.map(async (tool) => {
if (tool.function.name === "sanitycheck") {
return {
role: "tool",
content: "true",
tool_call_id: tool.id,
};
} else {
return {
role: "tool",
content: "Unknown tool or not implemented",
tool_call_id: tool.id,
//};
};
}
});

// Use Promise.all to wait for all promises to resolve
const results = await Promise.all(promises);

results.forEach((result) => {
res.messages.push(result);
});

const newres = await send(null, res.messages, callback, requirements);

console.log(newres);

return newres;
}

export async function send(
prompt: string | null,
messages: types.Message[],
Expand Down Expand Up @@ -57,7 +124,7 @@ export async function send(
}),
});

const resp: types.Response = await res.json();
let resp: types.Response = await res.json();

if (resp.error) {
// Fuck.
Expand All @@ -68,6 +135,17 @@ export async function send(

resp.messages = messages;

if (resp.choices[0].finish_reason === "tool_calls") {
if (callback) {
callback({
toolCalls: resp.choices[0].message.tool_calls,
data: resp.choices[0].message.content,
}, false);
}
resp = await doTools(resp, null, requirements);
resp.choices[0].message.content = resp.choices[0].message.content as string;
}

if (callback) callback({ data: resp.choices[0].message.content }, true);

return resp;
Expand Down

0 comments on commit e3fb9f4

Please sign in to comment.