diff --git a/pyproject.toml b/pyproject.toml index 6823253..2234dcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "RAGatouille" -version = "0.0.7post4" +version = "0.0.7post5" description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts." authors = ["Benjamin Clavie "] license = "Apache-2.0" diff --git a/ragatouille/__init__.py b/ragatouille/__init__.py index 1ad2fe1..df4258b 100644 --- a/ragatouille/__init__.py +++ b/ragatouille/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.7post4" +__version__ = "0.0.7post5" from .RAGPretrainedModel import RAGPretrainedModel from .RAGTrainer import RAGTrainer diff --git a/ragatouille/models/colbert.py b/ragatouille/models/colbert.py index 534a952..422baa3 100644 --- a/ragatouille/models/colbert.py +++ b/ragatouille/models/colbert.py @@ -718,7 +718,9 @@ def _encode_index_free_documents( embedded_docs = self.inference_ckpt.docFromText( documents, bsize=bsize, showprogress=verbose )[0] - doc_mask = torch.full(embedded_docs.shape[:2], -float("inf")) + doc_mask = torch.full(embedded_docs.shape[:2], -float("inf")).to( + embedded_docs.device + ) return embedded_docs, doc_mask def rank( @@ -771,7 +773,7 @@ def encode( - doc_masks.shape[1], ), -float("inf"), - ).to(device=encodings.device), + ).to(device=doc_masks.device), ], dim=1, )