diff --git a/src/utils/config.py b/src/utils/config.py
index b166973..6cede3b 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -4,6 +4,8 @@ class Config(BaseSettings):
 
     # llm providers
 
+    siliconflow_api_key: str = ""
+    siliconflow_base_url: str = "https://api.siliconflow.cn/v1/"
     anthropic_api_key: str = ""
     dashscope_api_key: str = ""
     minimax_api_key: str = ""
diff --git a/src/utils/llm/__init__.py b/src/utils/llm/__init__.py
index e397038..73b35c3 100644
--- a/src/utils/llm/__init__.py
+++ b/src/utils/llm/__init__.py
@@ -8,6 +8,7 @@
 from .octoai import octoai
 from .openai import openai
 from .qwen import qwen
+from .siliconflow import siliconflow
 
 Model = Literal[
     "gpt-3.5-turbo-0301",
@@ -36,4 +37,20 @@
     "abab5.5s-chat",
     "abab5.5-chat",
     "abab6-chat",
+    "Qwen/Qwen2-7B-Instruct",
+    "Qwen/Qwen2-1.5B-Instruct",
+    "Qwen/Qwen1.5-7B-Chat",
+    "Qwen/Qwen2-72B-Instruct",
+    "Qwen/Qwen2-57B-A14B-Instruct",
+    "Qwen/Qwen1.5-110B-Chat",
+    "Qwen/Qwen1.5-32B-Chat",
+    "Qwen/Qwen1.5-14B-Chat",
+    "THUDM/glm-4-9b-chat",
+    "THUDM/chatglm3-6b",
+    "01-ai/Yi-1.5-9B-Chat-16K",
+    "01-ai/Yi-1.5-6B-Chat",
+    "01-ai/Yi-1.5-34B-Chat-16K",
+    "deepseek-ai/DeepSeek-Coder-V2-Instruct",
+    "deepseek-ai/DeepSeek-V2-Chat",
+    "deepseek-ai/deepseek-llm-67b-chat",
 ]
diff --git a/src/utils/llm/siliconflow.py b/src/utils/llm/siliconflow.py
new file mode 100644
index 0000000..dd5be9d
--- /dev/null
+++ b/src/utils/llm/siliconflow.py
@@ -0,0 +1,37 @@
+from promplate import Message
+from promplate.llm.openai import AsyncChatComplete, AsyncChatGenerate, AsyncChatOpenAI
+from promplate_trace.auto import patch
+
+from ..config import env
+from .common import client
+from .dispatch import link_llm
+
+complete = AsyncChatComplete(http_client=client, base_url=env.siliconflow_base_url, api_key=env.siliconflow_api_key)
+generate = AsyncChatGenerate(http_client=client, base_url=env.siliconflow_base_url, api_key=env.siliconflow_api_key)
+
+
+@link_llm("Qwen/")
+@link_llm("01-ai/")
+@link_llm("THUDM/")
+@link_llm("deepseek-ai/")
+class Siliconflow(AsyncChatOpenAI):
+    async def complete(self, prompt: str | list[Message], /, **config):
+        config = self._run_config | config
+        return (await complete(prompt, **config)).removeprefix(" ")
+
+    async def generate(self, prompt: str | list[Message], /, **config):
+        config = self._run_config | config
+
+        async for token in generate(prompt, **config):
+            yield token
+
+    def bind(self, **run_config):  # type: ignore
+        self._run_config.update(run_config)  # inplace
+        return self
+
+
+siliconflow = Siliconflow()
+
+
+siliconflow.complete = patch.chat.acomplete(siliconflow.complete)  # type: ignore
+siliconflow.generate = patch.chat.agenerate(siliconflow.generate)  # type: ignore
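
For context, here is a minimal usage sketch of the new provider (not part of the diff). It assumes the repo root is on `sys.path` so the package imports as `src.utils.llm`, that `SILICONFLOW_API_KEY` is set in the environment for `Config` to pick up, and it uses only what the diff shows: the exported `siliconflow` instance, its `bind`/`complete`/`generate` methods, and a model name from the extended `Model` literal. Prefix-based dispatch via `link_llm` routes any `Qwen/`, `01-ai/`, `THUDM/`, or `deepseek-ai/` model to this class.

```python
# Minimal usage sketch, under the assumptions stated above.
import asyncio

from src.utils.llm import siliconflow


async def main():
    # bind() mutates the shared instance's run config in place and returns
    # the same instance, so the model only needs to be chosen once
    llm = siliconflow.bind(model="Qwen/Qwen2-7B-Instruct", temperature=0.7)

    # one-shot completion (complete() strips a single leading space)
    print(await llm.complete("Hello!"))

    # streamed generation, token by token
    async for token in llm.generate("Tell me a joke."):
        print(token, end="", flush=True)


asyncio.run(main())
```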