diff --git a/pyproject.toml b/pyproject.toml index e203e449..17cefe52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ readme = "README.md" license = { text = "MIT" } dependencies = [ "fastapi~=0.110.0", - "uvicorn[standard]~=0.28.0", + "uvicorn[standard]~=0.29.0", "promplate[all]~=0.3.3.4", "promplate-trace[langfuse,langsmith]==0.3.0dev2", "python-box~=7.1.1", @@ -17,7 +17,7 @@ dependencies = [ "beautifulsoup4~=4.12.3", "rich~=13.7.1", "zhipuai~=2.0.1", - "anthropic~=0.20.0", + "anthropic~=0.21.3", "dashscope~=1.15.0", ] diff --git a/src/routes/run.py b/src/routes/run.py index 811d3d9b..36040f87 100644 --- a/src/routes/run.py +++ b/src/routes/run.py @@ -33,6 +33,8 @@ class Msg(BaseModel): "gpt-4-0125-preview", "chatglm_turbo", "claude-3-haiku-20240307", + "gemma-7b-it", + "mixtral-8x7b-32768", "nous-hermes-2-mixtral-8x7b-dpo", "qwen-turbo", "abab5.5s-chat", diff --git a/src/utils/config.py b/src/utils/config.py index 95181399..96a29bf5 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -11,6 +11,7 @@ class Config(BaseSettings): openai_base_url: str = "" octoai_api_key: str = "" zhipu_api_key: str = "" + groq_api_key: str = "" # other services serper_api_key: str = "" diff --git a/src/utils/llm/__init__.py b/src/utils/llm/__init__.py index 810ca28b..5b7dcd37 100644 --- a/src/utils/llm/__init__.py +++ b/src/utils/llm/__init__.py @@ -1,6 +1,7 @@ from .anthropic import anthropic from .chatglm import glm from .dispatch import find_llm +from .groq import groq from .minimax import minimax from .octoai import octoai from .openai import openai diff --git a/src/utils/llm/groq.py b/src/utils/llm/groq.py new file mode 100644 index 00000000..b3a3a891 --- /dev/null +++ b/src/utils/llm/groq.py @@ -0,0 +1,44 @@ +from promplate import Message +from promplate.llm.base import AsyncComplete, AsyncGenerate +from promplate.llm.openai import AsyncChatComplete, AsyncChatGenerate, AsyncChatOpenAI +from promplate_trace.auto import patch + +from ..config import 
env
+from .common import client
+from .dispatch import link_llm
+
+GROQ_BASE_URL = "https://api.groq.com/openai/v1"
+
+complete: AsyncComplete = AsyncChatComplete(http_client=client, base_url=GROQ_BASE_URL, api_key=env.groq_api_key)
+generate: AsyncGenerate = AsyncChatGenerate(http_client=client, base_url=GROQ_BASE_URL, api_key=env.groq_api_key)
+
+
+@link_llm("gemma")
+@link_llm("mixtral")
+class Groq(AsyncChatOpenAI):
+    """Chat LLM that sends requests to GROQ_BASE_URL via the module-level helpers."""
+
+    async def complete(self, prompt: str | list[Message], /, **config):
+        # per-call options take precedence over the bound run config
+        text = await complete(prompt, **(self._run_config | config))
+        return text.removeprefix(" ")
+
+    async def generate(self, prompt: str | list[Message], /, **config):
+        at_start = True  # stays True until the first non-empty token is seen
+        # per-call options take precedence over the bound run config
+        async for token in generate(prompt, **(self._run_config | config)):
+            if at_start and token:
+                at_start = False
+                token = token.removeprefix(" ")  # drop one leading space only
+            yield token
+
+    def bind(self, **run_config):  # type: ignore
+        self._run_config.update(run_config)  # inplace
+        return self
+
+
+groq = Groq().bind(model="mixtral-8x7b-32768")
+
+
+groq.complete = patch.chat.acomplete(groq.complete)  # type: ignore
+groq.generate = patch.chat.agenerate(groq.generate)  # type: ignore
diff --git a/src/utils/llm/octoai.py b/src/utils/llm/octoai.py
index 916acd9b..425f27f9 100644
--- a/src/utils/llm/octoai.py
+++ b/src/utils/llm/octoai.py
@@ -13,7 +13,6 @@
 generate: AsyncGenerate = AsyncChatGenerate(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)
 
 
-@link_llm("mixtral")
 @link_llm("nous-hermes")
 class OctoAI(AsyncChatOpenAI):
     async def complete(self, prompt: str | list[Message], /, **config):