From de0513531e465cff9824202bc231408cd24101ac Mon Sep 17 00:00:00 2001
From: Muspi Merol
Date: Wed, 26 Jun 2024 22:54:16 +0800
Subject: [PATCH] chore: simplify code in `llm/groq.py` and `llm/octoai.py`

---
 src/utils/llm/groq.py   | 13 +++----------
 src/utils/llm/octoai.py | 13 +++----------
 2 files changed, 6 insertions(+), 20 deletions(-)

diff --git a/src/utils/llm/groq.py b/src/utils/llm/groq.py
index 54a40308..3a20798e 100644
--- a/src/utils/llm/groq.py
+++ b/src/utils/llm/groq.py
@@ -1,5 +1,4 @@
 from promplate import Message
-from promplate.llm.base import AsyncComplete, AsyncGenerate
 from promplate.llm.openai import AsyncChatComplete, AsyncChatGenerate, AsyncChatOpenAI
 from promplate_trace.auto import patch
 
@@ -7,8 +6,8 @@
 from .common import client
 from .dispatch import link_llm
 
-complete: AsyncComplete = AsyncChatComplete(http_client=client, base_url=env.groq_base_url, api_key=env.groq_api_key)
-generate: AsyncGenerate = AsyncChatGenerate(http_client=client, base_url=env.groq_base_url, api_key=env.groq_api_key)
+complete = AsyncChatComplete(http_client=client, base_url=env.groq_base_url, api_key=env.groq_api_key)
+generate = AsyncChatGenerate(http_client=client, base_url=env.groq_base_url, api_key=env.groq_api_key)
 
 
 @link_llm("gemma")
@@ -22,14 +21,8 @@ async def complete(self, prompt: str | list[Message], /, **config):
     async def generate(self, prompt: str | list[Message], /, **config):
         config = self._run_config | config
 
-        first_token = True
-
         async for token in generate(prompt, **config):
-            if token and first_token:
-                first_token = False
-                yield token.removeprefix(" ")
-            else:
-                yield token
+            yield token
 
     def bind(self, **run_config):  # type: ignore
         self._run_config.update(run_config)  # inplace
diff --git a/src/utils/llm/octoai.py b/src/utils/llm/octoai.py
index 425f27f9..3b38fae1 100644
--- a/src/utils/llm/octoai.py
+++ b/src/utils/llm/octoai.py
@@ -1,5 +1,4 @@
 from promplate import Message
-from promplate.llm.base import AsyncComplete, AsyncGenerate
 from promplate.llm.openai import AsyncChatComplete, AsyncChatGenerate, AsyncChatOpenAI
 from promplate_trace.auto import patch
 
@@ -9,8 +8,8 @@
 from .common import client
 from .dispatch import link_llm
 
 OCTOAI_BASE_URL = "https://text.octoai.run/v1"
-complete: AsyncComplete = AsyncChatComplete(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)
-generate: AsyncGenerate = AsyncChatGenerate(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)
+complete = AsyncChatComplete(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)
+generate = AsyncChatGenerate(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)
 
 @link_llm("nous-hermes")
@@ -22,14 +21,8 @@ async def complete(self, prompt: str | list[Message], /, **config):
     async def generate(self, prompt: str | list[Message], /, **config):
         config = self._run_config | config
 
-        first_token = True
-
         async for token in generate(prompt, **config):
-            if token and first_token:
-                first_token = False
-                yield token.removeprefix(" ")
-            else:
-                yield token
+            yield token
 
     def bind(self, **run_config):  # type: ignore
         self._run_config.update(run_config)  # inplace