From d028cab90957e7cf0f83627e28418e03702410c7 Mon Sep 17 00:00:00 2001 From: Ancss Date: Wed, 13 Sep 2023 00:09:27 +0800 Subject: [PATCH 1/7] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=96=87?= =?UTF-8?q?=E5=BF=83=E4=B8=80=E8=A8=80bot?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot/baidu_wenxin.py | 110 +++++++++++++++++++++++++++++++++++++++++++ bot/bot.py | 4 ++ common/session.py | 14 ++++-- config.template.json | 6 ++- 4 files changed, 129 insertions(+), 5 deletions(-) create mode 100644 bot/baidu_wenxin.py diff --git a/bot/baidu_wenxin.py b/bot/baidu_wenxin.py new file mode 100644 index 0000000..3341898 --- /dev/null +++ b/bot/baidu_wenxin.py @@ -0,0 +1,110 @@ +# encoding:utf-8 + +import requests, json +from config import conf +from utils.log import logger +from common.session import Session +from common.reply import Reply, ReplyType +from common.context import ContextType, Context + +from config import conf + + +class BaiduWenxinBot: + def __init__(self): + self.model = conf().get("baidu_wenxin_model") or "eb-instant" + self.baidu_wenxin_api_key = conf().get("baidu_wenxin_api_key") + self.baidu_wenxin_secret_key = conf().get("baidu_wenxin_secret_key") + self.name = self.__class__.__name__ + + def reply(self, context=None): + # acquire reply content + query = context.query + logger.info(f"[{self.name}] Query={query}") + if context.type == ContextType.CREATE_IMAGE: + return self.reply_img(query) + else: + session_id = context.session_id + session = Session.build_session_query(context) + response = self.reply_text(session) + total_tokens, completion_tokens, reply_content = ( + response["total_tokens"], + response["completion_tokens"], + response["content"], + ) + logger.debug( + "[{}] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( + self.name, + session.messages, + session_id, + reply_content, + completion_tokens, + ) + ) + + if total_tokens > 0: + Session.save_session( + response["content"], session_id, response["total_tokens"] + ) + return Reply(ReplyType.TEXT, response["content"]) + + def reply_img(self, query) -> Reply: + ok, image_url = self.create_img(query, 0) + if ok: + return Reply(ReplyType.IMAGE, image_url) + else: + logger.error(f"[{self.name}] Create image failed: {e}") + return Reply(ReplyType.TEXT, "Image created failed") + + def reply_text(self, session, retry_count=0): + try: + logger.info("[{}] model={}".format(self.name,self.model)) + access_token = self.get_access_token() + if access_token == "None": + logger.warn( + "[{self.name}] access token 获取失败" + ) + return { + "total_tokens": 0, + "completion_tokens": 0, + "content": 0, + } + url = ( + "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + + self.model + + "?access_token=" + + access_token + ) + headers = {"Content-Type": "application/json"} + payload = {"messages": session} + response = requests.request( + "POST", url, headers=headers, data=json.dumps(payload) + ) + response_text = json.loads(response.text) + logger.info(f"[{self.name}] response text={response_text}") + res_content = response_text["result"] + total_tokens = response_text["usage"]["total_tokens"] + completion_tokens = response_text["usage"]["completion_tokens"] + logger.info("[{}] reply={}".format(self.name,res_content)) + return { + "total_tokens": total_tokens, + "completion_tokens": completion_tokens, + "content": res_content, + } + except Exception as e: + logger.warn("[{}] Exception: {}".format(self.name,e)) + result = {"completion_tokens": 0, "content": "出错了: {}".format(e)} + return result + + def get_access_token(self): + """ + 使用 AK,SK 生成鉴权签名(Access Token) + :return: access_token,或是None(如果错误) + """ + url = "https://aip.baidubce.com/oauth/2.0/token" + params = { + "grant_type": "client_credentials", + "client_id": self.baidu_wenxin_api_key, + "client_secret": self.baidu_wenxin_secret_key, + } + return str(requests.post(url, params=params).json().get("access_token")) diff --git a/bot/bot.py b/bot/bot.py index 2abd5ae..88bad84 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -9,12 +9,16 @@ class Bot: def __init__(self): use_azure_chatgpt = conf().get("use_azure_chatgpt", False) + use_wxyx = conf().get("use_wxyx", False) model = conf().get("model", "gpt-3.5-turbo") if use_azure_chatgpt: from bot.azure_chatgpt import AzureChatGPTBot self.bot = AzureChatGPTBot() + elif use_wxyx: + from bot.baidu_wenxin import BaiduWenxinBot + self.bot = BaiduWenxinBot() elif model in litellm.open_ai_chat_completion_models: from bot.chatgpt import ChatGPTBot diff --git a/common/session.py b/common/session.py index d08c0d2..b2f0ab4 100644 --- a/common/session.py +++ b/common/session.py @@ -7,7 +7,7 @@ class Session(object): all_sessions = ExpiredDict(conf().get("session_expired_duration") or 3600) @staticmethod - def build_session_query(context: Context): + def build_session_query(context: Context,cls=None): """ build query with conversation history e.g. [ @@ -21,10 +21,16 @@ def build_session_query(context: Context): :return: query content with conversaction """ session = Session.all_sessions.get(context.session_id, []) + use_wxyx = conf().get("use_wxyx", False) + if len(session) == 0: - system_item = {"role": "system", "content": context.system_prompt} - session.append(system_item) - Session.all_sessions[context.session_id] = session + + if use_wxyx: + Session.all_sessions[context.session_id] = session + else: + system_item = {"role": "system", "content": context.system_prompt} + session.append(system_item) + Session.all_sessions[context.session_id] = session user_item = {"role": "user", "content": context.query} session.append(user_item) return session diff --git a/config.template.json b/config.template.json index 8e40348..1c112dc 100644 --- a/config.template.json +++ b/config.template.json @@ -18,5 +18,9 @@ "query_key_command": "#query key", "recent_days": 5, "plugins": [{ "name": "tiktok", "command": "#tiktok" }], - "openai_sensitive_id": "" + "openai_sensitive_id": "", + "use_wxyx": false, + "baidu_wenxin_model": "eb-instant", + "baidu_wenxin_api_key": "", + "baidu_wenxin_secret_key": "" } From e4ba950af87778f2cffc4c3ddd4da7af4aa24346 Mon Sep 17 00:00:00 2001 From: Ancss <61501274+Ancss@users.noreply.github.com> Date: Wed, 13 Sep 2023 10:48:21 +0800 Subject: [PATCH 2/7] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dmessage=E6=8A=A5?= =?UTF-8?q?=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot/baidu_wenxin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/baidu_wenxin.py b/bot/baidu_wenxin.py index 3341898..5b2c578 100644 --- a/bot/baidu_wenxin.py +++ b/bot/baidu_wenxin.py @@ -35,7 +35,7 @@ def reply(self, context=None): logger.debug( "[{}] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( self.name, - session.messages, + session, session_id, reply_content, completion_tokens, From 7cd367892d70dfd996f2385ace9fa329bffd6472 Mon Sep 17 00:00:00 2001 From: Ancss <61501274+Ancss@users.noreply.github.com> Date: Wed, 13 Sep 2023 10:49:31 +0800 Subject: [PATCH 3/7] =?UTF-8?q?fix:=20=E5=88=A0=E9=99=A4=E6=96=87=E5=BF=83?= =?UTF-8?q?=E4=B8=80=E8=A8=80=E8=B0=83=E5=9B=BE=E7=89=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot/baidu_wenxin.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/bot/baidu_wenxin.py b/bot/baidu_wenxin.py index 5b2c578..fb2dbd7 100644 --- a/bot/baidu_wenxin.py +++ b/bot/baidu_wenxin.py @@ -49,12 +49,13 @@ def reply(self, context=None): return Reply(ReplyType.TEXT, response["content"]) def reply_img(self, query) -> Reply: - ok, image_url = self.create_img(query, 0) - if ok: - return Reply(ReplyType.IMAGE, image_url) - else: - logger.error(f"[{self.name}] Create image failed: {e}") - return Reply(ReplyType.TEXT, "Image created failed") + None + # ok, image_url = self.create_img(query, 0) + # if ok: + # return Reply(ReplyType.IMAGE, image_url) + # else: + # logger.error(f"[{self.name}] Create image failed: {e}") + # return Reply(ReplyType.TEXT, "Image created failed") def reply_text(self, session, retry_count=0): try: From 9d59c2c6ad66eb636cec04bbf6269b5b2dfa8739 Mon Sep 17 00:00:00 2001 From: Ancss <61501274+Ancss@users.noreply.github.com> Date: Wed, 13 Sep 2023 10:53:59 +0800 Subject: [PATCH 4/7] =?UTF-8?q?feat:=20=E5=AE=8C=E5=96=84readme?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README_ZH.md | 4 ++++ bot/baidu_wenxin.py | 2 +- config.template.json | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/README_ZH.md b/README_ZH.md index d749962..f55319d 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -96,6 +96,10 @@ pip install -r requirements.txt "recent_days": 5 # 查询最近的天 "plugins": [{ "name": , other configs }]# 添加你喜爱的插件 "openai_sensitive_id": "" # 查询api key时使用 + "use_wxyx": false, # 是否使用文心一言 + "baidu_wenxin_model": "", # 前往千帆大模型->服务管理->详情(要先创建服务)->顶部有个服务地址:https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/xxxxx 最后的xxxxx就是 baidu_wenxin_model + "baidu_wenxin_api_key": "", # 服务管理 详情可以看到 + "baidu_wenxin_secret_key": "" # 服务管理 详情可以看到 } ``` diff --git a/bot/baidu_wenxin.py b/bot/baidu_wenxin.py index fb2dbd7..7d69f8c 100644 --- a/bot/baidu_wenxin.py +++ b/bot/baidu_wenxin.py @@ -12,7 +12,7 @@ class BaiduWenxinBot: def __init__(self): - self.model = conf().get("baidu_wenxin_model") or "eb-instant" + self.model = conf().get("baidu_wenxin_model") self.baidu_wenxin_api_key = conf().get("baidu_wenxin_api_key") self.baidu_wenxin_secret_key = conf().get("baidu_wenxin_secret_key") self.name = self.__class__.__name__ diff --git a/config.template.json b/config.template.json index 1c112dc..e1e9960 100644 --- a/config.template.json +++ b/config.template.json @@ -20,7 +20,7 @@ "plugins": [{ "name": "tiktok", "command": "#tiktok" }], "openai_sensitive_id": "", "use_wxyx": false, - "baidu_wenxin_model": "eb-instant", + "baidu_wenxin_model": "", "baidu_wenxin_api_key": "", "baidu_wenxin_secret_key": "" } From 4078671d58075adb4e2deed8d72313e74718779e Mon Sep 17 00:00:00 2001 From: Ancss <61501274+Ancss@users.noreply.github.com> Date: Wed, 13 Sep 2023 11:41:42 +0800 Subject: [PATCH 5/7] =?UTF-8?q?fix:=20=E5=8A=A0=E5=85=A5=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot/baidu_wenxin.py | 2 +- config.template.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bot/baidu_wenxin.py b/bot/baidu_wenxin.py index 7d69f8c..fb2dbd7 100644 --- a/bot/baidu_wenxin.py +++ b/bot/baidu_wenxin.py @@ -12,7 +12,7 @@ class BaiduWenxinBot: def __init__(self): - self.model = conf().get("baidu_wenxin_model") + self.model = conf().get("baidu_wenxin_model") or "eb-instant" self.baidu_wenxin_api_key = conf().get("baidu_wenxin_api_key") self.baidu_wenxin_secret_key = conf().get("baidu_wenxin_secret_key") self.name = self.__class__.__name__ diff --git a/config.template.json b/config.template.json index e1e9960..1c112dc 100644 --- a/config.template.json +++ b/config.template.json @@ -20,7 +20,7 @@ "plugins": [{ "name": "tiktok", "command": "#tiktok" }], "openai_sensitive_id": "", "use_wxyx": false, - "baidu_wenxin_model": "", + "baidu_wenxin_model": "eb-instant", "baidu_wenxin_api_key": "", "baidu_wenxin_secret_key": "" } From 3cd1a481ef2c3f5f2db25e031cb4a2bea9bba3a3 Mon Sep 17 00:00:00 2001 From: Ancss <61501274+Ancss@users.noreply.github.com> Date: Wed, 13 Sep 2023 16:27:50 +0800 Subject: [PATCH 6/7] =?UTF-8?q?feat:=20=E6=96=87=E5=BF=83=E4=B8=80?= =?UTF-8?q?=E8=A8=80=E5=A2=9E=E5=8A=A0=E6=8F=92=E4=BB=B6=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README_ZH.md | 5 +++-- bot/baidu_wenxin.py | 17 ++++++++++++++--- config.template.json | 3 ++- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/README_ZH.md b/README_ZH.md index f55319d..0f8ad94 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -97,9 +97,10 @@ pip install -r requirements.txt "plugins": [{ "name": , other configs }]# 添加你喜爱的插件 "openai_sensitive_id": "" # 查询api key时使用 "use_wxyx": false, # 是否使用文心一言 - "baidu_wenxin_model": "", # 前往千帆大模型->服务管理->详情(要先创建服务)->顶部有个服务地址:https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/xxxxx 最后的xxxxx就是 baidu_wenxin_model + "baidu_wenxin_model": "eb-instant", # 使用默认eb-instant或 前往千帆大模型->服务管理->详情(要先创建服务)->顶部有个服务地址:https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/xxxxx 最后的xxxxx就是 baidu_wenxin_model "baidu_wenxin_api_key": "", # 服务管理 详情可以看到 - "baidu_wenxin_secret_key": "" # 服务管理 详情可以看到 + "baidu_wenxin_secret_key": "", # 服务管理 详情可以看到 + "baidu_wenxin_plugin_suffix":"" # 获取插件配置接口后的后缀 } ``` diff --git a/bot/baidu_wenxin.py b/bot/baidu_wenxin.py index fb2dbd7..da5eb44 100644 --- a/bot/baidu_wenxin.py +++ b/bot/baidu_wenxin.py @@ -16,6 +16,7 @@ def __init__(self): self.baidu_wenxin_api_key = conf().get("baidu_wenxin_api_key") self.baidu_wenxin_secret_key = conf().get("baidu_wenxin_secret_key") self.name = self.__class__.__name__ + self.baidu_wenxin_plugin_suffix = conf().get("baidu_wenxin_plugin_suffix") def reply(self, context=None): # acquire reply content @@ -26,7 +27,7 @@ def reply(self, context=None): else: session_id = context.session_id session = Session.build_session_query(context) - response = self.reply_text(session) + response = self.reply_text(session,query=query) total_tokens, completion_tokens, reply_content = ( response["total_tokens"], response["completion_tokens"], @@ -57,7 +58,7 @@ def reply_img(self, query) -> Reply: # logger.error(f"[{self.name}] Create image failed: {e}") # return Reply(ReplyType.TEXT, "Image created failed") - def reply_text(self, session, retry_count=0): + def reply_text(self, session,query): try: logger.info("[{}] model={}".format(self.name,self.model)) access_token = self.get_access_token() @@ -76,8 +77,18 @@ def reply_text(self, session, retry_count=0): + "?access_token=" + access_token ) - headers = {"Content-Type": "application/json"} payload = {"messages": session} + + if self.baidu_wenxin_plugin_suffix is not None: + self.url = ( + "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/plugin/" + + self.baidu_wenxin_plugin_suffix+ '/' + + "?access_token=" + + access_token) + payload = {"query":query} + + headers = {"Content-Type": "application/json"} + response = requests.request( "POST", url, headers=headers, data=json.dumps(payload) ) diff --git a/config.template.json b/config.template.json index 1c112dc..81ec287 100644 --- a/config.template.json +++ b/config.template.json @@ -22,5 +22,6 @@ "use_wxyx": false, "baidu_wenxin_model": "eb-instant", "baidu_wenxin_api_key": "", - "baidu_wenxin_secret_key": "" + "baidu_wenxin_secret_key": "", + "baidu_wenxin_plugin_suffix":"" } From 09393bfaeb4ff9e3a4ab11bf3d12da180a4ec081 Mon Sep 17 00:00:00 2001 From: Ancss <61501274+Ancss@users.noreply.github.com> Date: Sat, 7 Oct 2023 15:47:31 +0800 Subject: [PATCH 7/7] fix: fix the lint --- bot/baidu_wenxin.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bot/baidu_wenxin.py b/bot/baidu_wenxin.py index da5eb44..4c9fad8 100644 --- a/bot/baidu_wenxin.py +++ b/bot/baidu_wenxin.py @@ -1,11 +1,11 @@ # encoding:utf-8 -import requests, json -from config import conf +import requests +import json from utils.log import logger from common.session import Session from common.reply import Reply, ReplyType -from common.context import ContextType, Context +from common.context import ContextType from config import conf