Skip to content

Commit

Permalink
hotfix: do not break old indices
Browse files Browse the repository at this point in the history
  • Loading branch information
bclavie committed Jan 24, 2024
1 parent 02cbdf2 commit 610a20b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "RAGatouille"
-version = "0.0.5a"
+version = "0.0.5a1"
description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts."
authors = ["Benjamin Clavie <[email protected]>"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion ragatouille/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-__version__ = "0.0.5a"
+__version__ = "0.0.5a1"
from .RAGPretrainedModel import RAGPretrainedModel
from .RAGTrainer import RAGTrainer

Expand Down
36 changes: 22 additions & 14 deletions ragatouille/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,30 @@ def __init__(
self.collection = self._get_collection_from_file(
str(pretrained_model_name_or_path / "collection.json")
)
-self.pid_docid_map = self._get_collection_from_file(
-str(pretrained_model_name_or_path / "pid_docid_map.json")
-)
-# convert all keys to int when loading from file because saving converts to str
-self.pid_docid_map = {
-int(key): value for key, value in self.pid_docid_map.items()
-}
-self.docid_pid_map = defaultdict(list)
-for pid, docid in self.pid_docid_map.items():
-self.docid_pid_map[docid].append(pid)
-if os.path.exists(
-str(pretrained_model_name_or_path / "docid_metadata_map.json")
-):
-self.docid_metadata_map = self._get_collection_from_file(
+try:
+self.pid_docid_map = self._get_collection_from_file(
+str(pretrained_model_name_or_path / "pid_docid_map.json")
+)
+# convert all keys to int when loading from file because saving converts to str
+self.pid_docid_map = {
+int(key): value for key, value in self.pid_docid_map.items()
+}
+self.docid_pid_map = defaultdict(list)
+for pid, docid in self.pid_docid_map.items():
+self.docid_pid_map[docid].append(pid)
+if os.path.exists(
 str(pretrained_model_name_or_path / "docid_metadata_map.json")
+):
+self.docid_metadata_map = self._get_collection_from_file(
+str(pretrained_model_name_or_path / "docid_metadata_map.json")
+)
+except Exception:
+print(
+"WARNING: Could not load pid_docid_map or docid_metadata_map from index!",
+"This is likely because you are loading an old index.",
 )
+self.pid_docid_map = None
+self.docid_metadata_map = None
# TODO: Modify root assignment when loading from HF

else:
Expand Down

0 comments on commit 610a20b

Please sign in to comment.