-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrag1.py
176 lines (128 loc) · 5.3 KB
/
rag1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
import shutil
import time
import streamlit as st
from PyPDF2 import PdfReader
from pydantic import BaseModel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import OllamaEmbeddings
from langchain_community.chat_models import ChatOllama
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from langchain_community.llms import Replicate
from dotenv import load_dotenv
load_dotenv()

# API keys (GOOGLE_API_KEY, COHERE_API_KEY, HUGGINGFACE_API_KEY) are read from
# the process environment or a local .env file loaded by load_dotenv() above.
# NOTE: never hard-code keys in source. The keys previously committed here
# must be considered leaked and should be revoked/rotated.
if not os.environ.get("GOOGLE_API_KEY"):
    print("Warning: GOOGLE_API_KEY is not set; Google Generative AI calls will likely fail.")

# CONSTANTS
DB_PATH = "db/"  # directory where the Chroma vector store is persisted
MAX_RETRIES = 3  # attempts to reset DB_PATH before giving up
MAX_RETIRES = MAX_RETRIES  # backward-compatible alias for the original misspelled name
RETRY_DELAY = 1  # seconds to wait between reset attempts
server_thread = None  # NOTE(review): not referenced in this file; kept for compatibility
def get_embeddings():
    """Build the Google Generative AI embedding model used for the vector store.

    The API key is read from the ``GOOGLE_API_KEY`` environment variable and is
    ``None`` when that variable is unset.
    """
    model_name = "models/embedding-001"
    return GoogleGenerativeAIEmbeddings(
        model=model_name,
        google_api_key=os.environ.get("GOOGLE_API_KEY"),
    )
def get_ollama_embeddings():
    """Return an Ollama embedding model using ``nomic-embed-text``.

    NOTE(review): not called anywhere in this file; it is presumably a local
    alternative to ``get_embeddings``.
    """
    embedding_model = "nomic-embed-text"
    return OllamaEmbeddings(model=embedding_model)
def get_llm():
    """Return the Gemini chat model (``gemini-pro``) used to answer questions.

    No key is passed explicitly. ``ChatGoogleGenerativeAI`` presumably reads
    ``GOOGLE_API_KEY`` from the environment; confirm.
    """
    chat_model_name = "gemini-pro"
    return ChatGoogleGenerativeAI(model=chat_model_name)
def get_ollama_llm():
    """Return an Ollama chat model using ``llama3``.

    NOTE(review): not called anywhere in this file; it is presumably a local
    alternative to ``get_llm``.
    """
    local_model = "llama3"
    return ChatOllama(model=local_model)
def get_replicate_llm():
    """Return the Replicate-hosted ``meta/meta-llama-3-70b-instruct`` model.

    NOTE(review): not called anywhere in this file. Replicate presumably needs
    an API token in the environment; confirm.
    """
    replicate_model = "meta/meta-llama-3-70b-instruct"
    return Replicate(model=replicate_model)
def get_pdf_text(pdf_docs):
    """Extract the text of every page of every PDF in *pdf_docs*.

    Args:
        pdf_docs: iterable of inputs accepted by ``PdfReader``. In this file
            these are the objects returned by Streamlit's uploader.

    Returns:
        All page texts, joined with a newline between consecutive pages, or
        ``""`` when *pdf_docs* is empty.
    """
    page_texts = []
    for pdf in pdf_docs:
        for page in PdfReader(pdf).pages:
            # extract_text() may return None (some PyPDF2 versions) or "" for
            # image-only pages; treat either as empty text instead of crashing.
            page_texts.append(page.extract_text() or "")
    # Separate pages so the last word of one page does not fuse with the first
    # word of the next; join also avoids quadratic string concatenation.
    return "\n".join(page_texts)
def get_text_chunks(text):
    """Split *text* into overlapping chunks for embedding.

    Uses ``RecursiveCharacterTextSplitter`` with ``chunk_size=512`` and
    ``chunk_overlap=64`` and returns the list of chunk strings.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
    return splitter.split_text(text)
def get_vector_store(text_chunks, path: str):
    """Embed *text_chunks* and store them in a Chroma collection persisted at *path*.

    Returns the resulting ``Chroma`` vector store.
    """
    embedder = get_embeddings()
    return Chroma.from_texts(
        text_chunks,
        embedding=embedder,
        persist_directory=path,
    )
def get_conversational_chain(retriever):
    """Build the RAG chain: retrieve, fill the prompt, call the LLM, return text.

    Args:
        retriever: a LangChain retriever. It is invoked with the question
            string, and its documents become the prompt's ``{context}``.

    Returns:
        A runnable that takes the user's question (a str) as input and produces
        the model's answer as a plain string (via ``StrOutputParser``).
    """

    def _format_docs(docs):
        # Render only the document text. Without this, the prompt would receive
        # the repr of a list of Document objects (brackets, metadata, escapes).
        return "\n\n".join(doc.page_content for doc in docs)

    llm = get_llm()
    prompt_text = """You are a helpful assistant providing detailed, accurate, and informative responses. Your
