Skip to content

Commit

Permalink
fix: restore compatibility with previous loading method but remove fr…
Browse files Browse the repository at this point in the history
…om README
  • Loading branch information
bclavie committed Jan 16, 2024
1 parent 68e37d3 commit 46842e0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
12 changes: 1 addition & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,23 +149,13 @@ results = RAG.search(query)
```

This is the preferred way of doing things, since every index saves the full configuration of the model used to create it, and you can easily load it back up.
However, if you'd rather do it yourself or want to use a slightly different configuration, you can spin-up an instance of `RAGPretrainedModel` and specify the index you want to use:

```python
from ragatouille import RAGPretrainedModel

query = "What manga did Hayao Miyazaki write?"
RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
results = RAG.search(query, index_name="my_index")
```

`RAG.search` is a flexible method! You can set the `k` value to however many results you want (it defaults to `10`), and you can also use it to search for multiple queries at once:

```python
RAG.search(["What manga did Hayao Miyazaki write?",
"Who are the founders of Ghibli?"
"Who is the director of Spirited Away?"],
index_name="my_index")
"Who is the director of Spirited Away?"],)
```

`RAG.search` returns results in the form of a list of dictionaries, or a list of list of dictionaries if you used multiple queries:
Expand Down
24 changes: 23 additions & 1 deletion ragatouille/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(
if n_gpu == -1:
n_gpu = 1 if torch.cuda.device_count() == 0 else torch.cuda.device_count()

self.loaded_from_index = load_from_index

if load_from_index:
ckpt_config = ColBERTConfig.load_from_index(
str(pretrained_model_name_or_path)
Expand Down Expand Up @@ -104,12 +106,32 @@ def add_to_index(
"add_to_index support will be more thorough in future versions",
)

if self.loaded_from_index:
index_root = self.config.root
else:
index_root = str(
Path(self.config.root)
/ Path(self.config.experiment)
/ "indexes"
)
if not self.collection:
self.collection = self._get_collection_from_file(
str(
Path(self.config.root)
/ Path(self.config.experiment)
/ "indexes"
/ self.index_name
/ "collection.json"
)
)


searcher = Searcher(
checkpoint=self.checkpoint,
config=None,
collection=self.collection,
index=self.index_name,
index_root=self.config.root,
index_root=index_root,
verbose=self.verbose,
)
new_documents = list(set(new_documents))
Expand Down

0 comments on commit 46842e0

Please sign in to comment.