-
Notifications
You must be signed in to change notification settings - Fork 1
/
streamlit_app.py
124 lines (107 loc) · 4.06 KB
/
streamlit_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from typing import Dict, Any
import asyncio
# Create a fresh asyncio event loop and install it as the current event loop
# for the executing thread. This deliberately runs before the llama_index
# imports below.
# NOTE(review): presumably needed because Streamlit executes the script in a
# thread without a default event loop and llama_index expects one — confirm.
# This also runs on every Streamlit rerun, so a new loop is created each time.
loop = asyncio.new_event_loop()
# Set the event loop as the current event loop
asyncio.set_event_loop(loop)
from llama_index import (
VectorStoreIndex,
ServiceContext,
download_loader,
)
from llama_index.llama_pack.base import BaseLlamaPack
from llama_index.llms import OpenAI
import streamlit as st
from streamlit_pills import pills
# Page metadata. Streamlit requires set_page_config to run before any other
# st.* command in the script.
# Plain literal: the original f-string had no placeholders (ruff F541).
st.set_page_config(
    page_title="Chat with Snowflake's Wikipedia page, powered by LlamaIndex",
    page_icon="🦙",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items=None,
)
# Seed this session's chat history with a greeting the first time the script
# runs for it; later reruns find the key already present and leave it alone.
st.session_state.setdefault(
    "messages",
    [{"role": "assistant", "content": "Ask me a question about Snowflake!"}],
)
# Page heading. Plain literal: the original f-string had no placeholders
# (ruff F541).
st.title("Chat with Snowflake's Wikipedia page, powered by LlamaIndex 💬🦙")
# Banner pointing readers at the data loader used and at other Llama Hub loaders.
st.info(
    "This example is powered by the **[Llama Hub Wikipedia Loader](https://llamahub.ai/l/wikipedia)**. Use any of [Llama Hub's many loaders](https://llamahub.ai/) to retrieve and chat with your data via a Streamlit app.",
    icon="ℹ️",
)
def add_to_message_history(role, content):
    """Append one message to this session's chat history.

    Args:
        role: Who produced the message; this file passes "user" or "assistant".
        content: Message body; converted with ``str()`` before it is stored.
    """
    entry = {"role": role, "content": str(content)}
    history = st.session_state["messages"]
    history.append(entry)
@st.cache_resource
def load_index_data():
    """Build a vector index over the "Snowflake Inc." Wikipedia page.

    The result is memoised by ``st.cache_resource``, so the download and
    indexing happen once rather than on every rerun.

    Returns:
        A ``VectorStoreIndex`` whose service context uses gpt-3.5-turbo at
        temperature 0.5.
    """
    # Fetch the Wikipedia loader class from Llama Hub into ./local_dir.
    reader_cls = download_loader("WikipediaReader", custom_path="local_dir")
    documents = reader_cls().load_data(pages=["Snowflake Inc."])
    llm = OpenAI(model="gpt-3.5-turbo", temperature=0.5)
    return VectorStoreIndex.from_documents(
        documents,
        service_context=ServiceContext.from_defaults(llm=llm),
    )
# Cached after the first run (see load_index_data's decorator).
index = load_index_data()

# Clickable starter questions. Nothing is preselected and the choice can be
# cleared; `selected` is falsy when no pill is chosen (presumably None).
starter_questions = [
    "What is Snowflake?",
    "What company did Snowflake announce they would acquire in October 2023?",
    "What company did Snowflake acquire in March 2022?",
    "When did Snowflake IPO?",
]
selected = pills(
    "Choose a question to get started or write your own below.",
    starter_questions,
    clearable=True,
    index=None,
)
# Build the chat engine only once per session and keep it in session state,
# so every rerun of this session reuses the same engine object.
if "chat_engine" not in st.session_state:
    chat_engine = index.as_chat_engine(chat_mode="context", verbose=True)
    st.session_state["chat_engine"] = chat_engine
# Streamlit redraws the page from scratch on each rerun, so replay every
# stored message in order.
for past in st.session_state["messages"]:
    role, text = past["role"], past["content"]
    with st.chat_message(role):
        st.write(text)
# Answer a starter-question pill at most once per session. The pill stays
# selected across reruns, so without this guard the same question would be
# asked and displayed again on every rerun.
if selected and selected not in st.session_state.get(
    "displayed_pill_questions", set()
):
    st.session_state.setdefault("displayed_pill_questions", set()).add(selected)
    with st.chat_message("user"):
        st.write(selected)
    with st.chat_message("assistant"):
        response = st.session_state["chat_engine"].stream_chat(selected)
        response_str = ""
        response_container = st.empty()
        # Re-render the growing answer as each streamed token arrives.
        for token in response.response_gen:
            response_str += token
            response_container.write(response_str)
    add_to_message_history("user", selected)
    # Store the exact text that was streamed to the user (not the response
    # object), so replayed history matches what was shown.
    add_to_message_history("assistant", response_str)
# Free-text question from the chat box: record it, echo it immediately, then
# stream the engine's answer into a placeholder that is rewritten per token.
if prompt := st.chat_input("Your question"):
    add_to_message_history("user", prompt)
    with st.chat_message("user"):
        st.write(prompt)
    with st.chat_message("assistant"):
        response = st.session_state["chat_engine"].stream_chat(prompt)
        response_str = ""
        response_container = st.empty()
        for token in response.response_gen:
            response_str += token
            response_container.write(response_str)
    # Store exactly what was displayed, so replayed history matches the stream.
    add_to_message_history("assistant", response_str)
    # NOTE(review): nothing in this file reads "response_gen" back, and the
    # stream was already consumed above; the write is kept so any external
    # reader of this session key keeps working.
    st.session_state["response_gen"] = response.response_gen