-
Notifications
You must be signed in to change notification settings - Fork 1
/
index_wiki.py
61 lines (46 loc) · 1.99 KB
/
index_wiki.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import sys
sys.path.insert(0, "./ColBERT/")
import argparse
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert.data import Queries, Collection
from colbert import Indexer, Searcher
def parse_args(argv=None):
    """Parse the command-line options of the indexing script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets
            argparse read them from ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``nbits``,
        ``doc_maxlen``, ``checkpoint``, ``split``, ``experiment_name``,
        ``index_name``, ``collection`` and ``nranks``. No option declares a
        default, so any option left off the command line is ``None``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--nbits", type=int)
    parser.add_argument("--doc_maxlen", type=int)
    parser.add_argument("--checkpoint", type=str)
    # NOTE(review): --split is accepted but never read by this script.
    parser.add_argument("--split", type=str)
    parser.add_argument("--experiment_name", type=str)
    parser.add_argument("--index_name", type=str)
    parser.add_argument("--collection", type=str)
    parser.add_argument(
        "--nranks", type=int, help="Number of GPUs to use for indexing"
    )
    return parser.parse_args(argv)


def format_loaded_message(collection):
    """Return the "Loaded N passages" status line for *collection*.

    Only ``len(collection)`` is consulted; the count is rendered with
    thousands separators.
    """
    return f"Loaded {len(collection):,} passages"


def main(argv=None):
    """Index a passage collection, then run one sample query against it.

    Steps, in order: parse the options, load the collection from
    ``--collection``, build the index named ``--index_name`` under the
    ``--experiment_name`` run, and finally print the top-3 passages for a
    fixed demo query as a quick sanity check.

    Args:
        argv: Argument strings forwarded to ``parse_args``; ``None`` means
            the process command line.
    """
    args = parse_args(argv)
    print(args)

    collection = Collection(path=args.collection)
    # The original evaluated this f-string as a bare expression and threw the
    # result away, so the status line never reached the console.
    print(format_loaded_message(collection))
    print("example passage from collection: ", collection[0])
    print()

    # nranks specifies the number of GPUs to use.
    index_run = RunConfig(nranks=args.nranks, experiment=args.experiment_name)
    with Run().context(index_run):
        config = ColBERTConfig(doc_maxlen=args.doc_maxlen, nbits=args.nbits)
        indexer = Indexer(checkpoint=args.checkpoint, config=config)
        # overwrite=True is passed on every run; presumably this replaces an
        # existing index of the same name -- confirm against the Indexer API.
        indexer.index(name=args.index_name, collection=collection, overwrite=True)
        print("index location: ", indexer.get_index())

    # NOTE(review): the index path is rebuilt by hand here instead of reusing
    # indexer.get_index(); confirm both resolve to the same directory.
    index_path = os.path.join(
        f"experiments/{args.experiment_name}/indexes", args.index_name
    )
    search_run = RunConfig(experiment=args.experiment_name, index_root="")
    with Run().context(search_run):
        searcher = Searcher(index=index_path, collection=collection)

    query = "who is Haruki Murakami"
    print(f"#> {query}")

    # Find the top-3 passages for this query
    results = searcher.search(query, k=3)

    # Print out the top-k retrieved passages; `results` is unpacked as
    # parallel sequences of passage ids, ranks and scores.
    for passage_id, passage_rank, passage_score in zip(*results):
        print(
            f"\t [{passage_rank}] \t\t {passage_score:.1f} \t\t {searcher.collection[passage_id]}"
        )


# The guard keeps the script from running when the module is merely imported.
if __name__ == "__main__":
    main()