Skip to content

Commit

Permalink
remove unnecessary changes
Browse files (browse the repository at this point in the history)
  • Loading branch information
JoshuaPurtell committed Feb 26, 2024
1 parent d5a1d5a commit 6c4813d
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions ragatouille/RAGTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,6 @@
class RAGTrainer:
"""Main trainer to fine-tune/train ColBERT models with a few lines."""

model: Union[LateInteractionModel, None] = None
negative_miner: Union[HardNegativeMiner, None] = None
collection: list[str] = []
queries: Union[list[str], None] = None
raw_data: Union[list[tuple], list[list], None] = None
training_triplets: list[list[int]] = list()

def __init__(
self,
model_name: str,
Expand All @@ -38,15 +31,19 @@ def __init__(
Returns:
self (RAGTrainer): The current instance of RAGTrainer, with the base model initialised.
"""

self.model_name = model_name
self.pretrained_model_name = pretrained_model_name
self.language_code = language_code
self.model = ColBERT(
self.model: Union[LateInteractionModel, None] = ColBERT(
pretrained_model_name_or_path=pretrained_model_name,
n_gpu=n_usable_gpus,
training_mode=True,
)
self.negative_miner: Union[HardNegativeMiner, None] = None
self.collection: list[str] = []
self.queries: Union[list[str], None] = None
self.raw_data: Union[list[tuple], list[list], None] = None
self.training_triplets: list[list[int]] = list()

def add_documents(self, documents: list[str]):
self.collection += documents
Expand All @@ -60,6 +57,14 @@ def export_training_data(self, path: Union[str, Path]):
path: Union[str, Path] - Path to the directory where the data will be exported."""
self.data_processor.export_training_data(path)

def _add_to_collection(self, content: Union[str, list, dict]):
    """Add the passage text(s) carried by ``content`` to ``self.collection``.

    Args:
        content: Either a single passage string, a list of passages
            (items that are not strings are skipped), or a dict whose
            ``"content"`` entry is added as-is. Values of any other type
            are ignored.

    Raises:
        KeyError: If ``content`` is a dict without a ``"content"`` key.
    """
    if isinstance(content, str):
        new_entries = [content]
    elif isinstance(content, list):
        # Only plain strings from the list make it into the collection.
        new_entries = [item for item in content if isinstance(item, str)]
    elif isinstance(content, dict):
        new_entries = [content["content"]]
    else:
        # Unsupported input types leave the collection untouched.
        return
    self.collection.extend(new_entries)

def prepare_training_data(
self,
raw_data: Union[list[tuple], list[list]],
Expand Down Expand Up @@ -98,12 +103,16 @@ def prepare_training_data(
self.collection += [doc for doc in all_documents if isinstance(doc, str)]

self.data_dir = Path(data_out_path)
if len(raw_data[0]) == 2:
sample = raw_data[0]
if len(sample) == 2:
data_type = "pairs"
elif len(sample) == 3:
if pairs_with_labels:
data_type = "labeled_pairs"
elif len(raw_data[0]) == 3:
data_type = "triplets"
if sample[-1] not in [positive_label, negative_label]:
raise ValueError(f"Invalid value for label: {sample}")
else:
data_type = "triplets"
else:
raise ValueError("Raw data must be a list of pairs or triplets of strings.")

Expand All @@ -113,16 +122,12 @@ def prepare_training_data(
self.queries.add(x[0])
else:
raise ValueError("Queries must be a strings.")
if isinstance(x[1], str):
self.collection.append(x[1])
elif isinstance(x[1], list):
self.collection += [txt for txt in x[1] if isinstance(txt, str)]

if len(x) == 3: # For triplets
if isinstance(x[2], str):
self.collection.append(x[2])
elif isinstance(x[2], list):
self.collection += [txt for txt in x[2] if isinstance(txt, str)]

self._add_to_collection(x[1])

if data_type == "triplets":
self._add_to_collection(x[2])

self.collection = list(set(self.collection))
seeded_shuffle(self.collection)

Expand Down

0 comments on commit 6c4813d

Please sign in to comment.