- A TensorFlow implementation of R-NET: MACHINE READING COMPREHENSION WITH SELF-MATCHING NETWORKS. This project is designed specifically for the SQuAD dataset.
- Should you have any questions, please contact Wenxuan Zhou ([email protected]).
- Python >= 3.4
- unzip, wget
- Tensorflow == 1.4.0
- spaCy >= 2.0.0
- tqdm
- ujson
To download and preprocess the data, run
# download SQuAD and GloVe
sh download.sh
# preprocess the data
python config.py --mode prepro
Hyperparameters are stored in config.py. To debug, train, or test the model, run
python config.py --mode debug/train/test
To get the official score, run
python evaluate-v1.1.py ~/data/squad/dev-v1.1.json log/answer/answer.json
The default directory for TensorBoard log files is log/event.
- The original paper uses additive attention, which consumes a lot of memory. This project instead adopts the scaled multiplicative attention presented in Attention Is All You Need.
- This project adopts the variational dropout presented in A Theoretically Grounded Application of Dropout in Recurrent Neural Networks.
- To solve the degradation problem in stacked RNNs, the outputs of each layer are concatenated to produce the final output.
- When the loss on the dev set increases over a certain period, the learning rate is halved.
- During prediction, the project adopts the search method presented in Machine Comprehension Using Match-LSTM and Answer Pointer.
- To address the efficiency issue, this implementation uses a bucketing method (contributed by xiongyifan) and CudnnGRU. Due to a known bug #13254 in TensorFlow, the weights of CudnnGRU may not be properly restored. Check the test score if you want to use it for prediction. The bucketing method can speed up training, but lowers the F1 score by 0.3%.
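
Three of the points above (scaled multiplicative attention, the constrained answer-span search, and learning-rate halving) can be illustrated with a minimal, framework-free sketch. This is not the project's TensorFlow code: the function and class names, the `max_len` span limit, and the `patience` value are all assumptions chosen for illustration, and the real settings live in config.py.

```python
import math


def scaled_dot_attention(queries, keys, values):
    """Scaled multiplicative attention, softmax(Q K^T / sqrt(d)) V, on nested lists."""
    d = len(keys[0])
    dim_v = len(values[0])
    outputs = []
    for q in queries:
        # Dot-product score of this query against every key, scaled by sqrt(d).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        # Numerically stable softmax over the key positions.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Attention output is the weighted sum of the value vectors.
        outputs.append([sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim_v)])
    return outputs


def best_span(p_start, p_end, max_len=15):
    """Return (start, end) maximizing p_start[i] * p_end[j] with i <= j < i + max_len.

    max_len is a hypothetical limit on the answer length, in tokens.
    """
    best, best_score = (0, 0), -1.0
    for i, ps in enumerate(p_start):
        for j in range(i, min(i + max_len, len(p_end))):
            score = ps * p_end[j]
            if score > best_score:
                best, best_score = (i, j), score
    return best


class HalvingSchedule:
    """Halve the learning rate after `patience` evaluations without improvement.

    The patience value here is an assumption, not the project's default.
    """

    def __init__(self, lr, patience=3):
        self.lr = lr
        self.patience = patience
        self.best_loss = float("inf")
        self.bad_evals = 0

    def update(self, dev_loss):
        if dev_loss < self.best_loss:
            self.best_loss = dev_loss
            self.bad_evals = 0
        else:
            self.bad_evals += 1
            if self.bad_evals >= self.patience:
                # Halve the rate and restart the comparison from the current loss.
                self.lr /= 2.0
                self.best_loss = dev_loss
                self.bad_evals = 0
        return self.lr
```

The span search matters because taking the most likely start and end positions independently can produce an end that precedes the start. Searching over pairs with `start <= end` always yields a valid span.
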
| | EM | F1 |
|---|---|---|
| original paper | 71.1 | 79.5 |
| this project | 71.07 | 79.51 |
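
For readers unfamiliar with the two metrics in the table, the following is a simplified sketch of how exact match (EM) and token-level F1 are commonly computed for SQuAD v1.1. It is an approximation written for illustration, with hypothetical function names. The official evaluate-v1.1.py script remains the authority for reported scores.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize(text):
    """Lowercase, drop punctuation and English articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, truth):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize(prediction) == normalize(truth))


def f1(prediction, truth):
    """Harmonic mean of token-level precision and recall after normalization."""
    pred_tokens = normalize(prediction).split()
    true_tokens = normalize(truth).split()
    # Multiset intersection counts each shared token at most as often as it appears in both.
    overlap = sum((Counter(pred_tokens) & Counter(true_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(true_tokens)
    return 2 * precision * recall / (precision + recall)


def best_over_truths(metric, prediction, truths):
    """A question may have several reference answers; score against the best one."""
    return max(metric(prediction, t) for t in truths)
```

The per-question scores are then averaged over the dev set and reported as percentages.
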
| | Native | Native + Bucket | Cudnn | Cudnn + Bucket |
|---|---|---|---|---|
| E5-2640 | 6.21 | 3.56 | - | - |
| TITAN X | 2.72 | 1.67 | 0.61 | 0.35 |
These settings may increase the score but are not used in the model by default. You can turn them on in config.py.
- Pretrained GloVe character embedding. Contributed by yanghanxy.
- fastText embedding. Contributed by xiongyifan. May increase the F1 score by 1% (reported by xiongyifan).