rag_fusion.py
import threading
from cohere.client import Chat, Client, Reranking
from config import defaults
from langchain.llms import Cohere
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema.document import Document
from langchain.schema.output_parser import StrOutputParser
from langchain.vectorstores import Weaviate
from operator import itemgetter
from streamlit.runtime.scriptrunner import add_script_run_ctx, ScriptRunContext
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
# RAG Fusion logic
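# Pipeline overview (implemented step by step below):
#   1. Generate query variations with an LLM.
#   2. Retrieve documents from Weaviate for each variation on parallel threads.
#   3. Fuse the result sets with Reciprocal Rank Fusion (RRF).
#   4. Re-score the fused candidates with Cohere Rerank.
#   5. Run two final augmented chat calls: one grounded in the retrieved
#      documents and one using Cohere's web-search connector.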
# Step 1: Generate query variations
def generate_variations(query: str, variation_count: int, llm: Cohere, example_questions: bool) -> str:
# Step 1: Generate query variations:
variation_prompt_array = [
SystemMessagePromptTemplate.from_template("""Your task is to generate {variation_count} different search queries that aim to answer the user question from multiple perspectives.
The user questions focus on ThruThink budgeting analysis and projection web application usage, or on a wide range of budgeting and accounting topics, including EBITDA, cash flow balance, inventory management, and more.
Each query MUST tackle the question from a different viewpoint so that we get a variety of RELEVANT search results.
Each query MUST be on one line and one line only. You SHOULD NOT include any preamble or explanations, and you SHOULD NOT answer the questions or add anything else, just generate the queries."""),
HumanMessagePromptTemplate.from_template("Original question: {query}"),
]
variation_user_example_prompt_template = "Example output:\n"
if example_questions:
for i in range(variation_count):
variation_user_example_prompt_template += f"{i + 1}. Query variation\n"
variation_prompt_array.append(HumanMessagePromptTemplate.from_template(variation_user_example_prompt_template))
variation_prompt_array.append(HumanMessagePromptTemplate.from_template("OUTPUT ({variation_count} numbered queries):"))
variation_prompt = ChatPromptTemplate.from_messages(variation_prompt_array)
variation_chain = (
dict(
query=itemgetter("query"),
variation_count=itemgetter("variation_count")
)
| variation_prompt
| llm
| StrOutputParser()
)
query_variations = []
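    # Retry until the output looks like a numbered list with the requested number
    # of variations (at least one "N." prefix per expected line).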
for t in range(defaults["max_retries"]):
query_variations = variation_chain.invoke(dict(query=query, variation_count=variation_count))
# print(f"{t}.: {query_variations}")
if query_variations.count(".") >= variation_count and query_variations.count("\n") >= variation_count - 1:
break
return query_variations
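# Parse the raw LLM output (numbered lines such as "1. ...") into a list of distinct
# queries; the original user query is always kept as the first element.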
def extract_query_variations(query: str, query_variations: str, variation_count: int) -> list[str]:
queries = [query]
if query_variations.count(".") >= variation_count:
for query_variation in query_variations.split("\n")[:variation_count]:
dot_index = query_variation.index(".") if "." in query_variation else -1
q = query_variation[dot_index + 1:].strip()
if q not in queries:
queries.append(q)
return queries
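# Thread worker: the Streamlit ScriptRunContext captured on the main thread is
# attached here so the retrieval threads run cleanly inside a Streamlit session.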
def retrieve_documents_for_query_variation_func(ctx: ScriptRunContext, query: str, document_sets: list, vectorstore: Weaviate, document_k: int):
    add_script_run_ctx(ctx=ctx)  # attach the captured Streamlit context to this worker thread
docs = vectorstore.similarity_search_by_text(query, k=document_k)
document_sets.append(docs)
# Step 2: Retrieve documents for each query variation
def retrieve_documents_for_query_variations(queries: list[str], vectorstore: Weaviate, document_k: int) -> list[list[Document]]:
    ctx = get_script_run_ctx()  # grab the current Streamlit script run context for the worker threads
thread_list = []
document_sets = []
for q in queries:
# pass context to thread
t = threading.Thread(target=retrieve_documents_for_query_variation_func, args=(ctx, q, document_sets, vectorstore, document_k))
t.start()
thread_list.append(t)
for t in thread_list:
t.join()
print(len(document_sets))
return document_sets
# Step 3: Rerank the document sets with reciprocal rank fusion
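# Each document scores sum(1 / (rank + rerank_k)) over every result set it appears in
# (rank is 0-based here); documents are deduplicated by title and sorted by fused score.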
def rerank_and_fuse_documents(document_sets: list[list[Document]], rerank_k: int) -> list[tuple[Document, float]]:
fused_scores = dict()
doc_map = dict()
for doc_set in document_sets:
for rank, doc in enumerate(doc_set):
title = doc.metadata["title"]
if title not in doc_map:
doc_map[title] = doc
if title not in fused_scores:
fused_scores[title] = 0
fused_scores[title] += 1 / (rank + rerank_k)
    # return the fused documents sorted by descending RRF score
return [
(doc_map[title], score)
for title, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
]
# Step 4: Cohere Rerank
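# Second-stage rerank: the RRF-fused candidates are re-scored against the original
# query by Cohere's hosted rerank model; only the top_n survive for augmentation.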
def cohere_reranking(
query: str,
reranked_results: list[tuple[Document, float]],
top_k_augment_doc: int,
co: Client,
) -> Reranking:
documents_to_cohere_rank = []
for rrr in reranked_results:
documents_to_cohere_rank.append(rrr[0].page_content)
return co.rerank(
query=query,
documents=documents_to_cohere_rank,
max_chunks_per_doc=100,
top_n=top_k_augment_doc,
model="rerank-english-v2.0"
)
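# The two chat calls below run on worker threads and drop their Chat responses into a
# shared `results` list (index 0: document-grounded answer, index 1: web-search answer).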
def document_based_query_func(
ctx: ScriptRunContext,
cohere_fusion_model: str,
temperature: float,
conversation_id: str,
chat_system_prompt: str,
documents: list[dict],
query: str,
results: list[Chat],
co: Client,
):
    add_script_run_ctx(ctx=ctx)  # attach the captured Streamlit context to this worker thread
results[0] = co.chat(
model=cohere_fusion_model,
prompt_truncation="auto",
temperature=temperature,
citation_quality="accurate",
conversation_id=conversation_id,
documents=documents,
preamble_override=chat_system_prompt,
message=query,
)
def web_connector_query_func(
ctx: ScriptRunContext,
cohere_fusion_model: str,
temperature: float,
conversation_id: str,
chat_system_prompt: str,
rag_query: str,
results: list[Chat],
co: Client,
):
    add_script_run_ctx(ctx=ctx)  # attach the captured Streamlit context to this worker thread
results[1] = co.chat(
model=cohere_fusion_model,
prompt_truncation="auto",
temperature=temperature,
connectors=[dict(id="web-search")],
citation_quality="accurate",
conversation_id=conversation_id,
preamble_override=chat_system_prompt,
message=rag_query,
)
# Step 5: Prepare and execute the final RAG calls
# (one document-based and one web-connector-based, both augmented)
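# Builds the numbered context string and the Cohere `documents` payload from the
# Cohere-reranked results, then fires the two chat calls on parallel threads.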
def final_rag_operations(
query: str,
reranked_results: list[tuple[Document, float]],
reranking: Reranking,
cohere_fusion_model: str,
temperature: float,
conversation_id: str,
co: Client,
) -> tuple[Chat, Chat]:
# Step 6: Prepare prompt augmentation for RAG
context = ""
documents = []
for index, cohere_rank in enumerate(reranking):
if context:
context += "\n"
rrr = reranked_results[cohere_rank.index]
context_content = rrr[0].page_content
context += f"{index + 1}. context: `{context_content}`"
documents.append(dict(
id=rrr[0].metadata["slug"],
title=rrr[0].metadata["title"],
category=rrr[0].metadata["category"],
snippet=rrr[0].page_content,
))
# Step 7: Final augmented RAG calls
chat_system_prompt = """You are an assistant specialized in ThruThink budgeting analysis and projection web application usage.
You are also knowledgeable in a wide range of budgeting and accounting topics, including EBITDA, cash flow balance, inventory management, and more.
While you strive to provide accurate information and assistance, please keep in mind that you are not a licensed investment advisor, financial advisor, or tax advisor.
Therefore, you cannot provide personalized investment advice, financial planning, or tax guidance.
You are here to assist with ThruThink-related inquiries, offer general information, and answer questions to the best of your knowledge.
When provided, factor in any pieces of retrieved context to answer the question. Also factor in any relevant conversation history.
If you don't know the answer, just say "I don't know"; don't try to make up an answer."""
rag_query = f"""Use the following pieces of retrieved context to answer the question.
---
Contexts: {context}
---
Question: {query}
Answer:
"""
    ctx = get_script_run_ctx()  # grab the current Streamlit script run context for the worker threads
results = [None, None]
document_based_query_thread = threading.Thread(
target=document_based_query_func,
args=(
ctx,
cohere_fusion_model,
temperature,
conversation_id,
chat_system_prompt,
documents,
query,
results,
co
)
)
web_connector_query_thread = threading.Thread(
target=web_connector_query_func,
args=(
ctx,
cohere_fusion_model,
temperature,
conversation_id,
chat_system_prompt,
rag_query,
results,
co
)
)
document_based_query_thread.start()
web_connector_query_thread.start()
document_based_query_thread.join()
web_connector_query_thread.join()
    return results[0], results[1]
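

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal end-to-end wiring of the steps
# above. It assumes a populated Weaviate index whose objects carry "title",
# "slug", and "category" metadata, plus COHERE_API_KEY and WEAVIATE_URL
# environment variables; the index name, text key, model name, and parameter
# values below are placeholders, not part of the original module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os
    import uuid
    import cohere
    import weaviate

    co = cohere.Client(os.environ["COHERE_API_KEY"])
    llm = Cohere(cohere_api_key=os.environ["COHERE_API_KEY"], temperature=0.3)
    vectorstore = Weaviate(
        client=weaviate.Client(url=os.environ["WEAVIATE_URL"]),
        index_name="Document",  # assumed index name
        text_key="content",  # assumed text property
        attributes=["title", "slug", "category"],
    )

    question = "How does ThruThink project the cash flow balance?"
    variations = generate_variations(question, variation_count=4, llm=llm, example_questions=True)
    queries = extract_query_variations(question, variations, variation_count=4)
    document_sets = retrieve_documents_for_query_variations(queries, vectorstore, document_k=5)
    fused = rerank_and_fuse_documents(document_sets, rerank_k=60)
    reranking = cohere_reranking(question, fused, top_k_augment_doc=5, co=co)
    doc_chat, web_chat = final_rag_operations(
        question,
        fused,
        reranking,
        cohere_fusion_model="command",  # assumed chat model
        temperature=0.3,
        conversation_id=str(uuid.uuid4()),
        co=co,
    )
    print(doc_chat.text)
    print(web_chat.text)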