-
Notifications
You must be signed in to change notification settings - Fork 0
/
ingest.py
201 lines (157 loc) · 6 KB
/
ingest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# Description: This script loads the Arch Linux Wiki dump, splits the text into chunks, embeds the chunks, and stores them in a Chroma vector store
import os
print(os.getcwd())
import sys
from collections import namedtuple
from typing import Any
import argparse
import yaml
import torch
from tqdm.contrib.concurrent import process_map
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import MWDumpLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
# Plain record holding a document's text and its metadata dictionary.
Document = namedtuple("Document", ("page_content", "metadata"))

# When no CUDA device is available, double the number of threads torch uses.
if not torch.cuda.is_available():
    torch.set_num_threads(2 * torch.get_num_threads())
def parse_args(config: dict, args: list):
    """Apply command line overrides to the loaded configuration.

    Args:
        config (dict): items in config.yaml; updated in place
        args (list(str)): user input parameters

    Returns:
        dict: the same config dictionary, with the test settings written in
            when --test-embed was given
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--test-embed", dest="test_embed", action="store_true")
    options = cli.parse_args(args)

    if not options.test_embed:
        return config

    # Test mode: use only the archwiki dump, a separate data directory and a
    # fixed sample question.
    config.update(
        {
            "mediawikis": ["archwiki"],
            "data_dir": "./test_data",
            "question": (
                "What is the the best editor for the terminal in Arch Linux?"
            ),
        }
    )
    return config
def load_config():
    """Read ./config.yaml and return its parsed contents.

    Prints an error message and exits with status 1 when the file does not
    exist or cannot be parsed as YAML.

    Returns:
        dict: items in config.yaml
    """
    try:
        with open("./config.yaml", "r", encoding="utf-8") as stream:
            return yaml.safe_load(stream)
    except FileNotFoundError:
        print("Error: File config.yaml not found.")
    except yaml.YAMLError as exc:
        print(f"Error reading YAML file: {exc}")
    # Only reached after one of the errors above has been reported.
    sys.exit(1)
def rename_duplicates(documents: "list[Document]"):
    """Make the ``source`` metadata of every document unique.

    The first document seen with a given source keeps it unchanged; each later
    document with the same source gets ``_<n>`` appended (``_1``, ``_2``, ...).
    A suffix is skipped when the resulting name is already in use, so a
    renamed duplicate never collides with another document's source.

    Args:
        documents (list(Document)): input documents via loader.load(); their
            metadata dictionaries are modified in place

    Returns:
        list(Document): the same list, with modified source metadata
    """
    used_sources = set()
    # Last suffix number handed out per original source, so the search for a
    # free suffix resumes where it left off instead of restarting at 1.
    last_suffix = {}
    for doc in documents:
        source = doc.metadata["source"]
        unique = source
        if unique in used_sources:
            suffix = last_suffix.get(source, 0)
            while unique in used_sources:
                suffix += 1
                unique = f"{source}_{suffix}"
            last_suffix[source] = suffix
        used_sources.add(unique)
        doc.metadata["source"] = unique
    return documents
############################################################################################################
def load_document(wiki: tuple):
    """Read one MediaWiki XML dump and return its pages as documents.

    Args:
        wiki (tuple): pair whose first item is the directory holding the dump
            and whose second item is the wiki name; the file read is
            <directory>/<name>_pages_current.xml

    Returns:
        list(Document): pages of the dump, with source metadata made unique
            (<name>_n) and suffixed with the wiki name (<name>_n - <wikiname>)
    """
    dump_dir = wiki[0]
    wiki_name = wiki[1]

    # https://python.langchain.com/docs/integrations/document_loaders/mediawikidump
    loader = MWDumpLoader(
        encoding="utf-8",
        file_path=f"{dump_dir}/{wiki_name}_pages_current.xml",
        # Namespace 0 only, see https://www.mediawiki.org/wiki/Help:Namespaces
        namespaces=[0],
        skip_redirects=True,
        stop_on_error=False,
    )

    # Duplicated page sources are renamed first, then every source is tagged
    # with the wiki it came from.
    pages = rename_duplicates(loader.load())
    wiki_suffix = f" - {wiki_name}"
    return [
        Document(page.page_content, {"source": page.metadata["source"] + wiki_suffix})
        for page in pages
    ]
############################################################################################################
class CustomTextSplitter(RecursiveCharacterTextSplitter):
    """RecursiveCharacterTextSplitter preconfigured with this script's separators.

    The separator list and keep_separator=False are fixed here; every other
    keyword argument is forwarded unchanged to RecursiveCharacterTextSplitter.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Build the splitter.

        Args:
            **kwargs: extra options for RecursiveCharacterTextSplitter; must not
                contain separators or keep_separator, which are set here
        """
        super().__init__(
            separators=[
                r"\w(=){3}\n",
                r"\w(=){2}\n",
                r"\n\n",
                r"\n",
                r"\s",
            ],
            keep_separator=False,
            **kwargs,
        )
############################################################################################################
def load_documents(config: dict):
    """Load every configured MediaWiki dump and split its pages into chunks.

    Loading (one task per wiki) and splitting (one task per page) are both
    distributed with process_map, using torch.get_num_threads() workers.

    Args:
        config (dict): items in config.yaml; the "source" and "mediawikis"
            entries are read

    Returns:
        list(Document): chunks of all loaded pages, with source metadata
            renamed so that repeated sources get a _n suffix
    """
    workers = torch.get_num_threads()

    pages_per_wiki = process_map(
        load_document,
        [(config["source"], wiki) for wiki in config["mediawikis"]],
        desc="Loading Documents",
        max_workers=workers,
    )
    # Flatten in a single linear pass: sum(lists, []) would copy the growing
    # result list once per element, which is quadratic for thousands of pages.
    pages = [page for wiki_pages in pages_per_wiki for page in wiki_pages]

    splitter = CustomTextSplitter(
        add_start_index=True,
        chunk_size=1000,
        is_separator_regex=True,
    )
    chunks_per_page = process_map(
        splitter.split_documents,
        # split_documents takes a list, so each page is wrapped on its own.
        [[page] for page in pages],
        chunksize=1,
        desc="Splitting Documents",
        max_workers=workers,
    )
    chunks = [chunk for page_chunks in chunks_per_page for chunk in page_chunks]

    # Chunks of the same page share a source, so number them apart.
    return rename_duplicates(chunks)
############################################################################################################
if __name__ == "__main__":
    # Read config.yaml, then let command line flags adjust it.
    settings = parse_args(load_config(), sys.argv[1:])

    chunks = load_documents(settings)
    print(f"Embedding {len(chunks)} Documents, this may take a while.")

    # https://python.langchain.com/docs/integrations/text_embedding/huggingfacehub
    embedder = HuggingFaceEmbeddings(
        cache_folder="./model",
        model_name=settings["embeddings_model"],
        show_progress=True,
    )

    # https://python.langchain.com/docs/integrations/vectorstores/chroma
    store = Chroma.from_documents(
        documents=chunks,
        embedding=embedder,
        persist_directory=settings["data_dir"],
    )
    # save to disk
    store.persist()

    # Test code: run one sample query against the new store and print the hits.
    sample_query = "how to setup LUKS2?"
    hits = store.as_retriever().get_relevant_documents(sample_query)
    print(hits)