From 0b340ab5edf723cf10309eddf5439f4be47c1289 Mon Sep 17 00:00:00 2001 From: bclavie Date: Fri, 23 Feb 2024 12:17:37 +0100 Subject: [PATCH 1/3] fix: doc_masks on encodings device --- ragatouille/models/colbert.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ragatouille/models/colbert.py b/ragatouille/models/colbert.py index 534a952..bbfbd5a 100644 --- a/ragatouille/models/colbert.py +++ b/ragatouille/models/colbert.py @@ -11,7 +11,6 @@ from colbert import Indexer, IndexUpdater, Searcher, Trainer from colbert.infra import ColBERTConfig, Run, RunConfig from colbert.modeling.checkpoint import Checkpoint - from ragatouille.models.base import LateInteractionModel # TODO: Move all bsize related calcs to `_set_bsize()` @@ -718,7 +717,9 @@ def _encode_index_free_documents( embedded_docs = self.inference_ckpt.docFromText( documents, bsize=bsize, showprogress=verbose )[0] - doc_mask = torch.full(embedded_docs.shape[:2], -float("inf")) + doc_mask = torch.full(embedded_docs.shape[:2], -float("inf")).to( + embedded_docs.device + ) return embedded_docs, doc_mask def rank( @@ -771,7 +772,7 @@ def encode( - doc_masks.shape[1], ), -float("inf"), - ).to(device=encodings.device), + ).to(device=doc_masks.device), ], dim=1, ) From 9bda84272e8d6387c18fff2ec3338cfe7abd3c7d Mon Sep 17 00:00:00 2001 From: bclavie Date: Fri, 23 Feb 2024 12:18:01 +0100 Subject: [PATCH 2/3] version bump --- pyproject.toml | 2 +- ragatouille/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6823253..2234dcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "RAGatouille" -version = "0.0.7post4" +version = "0.0.7post5" description = "Library to facilitate the use of state-of-the-art retrieval models in common RAG contexts." authors = ["Benjamin Clavie "] license = "Apache-2.0" diff --git a/ragatouille/__init__.py b/ragatouille/__init__.py index 1ad2fe1..df4258b 100644 --- a/ragatouille/__init__.py +++ b/ragatouille/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.7post4" +__version__ = "0.0.7post5" from .RAGPretrainedModel import RAGPretrainedModel from .RAGTrainer import RAGTrainer From c5607352efa3f93bdd2c4165079e0a02b7f83b1e Mon Sep 17 00:00:00 2001 From: bclavie Date: Fri, 23 Feb 2024 12:19:19 +0100 Subject: [PATCH 3/3] chore: isort --- ragatouille/models/colbert.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ragatouille/models/colbert.py b/ragatouille/models/colbert.py index bbfbd5a..422baa3 100644 --- a/ragatouille/models/colbert.py +++ b/ragatouille/models/colbert.py @@ -11,6 +11,7 @@ from colbert import Indexer, IndexUpdater, Searcher, Trainer from colbert.infra import ColBERTConfig, Run, RunConfig from colbert.modeling.checkpoint import Checkpoint + from ragatouille.models.base import LateInteractionModel # TODO: Move all bsize related calcs to `_set_bsize()`