diff --git a/README_ZH.md b/README_ZH.md index d749962..0f8ad94 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -96,6 +96,11 @@ pip install -r requirements.txt "recent_days": 5 # 查询最近的天 "plugins": [{ "name": , other configs }]# 添加你喜爱的插件 "openai_sensitive_id": "" # 查询api key时使用 + "use_wxyx": false, # 是否使用文心一言 + "baidu_wenxin_model": "eb-instant", # 使用默认eb-instant或 前往千帆大模型->服务管理->详情(要先创建服务)->顶部有个服务地址:https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/xxxxx 最后的xxxxx就是 baidu_wenxin_model + "baidu_wenxin_api_key": "", # 服务管理 详情可以看到 + "baidu_wenxin_secret_key": "", # 服务管理 详情可以看到 + "baidu_wenxin_plugin_suffix":"" # 获取插件配置接口后的后缀 } ``` diff --git a/bot/baidu_wenxin.py b/bot/baidu_wenxin.py new file mode 100644 index 0000000..4c9fad8 --- /dev/null +++ b/bot/baidu_wenxin.py @@ -0,0 +1,122 @@ +# encoding:utf-8 + +import requests +import json +from utils.log import logger +from common.session import Session +from common.reply import Reply, ReplyType +from common.context import ContextType + +from config import conf + + +class BaiduWenxinBot: + def __init__(self): + self.model = conf().get("baidu_wenxin_model") or "eb-instant" + self.baidu_wenxin_api_key = conf().get("baidu_wenxin_api_key") + self.baidu_wenxin_secret_key = conf().get("baidu_wenxin_secret_key") + self.name = self.__class__.__name__ + self.baidu_wenxin_plugin_suffix = conf().get("baidu_wenxin_plugin_suffix") + + def reply(self, context=None): + # acquire reply content + query = context.query + logger.info(f"[{self.name}] Query={query}") + if context.type == ContextType.CREATE_IMAGE: + return self.reply_img(query) + else: + session_id = context.session_id + session = Session.build_session_query(context) + response = self.reply_text(session,query=query) + total_tokens, completion_tokens, reply_content = ( + response["total_tokens"], + response["completion_tokens"], + response["content"], + ) + logger.debug( + "[{}] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( + self.name, + session, + session_id, + reply_content, + completion_tokens, + ) + ) + + if total_tokens > 0: + Session.save_session( + response["content"], session_id, response["total_tokens"] + ) + return Reply(ReplyType.TEXT, response["content"]) + + def reply_img(self, query) -> Reply: + None + # ok, image_url = self.create_img(query, 0) + # if ok: + # return Reply(ReplyType.IMAGE, image_url) + # else: + # logger.error(f"[{self.name}] Create image failed: {e}") + # return Reply(ReplyType.TEXT, "Image created failed") + + def reply_text(self, session,query): + try: + logger.info("[{}] model={}".format(self.name,self.model)) + access_token = self.get_access_token() + if access_token == "None": + logger.warn( + "[{self.name}] access token 获取失败" + ) + return { + "total_tokens": 0, + "completion_tokens": 0, + "content": 0, + } + url = ( + "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + + self.model + + "?access_token=" + + access_token + ) + payload = {"messages": session} + + if self.baidu_wenxin_plugin_suffix is not None: + self.url = ( + "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/plugin/" + + self.baidu_wenxin_plugin_suffix+ '/' + + "?access_token=" + + access_token) + payload = {"query":query} + + headers = {"Content-Type": "application/json"} + + response = requests.request( + "POST", url, headers=headers, data=json.dumps(payload) + ) + response_text = json.loads(response.text) + logger.info(f"[{self.name}] response text={response_text}") + res_content = response_text["result"] + total_tokens = response_text["usage"]["total_tokens"] + completion_tokens = response_text["usage"]["completion_tokens"] + logger.info("[{}] reply={}".format(self.name,res_content)) + return { + "total_tokens": total_tokens, + "completion_tokens": completion_tokens, + "content": res_content, + } + except Exception as e: + logger.warn("[{}] Exception: {}".format(self.name,e)) + result = {"completion_tokens": 0, "content": "出错了: {}".format(e)} + return result + + def get_access_token(self): + """ + 使用 AK,SK 生成鉴权签名(Access Token) + :return: access_token,或是None(如果错误) + """ + url = "https://aip.baidubce.com/oauth/2.0/token" + params = { + "grant_type": "client_credentials", + "client_id": self.baidu_wenxin_api_key, + "client_secret": self.baidu_wenxin_secret_key, + } + return str(requests.post(url, params=params).json().get("access_token")) diff --git a/bot/bot.py b/bot/bot.py index 2abd5ae..88bad84 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -9,12 +9,16 @@ class Bot: def __init__(self): use_azure_chatgpt = conf().get("use_azure_chatgpt", False) + use_wxyx = conf().get("use_wxyx", False) model = conf().get("model", "gpt-3.5-turbo") if use_azure_chatgpt: from bot.azure_chatgpt import AzureChatGPTBot self.bot = AzureChatGPTBot() + elif use_wxyx: + from bot.baidu_wenxin import BaiduWenxinBot + self.bot = BaiduWenxinBot() elif model in litellm.open_ai_chat_completion_models: from bot.chatgpt import ChatGPTBot diff --git a/common/session.py b/common/session.py index d08c0d2..b2f0ab4 100644 --- a/common/session.py +++ b/common/session.py @@ -7,7 +7,7 @@ class Session(object): all_sessions = ExpiredDict(conf().get("session_expired_duration") or 3600) @staticmethod - def build_session_query(context: Context): + def build_session_query(context: Context,cls=None): """ build query with conversation history e.g. [ @@ -21,10 +21,16 @@ def build_session_query(context: Context): :return: query content with conversaction """ session = Session.all_sessions.get(context.session_id, []) + use_wxyx = conf().get("use_wxyx", False) + if len(session) == 0: - system_item = {"role": "system", "content": context.system_prompt} - session.append(system_item) - Session.all_sessions[context.session_id] = session + + if use_wxyx: + Session.all_sessions[context.session_id] = session + else: + system_item = {"role": "system", "content": context.system_prompt} + session.append(system_item) + Session.all_sessions[context.session_id] = session user_item = {"role": "user", "content": context.query} session.append(user_item) return session diff --git a/config.template.json b/config.template.json index 8e40348..81ec287 100644 --- a/config.template.json +++ b/config.template.json @@ -18,5 +18,10 @@ "query_key_command": "#query key", "recent_days": 5, "plugins": [{ "name": "tiktok", "command": "#tiktok" }], - "openai_sensitive_id": "" + "openai_sensitive_id": "", + "use_wxyx": false, + "baidu_wenxin_model": "eb-instant", + "baidu_wenxin_api_key": "", + "baidu_wenxin_secret_key": "", + "baidu_wenxin_plugin_suffix":"" }