CAiRE in DialDoc21

This repository contains the code for the CAiRE submissions to the DialDoc21 shared task, described in: CAiRE in DialDoc21: Data Augmentation for Information-Seeking Dialogue System. Yan Xu, Etsuko Ishii, Genta Indra Winata, Zhaojiang Lin, Andrea Madotto, Zihan Liu, Peng Xu, Pascale Fung. DialDoc Shared Task@ACL2021 [PDF]

The implementation is mainly based on the Hugging Face package, and Sharded DDP is used during training. If you use any source code from this toolkit in your work, please cite the paper above. The BibTeX entry is listed below:

TBD

Install environment

pip install -r requirements.txt
python -m spacy download en_core_web_sm

Load the DialDoc dataset

from datasets import load_dataset
datasets = load_dataset("utils/dialdoc.py", "doc2dial_rc")

Task 1

cd task1

Model training

  • Train the model on the MRQA dataset
sh run_mrqa.sh
  • Train the model on the CQA dataset and the DialDoc dataset
sh run_qa_extra.sh

For the RoBERTa_{all} model, pass ../utils/mrqa.py and mrqa_rc_small as the extra_dataset_name and extra_dataset_config_name arguments, respectively.

  • Fine-tune the model on the DialDoc dataset
sh run_qa.sh
  • Evaluate the model
sh eval_qa.sh

Data postprocessing

  • Post-processing on the predicted spans.

To compute the metrics as well, add --do_eval. This argument can only be used with the validation set.

python postprocess_prediction.py --task split --prediction_file [PATH TO THE POSITIONS FILE (appears as positions.json)] --output_file [PATH TO OUTPUT FILE] --split [validation/devtest/test] --threshold 0.1 --save_span
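
When post-processing several splits, it can help to assemble the command programmatically. The sketch below only rebuilds the command shown above from its documented flags. The helper name `build_postprocess_cmd` is hypothetical and not part of this repository, and the check on `--do_eval` simply encodes the validation-only restriction noted above.

```python
import sys


def build_postprocess_cmd(prediction_file, output_file, split,
                          threshold=0.1, do_eval=False):
    """Assemble the postprocess_prediction.py command as an argument list.

    Hypothetical helper: it only mirrors the flags documented in this README.
    """
    if split not in ("validation", "devtest", "test"):
        raise ValueError(f"unknown split: {split}")
    if do_eval and split != "validation":
        # The README states that --do_eval only applies to the validation set.
        raise ValueError("--do_eval can only be used with the validation split")

    cmd = [
        sys.executable, "postprocess_prediction.py",
        "--task", "split",
        "--prediction_file", str(prediction_file),
        "--output_file", str(output_file),
        "--split", split,
        "--threshold", str(threshold),
        "--save_span",
    ]
    if do_eval:
        cmd.append("--do_eval")
    return cmd
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from inside the task1 directory, with the placeholder paths replaced by real ones.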

Ensemble

  • Build an ensemble of the existing models.

Before building the ensemble, put all the post-processed positions.json files into a single folder, e.g. test_sp.

python ensemble test_sp
sh do_ensemble.sh
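
Because every model writes a file with the same name, gathering them into one folder requires some renaming. The following is a minimal sketch of one way to do that. The function name `collect_positions`, the one-directory-per-model layout, and the `<model_dir_name>_positions.json` naming scheme are all assumptions made for illustration, so check how the ensemble script reads the folder before relying on it.

```python
import shutil
from pathlib import Path


def collect_positions(model_dirs, target_dir="test_sp"):
    """Copy each model's post-processed positions.json into one folder.

    Assumption: each directory in model_dirs holds one positions.json and
    the directories have distinct names. Files are renamed to
    <model_dir_name>_positions.json so they do not overwrite each other.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    copied = []
    for model_dir in map(Path, model_dirs):
        src = model_dir / "positions.json"
        if not src.is_file():
            raise FileNotFoundError(f"missing {src}")
        dst = target / f"{model_dir.name}_positions.json"
        shutil.copy2(src, dst)
        copied.append(dst)
    return copied
```

For example, `collect_positions(["out/model_a", "out/model_b"])` would create `test_sp/model_a_positions.json` and `test_sp/model_b_positions.json`; the `out/...` paths are placeholders.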

Task 2

cd task2

Model Pre-training

  • Pre-train BART model on WoW dataset.
TBC

Model Fine-tuning

  • Further fine-tune the BART model on the DialDoc dataset
sh run_seq2seq_ddp.sh
  • Evaluate the model
sh eval_seq2seq_ddp.sh
  • Get the model generations
sh run_predict.sh

Post-processing

  • Post-processing is only applied to the final test set.
python merge.py --gen_preds [PATH TO BART GENERATIONS] --raw_preds [PATH TO THE PREDICTIONS FROM TASK1] --domain_file cache/test_domain.json --output_file [PATH TO OUTPUT FILE]