diff --git a/README.md b/README.md index dad9f20..a04185b 100644 --- a/README.md +++ b/README.md @@ -149,23 +149,13 @@ results = RAG.search(query) ``` This is the preferred way of doing things, since every index saves the full configuration of the model used to create it, and you can easily load it back up. -However, if you'd rather do it yourself or want to use a slightly different configuration, you can spin-up an instance of `RAGPretrainedModel` and specify the index you want to use: - -```python -from ragatouille import RAGPretrainedModel - -query = "What manga did Hayao Miyazaki write?" -RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0") -results = RAG.search(query, index_name="my_index") -``` `RAG.search` is a flexible method! You can set the `k` value to however many results you want (it defaults to `10`), and you can also use it to search for multiple queries at once: ```python RAG.search(["What manga did Hayao Miyazaki write?", "Who are the founders of Ghibli?" -"Who is the director of Spirited Away?"], -index_name="my_index") +"Who is the director of Spirited Away?"],) ``` `RAG.search` returns results in the form of a list of dictionaries, or a list of list of dictionaries if you used multiple queries: diff --git a/poetry.lock b/poetry.lock index 647f967..e5af88d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1419,13 +1419,13 @@ requests = ">=2,<3" [[package]] name = "llama-index" -version = "0.9.31" +version = "0.9.32" description = "Interface between LLMs and your data" optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "llama_index-0.9.31-py3-none-any.whl", hash = "sha256:b5a2394ac1463a687df7d37233abfc69924c0441b3984423e6f1653bcb1a3a59"}, - {file = "llama_index-0.9.31.tar.gz", hash = "sha256:1a3018ab9aa05f7ef217c9dc6d95117cd9146f8211023e866da0b113c5b75b9f"}, + {file = "llama_index-0.9.32-py3-none-any.whl", hash = "sha256:ca3122d78169fe700a47a77f0a0d6b7ddeb2532f7f44c66899cbf05818bb59b0"}, + {file = "llama_index-0.9.32.tar.gz", hash = "sha256:8ba179259ea6589f9e085a0acaf27236858bb7307825e203c192ffeeb8452574"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index 7b89ce7..a846223 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "RAGatouille" -version = "0.0.4b2" +version = "0.0.4b3" description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts." authors = ["Benjamin Clavie "] readme = "README.md" diff --git a/ragatouille/__init__.py b/ragatouille/__init__.py index d226666..41920e6 100644 --- a/ragatouille/__init__.py +++ b/ragatouille/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.4b2" +__version__ = "0.0.4b3" from .RAGPretrainedModel import RAGPretrainedModel from .RAGTrainer import RAGTrainer diff --git a/ragatouille/models/colbert.py b/ragatouille/models/colbert.py index c4cc492..95d0dba 100644 --- a/ragatouille/models/colbert.py +++ b/ragatouille/models/colbert.py @@ -30,6 +30,8 @@ def __init__( if n_gpu == -1: n_gpu = 1 if torch.cuda.device_count() == 0 else torch.cuda.device_count() + self.loaded_from_index = load_from_index + if load_from_index: ckpt_config = ColBERTConfig.load_from_index( str(pretrained_model_name_or_path) @@ -104,12 +106,29 @@ def add_to_index( "add_to_index support will be more thorough in future versions", ) + if self.loaded_from_index: + index_root = self.config.root + else: + index_root = str( + Path(self.config.root) / Path(self.config.experiment) / "indexes" + ) + if not self.collection: + self.collection = self._get_collection_from_file( + str( + Path(self.config.root) + / Path(self.config.experiment) + / "indexes" + / self.index_name + / "collection.json" + ) + ) + searcher = Searcher( checkpoint=self.checkpoint, config=None, collection=self.collection, index=self.index_name, - index_root=self.config.root, + index_root=index_root, verbose=self.verbose, ) new_documents = list(set(new_documents))