diff --git a/pyproject.toml b/pyproject.toml index 067a48e..abff95d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "RAGatouille" -version = "0.0.5a" +version = "0.0.5a1" description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts." authors = ["Benjamin Clavie "] readme = "README.md" diff --git a/ragatouille/__init__.py b/ragatouille/__init__.py index 64f621d..6a50cee 100644 --- a/ragatouille/__init__.py +++ b/ragatouille/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.5a" +__version__ = "0.0.5a1" from .RAGPretrainedModel import RAGPretrainedModel from .RAGTrainer import RAGTrainer diff --git a/ragatouille/models/colbert.py b/ragatouille/models/colbert.py index 6d099d3..840d34a 100644 --- a/ragatouille/models/colbert.py +++ b/ragatouille/models/colbert.py @@ -53,22 +53,30 @@ def __init__( self.collection = self._get_collection_from_file( str(pretrained_model_name_or_path / "collection.json") ) - self.pid_docid_map = self._get_collection_from_file( - str(pretrained_model_name_or_path / "pid_docid_map.json") - ) - # convert all keys to int when loading from file because saving converts to str - self.pid_docid_map = { - int(key): value for key, value in self.pid_docid_map.items() - } - self.docid_pid_map = defaultdict(list) - for pid, docid in self.pid_docid_map.items(): - self.docid_pid_map[docid].append(pid) - if os.path.exists( - str(pretrained_model_name_or_path / "docid_metadata_map.json") - ): - self.docid_metadata_map = self._get_collection_from_file( + try: + self.pid_docid_map = self._get_collection_from_file( + str(pretrained_model_name_or_path / "pid_docid_map.json") + ) + # convert all keys to int when loading from file because saving converts to str + self.pid_docid_map = { + int(key): value for key, value in self.pid_docid_map.items() + } + self.docid_pid_map = defaultdict(list) + for pid, docid in self.pid_docid_map.items(): + self.docid_pid_map[docid].append(pid) + if os.path.exists( str(pretrained_model_name_or_path / "docid_metadata_map.json") + ): + self.docid_metadata_map = self._get_collection_from_file( + str(pretrained_model_name_or_path / "docid_metadata_map.json") + ) + except Exception: + print( + "WARNING: Could not load pid_docid_map or docid_metadata_map from index!", + "This is likely because you are loading an old index.", ) + self.pid_docid_map = None + self.docid_metadata_map = None # TODO: Modify root assignment when loading from HF else: