This is a third-party PyTorch implementation of STN as described in the paper.
- opencv (for faster data loading)
- python3
- PyTorch >= 0.4.1
- torchvision
- numpy
- Network building
- Multi STN
- Train
Network | Test Acc (%) |
---|---|
Paper Best | 84.1 |
This implementation | 85.71 |
To train the network, run the following command:
$ python train.py
Currently, only CUDA-capable devices are supported.
See README.md
- Single STN (learns [[a, 0, c], [0, e, f]])
- Multi STN (learns [[parallel, 0, c], [0, parallel, f]], with parallel=[0.9, 0.7, 0.5, 0.3, 0.1])

Judging from the results, parallel=[0.9, 0.7, 0.5] is sufficient.
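
The two parameterizations can be illustrated with a small, framework-free sketch. It is not taken from this repository: the helper names `single_stn_theta`, `multi_stn_thetas`, and `apply_theta` are hypothetical, and it assumes that each `parallel` value is a fixed scale for one branch while only the translations `c` and `f` are learned.

```python
def single_stn_theta(a, c, e, f):
    """2x3 affine matrix [[a, 0, c], [0, e, f]]: scales a, e and translations c, f."""
    return [[a, 0.0, c], [0.0, e, f]]


def multi_stn_thetas(translations, parallel=(0.9, 0.7, 0.5, 0.3, 0.1)):
    """One 2x3 matrix per branch.

    Assumption: the scale of branch i is fixed to parallel[i], and only the
    translation pair (c, f) of that branch is a learned quantity.
    """
    if len(translations) != len(parallel):
        raise ValueError("need one (c, f) pair per parallel scale")
    return [[[s, 0.0, c], [0.0, s, f]] for s, (c, f) in zip(parallel, translations)]


def apply_theta(theta, x, y):
    """Map a normalized output coordinate (x, y) to its input sampling location."""
    (a, b, c), (d, e, f) = theta
    return (a * x + b * y + c, d * x + e * y + f)


# Identity translations: each branch samples a centered crop at its own scale.
thetas = multi_stn_thetas([(0.0, 0.0)] * 5)
corner = apply_theta(thetas[2], 1.0, 1.0)  # branch with scale 0.5
```

With zero translation, the branch with scale 0.5 maps the output corner (1, 1) to (0.5, 0.5), so it samples the central half of the input in each dimension. In PyTorch, a batch of such matrices with shape (N, 2, 3) is the form that `torch.nn.functional.affine_grid` accepts as `theta`.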