pdf_helper.py
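"""Helper utilities for answering questions about a PDF document.

The module loads a PDF with PyPDFLoader, splits it into chunks, embeds the
chunks into a FAISS vectorstore with HuggingFace embeddings, and answers
questions through a RetrievalQA chain backed by a ChatGroq LLM.
"""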
import os
import time
import uuid
from dotenv import load_dotenv
from pathlib import Path
from langchain_core.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOllama
from langchain.document_loaders import PyPDFLoader
from langchain_community.document_loaders.llmsherpa import LLMSherpaFileLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_groq import ChatGroq
from config import Config
from langchain.vectorstores import FAISS
load_dotenv()
groq_api_key = os.environ['GROQ_API_KEY']
# This loads the PDF file
def load_pdf_data(file_path):
    # Create a PyPDFLoader object with file_path
    loader = PyPDFLoader(file_path=file_path)
    docs = loader.load()
    # return the loaded documents
    return docs

    # Alternative loader (commented out): parse the PDF with LLMSherpa instead.
    # loader = LLMSherpaFileLoader(
    #     file_path=file_path,
    #     new_indent_parser=True,
    #     apply_ocr=True,
    #     strategy="sections",
    #     llmsherpa_api_url="https://readers.llmsherpa.com/api/document/developer/parseDocument?renderFormat=all",
    # )
    # docs = loader.load()
    # return docs
# Responsible for splitting the documents into several chunks
def split_docs(documents, chunk_size=1024, chunk_overlap=40):
    # Initialize the RecursiveCharacterTextSplitter with
    # chunk_size and chunk_overlap
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    # Split the documents into chunks
    chunks = text_splitter.split_documents(documents=documents)
    # return the document chunks
    return chunks
# Function for loading the embedding model
def load_embedding_model(model_name, normalize_embedding=True):
    print("Loading embedding model...")
    start_time = time.time()
    hugging_face_embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={'device': Config.HUGGING_FACE_EMBEDDINGS_DEVICE_TYPE},  # here we will run the model with CPU only
        encode_kwargs={
            'normalize_embeddings': normalize_embedding  # keep True to compute cosine similarity
        }
    )
    end_time = time.time()
    time_taken = round(end_time - start_time, 2)
    print(f"Embedding model load time: {time_taken} seconds.\n")
    return hugging_face_embeddings
# Function for creating embeddings using FAISS
def create_embeddings(chunks, embedding_model, storing_path="vectorstore"):
    print("Creating embeddings...")
    e_start_time = time.time()
    # Create the embeddings using FAISS
    vectorstore = FAISS.from_documents(chunks, embedding_model)
    e_end_time = time.time()
    e_time_taken = round(e_end_time - e_start_time, 2)
    print(f"Embeddings creation time: {e_time_taken} seconds.\n")
    print("Writing vectorstore...")
    v_start_time = time.time()
    vectorstore.save_local(storing_path)
    v_end_time = time.time()
    v_time_taken = round(v_end_time - v_start_time, 2)
    print(f"Vectorstore write time: {v_time_taken} seconds.\n")
    # return the vectorstore
    return vectorstore
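
# Illustrative sketch, not part of the original module: create_embeddings writes the
# FAISS index to disk with save_local, so a previously built index can be read back
# instead of re-embedding the PDF. Depending on the installed LangChain version,
# FAISS.load_local may also require allow_dangerous_deserialization=True.
def load_existing_vectorstore(storing_path, embedding_model):
    # Read a FAISS index previously written by create_embeddings
    vectorstore = FAISS.load_local(storing_path, embedding_model)
    return vectorstore
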
# Create the chain for Question Answering
def load_qa_chain(retriever, llm, prompt):
    print("Loading QA chain...")
    start_time = time.time()
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,  # here we are using the vectorstore as a retriever
        chain_type="stuff",
        return_source_documents=True,  # including source documents in output
        chain_type_kwargs={'prompt': prompt}  # customizing the prompt
    )
    end_time = time.time()
    time_taken = round(end_time - start_time, 2)
    print(f"QA chain load time: {time_taken} seconds.\n")
    return qa_chain
def get_response(query, chain) -> str:
    response = chain({'query': query})
    res = response['result']
    return res
class PDFHelper:

    def __init__(self, ollama_api_base_url: str, model_name: str = Config.MODEL,
                 embedding_model_name: str = Config.EMBEDDING_MODEL_NAME):
        self._ollama_api_base_url = ollama_api_base_url
        self._model_name = model_name
        self._embedding_model_name = embedding_model_name

    def ask(self, pdf_file_path: str, question: str) -> str:
        vector_store_directory = os.path.join(str(Path.home()), 'langchain-store', 'vectorstore',
                                              'pdf-doc-helper-store', str(uuid.uuid4()))
        os.makedirs(vector_store_directory, exist_ok=True)
        print(f"Using vector store: {vector_store_directory}")
llm = ChatGroq(groq_api_key=groq_api_key, model="llama2-70b-4096")
# temperature=0,
# base_url=self._ollama_api_base_url,
# model=self._model_name,
# streaming=True,
# # seed=2,
# top_k=5,
# # A higher value (100) will give more diverse answers, while a lower value (10) will be more conservative.
# top_p=0.3,
# # Higher value (0.95) will lead to more diverse text, while a lower value (0.5) will generate more
# # focused text.
# num_ctx=3072, # Sets the size of the context window used to generate the next token.
# verbose=False
        # Load the Embedding Model
        embed = load_embedding_model(model_name=self._embedding_model_name)

        # load and split the documents
        docs = load_pdf_data(file_path=pdf_file_path)
        documents = split_docs(documents=docs)

        # create vectorstore
        vectorstore = create_embeddings(chunks=documents, embedding_model=embed, storing_path=vector_store_directory)

        # convert vectorstore to a retriever
        retriever = vectorstore.as_retriever()

        template = """
### System:
You are an honest assistant.
You will accept PDF files and you will answer the question asked by the user accurately.
Don't infer anything by yourself. Just analyze the data.
If you don't know the answer, just say you don't know. Don't try to make up an answer.
### Context:
{context}
### User:
{question}
### Response:
"""
        prompt = PromptTemplate.from_template(template)

        # Create the chain
        chain = load_qa_chain(retriever, llm, prompt)

        start_time = time.time()
        response = get_response(question, chain)
        end_time = time.time()
        time_taken = round(end_time - start_time, 2)
        print(f"Response time: {time_taken} seconds.\n")

        with open("output.txt", "w") as f:
            f.write(response.strip())

        return response.strip()
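

# Minimal usage sketch, not part of the original module. The PDF path, the question,
# and the Ollama base URL below are placeholders; GROQ_API_KEY must be available in
# the environment (or a .env file) for ChatGroq to authenticate.
if __name__ == "__main__":
    helper = PDFHelper(ollama_api_base_url="http://localhost:11434")  # placeholder URL
    answer = helper.ask(
        pdf_file_path="sample.pdf",  # placeholder path to a local PDF
        question="What is this document about?",
    )
    print(answer)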