# ---------------------------------------------------------------------------
# Legacy prototype (standalone embedding client + FastAPI experiments),
# kept commented out for reference.
# ---------------------------------------------------------------------------
# from typing import List, Union
# from code.vector_database import VectorDatabase
# import os
# from code.splitters import RecursiveCharacterTextSplitter
# from code.loaders import TextFileLoader
# import requests
# import datetime

# # app = FastAPI()

# # class Item(BaseModel):
# #     text: Union[dict, List[dict]] = Field(..., example={"query": True, "text": "What is the capital of India."})

# # class Documents(BaseModel):
# #     documents: Union[dict, List[dict]] = Field(..., example={"content": "This is a sample document"})

# # class Query(BaseModel):
# #     query_vector: List[float] = Field(..., example=[0.1, 0.2, 0.3])

# # model = EmbeddingModel(model_dir="path_to_your_model")

# # @app.post("/bulk_insert")
# # async def bulk_insert_documents(documents: Documents):
# #     response = db.bulk_insert(documents=documents.documents)
# #     if response is True:
# #         return {"status": "success"}
# #     else:
# #         return {"status": "failure", "error": response}

# # @app.post("/similarity_search")
# # async def similarity_search(query: Query):
# #     results = db.similarity_search(query.query_vector)
# #     return {"results": results}

# # if __name__ == "__main__":
# #     loader = TextFileLoader(path=r"D:\Files\LocalLLM\demo_documents\demo.txt", is_folder=False)
# #     docs = loader.load()
# #     print(docs[0]["content"])
# #     splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=30, length_function=len)
# #     splits = splitter.split_documents(docs)
# #     for idx, split in enumerate(splits, start=1):
# #         print(f"Split {idx}:\n{split}\n\n-------------------------\n\n")

# def read_files(folder_path: str) -> List[dict]:
#     files = TextFileLoader(path=folder_path, is_folder=True)
#     files = files.load()
#     return files

# def split_documents(documents: List[dict]) -> List[dict]:
#     splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=30, length_function=len)
#     splits = splitter.split_documents(documents)
#     return splits

# def get_embeddings(texts: List[dict], queries: List[bool], embedding_model_api) -> List[dict]:
#     inputs = [{"query": is_query, "text": text["content"]} for is_query, text in zip(queries, texts)]
#     response = requests.post(embedding_model_api, json={"text": inputs})
#     embeddings = response.json()
#     outputs = []
#     for idx in range(len(texts)):
#         if not inputs[idx]["query"]:
#             outputs.append({"content": texts[idx]["content"], "embedding": embeddings["embeddings"][idx], "metadata": texts[idx]["metadata"]})
#         else:
#             outputs.append({"content": texts[idx]["content"], "embedding": embeddings["embeddings"][idx]})
#     return outputs

# def add_to_database(documents: List[dict], database: VectorDatabase):
#     for doc in documents:
#         doc["metadata"]["creation_date"] = datetime.datetime.now().isoformat()
#     return database.bulk_insert(documents)

# def main():
#     embedding_model_api = "https://7ee5-104-199-149-72.ngrok-free.app/embed"
#     vector_database_host = "https://localhost:9200/"
#     vector_database_index = "test_index"
#     vector_database_user = "elastic"
#     vector_database_password = "xiayY08ILG7iiYuf3Xx5"
#     database = VectorDatabase(host=vector_database_host, index_name=vector_database_index, user_name=vector_database_user, password=vector_database_password)
#     # folder_path = "./demo_documents"
#     # files = read_files(folder_path)
#     # splits = split_documents(files)
#     # embeddings = get_embeddings(splits, [False] * len(splits), embedding_model_api)
#     # response = add_to_database(embeddings, database)
#     # print(response)
#     while True:
#         question = input("Enter a query: ")
#         query_embedding = get_embeddings(texts=[{"content": question}], queries=[True], embedding_model_api=embedding_model_api)
#         results = database.similarity_search(query_embedding[0]["embedding"], top_k=1)
#         print(results[0])

# if __name__ == "__main__":
#     main()
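
# NOTE: `code` below is this repo's local package (templates, embeddings,
# llms, agents, loaders, splitters, vector_database), not the stdlib
# `code` module.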
import code
import warnings
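# Blanket-suppress warnings so the interactive chat loop stays readable.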
warnings.filterwarnings("ignore")
# CHAT_TEMPLATE = """<s> [INST] \
# You are an AI assistant. You will be given some contexts with a chat history and a question. \
# Your task is to answer the question based on the context and the chat history. You are not supposed to answer \
# the question out of context or chat history. If you are unable to answer the question, you can say \
# "Sorry, I don't know the answer. Please rephrase the question." \
# \n\n Context: {context}\nChat History: {chat_history}\nQuestion: \n {question} [/INST]\
# \n Answer: </s>\
# """
CHAT_TEMPLATE = """<s> [INST] \
You are an AI assistant. You will be given some contexts and a question. \
Your task is to answer the question based on the context only. You are not supposed to answer \
the question out of context. If you are unable to answer the question, you can say \
"Sorry, I don't know the answer. Please rephrase the question." \
\n\n Context: {context}\nQuestion: \n {question} [/INST]\
\n Answer: </s>\
"""
chat_template = code.templates.BaseTemplate(CHAT_TEMPLATE, ["context", "question"])
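
# Prompt that condenses the chat history plus a follow-up question into a
# single stand-alone question, so retrieval can work on one self-contained
# query.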
QUESTION_GENERATOR = """<s> [INST] \
You are given a chat history and a question. Your task is to generate a stand-alone \
question based on the chat history and the question. \
\n\nChat History: {chat_history}\nQuestion: {question} [/INST]\
\nStand-alone Question: </s>\
"""
question_generator = code.templates.BaseTemplate(QUESTION_GENERATOR, ["chat_history", "question"])
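
# Instruction prefix wrapped around the raw query before it is embedded;
# this "Instruct: ...\nQuery: ..." shape matches what instruction-tuned
# embedding models typically expect.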
query_template = code.templates.BaseTemplate("Instruct: Given a query, retrieve relevant documents that answer the query.\nQuery: {query}", ["query"])
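
# A minimal sketch of how these templates get filled downstream. The
# `format` method name here is an assumption about code.templates.BaseTemplate
# (its definition is not part of this file):
#
#   prompt = query_template.format(query="What is the capital of India?")
#   # -> "Instruct: Given a query, retrieve relevant documents that answer
#   #     the query.\nQuery: What is the capital of India?"

# Vector store: an Elasticsearch index that holds the document chunks and
# their embeddings.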
db = code.vector_database.VectorDatabase(
    host="https://54.174.178.103:4100",
    index_name="mcube_genai_v1",
    user_name="elastic",
    password="zDP1wbqb3LBcxh1D=KGt"
)
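
# Remote embedding service; MistralEmbeddings wraps the HTTP endpoint that
# turns text into vectors (compare the legacy get_embeddings helper above).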
embedding_model = code.embeddings.MistralEmbeddings("http://54.174.178.103:4000/create_embeddings")
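
# The retriever wraps an incoming query in query_template, embeds it, and
# runs a similarity search against the vector store.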
retriever = code.Retriever(
    vector_database=db,
    embedding_model=embedding_model,
    query_template=query_template
)
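
# Remote Mixtral text-generation endpoint that produces the final answer.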
llm = code.llms.Mixtral(
    model_url="http://54.174.178.103:4000/generate_text"
)
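
# RAG chat agent: presumably rewrites follow-up questions into stand-alone
# ones with question_generator, retrieves supporting chunks, and answers
# with chat_template; the exact flow lives in code.agents.RAGChatAgent.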
chat_agent = code.agents.RAGChatAgent(
    retriever=retriever,
    llm_model=llm,
    question_generator=question_generator,
    chat_template=chat_template
)

if __name__ == "__main__":
    # Simple REPL: each question is answered from the indexed documents.
    while True:
        query = input("Enter a query: ")
        response = chat_agent.generate(query)
        print(response, end="\n----------------------------------\n")

# One-off ingestion pipeline, kept for reference: load the demo folder,
# chunk the documents, and index the chunks with their embeddings.
# loader = code.loaders.TextFileLoader(r"D:\Files\LocalLLM\demo_documents", is_folder=True)
# docs = loader.load()
# splitter = code.splitters.RecursiveCharacterTextSplitter(1500, 500, len)
# chunks = splitter.split_documents(docs)
# db.add_documents(chunks, embedding_model)