-
Notifications
You must be signed in to change notification settings - Fork 30
/
chat.py
86 lines (62 loc) · 2.55 KB
/
chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# from https://docs.streamlit.io/develop/tutorials/llms/build-conversational-apps
import streamlit as st
from langchain_upstage import ChatUpstage as Chat
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage
from solar_util import initialize_solar_llm
from solar_util import prompt_engineering
llm = initialize_solar_llm()
st.set_page_config(page_title="Chat")
st.title("SolarLLM")
chat_with_history_prompt = ChatPromptTemplate.from_messages(
[
("human", """You are Solar, a smart chatbot by Upstage, loved by many people.
Be smart, cheerful, and fun. Give engaging answers and avoid inappropriate language.
reply in the same language of the user query.
Solar is now being connected with a human.
Please put <END> in the end of your answer."""),
MessagesPlaceholder("chat_history"),
("human", "{user_query}"),
]
)
def get_response(user_query, chat_history):
    """Stream the assistant's reply to *user_query* given *chat_history*.

    Yields text chunks suitable for ``st.write_stream``. The prompt asks the
    model to finish with a literal ``<END>`` marker; streaming stops as soon
    as the marker appears, and the marker itself is never yielded.

    Fixes over the previous version:
      * raw chunks were yielded before the ``<END>`` check, so the marker
        (or a split prefix of it) could leak into the displayed output;
      * the accumulated response was yielded again after the chunk loop,
        so ``st.write_stream`` rendered and returned the text twice;
      * removed a leftover debug ``print``.
    """
    chain = chat_with_history_prompt | llm | StrOutputParser()

    marker = "<END>"
    # The marker may arrive split across chunk boundaries, so always hold
    # back the last len(marker) - 1 characters before releasing text.
    holdback = len(marker) - 1
    buffer = ""
    for chunk in chain.stream(
        {
            "chat_history": chat_history,
            "user_query": user_query,
        }
    ):
        buffer += chunk
        if marker in buffer:
            # Emit everything before the marker and stop streaming.
            yield buffer.split(marker)[0]
            return
        if len(buffer) > holdback:
            yield buffer[:-holdback]
            buffer = buffer[-holdback:]
    # Model finished without emitting the marker: flush whatever is left.
    yield buffer
# --- Chat UI ----------------------------------------------------------------
# The conversation lives in Streamlit session state so it survives reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored conversation on every rerun of the script.
for past in st.session_state.messages:
    with st.chat_message("AI" if isinstance(past, AIMessage) else "Human"):
        st.markdown(past.content)

enhance_prompt = st.toggle("Enhance prompt", True)

user_input = st.chat_input("What is up?")
if user_input:
    prompt = user_input
    if enhance_prompt:
        # Optionally rewrite the raw query via the project-local
        # prompt_engineering helper before sending it to the model.
        with st.status("Prompt engineering..."):
            new_prompt = prompt_engineering(prompt, st.session_state.messages)
            st.write(new_prompt)
        # Fall back to the original prompt when no enhanced version came back.
        prompt = new_prompt.get('enhanced_prompt', prompt)

    st.session_state.messages.append(HumanMessage(content=prompt))
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        # write_stream renders the chunks live and returns the joined text.
        response = st.write_stream(get_response(prompt, st.session_state.messages))
    st.session_state.messages.append(AIMessage(content=response))