Skip to content

Commit

Permalink
update code and README for scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhainebula committed Nov 23, 2024
1 parent a719aaa commit ada9af0
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 61 deletions.
30 changes: 20 additions & 10 deletions scripts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ Hard negatives is a widely used method to improve the quality of sentence embedd

```shell
python hn_mine.py \
--model_name_or_path BAAI/bge-base-en-v1.5 \
--input_file toy_finetune_data.jsonl \
--output_file toy_finetune_data_minedHN.jsonl \
--range_for_sampling 2-200 \
--negative_number 15 \
--use_gpu_for_searching
--use_gpu_for_searching \
--embedder_name_or_path BAAI/bge-base-en-v1.5
```

- **`input_file`**: json data for finetuning. This script will retrieve top-k documents for each query, and random sample negatives from the top-k documents (not including the positive documents).
Expand All @@ -31,6 +31,19 @@ python hn_mine.py \
- **`range_for_sampling`**: where to sample negative. For example, `2-100` means sampling `negative_number` negatives from top2-top200 documents. **You can set larger value to reduce the difficulty of negatives (e.g., set it `60-300` to sample negatives from top60-300 passages)**
- **`candidate_pool`**: The pool to retrieval. The default value is None, and this script will retrieve from the combination of all `neg` in `input_file`. The format of this file is the same as [pretrain data](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/pretrain#2-data-format). If input a candidate_pool, this script will retrieve negatives from this file.
- **`use_gpu_for_searching`**: whether to use faiss-gpu to retrieve negatives.
- **`search_batch_size`**: batch size for searching. Default is 64.
- **`embedder_name_or_path`**: The name or path to the embedder.
- **`embedder_model_class`**: Class of the model used for embedding (current options include 'encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl'.). Default is None. For the custom model, you should set this argument.
- **`normalize_embeddings`**: Set to `True` to normalize embeddings.
- **`pooling_method`**: The pooling method for the embedder.
- **`use_fp16`**: Use FP16 precision for inference.
- **`devices`**: List of devices used for inference.
- **`query_instruction_for_retrieval`**, **`query_instruction_format_for_retrieval`**: Instructions and format for query during retrieval.
- **`examples_for_task`**, **`examples_instruction_format`**: Example tasks and their instructions format. This is only used when `embedder_model_class` is set to `decoder-only-icl`.
- **`trust_remote_code`**: Set to `True` to trust remote code execution.
- **`cache_dir`**: Cache directory for models.
- **`embedder_batch_size`**: Batch sizes for embedding and reranking.
- **`embedder_query_max_length`**, **`embedder_passage_max_length`**: Maximum length for embedding queries and passages.

### Teacher Scores

Expand All @@ -40,9 +53,7 @@ Teacher scores can be used for model distillation. You can obtain the scores usi
python add_reranker_score.py \
--input_file toy_finetune_data_minedHN.jsonl \
--output_file toy_finetune_data_score.jsonl \
--range_for_sampling 2-200 \
--negative_number 15 \
--use_gpu_for_searching
--reranker_name_or_path BAAI/bge-reranker-v2-m3
```

- **`input_file`**: path to save JSON data with mined hard negatives for finetuning
Expand Down Expand Up @@ -80,15 +91,14 @@ python split_data_by_length.py \
--log_name .split_log \
--length_list 0 500 1000 2000 3000 4000 5000 6000 7000 \
--model_name_or_path BAAI/bge-m3 \
--num_proc 16 \
--overwrite False
--num_proc 16
```

- **`input_path`**: The path of input data. (Required)
- **`output_dir`**: The directory of output data. (Required)
- **`input_path`**: The path of input data. It can be a file or a directory containing multiple files.
- **`output_dir`**: The directory of output data. The split data files will be saved to this directory.
- **`cache_dir`**: The cache directory. Default: None
- **`log_name`**: The name of the log file. Default: `.split_log`, which will be saved to `output_dir`
- **`length_list`**: The length list to split. Default: [0, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000]
- **`model_name_or_path`**: The model name or path of the tokenizer. Default: `BAAI/bge-m3`
- **`num_proc`**: The number of processes. Default: 16
- **`overwrite`**: Whether to overwrite the output file. Default: False
- **`overwrite`**: Whether to overwrite the output file. Default: False
10 changes: 6 additions & 4 deletions scripts/add_reranker_score.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
from typing import Optional, List

from FlagEmbedding import FlagAutoReranker
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from FlagEmbedding import FlagAutoReranker


@dataclass
class ScoreArgs:
Expand All @@ -14,6 +15,7 @@ class ScoreArgs:
default=None, metadata={"help": "The output jsonl file, it includes query, pos, neg, pos_scores and neg_scores."}
)


