From c1ddac0c2ff0e55fb6f8d76d72396da8f6bde34e Mon Sep 17 00:00:00 2001
From: Muspi Merol
Date: Wed, 27 Mar 2024 16:28:14 +0800
Subject: [PATCH] feat: integrate with `Groq` provider

chore: update deps

---
 pyproject.toml            |  4 ++--
 src/routes/run.py         |  3 ++-
 src/utils/config.py       |  1 +
 src/utils/llm/__init__.py |  1 +
 src/utils/llm/groq.py     | 44 +++++++++++++++++++++++++++++++++++++++
 src/utils/llm/octoai.py   |  1 -
 6 files changed, 50 insertions(+), 4 deletions(-)
 create mode 100644 src/utils/llm/groq.py

diff --git a/pyproject.toml b/pyproject.toml
index e203e449..17cefe52 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ readme = "README.md"
 license = { text = "MIT" }
 dependencies = [
     "fastapi~=0.110.0",
-    "uvicorn[standard]~=0.28.0",
+    "uvicorn[standard]~=0.29.0",
     "promplate[all]~=0.3.3.4",
     "promplate-trace[langfuse,langsmith]==0.3.0dev2",
     "python-box~=7.1.1",
@@ -17,7 +17,7 @@ dependencies = [
     "beautifulsoup4~=4.12.3",
     "rich~=13.7.1",
     "zhipuai~=2.0.1",
-    "anthropic~=0.20.0",
+    "anthropic~=0.21.3",
     "dashscope~=1.15.0",
 ]
 
diff --git a/src/routes/run.py b/src/routes/run.py
index 84d9bac1..886da9dd 100644
--- a/src/routes/run.py
+++ b/src/routes/run.py
@@ -36,7 +36,8 @@ class Msg(BaseModel):
         "claude-3-opus-20240229",
         "claude-3-sonnet-20240229",
         "claude-3-haiku-20240307",
-        "mixtral-8x7b-instruct-fp16",
+        "gemma-7b-it",
+        "mixtral-8x7b-32768",
         "nous-hermes-2-mixtral-8x7b-dpo",
         "qwen-turbo",
         "qwen-plus",
diff --git a/src/utils/config.py b/src/utils/config.py
index 95181399..96a29bf5 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -11,6 +11,7 @@ class Config(BaseSettings):
     openai_base_url: str = ""
     octoai_api_key: str = ""
     zhipu_api_key: str = ""
+    groq_api_key: str = ""
 
     # other services
     serper_api_key: str = ""
diff --git a/src/utils/llm/__init__.py b/src/utils/llm/__init__.py
index 810ca28b..5b7dcd37 100644
--- a/src/utils/llm/__init__.py
+++ b/src/utils/llm/__init__.py
@@ -1,6 +1,7 @@
 from .anthropic import anthropic
 from .chatglm import glm
 from .dispatch import find_llm
+from .groq import groq
 from .minimax import minimax
 from .octoai import octoai
 from .openai import openai
diff --git a/src/utils/llm/groq.py b/src/utils/llm/groq.py
new file mode 100644
index 00000000..b3a3a891
--- /dev/null
+++ b/src/utils/llm/groq.py
@@ -0,0 +1,44 @@
+from promplate import Message
+from promplate.llm.base import AsyncComplete, AsyncGenerate
+from promplate.llm.openai import AsyncChatComplete, AsyncChatGenerate, AsyncChatOpenAI
+from promplate_trace.auto import patch
+
+from ..config import env
+from .common import client
+from .dispatch import link_llm
+
+GROQ_BASE_URL = "https://api.groq.com/openai/v1"
+
+complete: AsyncComplete = AsyncChatComplete(http_client=client, base_url=GROQ_BASE_URL, api_key=env.groq_api_key)
+generate: AsyncGenerate = AsyncChatGenerate(http_client=client, base_url=GROQ_BASE_URL, api_key=env.groq_api_key)
+
+
+@link_llm("gemma")
+@link_llm("mixtral")
+class Groq(AsyncChatOpenAI):
+    async def complete(self, prompt: str | list[Message], /, **config):
+        config = self._run_config | config
+        return (await complete(prompt, **config)).removeprefix(" ")
+
+    async def generate(self, prompt: str | list[Message], /, **config):
+        config = self._run_config | config
+
+        first_token = True
+
+        async for token in generate(prompt, **config):
+            if token and first_token:
+                first_token = False
+                yield token.removeprefix(" ")
+            else:
+                yield token
+
+    def bind(self, **run_config):  # type: ignore
+        self._run_config.update(run_config)  # inplace
+        return self
+
+
+groq = Groq().bind(model="mixtral-8x7b-32768")
+
+
+groq.complete = patch.chat.acomplete(groq.complete)  # type: ignore
+groq.generate = patch.chat.agenerate(groq.generate)  # type: ignore
diff --git a/src/utils/llm/octoai.py b/src/utils/llm/octoai.py
index 916acd9b..425f27f9 100644
--- a/src/utils/llm/octoai.py
+++ b/src/utils/llm/octoai.py
@@ -13,7 +13,6 @@
 generate: AsyncGenerate = AsyncChatGenerate(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)
 
 
-@link_llm("mixtral")
 @link_llm("nous-hermes")
 class OctoAI(AsyncChatOpenAI):
     async def complete(self, prompt: str | list[Message], /, **config):
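
Note for reviewers (not part of the patch): a minimal sketch of how the new `groq` instance could be exercised, given the `complete`/`generate` methods it defines above. The import path and the `temperature` kwarg are assumptions; the model is already bound via `bind(model="mixtral-8x7b-32768")`.

    import asyncio

    from utils.llm import groq  # assumed import path under the src layout

    async def main():
        # one-shot completion; Groq.complete strips the leading space Groq tends to emit
        print(await groq.complete("Hello!", temperature=0))

        # streaming: Groq.generate is an async generator yielding tokens
        async for token in groq.generate("Tell me a joke"):
            print(token, end="", flush=True)

    asyncio.run(main())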