# RetroMAE Bi-encoder on MSMARCO

The fine-tuning code is adapted from Tevatron.

## Prepare Data

Run:

```bash
bash get_data.sh
```

This will download the cleaned corpus hosted by the RocketQA team. Then tokenize the data and save it to `./data/BertTokenizer_data`:

```bash
python preprocess.py --tokenizer_name bert-base-uncased --max_seq_length 150 --output_dir ./data/BertTokenizer_data --use_title
```
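If you want to see what the `--max_seq_length` budget does before running the full preprocessing, here is a standalone sanity check using the Hugging Face tokenizer directly (an illustration only, not the repo's `preprocess.py` logic):

```python
from transformers import AutoTokenizer

# Same tokenizer and length budget as the preprocessing command above.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

passage = "RetroMAE is a retrieval-oriented pre-training paradigm based on masked auto-encoding."
encoded = tokenizer(passage, max_length=150, truncation=True)

print(len(encoded["input_ids"]))                # token count after truncation
print(tokenizer.decode(encoded["input_ids"]))   # round-trip check
```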

## Train

You can download our fine-tuned checkpoint from the Hugging Face hub: `Shitao/RetroMAE_MSMARCO_finetune`. You can also fine-tune your own model as follows:

```bash
torchrun --nproc_per_node 8 \
-m bi_encoder.run \
--output_dir {path to save model} \
--model_name_or_path Shitao/RetroMAE_MSMARCO \
--do_train \
--corpus_file ./data/BertTokenizer_data/corpus \
--train_query_file ./data/BertTokenizer_data/train_query \
--train_qrels ./data/BertTokenizer_data/train_qrels.txt \
--neg_file {negative file} \
--query_max_len 32 \
--passage_max_len 140 \
--fp16 \
--per_device_train_batch_size 16 \
--train_group_size 16 \
--sample_neg_from_topk 200 \
--learning_rate 2e-5 \
--num_train_epochs 4 \
--negatives_x_device \
--dataloader_num_workers 6
```
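As a rough sanity check on what these flags imply for the contrastive loss, here is a back-of-the-envelope count of candidates per query (a sketch assuming the usual Tevatron-style in-batch negatives shared across devices via `--negatives_x_device`; the exact behaviour is defined in `bi_encoder.run`):

```python
# Candidate passages per query for the command above (assumption: in-batch negatives
# are shared across all devices when --negatives_x_device is set).
n_gpus = 8
per_device_train_batch_size = 16   # queries per GPU per step
train_group_size = 16              # 1 positive + 15 sampled negatives per query

queries_per_step = n_gpus * per_device_train_batch_size     # 128
passages_per_step = queries_per_step * train_group_size     # 2048

print(f"each query is contrasted against roughly {passages_per_step} passages per step")
```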

For `--neg_file`, you can use the official negatives generated by BM25 (`./data/train_negs.tsv`) or generate hard negatives following the Hard negative mining section below. We also provide our own hard negatives, `hard_negs.txt`, retrieved by the bi-encoder model and hosted on Google Drive. Besides, you can set `--teacher_score_files` to distill knowledge from a cross-encoder (we find that a larger `train_group_size` usually gives better performance in distillation).

## Inference

Generate the embeddings of the passages and save them to `results/passage_reps`:

```bash
torchrun --nproc_per_node 8 \
-m bi_encoder.run \
--output_dir retromae_msmarco_passage_fintune \
--model_name_or_path Shitao/RetroMAE_MSMARCO_finetune \
--corpus_file ./data/BertTokenizer_data/corpus \
--passage_max_len 140 \
--fp16 \
--do_predict \
--prediction_save_path results/ \
--per_device_eval_batch_size 256 \
--dataloader_num_workers 6 \
--eval_accumulation_steps 100
```
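To make the saved representations less of a black box, here is a minimal encoding sketch with plain `transformers` (assumptions: the checkpoint loads as a standard BERT encoder and the `[CLS]` hidden state is used as the passage embedding; the actual inference path is `bi_encoder.run` above):

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_name = "Shitao/RetroMAE_MSMARCO_finetune"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).eval()

passages = ["The Manhattan Project produced the first nuclear weapons during World War II."]
inputs = tokenizer(passages, max_length=140, truncation=True, padding=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Assumption: the [CLS] token's last hidden state serves as the dense passage representation.
passage_embeddings = outputs.last_hidden_state[:, 0]
print(passage_embeddings.shape)   # (num_passages, hidden_size)
```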

Generate the embeddings of the queries and save them to `results/query_reps`:

```bash
torchrun --nproc_per_node 8 \
-m bi_encoder.run \
--output_dir retromae_msmarco_passage_fintune \
--model_name_or_path Shitao/RetroMAE_MSMARCO_finetune \
--test_query_file ./data/BertTokenizer_data/dev_query \
--query_max_len 32 \
--fp16 \
--do_predict \
--prediction_save_path results/ \
--per_device_eval_batch_size 256 \
--dataloader_num_workers 6 \
--eval_accumulation_steps 100
```
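Putting the two encoding steps together, here is a toy end-to-end scoring example (same assumptions as the passage sketch above, plus the assumption that relevance is the inner product of query and passage embeddings):

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Toy ranking: encode one query and two passages, score by dot product.
model_name = "Shitao/RetroMAE_MSMARCO_finetune"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).eval()

def encode(texts, max_len):
    inputs = tokenizer(texts, max_length=max_len, truncation=True, padding=True, return_tensors="pt")
    with torch.no_grad():
        return model(**inputs).last_hidden_state[:, 0]   # [CLS] embeddings

query_emb = encode(["what is a bi-encoder"], max_len=32)
passage_embs = encode(
    ["A bi-encoder encodes queries and passages independently into dense vectors.",
     "The Great Barrier Reef is the world's largest coral reef system."],
    max_len=140,
)

scores = query_emb @ passage_embs.T                 # (1, num_passages)
print(scores.argsort(dim=-1, descending=True))      # the relevant passage should rank first
```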

## Test

```bash
python test.py \
--query_reps_path results/query_reps \
--passage_reps_path results/passage_reps \
--qrels_file ./data/qrels.dev.tsv \
--ranking_file results/dev_ranking.txt \
--use_gpu
```
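MS MARCO dev retrieval is usually reported as MRR@10, and given the `--qrels_file` flag, `test.py` presumably computes the metric itself. If you want to recompute MRR@10 from the ranking file, here is a rough sketch; the qrels are assumed to be the standard 4-column MS MARCO format (`qid  iteration  pid  relevance`), and each ranking line is assumed to start with `qid` and `pid`, which may differ from what `test.py` actually writes:

```python
from collections import defaultdict

def load_qrels(path):
    """Assumed 4-column qrels: qid, iteration, pid, relevance."""
    positives = defaultdict(set)
    with open(path) as f:
        for line in f:
            qid, _, pid, rel = line.split()
            if int(rel) > 0:
                positives[qid].add(pid)
    return positives

def mrr_at_k(ranking_path, positives, k=10):
    """Assumed ranking lines start with: qid, pid (in rank order per query)."""
    ranked = defaultdict(list)
    with open(ranking_path) as f:
        for line in f:
            qid, pid = line.split()[:2]
            ranked[qid].append(pid)
    total = 0.0
    for qid, pids in ranked.items():
        for rank, pid in enumerate(pids[:k], start=1):
            if pid in positives.get(qid, set()):
                total += 1.0 / rank
                break
    return total / max(len(ranked), 1)

positives = load_qrels("./data/qrels.dev.tsv")
print(f"MRR@10: {mrr_at_k('results/dev_ranking.txt', positives):.4f}")
```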

## Hard negative mining

Hard negatives generated by a retriever model usually yield better performance than random or BM25 negatives. To mine them, use the retriever to generate embeddings for the train queries (`./data/BertTokenizer_data/train_query`), then rank the passages for each query in the train set and filter out the positives:

```bash
# rank the passages
python test.py \
--query_reps_path {embedding of train queries} \
--passage_reps_path results/passage_reps \
--ranking_file results/train_ranking.txt \
--depth 200 \
--use_gpu

# remove the positive passages from the top-k results
python generate_hard_negtives.py \
--ranking_file results/train_ranking.txt \
--qrels_file ./data/train_qrels.txt \
--output_neg_file {neg_file}
```
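`generate_hard_negtives.py` handles the positive-filtering step for you; as a rough illustration of the idea (assumed file layouts as in the MRR sketch above, and a hypothetical output format of one `qid` plus its comma-separated negative pids per line — check the script for the format `--neg_file` actually expects):

```python
from collections import defaultdict

# Sketch of hard-negative mining: keep top-ranked passages that are NOT labeled positive.
# File formats below are assumptions; the real logic lives in generate_hard_negtives.py.
positives = defaultdict(set)
with open("./data/train_qrels.txt") as f:
    for line in f:
        qid, _, pid, rel = line.split()
        if int(rel) > 0:
            positives[qid].add(pid)

negatives = defaultdict(list)
with open("results/train_ranking.txt") as f:
    for line in f:
        qid, pid = line.split()[:2]
        if pid not in positives[qid]:
            negatives[qid].append(pid)

with open("hard_negs.example.txt", "w") as out:   # hypothetical output path and format
    for qid, pids in negatives.items():
        out.write(qid + "\t" + ",".join(pids[:200]) + "\n")
```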