Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

接入文心一言和文心一言插件 #89

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ pip install -r requirements.txt
"recent_days": 5 # 查询最近的<recent_days>天
"plugins": [{ "name": <plugin name>, other configs }]# 添加你喜爱的插件
"openai_sensitive_id": "" # 查询api key时使用
"use_wxyx": false, # 是否使用文心一言
"baidu_wenxin_model": "eb-instant", # 使用默认eb-instant或 前往千帆大模型->服务管理->详情(要先创建服务)->顶部有个服务地址:https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/xxxxx 最后的xxxxx就是 baidu_wenxin_model
"baidu_wenxin_api_key": "", # 服务管理 详情可以看到
"baidu_wenxin_secret_key": "", # 服务管理 详情可以看到
"baidu_wenxin_plugin_suffix":"" # 获取插件配置接口后的后缀
}
```

Expand Down
122 changes: 122 additions & 0 deletions bot/baidu_wenxin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# encoding:utf-8

import requests
import json
from utils.log import logger
from common.session import Session
from common.reply import Reply, ReplyType
from common.context import ContextType

from config import conf


class BaiduWenxinBot:
def __init__(self):
self.model = conf().get("baidu_wenxin_model") or "eb-instant"
self.baidu_wenxin_api_key = conf().get("baidu_wenxin_api_key")
self.baidu_wenxin_secret_key = conf().get("baidu_wenxin_secret_key")
self.name = self.__class__.__name__
self.baidu_wenxin_plugin_suffix = conf().get("baidu_wenxin_plugin_suffix")

def reply(self, context=None):
# acquire reply content
query = context.query
logger.info(f"[{self.name}] Query={query}")
if context.type == ContextType.CREATE_IMAGE:
return self.reply_img(query)
else:
session_id = context.session_id
session = Session.build_session_query(context)
response = self.reply_text(session,query=query)
total_tokens, completion_tokens, reply_content = (
response["total_tokens"],
response["completion_tokens"],
response["content"],
)
logger.debug(
"[{}] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
self.name,
session,
session_id,
reply_content,
completion_tokens,
)
)

if total_tokens > 0:
Session.save_session(
response["content"], session_id, response["total_tokens"]
)
return Reply(ReplyType.TEXT, response["content"])

def reply_img(self, query) -> Reply:
None
# ok, image_url = self.create_img(query, 0)
# if ok:
# return Reply(ReplyType.IMAGE, image_url)
# else:
# logger.error(f"[{self.name}] Create image failed: {e}")
# return Reply(ReplyType.TEXT, "Image created failed")

def reply_text(self, session,query):
try:
logger.info("[{}] model={}".format(self.name,self.model))
access_token = self.get_access_token()
if access_token == "None":
logger.warn(
"[{self.name}] access token 获取失败"
)
return {
"total_tokens": 0,
"completion_tokens": 0,
"content": 0,
}
url = (
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
+ self.model
+ "?access_token="
+ access_token
)
payload = {"messages": session}

if self.baidu_wenxin_plugin_suffix is not None:
self.url = (
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/plugin/"
+ self.baidu_wenxin_plugin_suffix+ '/'
+ "?access_token="
+ access_token)
payload = {"query":query}

headers = {"Content-Type": "application/json"}

response = requests.request(
"POST", url, headers=headers, data=json.dumps(payload)
)
response_text = json.loads(response.text)
logger.info(f"[{self.name}] response text={response_text}")
res_content = response_text["result"]
total_tokens = response_text["usage"]["total_tokens"]
completion_tokens = response_text["usage"]["completion_tokens"]
logger.info("[{}] reply={}".format(self.name,res_content))
return {
"total_tokens": total_tokens,
"completion_tokens": completion_tokens,
"content": res_content,
}
except Exception as e:
logger.warn("[{}] Exception: {}".format(self.name,e))
result = {"completion_tokens": 0, "content": "出错了: {}".format(e)}
return result

def get_access_token(self):
"""
使用 AK,SK 生成鉴权签名(Access Token)
:return: access_token,或是None(如果错误)
"""
url = "https://aip.baidubce.com/oauth/2.0/token"
params = {
"grant_type": "client_credentials",
"client_id": self.baidu_wenxin_api_key,
"client_secret": self.baidu_wenxin_secret_key,
}
return str(requests.post(url, params=params).json().get("access_token"))
4 changes: 4 additions & 0 deletions bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
class Bot:
def __init__(self):
use_azure_chatgpt = conf().get("use_azure_chatgpt", False)
use_wxyx = conf().get("use_wxyx", False)
model = conf().get("model", "gpt-3.5-turbo")
if use_azure_chatgpt:
from bot.azure_chatgpt import AzureChatGPTBot

self.bot = AzureChatGPTBot()
elif use_wxyx:
from bot.baidu_wenxin import BaiduWenxinBot

self.bot = BaiduWenxinBot()
elif model in litellm.open_ai_chat_completion_models:
from bot.chatgpt import ChatGPTBot

Expand Down
14 changes: 10 additions & 4 deletions common/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class Session(object):
all_sessions = ExpiredDict(conf().get("session_expired_duration") or 3600)

@staticmethod
def build_session_query(context: Context):
def build_session_query(context: Context,cls=None):
"""
build query with conversation history
e.g. [
Expand All @@ -21,10 +21,16 @@ def build_session_query(context: Context):
:return: query content with conversaction
"""
session = Session.all_sessions.get(context.session_id, [])
use_wxyx = conf().get("use_wxyx", False)

if len(session) == 0:
system_item = {"role": "system", "content": context.system_prompt}
session.append(system_item)
Session.all_sessions[context.session_id] = session

if use_wxyx:
Session.all_sessions[context.session_id] = session
else:
system_item = {"role": "system", "content": context.system_prompt}
session.append(system_item)
Session.all_sessions[context.session_id] = session
user_item = {"role": "user", "content": context.query}
session.append(user_item)
return session
Expand Down
7 changes: 6 additions & 1 deletion config.template.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,10 @@
"query_key_command": "#query key",
"recent_days": 5,
"plugins": [{ "name": "tiktok", "command": "#tiktok" }],
"openai_sensitive_id": ""
"openai_sensitive_id": "",
"use_wxyx": false,
"baidu_wenxin_model": "eb-instant",
"baidu_wenxin_api_key": "",
"baidu_wenxin_secret_key": "",
"baidu_wenxin_plugin_suffix":""
}