From ada9af098ca1a88e2f01743e5d62c283c656d023 Mon Sep 17 00:00:00 2001 From: hanhainebula <2512674094@qq.com> Date: Sat, 23 Nov 2024 16:35:12 +0800 Subject: [PATCH] update code and README for scripts --- scripts/README.md | 30 +++-- scripts/add_reranker_score.py | 10 +- scripts/hn_mine.py | 201 +++++++++++++++++++++++++------- scripts/split_data_by_length.py | 8 +- 4 files changed, 188 insertions(+), 61 deletions(-) diff --git a/scripts/README.md b/scripts/README.md index d6086a17..2150a32b 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -17,12 +17,12 @@ Hard negatives is a widely used method to improve the quality of sentence embedd ```shell python hn_mine.py \ ---model_name_or_path BAAI/bge-base-en-v1.5 \ --input_file toy_finetune_data.jsonl \ --output_file toy_finetune_data_minedHN.jsonl \ --range_for_sampling 2-200 \ --negative_number 15 \ ---use_gpu_for_searching +--use_gpu_for_searching \ +--embedder_name_or_path BAAI/bge-base-en-v1.5 ``` - **`input_file`**: json data for finetuning. This script will retrieve top-k documents for each query, and random sample negatives from the top-k documents (not including the positive documents). @@ -31,6 +31,19 @@ python hn_mine.py \ - **`range_for_sampling`**: where to sample negative. For example, `2-100` means sampling `negative_number` negatives from top2-top200 documents. **You can set larger value to reduce the difficulty of negatives (e.g., set it `60-300` to sample negatives from top60-300 passages)** - **`candidate_pool`**: The pool to retrieval. The default value is None, and this script will retrieve from the combination of all `neg` in `input_file`. The format of this file is the same as [pretrain data](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/pretrain#2-data-format). If input a candidate_pool, this script will retrieve negatives from this file. - **`use_gpu_for_searching`**: whether to use faiss-gpu to retrieve negatives. +- **`search_batch_size`**: batch size for searching. Default is 64. +- **`embedder_name_or_path`**: The name or path to the embedder. +- **`embedder_model_class`**: Class of the model used for embedding (current options include 'encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl'.). Default is None. For the custom model, you should set this argument. +- **`normalize_embeddings`**: Set to `True` to normalize embeddings. +- **`pooling_method`**: The pooling method for the embedder. +- **`use_fp16`**: Use FP16 precision for inference. +- **`devices`**: List of devices used for inference. +- **`query_instruction_for_retrieval`**, **`query_instruction_format_for_retrieval`**: Instructions and format for query during retrieval. +- **`examples_for_task`**, **`examples_instruction_format`**: Example tasks and their instructions format. This is only used when `embedder_model_class` is set to `decoder-only-icl`. +- **`trust_remote_code`**: Set to `True` to trust remote code execution. +- **`cache_dir`**: Cache directory for models. +- **`embedder_batch_size`**: Batch sizes for embedding and reranking. +- **`embedder_query_max_length`**, **`embedder_passage_max_length`**: Maximum length for embedding queries and passages. ### Teacher Scores @@ -40,9 +53,7 @@ Teacher scores can be used for model distillation. You can obtain the scores usi python add_reranker_score.py \ --input_file toy_finetune_data_minedHN.jsonl \ --output_file toy_finetune_data_score.jsonl \ ---range_for_sampling 2-200 \ ---negative_number 15 \ ---use_gpu_for_searching +--reranker_name_or_path BAAI/bge-reranker-v2-m3 ``` - **`input_file`**: path to save JSON data with mined hard negatives for finetuning @@ -80,15 +91,14 @@ python split_data_by_length.py \ --log_name .split_log \ --length_list 0 500 1000 2000 3000 4000 5000 6000 7000 \ --model_name_or_path BAAI/bge-m3 \ ---num_proc 16 \ ---overwrite False +--num_proc 16 ``` -- **`input_path`**: The path of input data. (Required) -- **`output_dir`**: The directory of output data. (Required) +- **`input_path`**: The path of input data. It can be a file or a directory containing multiple files. +- **`output_dir`**: The directory of output data. The split data files will be saved to this directory. - **`cache_dir`**: The cache directory. Default: None - **`log_name`**: The name of the log file. Default: `.split_log`, which will be saved to `output_dir` - **`length_list`**: The length list to split. Default: [0, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000] - **`model_name_or_path`**: The model name or path of the tokenizer. Default: `BAAI/bge-m3` - **`num_proc`**: The number of processes. Default: 16 -- **`overwrite`**: Whether to overwrite the output file. Default: False \ No newline at end of file +- **`overwrite`**: Whether to overwrite the output file. Default: False diff --git a/scripts/add_reranker_score.py b/scripts/add_reranker_score.py index ee14cfb4..46ebd6ad 100644 --- a/scripts/add_reranker_score.py +++ b/scripts/add_reranker_score.py @@ -1,9 +1,10 @@ import json from typing import Optional, List -from FlagEmbedding import FlagAutoReranker from dataclasses import dataclass, field from transformers import HfArgumentParser +from FlagEmbedding import FlagAutoReranker + @dataclass class ScoreArgs: @@ -14,6 +15,7 @@ class ScoreArgs: default=None, metadata={"help": "The output jsonl file, it includes query, pos, neg, pos_scores and neg_scores."} ) + @dataclass class ModelArgs: use_fp16: bool = field( @@ -78,7 +80,8 @@ class ModelArgs: default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"} ) -def main(score_args, model_args): + +def main(score_args: ScoreArgs, model_args: ModelArgs): reranker = FlagAutoReranker.from_finetuned( model_name_or_path=model_args.reranker_name_or_path, model_class=model_args.reranker_model_class, @@ -130,7 +133,7 @@ def main(score_args, model_args): f.write(json.dumps(d) + '\n') -if __name__ == '__main__': +if __name__ == "__main__": parser = HfArgumentParser(( ScoreArgs, ModelArgs @@ -139,4 +142,3 @@ def main(score_args, model_args): score_args: ScoreArgs model_args: ModelArgs main(score_args, model_args) - diff --git a/scripts/hn_mine.py b/scripts/hn_mine.py index dfa82a75..b1d47181 100644 --- a/scripts/hn_mine.py +++ b/scripts/hn_mine.py @@ -1,28 +1,98 @@ -import argparse import json import random import numpy as np -import faiss from tqdm import tqdm +from typing import Optional +from dataclasses import dataclass, field -from FlagEmbedding import FlagModel - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--model_name_or_path', default="BAAI/bge-base-en", type=str) - parser.add_argument('--input_file', default=None, type=str) - parser.add_argument('--candidate_pool', default=None, type=str) - parser.add_argument('--output_file', default=None, type=str) - parser.add_argument('--range_for_sampling', default="10-210", type=str, help="range to sample negatives") - parser.add_argument('--use_gpu_for_searching', action='store_true', help='use faiss-gpu') - parser.add_argument('--negative_number', default=15, type=int, help='the number of negatives') - parser.add_argument('--query_instruction_for_retrieval', default="") - - return parser.parse_args() - - -def create_index(embeddings, use_gpu): +import faiss +from transformers import HfArgumentParser +from FlagEmbedding import FlagAutoModel +from FlagEmbedding.abc.inference import AbsEmbedder + + +@dataclass +class DataArgs: + """ + Data arguments for hard negative mining. + """ + input_file: str = field( + metadata={"help": "The input file for hard negative mining."} + ) + output_file: str = field( + metadata={"help": "The output file for hard negative mining."} + ) + candidate_pool: Optional[str] = field( + default=None, metadata={"help": "The candidate pool for hard negative mining. If provided, it should be a jsonl file, each line is a dict with a key 'text'."} + ) + range_for_sampling: str = field( + default="10-210", metadata={"help": "The range to sample negatives."} + ) + negative_number: int = field( + default=15, metadata={"help": "The number of negatives."} + ) + use_gpu_for_searching: bool = field( + default=False, metadata={"help": "Whether to use faiss-gpu for searching."} + ) + search_batch_size: int = field( + default=64, metadata={"help": "The batch size for searching."} + ) + + +@dataclass +class ModelArgs: + """ + Model arguments for embedder. + """ + embedder_name_or_path: str = field( + metadata={"help": "The embedder name or path.", "required": True} + ) + embedder_model_class: Optional[str] = field( + default=None, metadata={"help": "The embedder model class. Available classes: ['encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl']. Default: None. For the custom model, you need to specifiy the model class.", "choices": ["encoder-only-base", "encoder-only-m3", "decoder-only-base", "decoder-only-icl"]} + ) + normalize_embeddings: bool = field( + default=True, metadata={"help": "whether to normalize the embeddings"} + ) + pooling_method: str = field( + default="cls", metadata={"help": "The pooling method fot the embedder."} + ) + use_fp16: bool = field( + default=True, metadata={"help": "whether to use fp16 for inference"} + ) + devices: Optional[str] = field( + default=None, metadata={"help": "Devices to use for inference.", "nargs": "+"} + ) + query_instruction_for_retrieval: Optional[str] = field( + default=None, metadata={"help": "Instruction for query"} + ) + query_instruction_format_for_retrieval: str = field( + default="{}{}", metadata={"help": "Format for query instruction"} + ) + examples_for_task: Optional[str] = field( + default=None, metadata={"help": "Examples for task"} + ) + examples_instruction_format: str = field( + default="{}{}", metadata={"help": "Format for examples instruction"} + ) + trust_remote_code: bool = field( + default=False, metadata={"help": "Trust remote code"} + ) + cache_dir: str = field( + default=None, metadata={"help": "Cache directory for models."} + ) + # ================ for inference =============== + batch_size: int = field( + default=3000, metadata={"help": "Batch size for inference."} + ) + embedder_query_max_length: int = field( + default=512, metadata={"help": "Max length for query."} + ) + embedder_passage_max_length: int = field( + default=512, metadata={"help": "Max length for passage."} + ) + + +def create_index(embeddings: np.ndarray, use_gpu: bool = False): index = faiss.IndexFlatIP(len(embeddings[0])) embeddings = np.asarray(embeddings, dtype=np.float32) if use_gpu: @@ -34,10 +104,12 @@ def create_index(embeddings, use_gpu): return index -def batch_search(index, - query, - topk: int = 200, - batch_size: int = 64): +def batch_search( + index: faiss.Index, + query: np.ndarray, + topk: int = 200, + batch_size: int = 64 +): all_scores, all_inxs = [], [] for start_index in tqdm(range(0, len(query), batch_size), desc="Batches", disable=len(query) < 256): batch_query = query[start_index:start_index + batch_size] @@ -47,15 +119,24 @@ def batch_search(index, return all_scores, all_inxs -def get_corpus(candidate_pool): +def get_corpus(candidate_pool: str): corpus = [] - for line in open(candidate_pool): - line = json.loads(line.strip()) - corpus.append(line['text']) + with open(candidate_pool, "r", encoding="utf-8") as f: + for line in f.readlines(): + line = json.loads(line.strip()) + corpus.append(line['text']) return corpus -def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, negative_number, use_gpu): +def find_knn_neg( + model: AbsEmbedder, + input_file: str, + output_file: str, + candidate_pool: Optional[str] = None, + sample_range: str = "10-210", + negative_number: int = 15, + use_gpu: bool = False +): corpus = [] queries = [] train_data = [] @@ -75,9 +156,9 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n corpus = list(set(corpus)) print(f'inferencing embedding for corpus (number={len(corpus)})--------------') - p_vecs = model.encode(corpus, batch_size=256) + p_vecs = model.encode(corpus) print(f'inferencing embedding for queries (number={len(queries)})--------------') - q_vecs = model.encode_queries(queries, batch_size=256) + q_vecs = model.encode_queries(queries) print('create index and search------------------') index = create_index(p_vecs, use_gpu=use_gpu) @@ -106,17 +187,47 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n f.write(json.dumps(data, ensure_ascii=False) + '\n') -if __name__ == '__main__': - args = get_args() - sample_range = args.range_for_sampling.split('-') - sample_range = [int(x) for x in sample_range] - - model = FlagModel(args.model_name_or_path, query_instruction_for_retrieval=args.query_instruction_for_retrieval) - - find_knn_neg(model, - input_file=args.input_file, - candidate_pool=args.candidate_pool, - output_file=args.output_file, - sample_range=sample_range, - negative_number=args.negative_number, - use_gpu=args.use_gpu_for_searching) +def load_model(model_args: ModelArgs): + model = FlagAutoModel.from_finetuned( + model_name_or_path=model_args.embedder_name_or_path, + model_class=model_args.embedder_model_class, + normalize_embeddings=model_args.normalize_embeddings, + pooling_method=model_args.pooling_method, + use_fp16=model_args.use_fp16, + query_instruction_for_retrieval=model_args.query_instruction_for_retrieval, + query_instruction_format=model_args.query_instruction_format_for_retrieval, + devices=model_args.devices, + examples_for_task=model_args.examples_for_task, + examples_instruction_format=model_args.examples_instruction_format, + trust_remote_code=model_args.trust_remote_code, + cache_dir=model_args.cache_dir, + batch_size=model_args.batch_size, + query_max_length=model_args.embedder_query_max_length, + passage_max_length=model_args.embedder_passage_max_length, + ) + return model + + +def main(data_args: DataArgs, model_args: ModelArgs): + model = load_model(model_args) + + find_knn_neg( + model=model, + input_file=data_args.input_file, + output_file=data_args.output_file, + candidate_pool=data_args.candidate_pool, + sample_range=[int(x) for x in data_args.range_for_sampling.split('-')], + negative_number=data_args.negative_number, + use_gpu=data_args.use_gpu_for_searching + ) + + +if __name__ == "__main__": + parser = HfArgumentParser(( + DataArgs, + ModelArgs + )) + data_args, model_args = parser.parse_args_into_dataclasses() + data_args: DataArgs + model_args: ModelArgs + main(data_args, model_args) diff --git a/scripts/split_data_by_length.py b/scripts/split_data_by_length.py index fc4e5ffa..eb2fcfdf 100644 --- a/scripts/split_data_by_length.py +++ b/scripts/split_data_by_length.py @@ -187,8 +187,7 @@ def run(self, input_path: str, output_dir: str, log_name: str=None): f.write('\n') -if __name__ == '__main__': - args = get_args() +def main(args): input_path = args.input_path output_dir = args.output_dir log_name = args.log_name @@ -207,3 +206,8 @@ def run(self, input_path: str, output_dir: str, log_name: str=None): log_name=log_name ) print('\nDONE!') + + +if __name__ == "__main__": + args = get_args() + main(args)