fix: restore compatibility with previous loading method but remove from README #59

Merged on Jan 16, 2024 (3 commits)
12 changes: 1 addition & 11 deletions README.md
@@ -149,23 +149,13 @@ results = RAG.search(query)
```

This is the preferred way of doing things, since every index saves the full configuration of the model used to create it, and you can easily load it back up.
-However, if you'd rather do it yourself or want to use a slightly different configuration, you can spin-up an instance of `RAGPretrainedModel` and specify the index you want to use:
-
-```python
-from ragatouille import RAGPretrainedModel
-
-query = "What manga did Hayao Miyazaki write?"
-RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
-results = RAG.search(query, index_name="my_index")
-```

`RAG.search` is a flexible method! You can set the `k` value to however many results you want (it defaults to `10`), and you can also use it to search for multiple queries at once:

```python
RAG.search(["What manga did Hayao Miyazaki write?",
"Who are the founders of Ghibli?",
-"Who is the director of Spirited Away?"],
-index_name="my_index")
+"Who is the director of Spirited Away?"],)
```

`RAG.search` returns results in the form of a list of dictionaries, or a list of lists of dictionaries if you used multiple queries:
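
As a rough sketch of coping with both return shapes (not part of this PR): the helper name `as_batches` and the sample dictionary keys below are assumptions made for illustration, not RAGatouille's documented API. The idea is only to wrap a single-query result so that calling code can always loop per query, then per hit.

```python
from typing import Any, Dict, List, Union

Result = Dict[str, Any]


def as_batches(
    results: Union[List[Result], List[List[Result]]]
) -> List[List[Result]]:
    """Return results as one list per query, whichever shape was passed in."""
    # A single-query call yields a flat list of dicts; wrap it so callers
    # can always iterate "per query, then per hit".
    if results and isinstance(results[0], dict):
        return [results]
    return list(results)


# Hypothetical payloads; the dictionary keys are illustrative only.
single = [{"content": "passage A", "score": 1.0, "rank": 1}]
multi = [
    [{"content": "passage A", "score": 1.0, "rank": 1}],
    [{"content": "passage B", "score": 0.5, "rank": 1}],
]

for hits in as_batches(single) + as_batches(multi):
    for hit in hits:
        print(hit["rank"], hit["content"])
```

An empty result list passes through unchanged, so the helper is safe to call before checking whether anything was found.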
6 changes: 3 additions & 3 deletions poetry.lock

(Generated file; diff not rendered.)

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "RAGatouille"
-version = "0.0.4b2"
+version = "0.0.4b3"
description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts."
authors = ["Benjamin Clavie <[email protected]>"]
readme = "README.md"
2 changes: 1 addition & 1 deletion ragatouille/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.0.4b2"
+__version__ = "0.0.4b3"
from .RAGPretrainedModel import RAGPretrainedModel
from .RAGTrainer import RAGTrainer

21 changes: 20 additions & 1 deletion ragatouille/models/colbert.py
@@ -30,6 +30,8 @@ def __init__(
if n_gpu == -1:
n_gpu = 1 if torch.cuda.device_count() == 0 else torch.cuda.device_count()

+self.loaded_from_index = load_from_index
+
if load_from_index:
ckpt_config = ColBERTConfig.load_from_index(
str(pretrained_model_name_or_path)
@@ -104,12 +106,29 @@ def add_to_index(
"add_to_index support will be more thorough in future versions",
)

+if self.loaded_from_index:
+index_root = self.config.root
+else:
+index_root = str(
+Path(self.config.root) / Path(self.config.experiment) / "indexes"
+)
+if not self.collection:
+self.collection = self._get_collection_from_file(
+str(
+Path(self.config.root)
+/ Path(self.config.experiment)
+/ "indexes"
+/ self.index_name
+/ "collection.json"
+)
+)
+
searcher = Searcher(
checkpoint=self.checkpoint,
config=None,
collection=self.collection,
index=self.index_name,
-index_root=self.config.root,
+index_root=index_root,
verbose=self.verbose,
)
new_documents = list(set(new_documents))
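
The path handling in this hunk can be sketched in isolation. This is a minimal illustration under stated assumptions, not the library's code: the function names `resolve_index_root` and `collection_path` are hypothetical, and the example values (`.ragatouille`, `colbert`, `my_index`) are placeholders rather than values taken from the diff.

```python
from pathlib import Path


def resolve_index_root(root: str, experiment: str, loaded_from_index: bool) -> str:
    """Pick the directory that holds the indexes, following the branch in the diff."""
    if loaded_from_index:
        # Assumption: a config loaded from an index already points at the index root.
        return root
    # Otherwise compose <root>/<experiment>/indexes from the model config.
    return str(Path(root) / Path(experiment) / "indexes")


def collection_path(root: str, experiment: str, index_name: str) -> str:
    """Compose the collection.json location the same way the diff does."""
    return str(
        Path(root) / Path(experiment) / "indexes" / index_name / "collection.json"
    )


# Placeholder values for illustration only.
print(resolve_index_root(".ragatouille", "colbert", loaded_from_index=False))
print(collection_path(".ragatouille", "colbert", "my_index"))
```

Keeping the two cases in one helper makes the intent of the fix easier to see: the `Searcher` receives a root that matches how the model was loaded, instead of always receiving `self.config.root`.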