Merge branch 'main' into deploy
CNSeniorious000 committed Jun 26, 2024
2 parents fecb54d + 746a5a0 commit 2423752
Showing 10 changed files with 69 additions and 30 deletions.
2 changes: 2 additions & 0 deletions src/utils/config.py
@@ -4,6 +4,8 @@

 class Config(BaseSettings):
     # llm providers
+    siliconflow_api_key: str = ""
+    siliconflow_base_url: str = "https://api.siliconflow.cn/v1/"
     anthropic_api_key: str = ""
     dashscope_api_key: str = ""
     minimax_api_key: str = ""
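The two new fields are ordinary BaseSettings fields, so they should pick up values from the environment the way pydantic-settings does by default (field names matched case-insensitively). A minimal, self-contained sketch of that behavior, assuming Config is a pydantic-settings BaseSettings; the real Config has many more fields and the key value here is made up:

import os

from pydantic_settings import BaseSettings


class Config(BaseSettings):
    # trimmed-down stand-in for src/utils/config.py
    siliconflow_api_key: str = ""
    siliconflow_base_url: str = "https://api.siliconflow.cn/v1/"


os.environ["SILICONFLOW_API_KEY"] = "sk-demo"  # hypothetical value, for illustration only

env = Config()
assert env.siliconflow_api_key == "sk-demo"       # read from the environment
assert env.siliconflow_base_url.endswith("/v1/")  # falls back to the declared default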
17 changes: 17 additions & 0 deletions src/utils/llm/__init__.py
@@ -8,6 +8,7 @@
 from .octoai import octoai
 from .openai import openai
 from .qwen import qwen
+from .siliconflow import siliconflow

 Model = Literal[
     "gpt-3.5-turbo-0301",
@@ -33,4 +34,20 @@
     "abab5.5s-chat",
     "abab5.5-chat",
     "abab6-chat",
+    "Qwen/Qwen2-7B-Instruct",
+    "Qwen/Qwen2-1.5B-Instruct",
+    "Qwen/Qwen1.5-7B-Chat",
+    "Qwen/Qwen2-72B-Instruct",
+    "Qwen/Qwen2-57B-A14B-Instruct",
+    "Qwen/Qwen1.5-110B-Chat",
+    "Qwen/Qwen1.5-32B-Chat",
+    "Qwen/Qwen1.5-14B-Chat",
+    "THUDM/glm-4-9b-chat",
+    "THUDM/chatglm3-6b",
+    "01-ai/Yi-1.5-9B-Chat-16K",
+    "01-ai/Yi-1.5-6B-Chat",
+    "01-ai/Yi-1.5-34B-Chat-16K",
+    "deepseek-ai/DeepSeek-Coder-V2-Instruct",
+    "deepseek-ai/DeepSeek-V2-Chat",
+    "deepseek-ai/deepseek-llm-67b-chat",
 ]
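Model is a Literal alias, so the payoff of extending it is mostly static: a type checker accepts the newly listed SiliconFlow-served names and flags anything outside the list. A small sketch with an abbreviated stand-in for the alias above:

from typing import Literal

Model = Literal[  # abbreviated stand-in for the alias in src/utils/llm/__init__.py
    "gpt-3.5-turbo-0301",
    "Qwen/Qwen2-72B-Instruct",
    "THUDM/glm-4-9b-chat",
]


def choose(model: Model) -> str:
    return model


choose("Qwen/Qwen2-72B-Instruct")  # fine
choose("Qwen/Qwen3-not-listed")    # mypy/pyright error, though it still runs at runtime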
3 changes: 1 addition & 2 deletions src/utils/llm/chatglm.py
@@ -27,8 +27,7 @@ def ensure_even(prompt: str | list[Message]) -> list[SafeMessage]:
 class ChatGLM(LLM):
     @staticmethod
     @validate_call
-    def validate(temperature: float = Field(0.95, gt=0, le=1), top_p: float = Field(0.7, gt=0, lt=1), **_):
-        pass
+    def validate(temperature: float = Field(0.95, gt=0, le=1), top_p: float = Field(0.7, gt=0, lt=1), **_): ...

     @staticmethod
     @patch.chat.acomplete
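The body change (pass to ...) is purely cosmetic; the stub exists so that pydantic's validate_call enforces the Field bounds whenever ChatGLM.validate is called. A standalone sketch of that mechanism, assuming pydantic v2:

from pydantic import Field, ValidationError, validate_call


@validate_call
def validate(temperature: float = Field(0.95, gt=0, le=1), top_p: float = Field(0.7, gt=0, lt=1), **_): ...


validate(temperature=0.5, top_p=0.7)  # within bounds, passes silently

try:
    validate(temperature=1.5)  # violates le=1
except ValidationError as exc:
    print(exc)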
2 changes: 2 additions & 0 deletions src/utils/llm/common/__init__.py
@@ -0,0 +1,2 @@
+from .http import client
+from .messages import SafeMessage, ensure_safe
3 changes: 3 additions & 0 deletions src/utils/llm/common/http.py
@@ -0,0 +1,3 @@
+from httpx import AsyncClient
+
+client = AsyncClient(http2=True)
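Centralizing the httpx client means every provider module passes the same instance to its OpenAI-compatible client, so connections and HTTP/2 streams are pooled process-wide. A rough sketch of the shared-client idea; note that http2=True needs the httpx[http2] extra (the h2 package), and the URL below is only a placeholder:

import asyncio

from httpx import AsyncClient

client = AsyncClient(http2=True)  # one shared connection pool; requires the `h2` package


async def main():
    # any HTTPS endpoint works here; httpx negotiates HTTP/2 when the server supports it
    response = await client.get("https://www.example.com")
    print(response.http_version, response.status_code)
    await client.aclose()


asyncio.run(main())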
3 changes: 0 additions & 3 deletions src/utils/llm/common.py → src/utils/llm/common/messages.py
@@ -1,11 +1,8 @@
 from typing import Literal, cast

-from httpx import AsyncClient
 from promplate.prompt.chat import Message, ensure
 from typing_extensions import TypedDict

-client = AsyncClient(http2=True)
-

 class SafeMessage(TypedDict):
     role: Literal["user", "assistant"]
13 changes: 3 additions & 10 deletions src/utils/llm/groq.py
@@ -1,14 +1,13 @@
 from promplate import Message
-from promplate.llm.base import AsyncComplete, AsyncGenerate
 from promplate.llm.openai import AsyncChatComplete, AsyncChatGenerate, AsyncChatOpenAI
 from promplate_trace.auto import patch

 from ..config import env
 from .common import client
 from .dispatch import link_llm

-complete: AsyncComplete = AsyncChatComplete(http_client=client, base_url=env.groq_base_url, api_key=env.groq_api_key)
-generate: AsyncGenerate = AsyncChatGenerate(http_client=client, base_url=env.groq_base_url, api_key=env.groq_api_key)
+complete = AsyncChatComplete(http_client=client, base_url=env.groq_base_url, api_key=env.groq_api_key)
+generate = AsyncChatGenerate(http_client=client, base_url=env.groq_base_url, api_key=env.groq_api_key)


 @link_llm("gemma")
@@ -22,14 +21,8 @@ async def complete(self, prompt: str | list[Message], /, **config):
     async def generate(self, prompt: str | list[Message], /, **config):
         config = self._run_config | config

-        first_token = True
-
         async for token in generate(prompt, **config):
-            if token and first_token:
-                first_token = False
-                yield token.removeprefix(" ")
-            else:
-                yield token
+            yield token

     def bind(self, **run_config): # type: ignore
         self._run_config.update(run_config) # inplace
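For reference, the branch deleted here (and the identical one in octoai.py below) only trimmed a single leading space from the first non-empty streamed chunk; with it gone, chunks are now yielded untouched. Isolated, the removed logic amounted to:

from typing import AsyncIterator


async def strip_leading_space(chunks: AsyncIterator[str]) -> AsyncIterator[str]:
    # what the deleted code did: drop one leading space from the first non-empty chunk only
    first_token = True
    async for token in chunks:
        if token and first_token:
            first_token = False
            yield token.removeprefix(" ")
        else:
            yield token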
13 changes: 3 additions & 10 deletions src/utils/llm/octoai.py
@@ -1,5 +1,4 @@
 from promplate import Message
-from promplate.llm.base import AsyncComplete, AsyncGenerate
 from promplate.llm.openai import AsyncChatComplete, AsyncChatGenerate, AsyncChatOpenAI
 from promplate_trace.auto import patch

@@ -9,8 +8,8 @@

 OCTOAI_BASE_URL = "https://text.octoai.run/v1"

-complete: AsyncComplete = AsyncChatComplete(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)
-generate: AsyncGenerate = AsyncChatGenerate(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)
+complete = AsyncChatComplete(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)
+generate = AsyncChatGenerate(http_client=client, base_url=OCTOAI_BASE_URL, api_key=env.octoai_api_key)


 @link_llm("nous-hermes")
@@ -22,14 +21,8 @@ async def complete(self, prompt: str | list[Message], /, **config):
     async def generate(self, prompt: str | list[Message], /, **config):
         config = self._run_config | config

-        first_token = True
-
         async for token in generate(prompt, **config):
-            if token and first_token:
-                first_token = False
-                yield token.removeprefix(" ")
-            else:
-                yield token
+            yield token

     def bind(self, **run_config): # type: ignore
         self._run_config.update(run_config) # inplace
6 changes: 1 addition & 5 deletions src/utils/llm/openai.py
@@ -23,8 +23,4 @@ def bind(self, **run_config): # type: ignore
         return self


-openai = OpenAI().bind(
-    model="gpt-3.5-turbo-0125",
-    temperature=0.7,
-    # response_format={"type": "json_object"},
-)
+openai = OpenAI().bind(model="gpt-3.5-turbo-0125", temperature=0.7)
37 changes: 37 additions & 0 deletions src/utils/llm/siliconflow.py
@@ -0,0 +1,37 @@
+from promplate import Message
+from promplate.llm.openai import AsyncChatComplete, AsyncChatGenerate, AsyncChatOpenAI
+from promplate_trace.auto import patch
+
+from ..config import env
+from .common import client
+from .dispatch import link_llm
+
+complete = AsyncChatComplete(http_client=client, base_url=env.siliconflow_base_url, api_key=env.siliconflow_api_key)
+generate = AsyncChatGenerate(http_client=client, base_url=env.siliconflow_base_url, api_key=env.siliconflow_api_key)
+
+
+@link_llm("Qwen/")
+@link_llm("01-ai/")
+@link_llm("THUDM/")
+@link_llm("deepseek-ai/")
+class Siliconflow(AsyncChatOpenAI):
+    async def complete(self, prompt: str | list[Message], /, **config):
+        config = self._run_config | config
+        return (await complete(prompt, **config)).removeprefix(" ")
+
+    async def generate(self, prompt: str | list[Message], /, **config):
+        config = self._run_config | config
+
+        async for token in generate(prompt, **config):
+            yield token
+
+    def bind(self, **run_config): # type: ignore
+        self._run_config.update(run_config) # inplace
+        return self
+
+
+siliconflow = Siliconflow()
+
+
+siliconflow.complete = patch.chat.acomplete(siliconflow.complete) # type: ignore
+siliconflow.generate = patch.chat.agenerate(siliconflow.generate) # type: ignore
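dispatch.py is not part of this diff, so the exact contract of link_llm is not shown; judging by the stacked decorators, it presumably maps model-name prefixes such as "Qwen/" or "deepseek-ai/" to this class. A hypothetical sketch of that kind of prefix registry, where every name is an assumption rather than the project's actual API:

# Hypothetical prefix registry illustrating how @link_llm("Qwen/") etc. could route
# model names to the Siliconflow client; the real dispatch.py may differ.
_registry: dict[str, type] = {}


def link_llm(prefix: str):
    def register(cls: type) -> type:
        _registry[prefix] = cls
        return cls

    return register


def find_llm(model: str) -> type:
    for prefix, cls in _registry.items():
        if model.startswith(prefix):
            return cls
    raise KeyError(f"no LLM linked for {model!r}")


@link_llm("Qwen/")
@link_llm("deepseek-ai/")
class Siliconflow:  # stand-in for the class above
    ...


assert find_llm("Qwen/Qwen2-72B-Instruct") is Siliconflow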
