Skip to content

Commit

Permalink
feat: integrate with Groq provider
Browse files Browse the repository at this point in the history
chore: update deps
  • Loading branch information
CNSeniorious000 committed Mar 28, 2024
1 parent 26e9fab commit c1ddac0
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 4 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ readme = "README.md"
license = { text = "MIT" }
dependencies = [
"fastapi~=0.110.0",
"uvicorn[standard]~=0.28.0",
"uvicorn[standard]~=0.29.0",
"promplate[all]~=0.3.3.4",
"promplate-trace[langfuse,langsmith]==0.3.0dev2",
"python-box~=7.1.1",
Expand All @@ -17,7 +17,7 @@ dependencies = [
"beautifulsoup4~=4.12.3",
"rich~=13.7.1",
"zhipuai~=2.0.1",
"anthropic~=0.20.0",
"anthropic~=0.21.3",
"dashscope~=1.15.0",
]

Expand Down
3 changes: 2 additions & 1 deletion src/routes/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class Msg(BaseModel):
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
"mixtral-8x7b-instruct-fp16",
"gemma-7b-it",
"mixtral-8x7b-32768",
"nous-hermes-2-mixtral-8x7b-dpo",
"qwen-turbo",
"qwen-plus",
Expand Down
1 change: 1 addition & 0 deletions src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class Config(BaseSettings):
openai_base_url: str = ""
octoai_api_key: str = ""
zhipu_api_key: str = ""
groq_api_key: str = ""

# other services
serper_api_key: str = ""
Expand Down
1 change: 1 addition & 0 deletions src/utils/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .anthropic import anthropic
from .chatglm import glm
from .dispatch import find_llm
from .groq import groq
from .minimax import minimax
from .octoai import octoai
from .openai import openai
Expand Down
44 changes: 44 additions & 0 deletions src/utils/llm/groq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from promplate import Message
from promplate.llm.base import AsyncComplete, AsyncGenerate
from promplate.llm.openai import AsyncChatComplete, AsyncChatGenerate, AsyncChatOpenAI
from promplate_trace.auto import patch

from ..config import env
from .common import client
from .dispatch import link_llm

# Base URL used for every Groq request (the OpenAI-style `/openai/v1` path on api.groq.com).
GROQ_BASE_URL = "https://api.groq.com/openai/v1"

# Module-level chat callables built from promplate's OpenAI chat classes. Both use the
# `client` imported from `.common`, GROQ_BASE_URL, and the `groq_api_key` setting from `env`.
# `complete` is awaited for a whole reply; `generate` is iterated with `async for` (see Groq below).
complete: AsyncComplete = AsyncChatComplete(http_client=client, base_url=GROQ_BASE_URL, api_key=env.groq_api_key)
generate: AsyncGenerate = AsyncChatGenerate(http_client=client, base_url=GROQ_BASE_URL, api_key=env.groq_api_key)


@link_llm("gemma")
@link_llm("mixtral")
class Groq(AsyncChatOpenAI):
    """Chat LLM that delegates to the module-level Groq ``complete`` / ``generate`` callables.

    Both entry points overlay the per-call ``config`` on ``self._run_config``
    and drop one leading space from the start of the produced text.

    NOTE(review): the ``link_llm`` decorators presumably register this class
    for "gemma" / "mixtral" model names -- confirm against ``dispatch``.
    """

    async def complete(self, prompt: str | list[Message], /, **config):
        """Return the whole completion for ``prompt`` minus one leading space, if present."""
        merged = self._run_config | config  # per-call keys override bound ones
        text = await complete(prompt, **merged)
        return text.removeprefix(" ")

    async def generate(self, prompt: str | list[Message], /, **config):
        """Yield tokens for ``prompt``; the first non-empty token loses one leading space."""
        merged = self._run_config | config  # per-call keys override bound ones
        awaiting_first = True

        async for token in generate(prompt, **merged):
            # Empty tokens pass through untouched and do not count as "first".
            if awaiting_first and token:
                awaiting_first = False
                token = token.removeprefix(" ")
            yield token

    def bind(self, **run_config):  # type: ignore
        """Merge ``run_config`` into this instance's run config in place and return ``self``."""
        self._run_config.update(run_config)
        return self


# Default module-level instance; `bind` mutates it in place, so this is the same object
# with `model` preset to "mixtral-8x7b-32768".
groq = Groq().bind(model="mixtral-8x7b-32768")


# Replace the instance's bound methods with versions wrapped by promplate_trace's
# `patch.chat` helpers (presumably to add tracing -- confirm against promplate_trace).
groq.complete = patch.chat.acomplete(groq.complete) # type: ignore
groq.generate = patch.chat.agenerate(groq.generate) # type: ignore
1 change: 0 additions & 1 deletion src/utils/llm/octoai.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
generate: AsyncGenerate = AsyncChatGenerate(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)


@link_llm("mixtral")
@link_llm("nous-hermes")
class OctoAI(AsyncChatOpenAI):
async def complete(self, prompt: str | list[Message], /, **config):
Expand Down

0 comments on commit c1ddac0

Please sign in to comment.