@dataclass
class ModelArgs:
use_fp16: bool = field(
Expand Down Expand Up @@ -78,7 +80,8 @@ class ModelArgs:
default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
)

def main(score_args, model_args):

def main(score_args: ScoreArgs, model_args: ModelArgs):
reranker = FlagAutoReranker.from_finetuned(
model_name_or_path=model_args.reranker_name_or_path,
model_class=model_args.reranker_model_class,
Expand Down Expand Up @@ -130,7 +133,7 @@ def main(score_args, model_args):
f.write(json.dumps(d) + '\n')


if __name__ == '__main__':
if __name__ == "__main__":
parser = HfArgumentParser((
ScoreArgs,
ModelArgs
Expand All @@ -139,4 +142,3 @@ def main(score_args, model_args):
score_args: ScoreArgs
model_args: ModelArgs
main(score_args, model_args)

201 changes: 156 additions & 45 deletions scripts/hn_mine.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,98 @@
import argparse
import json
import random
import numpy as np
import faiss
from tqdm import tqdm
from typing import Optional
from dataclasses import dataclass, field

from FlagEmbedding import FlagModel


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', default="BAAI/bge-base-en", type=str)
parser.add_argument('--input_file', default=None, type=str)
parser.add_argument('--candidate_pool', default=None, type=str)
parser.add_argument('--output_file', default=None, type=str)
parser.add_argument('--range_for_sampling', default="10-210", type=str, help="range to sample negatives")
parser.add_argument('--use_gpu_for_searching', action='store_true', help='use faiss-gpu')
parser.add_argument('--negative_number', default=15, type=int, help='the number of negatives')
parser.add_argument('--query_instruction_for_retrieval', default="")

return parser.parse_args()


def create_index(embeddings, use_gpu):
import faiss
from transformers import HfArgumentParser
from FlagEmbedding import FlagAutoModel
from FlagEmbedding.abc.inference import AbsEmbedder


@dataclass
class DataArgs:
"""
Data arguments for hard negative mining.
"""
input_file: str = field(
metadata={"help": "The input file for hard negative mining."}
)
output_file: str = field(
metadata={"help": "The output file for hard negative mining."}
)
candidate_pool: Optional[str] = field(
default=None, metadata={"help": "The candidate pool for hard negative mining. If provided, it should be a jsonl file, each line is a dict with a key 'text'."}
)
range_for_sampling: str = field(
default="10-210", metadata={"help": "The range to sample negatives."}
)
negative_number: int = field(
default=15, metadata={"help": "The number of negatives."}
)
use_gpu_for_searching: bool = field(
default=False, metadata={"help": "Whether to use faiss-gpu for searching."}
)
search_batch_size: int = field(
default=64, metadata={"help": "The batch size for searching."}
)


@dataclass
class ModelArgs:
"""
Model arguments for embedder.
"""
embedder_name_or_path: str = field(
metadata={"help": "The embedder name or path.", "required": True}
)
embedder_model_class: Optional[str] = field(
default=None, metadata={"help": "The embedder model class. Available classes: ['encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl']. Default: None. For the custom model, you need to specifiy the model class.", "choices": ["encoder-only-base", "encoder-only-m3", "decoder-only-base", "decoder-only-icl"]}
)
normalize_embeddings: bool = field(
default=True, metadata={"help": "whether to normalize the embeddings"}
)
pooling_method: str = field(
default="cls", metadata={"help": "The pooling method fot the embedder."}
)
use_fp16: bool = field(
default=True, metadata={"help": "whether to use fp16 for inference"}
)
devices: Optional[str] = field(
default=None, metadata={"help": "Devices to use for inference.", "nargs": "+"}
)
query_instruction_for_retrieval: Optional[str] = field(
default=None, metadata={"help": "Instruction for query"}
)
query_instruction_format_for_retrieval: str = field(
default="{}{}", metadata={"help": "Format for query instruction"}
)
examples_for_task: Optional[str] = field(
default=None, metadata={"help": "Examples for task"}
)
examples_instruction_format: str = field(
default="{}{}", metadata={"help": "Format for examples instruction"}
)
trust_remote_code: bool = field(
default=False, metadata={"help": "Trust remote code"}
)
cache_dir: str = field(
default=None, metadata={"help": "Cache directory for models."}
)
# ================ for inference ===============
batch_size: int = field(
default=3000, metadata={"help": "Batch size for inference."}
)
embedder_query_max_length: int = field(
default=512, metadata={"help": "Max length for query."}
)
embedder_passage_max_length: int = field(
default=512, metadata={"help": "Max length for passage."}
)


def create_index(embeddings: np.ndarray, use_gpu: bool = False):
index = faiss.IndexFlatIP(len(embeddings[0]))
embeddings = np.asarray(embeddings, dtype=np.float32)
if use_gpu:
Expand All @@ -34,10 +104,12 @@ def create_index(embeddings, use_gpu):
return index


def batch_search(index,
query,
topk: int = 200,
batch_size: int = 64):
def batch_search(
index: faiss.Index,
query: np.ndarray,
topk: int = 200,
batch_size: int = 64
):
all_scores, all_inxs = [], []
for start_index in tqdm(range(0, len(query), batch_size), desc="Batches", disable=len(query) < 256):
batch_query = query[start_index:start_index + batch_size]
Expand All @@ -47,15 +119,24 @@ def batch_search(index,
return all_scores, all_inxs


def get_corpus(candidate_pool):
def get_corpus(candidate_pool: str):
corpus = []
for line in open(candidate_pool):
line = json.loads(line.strip())
corpus.append(line['text'])
with open(candidate_pool, "r", encoding="utf-8") as f:
for line in f.readlines():
line = json.loads(line.strip())
corpus.append(line['text'])
return corpus


def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, negative_number, use_gpu):
def find_knn_neg(
model: AbsEmbedder,
input_file: str,
output_file: str,
candidate_pool: Optional[str] = None,
sample_range: str = "10-210",
negative_number: int = 15,
use_gpu: bool = False
):
corpus = []
queries = []
train_data = []
Expand All @@ -75,9 +156,9 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n
corpus = list(set(corpus))

print(f'inferencing embedding for corpus (number={len(corpus)})--------------')
p_vecs = model.encode(corpus, batch_size=256)
p_vecs = model.encode(corpus)
print(f'inferencing embedding for queries (number={len(queries)})--------------')
q_vecs = model.encode_queries(queries, batch_size=256)
q_vecs = model.encode_queries(queries)

print('create index and search------------------')
index = create_index(p_vecs, use_gpu=use_gpu)
Expand Down Expand Up @@ -106,17 +187,47 @@ def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, n
f.write(json.dumps(data, ensure_ascii=False) + '\n')


if __name__ == '__main__':
args = get_args()
sample_range = args.range_for_sampling.split('-')
sample_range = [int(x) for x in sample_range]

model = FlagModel(args.model_name_or_path, query_instruction_for_retrieval=args.query_instruction_for_retrieval)

find_knn_neg(model,
input_file=args.input_file,
candidate_pool=args.candidate_pool,
output_file=args.output_file,
sample_range=sample_range,
negative_number=args.negative_number,
use_gpu=args.use_gpu_for_searching)
def load_model(model_args: ModelArgs):
model = FlagAutoModel.from_finetuned(
model_name_or_path=model_args.embedder_name_or_path,
model_class=model_args.embedder_model_class,
normalize_embeddings=model_args.normalize_embeddings,
pooling_method=model_args.pooling_method,
use_fp16=model_args.use_fp16,
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
query_instruction_format=model_args.query_instruction_format_for_retrieval,
devices=model_args.devices,
examples_for_task=model_args.examples_for_task,
examples_instruction_format=model_args.examples_instruction_format,
trust_remote_code=model_args.trust_remote_code,
cache_dir=model_args.cache_dir,
batch_size=model_args.batch_size,
query_max_length=model_args.embedder_query_max_length,
passage_max_length=model_args.embedder_passage_max_length,
)
return model


def main(data_args: DataArgs, model_args: ModelArgs):
model = load_model(model_args)

find_knn_neg(
model=model,
input_file=data_args.input_file,
output_file=data_args.output_file,
candidate_pool=data_args.candidate_pool,
sample_range=[int(x) for x in data_args.range_for_sampling.split('-')],
negative_number=data_args.negative_number,
use_gpu=data_args.use_gpu_for_searching
)


if __name__ == "__main__":
parser = HfArgumentParser((
DataArgs,
ModelArgs
))
data_args, model_args = parser.parse_args_into_dataclasses()
data_args: DataArgs
model_args: ModelArgs
main(data_args, model_args)
8 changes: 6 additions & 2 deletions scripts/split_data_by_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ def run(self, input_path: str, output_dir: str, log_name: str=None):
f.write('\n')


if __name__ == '__main__':
args = get_args()
def main(args):
input_path = args.input_path
output_dir = args.output_dir
log_name = args.log_name
Expand All @@ -207,3 +206,8 @@ def run(self, input_path: str, output_dir: str, log_name: str=None):
log_name=log_name
)
print('\nDONE!')


if __name__ == "__main__":
args = get_args()
main(args)

0 comments on commit ada9af0

Please sign in to comment.