From 696867b53e1061601c1262b582d40b3ec13a9f52 Mon Sep 17 00:00:00 2001
From: vaibhavad
Date: Wed, 8 May 2024 16:16:10 +0000
Subject: [PATCH] Process split instructions in code rather than data files

---
 llm2vec/dataset/E5Data.py | 44 +++++++++++++++-----------------------------
 1 file changed, 15 insertions(+), 29 deletions(-)

diff --git a/llm2vec/dataset/E5Data.py b/llm2vec/dataset/E5Data.py
index 69e3ca3..cfe33cd 100644
--- a/llm2vec/dataset/E5Data.py
+++ b/llm2vec/dataset/E5Data.py
@@ -7,28 +7,11 @@
 
 logger = get_logger(__name__, log_level="INFO")
 
-datasets_list = [
-    "allnli_split1",
-    "allnli_split2",
-    "dureader",
-    "eli5_question_answer",
-    "fever",
-    "hotpot_qa",
-    "miracl",
-    "mrtydi",
-    "msmarco_passage",
-    "msmarco_document",
-    "nq",
-    "quora_duplicates_split1",
-    "quora_duplicates_split2",
-    "squad",
-    "t2ranking",
-    "trivia_qa",
-]
-
 E5_EMBEDDING_PROMPTS = {
-    "allnli_split1": "Given a premise, retrieve a hypothesis that is entailed by the premise",
-    "allnli_split2": "Retrieve semantically similar text",
+    "allnli": [
+        "Given a premise, retrieve a hypothesis that is entailed by the premise",
+        "Retrieve semantically similar text",
+    ],
     "dureader": "Given a Chinese search query, retrieve web passages that answer the question",
     "eli5_question_answer": "Provided a user question, retrieve the highest voted answers on Reddit ELI5 forum",
     "fever": "Given a claim, retrieve documents that support or refute the claim",
@@ -38,8 +21,10 @@
     "msmarco_passage": "Given a web search query, retrieve relevant passages that answer the query",
     "msmarco_document": "Given a web search query, retrieve relevant documents that answer the query",
     "nq": "Given a question, retrieve Wikipedia passages that answer the question",
-    "quora_duplicates_split1": "Given a question, retrieve questions that are semantically equivalent to the given question",
-    "quora_duplicates_split2": "Find questions that have the same meaning as the input question",
+    "quora_duplicates": [
+        "Given a question, retrieve questions that are semantically equivalent to the given question",
+        "Find questions that have the same meaning as the input question",
+    ],
     "squad": "Retrieve Wikipedia passages that answer the question",
     "t2ranking": "Given a Chinese search query, retrieve web passages that answer the question",
     "trivia_qa": "Retrieve Wikipedia passages that answer the question",
@@ -75,7 +60,7 @@ def load_data(self, file_path: str = None):
         data_map = {}
         all_samples = []
         id_ = 0
-        for dataset in datasets_list:
+        for dataset in E5_EMBEDDING_PROMPTS:
             logger.info(f"Loading dataset {dataset}...")
             if dataset not in data_map:
                 data_map[dataset] = []
@@ -84,12 +69,13 @@ def load_data(self, file_path: str = None):
 
             dataset_samples = [json.loads(d) for d in dataset_samples]
 
-            for sample in dataset_samples:
-                query = (
-                    f"{E5_EMBEDDING_PROMPTS[dataset]}; "
-                    + self.separator
-                    + sample["query"]
+            for i, sample in enumerate(dataset_samples):
+                instruction = (
+                    E5_EMBEDDING_PROMPTS[dataset]
+                    if isinstance(E5_EMBEDDING_PROMPTS[dataset], str)
+                    else E5_EMBEDDING_PROMPTS[dataset][i % 2]
                 )
+                query = f"{instruction}; " + self.separator + sample["query"]
                 if dataset in [
                     "allnli_split2",
                     "quora_duplicates_split1",
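
The patch above stores the two instructions of a split dataset as a list under a single key and picks between them by sample index inside load_data. A minimal standalone sketch of that selection logic, assuming a trimmed-down prompts dict; the pick_instruction helper and the assertions are illustrative only, not part of the patch:

# Sketch of the alternating-instruction selection introduced by this patch.
# pick_instruction is a hypothetical helper mirroring the expression used in load_data.
E5_EMBEDDING_PROMPTS = {
    "allnli": [
        "Given a premise, retrieve a hypothesis that is entailed by the premise",
        "Retrieve semantically similar text",
    ],
    "nq": "Given a question, retrieve Wikipedia passages that answer the question",
}

def pick_instruction(dataset: str, i: int) -> str:
    prompt = E5_EMBEDDING_PROMPTS[dataset]
    # Single-instruction datasets always use their one prompt; split datasets
    # alternate between their two prompts based on the sample index.
    return prompt if isinstance(prompt, str) else prompt[i % 2]

assert pick_instruction("allnli", 0).startswith("Given a premise")
assert pick_instruction("allnli", 1) == "Retrieve semantically similar text"
assert pick_instruction("nq", 7).startswith("Given a question")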