From 0f6097f76e9d69f195e0e3f3f0c42422422fd17b Mon Sep 17 00:00:00 2001 From: Luis Otavio Martins Date: Mon, 16 Sep 2024 03:49:13 -0300 Subject: [PATCH] Groq tool calling (#304) * Support for Groq's tool calling * removed console.log and unused code --- .env.example | 6 ++ package.json | 1 + pnpm-lock.yaml | 81 ++++++++++++++++++++++++ src/clients/open-router.ts | 27 ++++++-- src/clients/tools/tool-calling-models.ts | 5 ++ src/constants.ts | 1 + src/handlers/command/change-llm.ts | 4 ++ 7 files changed, 120 insertions(+), 5 deletions(-) diff --git a/.env.example b/.env.example index 5820dd3..34d3cc6 100644 --- a/.env.example +++ b/.env.example @@ -29,6 +29,12 @@ GOOGLE_API_KEY="" # AIz... # You can get this at https://anthropic.com/ ANTHROPIC_API_KEY="" # sk-... +# ------------------------------ +# Obligatory if you're using Groq's models and want to use tool calling: +# ------------------------------ +# You can get this at https://console.groq.com/keys +GROQ_API_KEY="" # gsk-... + # ------------------------------ # Obligatory if you're using one of OpenRouter models: # ------------------------------ diff --git a/package.json b/package.json index 404f0b6..516c37a 100644 --- a/package.json +++ b/package.json @@ -35,6 +35,7 @@ "@langchain/core": "0.2.32", "@langchain/google-genai": "^0.1.0", "@langchain/google-vertexai-web": "0.0.19", + "@langchain/groq": "^0.1.1", "@langchain/openai": "0.2.1", "@prisma/client": "5.3.1", "@types/common-tags": "^1.8.4", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 240d83b..ac155ab 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -26,6 +26,9 @@ importers: '@langchain/google-vertexai-web': specifier: 0.0.19 version: 0.0.19(langchain@0.2.9(axios@1.6.8)(cheerio@1.0.0-rc.12)(encoding@0.1.13)(fast-xml-parser@4.3.6)(openai@4.39.0(encoding@0.1.13))(web-auth-library@1.0.3)(ws@8.17.0))(openai@4.39.0(encoding@0.1.13))(zod@3.23.5) + '@langchain/groq': + specifier: ^0.1.1 + version: 0.1.1(@langchain/core@0.2.32(langchain@0.2.9(axios@1.6.8)(cheerio@1.0.0-rc.12)(encoding@0.1.13)(fast-xml-parser@4.3.6)(openai@4.39.0(encoding@0.1.13))(web-auth-library@1.0.3)(ws@8.17.0))(openai@4.39.0(encoding@0.1.13)))(encoding@0.1.13) '@langchain/openai': specifier: 0.2.1 version: 0.2.1(encoding@0.1.13)(langchain@0.2.9(axios@1.6.8)(cheerio@1.0.0-rc.12)(encoding@0.1.13)(fast-xml-parser@4.3.6)(openai@4.39.0(encoding@0.1.13))(web-auth-library@1.0.3)(ws@8.17.0)) @@ -567,6 +570,12 @@ packages: resolution: {integrity: sha512-3IqCvDhLDY739NiN0OH59xkNHBfd7aVVsmQH3atTcw6rCAHv9aicOabIQ9z7NpKJJh1KDHl5fcONUBRiq8ufpA==} engines: {node: '>=18'} + '@langchain/groq@0.1.1': + resolution: {integrity: sha512-HVrIudHj1rtnXLqbsRf20EGwL1B5fThv4c9yOgV7ngFd9YvHAwmcH7NbB7N+3j/8PEHSEo69KkiZZd/2YahsRw==} + engines: {node: '>=18'} + peerDependencies: + '@langchain/core': '>=0.2.21 <0.4.0' + '@langchain/openai@0.0.28': resolution: {integrity: sha512-2s1RA3/eAnz4ahdzsMPBna9hfAqpFNlWdHiPxVGZ5yrhXsbLWWoPcF+22LCk9t0HJKtazi2GCIWc0HVXH9Abig==} engines: {node: '>=18'} @@ -579,6 +588,12 @@ packages: resolution: {integrity: sha512-Ti3C6ZIUPaueIPAfMljMnLu3GSGNq5KmrlHeWkIbrLShOBlzj4xj7mRfR73oWgAC0qivfxdkfbB0e+WCY+oRJw==} engines: {node: '>=18'} + '@langchain/openai@0.3.0': + resolution: {integrity: sha512-yXrz5Qn3t9nq3NQAH2l4zZOI4ev2CFdLC5kvmi5SdW4bggRuM40SXTUAY3VRld4I5eocYfk82VbrlA+6dvN5EA==} + engines: {node: '>=18'} + peerDependencies: + '@langchain/core': '>=0.2.26 <0.4.0' + '@langchain/textsplitters@0.0.0': resolution: {integrity: sha512-3hPesWomnmVeYMppEGYbyv0v/sRUugUdlFBNn9m1ueJYHAIKbvCErkWxNUH3guyKKYgJVrkvZoQxcd9faucSaw==} engines: {node: '>=18'} @@ -665,6 +680,9 @@ packages: '@types/qrcode@1.5.5': resolution: {integrity: sha512-CdfBi/e3Qk+3Z/fXYShipBT13OJ2fDO2Q2w5CIP5anLTLIndQG9z6P1cnm+8zCWSpm5dnxMFd/uREtb0EXuQzg==} + '@types/qs@6.9.16': + resolution: {integrity: sha512-7i+zxXdPD0T4cKDuxCUXJ4wHcsJLwENa6Z3dCu8cfCK743OGy5Nu1RmAGqDPsoTDINVEcdXKRvR/zre+P2Ku1A==} + '@types/retry@0.12.0': resolution: {integrity: sha512-wWKOClTTiizcZhXnPY4wikVAwmdYHp8q6DmC+EJUzAMsycb7HB32Kh9RN4+0gExjmPmZSAQjgURXIGATPegAvA==} @@ -1398,6 +1416,9 @@ packages: graceful-fs@4.2.11: resolution: {integrity: sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==} + groq-sdk@0.5.0: + resolution: {integrity: sha512-RVmhW7qZ+XZoy5fIuSdx/LGQJONpL8MHgZEW7dFwTdgkzStub2XQx6OKv28CHogijdwH41J+Npj/z2jBPu3vmw==} + gtoken@7.1.0: resolution: {integrity: sha512-pCcEwRi+TKpMlxAQObHDQ56KawURgyAf6jtIY046fJ5tIv3zDe/LEIubckAO8fj6JnAxLdmWkUfNyulQ2iKdEw==} engines: {node: '>=14.0.0'} @@ -2237,6 +2258,15 @@ packages: resolution: {integrity: sha512-dgxA6UZHary6NXUHEDj5TWt8ogv0+ibH+b4pT5RrWMjiRZVylNwLcw/2ubDrX5n0oUmHX/ZgudMJeemxzOvz7A==} hasBin: true + openai@4.61.0: + resolution: {integrity: sha512-xkygRBRLIUumxzKGb1ug05pWmJROQsHkGuj/N6Jiw2dj0dI19JvbFpErSZKmJ/DA+0IvpcugZqCAyk8iLpyM6Q==} + hasBin: true + peerDependencies: + zod: ^3.23.8 + peerDependenciesMeta: + zod: + optional: true + openapi-types@12.1.3: resolution: {integrity: sha512-N4YtSYJqghVu4iek2ZUvcN/0aqH1kRDuNqzcycDxhOUpg7GdvLa2F3DgS6yBNhInhv2r/6I0Flkn7CqL8+nIcw==} @@ -3159,6 +3189,16 @@ snapshots: - openai - zod + '@langchain/groq@0.1.1(@langchain/core@0.2.32(langchain@0.2.9(axios@1.6.8)(cheerio@1.0.0-rc.12)(encoding@0.1.13)(fast-xml-parser@4.3.6)(openai@4.39.0(encoding@0.1.13))(web-auth-library@1.0.3)(ws@8.17.0))(openai@4.39.0(encoding@0.1.13)))(encoding@0.1.13)': + dependencies: + '@langchain/core': 0.2.32(langchain@0.2.9(axios@1.6.8)(cheerio@1.0.0-rc.12)(encoding@0.1.13)(fast-xml-parser@4.3.6)(openai@4.39.0(encoding@0.1.13))(web-auth-library@1.0.3)(ws@8.17.0))(openai@4.39.0(encoding@0.1.13)) + '@langchain/openai': 0.3.0(@langchain/core@0.2.32(langchain@0.2.9(axios@1.6.8)(cheerio@1.0.0-rc.12)(encoding@0.1.13)(fast-xml-parser@4.3.6)(openai@4.39.0(encoding@0.1.13))(web-auth-library@1.0.3)(ws@8.17.0))(openai@4.39.0(encoding@0.1.13)))(encoding@0.1.13) + groq-sdk: 0.5.0(encoding@0.1.13) + zod: 3.23.5 + zod-to-json-schema: 3.23.0(zod@3.23.5) + transitivePeerDependencies: + - encoding + '@langchain/openai@0.0.28(encoding@0.1.13)(langchain@0.2.3(axios@1.6.8)(cheerio@1.0.0-rc.12)(encoding@0.1.13)(fast-xml-parser@4.3.6)(openai@4.39.0(encoding@0.1.13))(web-auth-library@1.0.3)(ws@8.17.0))': dependencies: '@langchain/core': 0.1.61(langchain@0.2.3(axios@1.6.8)(cheerio@1.0.0-rc.12)(encoding@0.1.13)(fast-xml-parser@4.3.6)(openai@4.39.0(encoding@0.1.13))(web-auth-library@1.0.3)(ws@8.17.0))(openai@4.52.7(encoding@0.1.13)) @@ -3192,6 +3232,16 @@ snapshots: - encoding - langchain + '@langchain/openai@0.3.0(@langchain/core@0.2.32(langchain@0.2.9(axios@1.6.8)(cheerio@1.0.0-rc.12)(encoding@0.1.13)(fast-xml-parser@4.3.6)(openai@4.39.0(encoding@0.1.13))(web-auth-library@1.0.3)(ws@8.17.0))(openai@4.39.0(encoding@0.1.13)))(encoding@0.1.13)': + dependencies: + '@langchain/core': 0.2.32(langchain@0.2.9(axios@1.6.8)(cheerio@1.0.0-rc.12)(encoding@0.1.13)(fast-xml-parser@4.3.6)(openai@4.39.0(encoding@0.1.13))(web-auth-library@1.0.3)(ws@8.17.0))(openai@4.39.0(encoding@0.1.13)) + js-tiktoken: 1.0.12 + openai: 4.61.0(encoding@0.1.13)(zod@3.23.5) + zod: 3.23.5 + zod-to-json-schema: 3.23.0(zod@3.23.5) + transitivePeerDependencies: + - encoding + '@langchain/textsplitters@0.0.0(langchain@0.2.3(axios@1.6.8)(cheerio@1.0.0-rc.12)(encoding@0.1.13)(fast-xml-parser@4.3.6)(openai@4.39.0(encoding@0.1.13))(web-auth-library@1.0.3)(ws@8.17.0))(openai@4.39.0(encoding@0.1.13))': dependencies: '@langchain/core': 0.1.61(langchain@0.2.3(axios@1.6.8)(cheerio@1.0.0-rc.12)(encoding@0.1.13)(fast-xml-parser@4.3.6)(openai@4.39.0(encoding@0.1.13))(web-auth-library@1.0.3)(ws@8.17.0))(openai@4.39.0(encoding@0.1.13)) @@ -3298,6 +3348,8 @@ snapshots: dependencies: '@types/node': 20.12.7 + '@types/qs@6.9.16': {} + '@types/retry@0.12.0': {} '@types/uuid@10.0.0': {} @@ -4167,6 +4219,19 @@ snapshots: graceful-fs@4.2.11: {} + groq-sdk@0.5.0(encoding@0.1.13): + dependencies: + '@types/node': 18.19.31 + '@types/node-fetch': 2.6.11 + abort-controller: 3.0.0 + agentkeepalive: 4.5.0 + form-data-encoder: 1.7.2 + formdata-node: 4.4.1 + node-fetch: 2.7.0(encoding@0.1.13) + web-streams-polyfill: 3.3.3 + transitivePeerDependencies: + - encoding + gtoken@7.1.0(encoding@0.1.13): dependencies: gaxios: 6.5.0(encoding@0.1.13) @@ -4900,6 +4965,22 @@ snapshots: transitivePeerDependencies: - encoding + openai@4.61.0(encoding@0.1.13)(zod@3.23.5): + dependencies: + '@types/node': 18.19.31 + '@types/node-fetch': 2.6.11 + '@types/qs': 6.9.16 + abort-controller: 3.0.0 + agentkeepalive: 4.5.0 + form-data-encoder: 1.7.2 + formdata-node: 4.4.1 + node-fetch: 2.7.0(encoding@0.1.13) + qs: 6.12.1 + optionalDependencies: + zod: 3.23.5 + transitivePeerDependencies: + - encoding + openapi-types@12.1.3: {} ora@5.4.1: diff --git a/src/clients/open-router.ts b/src/clients/open-router.ts index 5239ab0..bc0ec30 100644 --- a/src/clients/open-router.ts +++ b/src/clients/open-router.ts @@ -17,6 +17,7 @@ import { ANTHROPIC_API_KEY, DEFAULT_MODEL, GOOGLE_API_KEY, + GROQ_API_KEY, MODEL_TEMPERATURE, OPENAI_API_KEY, OPENROUTER_API_KEY, @@ -32,10 +33,12 @@ import { import { anthropicToolCallingModels, googleToolCallingModels, + groqToolCallingModels, openAIToolCallingModels, } from "./tools/tool-calling-models"; import { tools } from "./tools/tools-openrouter"; import { ChatAnthropic } from "@langchain/anthropic"; +import { ChatGroq } from "@langchain/groq"; function parseMessageHistory( rawHistory: { [key: string]: string }[] @@ -122,7 +125,6 @@ export async function createExecutorForOpenRouter( // OpenAI LLM with Tool Calling Agent if (openAIToolCallingModels.includes(llmModel) && OPENAI_API_KEY !== "") { - console.log("Using OpenAI LLM"); prompt = await pull( "luisotee/wa-assistant-tool-calling" ); @@ -145,7 +147,6 @@ export async function createExecutorForOpenRouter( googleToolCallingModels.includes(llmModel) && GOOGLE_API_KEY !== "" ) { - console.log("Using Google Generative AI"); prompt = await pull( "luisotee/wa-assistant-tool-calling" ); @@ -168,7 +169,6 @@ export async function createExecutorForOpenRouter( anthropicToolCallingModels.includes(llmModel) && ANTHROPIC_API_KEY !== "" ) { - console.log("Using Anthropics LLM"); prompt = await pull( "luisotee/wa-assistant-tool-calling" ); @@ -186,9 +186,27 @@ export async function createExecutorForOpenRouter( prompt, }); } + // Groq LLM with Tool Calling Agent + else if (groqToolCallingModels.includes(llmModel) && GROQ_API_KEY !== "") { + prompt = await pull( + "luisotee/wa-assistant-tool-calling" + ); + + llm = new ChatGroq({ + modelName: llmModel, + streaming: true, + temperature: MODEL_TEMPERATURE, + apiKey: GROQ_API_KEY, + }); + + agent = await createToolCallingAgent({ + llm, + tools, + prompt, + }); + } // OpenRouter LLMs without Tool Calling Agent, with Structured Agent else { - console.log("Using OpenRouter LLM"); prompt = await pull("luisotee/wa-assistant"); llm = new ChatOpenAI( @@ -211,7 +229,6 @@ export async function createExecutorForOpenRouter( } const executor = new AgentExecutor({ - // @ts-ignore agent, tools, memory, diff --git a/src/clients/tools/tool-calling-models.ts b/src/clients/tools/tool-calling-models.ts index 3043072..d85eb61 100644 --- a/src/clients/tools/tool-calling-models.ts +++ b/src/clients/tools/tool-calling-models.ts @@ -13,3 +13,8 @@ export const anthropicToolCallingModels = [ "claude-3-sonnet-20240229", "claude-3-haiku-20240307", ]; + +export const groqToolCallingModels = [ + "llama3-groq-70b-8192-tool-use-preview", + "llava-v1.5-7b-4096-preview", +]; diff --git a/src/constants.ts b/src/constants.ts index 56655bd..8cf0b27 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -65,3 +65,4 @@ export const MODEL_TEMPERATURE = parseFloat( export const ENABLE_GOOGLE_ROUTES = process.env.ENABLE_GOOGLE_ROUTES as string; export const GOOGLE_API_KEY = process.env.GOOGLE_API_KEY as string; export const ANTHROPIC_API_KEY = process.env.ANTHROPIC_API_KEY as string; +export const GROQ_API_KEY = process.env.GROQ_API_KEY as string; diff --git a/src/handlers/command/change-llm.ts b/src/handlers/command/change-llm.ts index 418e31e..8cd8ac9 100644 --- a/src/handlers/command/change-llm.ts +++ b/src/handlers/command/change-llm.ts @@ -15,6 +15,8 @@ const LLM_OPTIONS = { "8": "claude-3-5-sonnet-20240620", "9": "claude-3-opus-20240229", "10": "claude-3-haiku-20240307", + "11": "llama3-groq-70b-8192-tool-use-preview", + "12": "llava-v1.5-7b-4096-preview", }; export async function handleChangeLLM(message: Message, args: string) { @@ -45,6 +47,8 @@ export async function handleChangeLLM(message: Message, args: string) { *${CMD_PREFIX}change 8* for _claude-3-5-sonnet_ (Anthropic API) *${CMD_PREFIX}change 9* for _claude-3-opus_ (Anthropic API) *${CMD_PREFIX}change 10* for _claude-3-haiku_ (Anthropic API) + *${CMD_PREFIX}change 11* for _llama3-groq-70b_ (Groq API) + *${CMD_PREFIX}change 12* for _llava-v1.5-7b_ (Groq API) You can also type the name of your desired model supported by OpenRouter, like *${CMD_PREFIX}change mistralai/mixtral-8x7b-instruct*