Skip to content

Commit

Permalink
Process split instructions in code rather than data field
Browse files Browse the repository at this point in the history
  • Loading branch information
vaibhavad committed May 8, 2024
1 parent c5cb6d7 commit 696867b
Showing 1 changed file with 15 additions and 29 deletions.
44 changes: 15 additions & 29 deletions llm2vec/dataset/E5Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,11 @@

logger = get_logger(__name__, log_level="INFO")

datasets_list = [
"allnli_split1",
"allnli_split2",
"dureader",
"eli5_question_answer",
"fever",
"hotpot_qa",
"miracl",
"mrtydi",
"msmarco_passage",
"msmarco_document",
"nq",
"quora_duplicates_split1",
"quora_duplicates_split2",
"squad",
"t2ranking",
"trivia_qa",
]

E5_EMBEDDING_PROMPTS = {
"allnli_split1": "Given a premise, retrieve a hypothesis that is entailed by the premise",
"allnli_split2": "Retrieve semantically similar text",
"allnli": [
"Given a premise, retrieve a hypothesis that is entailed by the premise",
"Retrieve semantically similar text",
],
"dureader": "Given a Chinese search query, retrieve web passages that answer the question",
"eli5_question_answer": "Provided a user question, retrieve the highest voted answers on Reddit ELI5 forum",
"fever": "Given a claim, retrieve documents that support or refute the claim",
Expand All @@ -38,8 +21,10 @@
"msmarco_passage": "Given a web search query, retrieve relevant passages that answer the query",
"msmarco_document": "Given a web search query, retrieve relevant documents that answer the query",
"nq": "Given a question, retrieve Wikipedia passages that answer the question",
"quora_duplicates_split1": "Given a question, retrieve questions that are semantically equivalent to the given question",
"quora_duplicates_split2": "Find questions that have the same meaning as the input question",
"quora_duplicates": [
"Given a question, retrieve questions that are semantically equivalent to the given question",
"Find questions that have the same meaning as the input question",
],
"squad": "Retrieve Wikipedia passages that answer the question",
"t2ranking": "Given a Chinese search query, retrieve web passages that answer the question",
"trivia_qa": "Retrieve Wikipedia passages that answer the question",
Expand Down Expand Up @@ -75,7 +60,7 @@ def load_data(self, file_path: str = None):
data_map = {}
all_samples = []
id_ = 0
for dataset in datasets_list:
for dataset in E5_EMBEDDING_PROMPTS:
logger.info(f"Loading dataset {dataset}...")
if dataset not in data_map:
data_map[dataset] = []
Expand All @@ -84,12 +69,13 @@ def load_data(self, file_path: str = None):

dataset_samples = [json.loads(d) for d in dataset_samples]

for sample in dataset_samples:
query = (
f"{E5_EMBEDDING_PROMPTS[dataset]}; "
+ self.separator
+ sample["query"]
for i, sample in enumerate(dataset_samples):
instruction = (
E5_EMBEDDING_PROMPTS[dataset]
if isinstance(E5_EMBEDDING_PROMPTS[dataset], str)
else E5_EMBEDDING_PROMPTS[dataset][i % 2]
)
query = f"{instruction}; " + self.separator + sample["query"]
if dataset in [
"allnli_split2",
"quora_duplicates_split1",
Expand Down

0 comments on commit 696867b

Please sign in to comment.