This repository contains the code of the CAiRE submissions to the DialDoc21 shared task, described in the paper "CAiRE in DialDoc21: Data Augmentation for Information-Seeking Dialogue System" by Yan Xu, Etsuko Ishii, Genta Indra Winata, Zhaojiang Lin, Andrea Madotto, Zihan Liu, Peng Xu and Pascale Fung (DialDoc Shared Task @ ACL 2021).
The implementation is mainly based on the Huggingface package, and Sharded DDP is used during training. If you use any source code included in this toolkit in your work, please cite the following paper; the BibTeX entry is listed below:
TBD
pip install -r requirements.txt
python -m spacy download en_core_web_sm
from datasets import load_dataset
datasets = load_dataset("utils/dialdoc.py", "doc2dial_rc")
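The loaded object is a dictionary of splits, so a single example can be read with e.g. datasets["train"][0], assuming utils/dialdoc.py defines a train split. The sketch below assumes the doc2dial_rc records follow the SQuAD-style layout (a context string plus an answers dict with parallel text and answer_start lists) and shows how to check that the character offsets point at the answer text. The record is made up for illustration, and the field names are an assumption to verify against utils/dialdoc.py.

```python
# Sketch only: the field layout below is ASSUMED to be SQuAD-style
# (context + answers{"text", "answer_start"}); verify against utils/dialdoc.py.

# Hypothetical record, made up for illustration.
record = {
    "id": "example-0",
    "title": "example-doc",
    "context": "You can renew your license online or at a DMV office.",
    "question": "How can I renew my license?",
    "answers": {"text": ["online or at a DMV office"], "answer_start": [27]},
}


def answer_spans(rec):
    """Return (start, end, text) for every gold answer, checking that the
    character offsets really point at the answer text in the context."""
    context = rec["context"]
    spans = []
    for text, start in zip(rec["answers"]["text"], rec["answers"]["answer_start"]):
        end = start + len(text)  # exclusive end offset
        if context[start:end] != text:
            raise ValueError(f"offset mismatch in {rec['id']}: {text!r}")
        spans.append((start, end, text))
    return spans


if __name__ == "__main__":
    print(answer_spans(record))  # [(27, 52, 'online or at a DMV office')]
```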
Subtask 1 (knowledge identification: predicting the grounding span in the document):
cd task1
- Train the model on the MRQA dataset
sh run_mrqa.sh
- Train the model on the CQA dataset and the DialDoc dataset
sh run_qa_extra.sh
For the RoBERTa_{all} model, add ../utils/mrqa.py and mrqa_rc_small to the extra_dataset_name and extra_dataset_config_name arguments, respectively.
- Fine-tune the model on the DialDoc dataset
sh run_qa.sh
- Evaluate the model
sh eval_qa.sh
- Post-process the predicted spans.
To calculate the metrics, add --do_eval. This argument can only be applied to the validation set.
python postprocess_prediction.py --task split --prediction_file [PATH TO THE POSITION FILE (appears as positions.json)] --output_file [PATH OF OUTPUT FILE] --split [validation/devtest/test] --threshold 0.1 --save_span
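Subtask 1 of the shared task was scored with exact match (EM) and token-level F1 between predicted and gold spans. As a hedged sketch of what such scoring involves, here is a standard SQuAD v1.1-style implementation; it is not necessarily the exact code that runs under --do_eval, and the evaluate helper and its {id: text} input format are assumptions made for illustration.

```python
import re
import string
from collections import Counter

PUNCT = set(string.punctuation)


def normalize_answer(s):
    """Lower-case, drop punctuation and English articles, collapse whitespace (SQuAD style)."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match(prediction, ground_truth):
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def f1_score(prediction, ground_truth):
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gold_tokens)  # token-level overlap
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def evaluate(predictions, references):
    """Hypothetical helper. predictions: {id: text}; references: {id: [gold texts]}.
    Each example takes the max score over its golds; results are percentages."""
    if not references:
        return {"exact_match": 0.0, "f1": 0.0}
    em = f1 = 0.0
    for qid, golds in references.items():
        pred = predictions.get(qid, "")
        em += max(exact_match(pred, g) for g in golds)
        f1 += max(f1_score(pred, g) for g in golds)
    n = len(references)
    return {"exact_match": 100.0 * em / n, "f1": 100.0 * f1 / n}


if __name__ == "__main__":
    preds = {"q1": "online or at a DMV office", "q2": "at the office"}
    golds = {"q1": ["online or at a DMV office"], "q2": ["a DMV office"]}
    print(evaluate(preds, golds))  # {'exact_match': 50.0, 'f1': 75.0}
```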
- Build an ensemble of the existing models.
Before building the ensemble, put all the post-processed positions.json files into the same folder, e.g. test_sp.
python ensemble test_sp
sh do_ensemble.sh
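This README does not describe how do_ensemble.sh combines the models. One simple strategy, shown here only as a hypothetical sketch and not as the repository's method, is majority voting over the predicted spans. The in-memory format (one dict per model mapping an example id to a (start, end) span) and the function name vote_spans are assumptions; the real positions.json layout may differ.

```python
from collections import Counter


def vote_spans(model_predictions):
    """Hypothetical majority vote over (start, end) spans from several models.

    model_predictions: list of dicts, one per model, mapping example id -> (start, end).
    Ties go to the span predicted by the earliest model in the list.
    """
    ensemble = {}
    example_ids = set().union(*model_predictions)  # ids seen by any model
    for qid in sorted(example_ids):
        candidates = [preds[qid] for preds in model_predictions if qid in preds]
        counts = Counter(candidates)
        best = max(counts.values())
        # First candidate (in model order) that reaches the top vote count.
        ensemble[qid] = next(span for span in candidates if counts[span] == best)
    return ensemble


if __name__ == "__main__":
    m1 = {"q1": (10, 20), "q2": (5, 9)}
    m2 = {"q1": (10, 20), "q2": (5, 12)}
    m3 = {"q1": (12, 20), "q2": (6, 9)}
    print(vote_spans([m1, m2, m3]))  # {'q1': (10, 20), 'q2': (5, 9)}
```

Voting on exact spans is the simplest choice; a real ensemble might instead average start/end scores, which this sketch does not attempt.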
Subtask 2 (response generation):
cd task2
- Pre-train the BART model on the Wizard of Wikipedia (WoW) dataset.
TBC
- Further fine-tune the BART model on the DialDoc dataset
sh run_seq2seq_ddp.sh
- Evaluate the model
sh eval_seq2seq_ddp.sh
- Get the model generations
sh run_predict.sh
- Post-processing is applied only to the final test set.
python merge.py --gen_preds [PATH TO BART GENERATIONS] --raw_preds [PATH TO THE PREDICTIONS FROM TASK1] --domain_file cache/test_domain.json --output_file [PATH TO OUTPUT FILE]
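How merge.py combines the two prediction files is not documented in this README. Purely as a hypothetical illustration of the inputs involved, the sketch below aligns the BART generations and the task-1 spans by example id and, for domains listed in an assumed span_domains set, falls back to the span. The selection rule, the function name merge_predictions, the {id: text} inputs, and the {"id", "utterance"} output layout are all assumptions, not the repository's logic.

```python
import json


def merge_predictions(gen_preds, span_preds, domains, span_domains):
    """Hypothetical merge rule: keep the BART generation unless the example's
    domain is in span_domains, in which case use the task-1 span instead.

    gen_preds / span_preds: {example id: text}; domains: {example id: domain}.
    Returns a list of {"id", "utterance"} entries (assumed submission layout).
    """
    merged = []
    for qid, generation in gen_preds.items():
        use_span = domains.get(qid) in span_domains and qid in span_preds
        merged.append({"id": qid, "utterance": span_preds[qid] if use_span else generation})
    return merged


if __name__ == "__main__":
    gen = {"d1": "You can renew it online.", "d2": "You can apply by phone."}
    spans = {"d1": "online or at a DMV office", "d2": "call us at the number below"}
    domains = {"d1": "dmv", "d2": "ssa"}
    merged = merge_predictions(gen, spans, domains, span_domains={"ssa"})
    print(json.dumps(merged, indent=2))
```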