diff --git a/ragatouille/RAGTrainer.py b/ragatouille/RAGTrainer.py index 8da5ba6..35d5c4e 100644 --- a/ragatouille/RAGTrainer.py +++ b/ragatouille/RAGTrainer.py @@ -12,13 +12,6 @@ class RAGTrainer: """Main trainer to fine-tune/train ColBERT models with a few lines.""" - model: Union[LateInteractionModel, None] = None - negative_miner: Union[HardNegativeMiner, None] = None - collection: list[str] = [] - queries: Union[list[str], None] = None - raw_data: Union[list[tuple], list[list], None] = None - training_triplets: list[list[int]] = list() - def __init__( self, model_name: str, @@ -38,15 +31,19 @@ def __init__( Returns: self (RAGTrainer): The current instance of RAGTrainer, with the base model initialised. """ - self.model_name = model_name self.pretrained_model_name = pretrained_model_name self.language_code = language_code - self.model = ColBERT( + self.model: Union[LateInteractionModel, None] = ColBERT( pretrained_model_name_or_path=pretrained_model_name, n_gpu=n_usable_gpus, training_mode=True, ) + self.negative_miner: Union[HardNegativeMiner, None] = None + self.collection: list[str] = [] + self.queries: Union[list[str], None] = None + self.raw_data: Union[list[tuple], list[list], None] = None + self.training_triplets: list[list[int]] = list() def add_documents(self, documents: list[str]): self.collection += documents @@ -60,6 +57,14 @@ def export_training_data(self, path: Union[str, Path]): path: Union[str, Path] - Path to the directory where the data will be exported.""" self.data_processor.export_training_data(path) + def _add_to_collection(self, content: Union[str, list, dict]): + if isinstance(content, str): + self.collection.append(content) + elif isinstance(content, list): + self.collection += [txt for txt in content if isinstance(txt, str)] + elif isinstance(content, dict): + self.collection += [content["content"]] + def prepare_training_data( self, raw_data: Union[list[tuple], list[list]], @@ -98,12 +103,16 @@ def prepare_training_data( self.collection += [doc for doc in all_documents if isinstance(doc, str)] self.data_dir = Path(data_out_path) - if len(raw_data[0]) == 2: + sample = raw_data[0] + if len(sample) == 2: data_type = "pairs" + elif len(sample) == 3: if pairs_with_labels: data_type = "labeled_pairs" - elif len(raw_data[0]) == 3: - data_type = "triplets" + if sample[-1] not in [positive_label, negative_label]: + raise ValueError(f"Invalid value for label: {sample}") + else: + data_type = "triplets" else: raise ValueError("Raw data must be a list of pairs or triplets of strings.") @@ -113,16 +122,12 @@ def prepare_training_data( self.queries.add(x[0]) else: raise ValueError("Queries must be a strings.") - if isinstance(x[1], str): - self.collection.append(x[1]) - elif isinstance(x[1], list): - self.collection += [txt for txt in x[1] if isinstance(txt, str)] - - if len(x) == 3: # For triplets - if isinstance(x[2], str): - self.collection.append(x[2]) - elif isinstance(x[2], list): - self.collection += [txt for txt in x[2] if isinstance(txt, str)] + + self._add_to_collection(x[1]) + + if data_type == "triplets": + self._add_to_collection(x[2]) + self.collection = list(set(self.collection)) seeded_shuffle(self.collection)