-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
101 lines (85 loc) · 3.86 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
from langchain.vectorstores import Chroma
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
import os
from LLM import InternLM_LLM
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
import gradio as gr
os.system("python download_model.py")
os.system("python create_db.py")
# 定义新闻链
class NewsQaChain():
def __init__(self):
# 构造函数,加载检索问答链
self.chain = self.load_chain()
def qa_chain_self_answer(self, question: str, chat_history: list = []):
"""
调用问答链进行回答
"""
if question == None or len(question) < 1:
return "", chat_history
try:
chat_history.append(
(question, self.chain({"query": question})["result"]))
# 将问答结果直接附加到问答历史中,Gradio 会将其展示出来
return "", chat_history
except Exception as e:
return e, chat_history
def load_chain(self):
# 加载向量数据库
vectordb = Chroma(
persist_directory = '/home/xlab-app-center/database/vector_db/news_qa',
embedding_function = HuggingFaceEmbeddings(model_name="/home/xlab-app-center/model/sentence-transformer")
)
# 加载自定义LLM
llm = InternLM_LLM(model_path = "/home/xlab-app-center/model/Shanghai_AI_Laboratory/internlm-chat-7b")
# 定义提示模板
template = """使用以下上下文来回答最后的问题。如果你不知道答案,就说你不知道,不要试图编造答案。
尽量使答案简明扼要。总是在回答的最后说“谢谢你的提问!”。
{context}
问题: {question}
有用的回答:"""
QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context","question"],template=template)
# 创建问答链
qa_chain = RetrievalQA.from_chain_type(llm,
retriever=vectordb.as_retriever(),
return_source_documents=True,
chain_type_kwargs={"prompt":QA_CHAIN_PROMPT})
return qa_chain
# 创建新闻链
news_qa_chain = NewsQaChain()
# 创建一个 Web 界面
block = gr.Blocks()
with block as demo:
with gr.Row(equal_height=True):
with gr.Column(scale=15):
# 展示的页面标题
gr.Markdown("""<h1><center>News_QA</center></h1>
<center>新闻问答小助手</center>
""")
with gr.Row():
with gr.Column(scale=4):
# 创建一个聊天机器人对象
chatbot = gr.Chatbot(height=450, show_copy_button=True)
# 创建一个文本框组件,用于输入 prompt。
msg = gr.Textbox(label="Prompt/问题")
with gr.Row():
# 创建提交按钮。
db_wo_his_btn = gr.Button("提问")
with gr.Row():
# 创建一个清除按钮,用于清除聊天机器人组件的内容。
clear = gr.ClearButton(
components=[chatbot], value="清空")
# 设置按钮的点击事件。当点击时,调用上面定义的 qa_chain_self_answer 函数,并传入用户的消息和聊天历史记录,然后更新文本框和聊天机器人组件。
db_wo_his_btn.click(news_qa_chain.qa_chain_self_answer, inputs=[
msg, chatbot], outputs=[msg, chatbot])
gr.Markdown("""提醒:<br>
1. 初始化数据库时间可能较长,请耐心等待。
2. 使用中如果出现异常,将会在文本输入框进行展示,请不要惊慌。 <br>
""")
gr.close_all()
# 直接启动
demo.launch()