
Commit

chore: simplify code in llm/groq.py and llm/octoai.py
CNSeniorious000 committed Jun 26, 2024
1 parent fc1953b commit de05135
Showing 2 changed files with 6 additions and 20 deletions.
src/utils/llm/groq.py: 13 changes (3 additions, 10 deletions)
@@ -1,14 +1,13 @@
 from promplate import Message
-from promplate.llm.base import AsyncComplete, AsyncGenerate
 from promplate.llm.openai import AsyncChatComplete, AsyncChatGenerate, AsyncChatOpenAI
 from promplate_trace.auto import patch
 
 from ..config import env
 from .common import client
 from .dispatch import link_llm
 
-complete: AsyncComplete = AsyncChatComplete(http_client=client, base_url=env.groq_base_url, api_key=env.groq_api_key)
-generate: AsyncGenerate = AsyncChatGenerate(http_client=client, base_url=env.groq_base_url, api_key=env.groq_api_key)
+complete = AsyncChatComplete(http_client=client, base_url=env.groq_base_url, api_key=env.groq_api_key)
+generate = AsyncChatGenerate(http_client=client, base_url=env.groq_base_url, api_key=env.groq_api_key)
 
 
 @link_llm("gemma")
@@ -22,14 +21,8 @@ async def complete(self, prompt: str | list[Message], /, **config):
     async def generate(self, prompt: str | list[Message], /, **config):
         config = self._run_config | config
 
-        first_token = True
-
         async for token in generate(prompt, **config):
-            if token and first_token:
-                first_token = False
-                yield token.removeprefix(" ")
-            else:
-                yield token
+            yield token
 
     def bind(self, **run_config):  # type: ignore
         self._run_config.update(run_config)  # inplace
src/utils/llm/octoai.py: 13 changes (3 additions, 10 deletions)
@@ -1,5 +1,4 @@
 from promplate import Message
-from promplate.llm.base import AsyncComplete, AsyncGenerate
 from promplate.llm.openai import AsyncChatComplete, AsyncChatGenerate, AsyncChatOpenAI
 from promplate_trace.auto import patch
 
@@ -9,8 +8,8 @@
 
 OCTOAI_BASE_URL = "https://text.octoai.run/v1"
 
-complete: AsyncComplete = AsyncChatComplete(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)
-generate: AsyncGenerate = AsyncChatGenerate(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)
+complete = AsyncChatComplete(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)
+generate = AsyncChatGenerate(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)
 
 
 @link_llm("nous-hermes")
@@ -22,14 +21,8 @@ async def complete(self, prompt: str | list[Message], /, **config):
     async def generate(self, prompt: str | list[Message], /, **config):
         config = self._run_config | config
 
-        first_token = True
-
         async for token in generate(prompt, **config):
-            if token and first_token:
-                first_token = False
-                yield token.removeprefix(" ")
-            else:
-                yield token
+            yield token
 
     def bind(self, **run_config):  # type: ignore
         self._run_config.update(run_config)  # inplace
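Both diffs drop the same first-token handling: previously, the first non-empty chunk of the stream had one leading space stripped before being yielded, while every later chunk passed through untouched; after this commit all chunks are forwarded unchanged. A minimal standalone sketch of the removed pattern, assuming only a plain async iterator of string chunks (the helper name and signature below are illustrative, not part of promplate's API):

from typing import AsyncIterator


async def strip_leading_space(chunks: AsyncIterator[str]) -> AsyncIterator[str]:
    # Hypothetical helper mirroring the deleted branch: the first non-empty
    # chunk loses a single leading space; everything else is passed through.
    first = True
    async for chunk in chunks:
        if chunk and first:
            first = False
            yield chunk.removeprefix(" ")
        else:
            yield chunk

With the simplified generate methods, any leading whitespace the provider emits in its first chunk now reaches the caller as-is.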