task is to assist users by providing relevant information retrieved from a set of documents and generating
coherent and contextually appropriate responses based on the retrieved information.
CONTEXT:
{context}
QUESTION: {question}
YOUR ANSWER:"""
    prompt = ChatPromptTemplate.from_template(prompt_text)
    # The question goes to the retriever (for context) and, unchanged, to
    # {question} via RunnablePassthrough.
    chain = (
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain
def user_input(user_question):
    """Answer *user_question* from the persisted Chroma store, streaming into Streamlit.

    Opens the store at ``DB_PATH`` and retrieves the top 2 matches. The answer
    is rendered incrementally in a single placeholder element.
    """
    vector_db = Chroma(persist_directory=DB_PATH, embedding_function=get_embeddings())
    chain = get_conversational_chain(vector_db.as_retriever(search_kwargs={"k": 2}))
    placeholder = st.empty()  # one element, overwritten with the growing answer
    answer_so_far = ""
    for piece in chain.stream(user_question):
        answer_so_far = answer_so_far + piece
        placeholder.markdown(answer_so_far)
def delete_and_recreate_db_directory(path=None):
    """Remove the vector-store directory (if present) and recreate it empty.

    Args:
        path: directory to reset. When omitted, ``DB_PATH`` is used (looked up
            at call time).

    Returns:
        True if the directory was cleared (or did not exist) and was created;
        False if deleting or creating it failed. The error is printed.
    """
    target = DB_PATH if path is None else path
    try:
        shutil.rmtree(target)
        print(f"Contents of '{target}' successfully cleared.")
    except FileNotFoundError:
        # Nothing to clear (e.g. the first run). This is not a failure;
        # previously it made every retry fail and the directory was never created.
        print(f"'{target}' does not exist yet; nothing to clear.")
    except OSError as e:
        print(f"Error: {target} : {e.strerror}")
        return False
    try:
        os.makedirs(target, exist_ok=True)
    except OSError as e:
        print(f"Error: {target} : {e.strerror}")
        return False
    print(f"Empty directory '{target}' successfully created.")
    return True
def retries():
    """Reset the ``DB_PATH`` directory, trying up to ``MAX_RETIRES`` times.

    After each failed attempt except the last, prints a notice and waits
    ``RETRY_DELAY`` seconds. No pause or "Retry" message follows the final
    attempt, because nothing comes after it.

    Returns:
        True once ``delete_and_recreate_db_directory()`` succeeds, or False if
        every attempt failed.
    """
    for attempt in range(1, MAX_RETIRES + 1):
        if delete_and_recreate_db_directory():
            return True
        if attempt < MAX_RETIRES:
            print(f"Retry {attempt} in {RETRY_DELAY} seconds...")
            time.sleep(RETRY_DELAY)
    print("Max retries exceeded. Could not delete or recreate directory.")
    return False
class QueryRequest(BaseModel):
    """Pydantic model carrying a single query string.

    NOTE(review): not referenced anywhere in this file. It is presumably meant
    as a request body for an API endpoint; confirm or remove.
    """

    query: str  # the question text
def main():
    """Streamlit entry point.

    The main pane has a question box answered by ``user_input``. The sidebar
    has a PDF uploader that rebuilds the vector store at ``DB_PATH``.
    """
    st.set_page_config("Chat Docx")
    st.header("Just a normal RAG chatbot")

    user_question = st.text_input("Ask a Question from the PDF Files")
    if user_question:
        user_input(user_question)

    with st.sidebar:
        st.title("Menu:")
        pdf_docs = st.file_uploader(
            "Upload your PDF Files and Click on the Submit & Process Button",
            accept_multiple_files=True,
        )
        if st.button("Submit & Process"):
            # `not pdf_docs` covers both an empty list and None.
            if not pdf_docs:
                st.warning("Upload a document", icon='⚠️')
            else:
                with st.spinner("Processing..."):
                    retries()
                    raw_text = get_pdf_text(pdf_docs)
                    text_chunks = get_text_chunks(raw_text)
                    if not text_chunks:
                        # e.g. scanned/image-only PDFs. Chroma may reject an
                        # empty batch, and an empty store is useless anyway.
                        st.error("No extractable text was found in the uploaded PDF files.")
                    else:
                        get_vector_store(text_chunks, path=DB_PATH)
                        st.success("Vector store created and data processed.")
# Launch the Streamlit UI when this file is executed as a script
# (presumably via `streamlit run rag1.py`; confirm).
if __name__ == "__main__":
    main()