-
Notifications
You must be signed in to change notification settings - Fork 84
/
Copy pathcustom_knowledge_chatbot.py
62 lines (49 loc) · 2.01 KB
/
custom_knowledge_chatbot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from pathlib import Path
from typing import List, Tuple
from langchain import PromptTemplate, LLMChain
from langchain.document_loaders import TextLoader
from langchain.embeddings import LlamaCppEmbeddings
from langchain.llms import GPT4All
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.faiss import FAISS
from pydantic import BaseModel, Field
from langchain.chains import ConversationalRetrievalChain
# Constants
# File-system paths used by the script. NOTE(review): none of these files
# are shipped with the code — the model weights must be downloaded and the
# FAISS index built (see the commented-out index-building block below).
local_path = "./models/gpt4all-converted.bin"  # GPT4All LLM weights (passed to GPT4All(model=...))
model_path = "./models/ggml-model-q4_0.bin"  # llama.cpp weights for LlamaCppEmbeddings
text_path = "./docs/state_of_the_union.txt"  # source document to load and index
index_path = "./full_sotu_index"  # directory holding the saved FAISS index
# Functions
def initialize_embeddings() -> LlamaCppEmbeddings:
    """Create the embedding backend from the local llama.cpp weights.

    Returns:
        A ``LlamaCppEmbeddings`` instance configured with ``model_path``.
    """
    embedder = LlamaCppEmbeddings(model_path=model_path)
    return embedder
def load_documents() -> List:
    """Read the source text file and return it as langchain documents.

    Returns:
        The list of documents produced by ``TextLoader`` for ``text_path``.
    """
    return TextLoader(text_path).load()
def split_chunks(sources: List) -> List:
    """Split loaded documents into overlapping chunks for embedding.

    Args:
        sources: Documents as returned by a langchain document loader.

    Returns:
        A list of chunk documents, each at most 512 characters with a
        32-character overlap between consecutive chunks.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
    # split_documents already returns a list, so the original element-by-element
    # append loop was a redundant copy.
    return splitter.split_documents(sources)
def generate_index(chunks: List, embeddings: LlamaCppEmbeddings) -> FAISS:
    """Embed the chunk texts and build a FAISS vector store over them.

    Args:
        chunks: Chunk documents (each with ``page_content`` and ``metadata``).
        embeddings: The embedding backend used to vectorize the texts.

    Returns:
        A ``FAISS`` index built from the chunk texts and their metadata.
    """
    texts = []
    metadatas = []
    for chunk in chunks:
        texts.append(chunk.page_content)
        metadatas.append(chunk.metadata)
    return FAISS.from_texts(texts, embeddings, metadatas=metadatas)
# Main execution
# NOTE(review): assumes the model files exist at the configured paths and
# that a FAISS index was previously saved at index_path — confirm before running.
llm = GPT4All(model=local_path, n_ctx=2048, verbose=True)
embeddings = initialize_embeddings()
# One-time index build — uncomment to (re)generate the FAISS index:
# sources = load_documents()
# chunks = split_chunks(sources)
# vectorstore = generate_index(chunks, embeddings)
# vectorstore.save_local("full_sotu_index")
index = FAISS.load_local(index_path, embeddings)
qa = ConversationalRetrievalChain.from_llm(llm, index.as_retriever(), max_tokens_limit=400)
# Chatbot loop
chat_history = []
print("Welcome to the State of the Union chatbot! Type 'exit' to stop.")
while True:
    query = input("Please enter your question: ")
    if query.lower() == 'exit':
        break
    result = qa({"question": query, "chat_history": chat_history})
    print("Answer:", result['answer'])
    # Bug fix: the original never recorded turns, so the "conversational"
    # chain always received an empty history and follow-up questions lost
    # all context. Append each (question, answer) pair after every turn.
    chat_history.append((query, result['answer']))