State at the time of submission: here
A PyTorch implementation of the CIFAR-10 example from Bayesian Uncertainty Estimation for Batch Normalized Deep Networks.
The implementation stays close to the paper's, but there are some differences. I also can't guarantee that the computed metrics match the paper's exact calculations, since the authors' published code is probably outdated. To convert NLL to PLL, multiply the NLL value by -1.
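As a small illustration of that conversion (a sketch under my own assumptions: `probs` holds the MC-averaged softmax outputs and the helper name is mine, not from the paper's code):

```python
import torch
import torch.nn.functional as F

def predictive_log_likelihood(probs, targets):
    # probs: (N, C) MC-averaged softmax outputs; targets: (N,) class indices.
    # PLL is simply the negated NLL of the averaged predictive distribution.
    return -F.nll_loss(torch.log(probs), targets).item()
```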
UNCERTAINTY ESTIMATION VIA STOCHASTIC BATCH NORMALIZATION describes a similar concept, but also addresses computational cost, a major drawback of the paper implemented here.
I implemented a modified version of the algorithm that produces similar results (nearly identical when looking only at individual samples) but is much faster: for batch size 128, it should be around 128 times faster. On my notebook, MCBN for all 10k test samples with 64 stochastic forward passes finishes in under 2 minutes.
The authors also suggest this approach, but to reduce memory requirements: "One can store BN statistics, instead of batches, to reduce memory issues."
```
set bn layers to train()
for each batch of training samples:
    run the batch through the net
    collect bn.running_mean and bn.running_var of all bn layers
    bn.reset_running_stats()
set bn layers to eval()
for each batch of test data:
    take one collected entry
    set bn.running_mean and bn.running_var of the corresponding bn layers
    process the batch
```
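A minimal PyTorch sketch of this procedure, under my own assumptions (names like `net` and `train_loader` are hypothetical, and I set `momentum=None` so that a single batch after a reset yields exactly that batch's statistics); this is not the repository's actual code:

```python
import random
import torch
import torch.nn as nn

def collect_bn_stats(net, train_loader, device="cpu"):
    # Record, per training batch, the BN statistics that batch induces.
    bn_layers = [m for m in net.modules()
                 if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))]
    for bn in bn_layers:
        bn.momentum = None  # cumulative average: after a reset, one batch -> exact batch stats
    collected = []
    net.train()
    with torch.no_grad():
        for x, _ in train_loader:
            for bn in bn_layers:
                bn.reset_running_stats()
            net(x.to(device))
            collected.append([(bn.running_mean.clone(), bn.running_var.clone())
                              for bn in bn_layers])
    return bn_layers, collected

def mcbn_predict(net, bn_layers, collected, x, n_iters=64):
    # Average softmax outputs over n_iters randomly drawn BN-statistics entries.
    net.eval()
    probs = 0.0
    with torch.no_grad():
        for _ in range(n_iters):
            entry = random.choice(collected)
            for bn, (mean, var) in zip(bn_layers, entry):
                bn.running_mean.copy_(mean)
                bn.running_var.copy_(var)
            probs = probs + torch.softmax(net(x), dim=1)
    return probs / n_iters
```

Because each drawn entry's statistics are reused for a whole test batch, all samples in that batch share the same BN values per pass, which is the caveat listed below.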
- much faster (for batch size 128, roughly 128 times faster)
- the same results when looking at individual samples
- when looking at multiple samples, those processed in the same batch are computed from the same BN layer statistics

While this doesn't impact general quality, it
- makes it more important that the chosen BN statistics are good (less problematic if the number of MCBN iterations is high enough)
- is not the approach of the original paper
| Number of stochastic forward passes | 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128 | baseline |
|---|---|---|---|---|---|---|---|---|---|
| PLL (paper) | -0.36 | -0.32 | -0.30 | -0.29 | -0.29 | -0.28 | -0.28 | -0.28 | -0.32 |
| PLL (this implementation) | -0.37 | -0.33 | -0.30 | -0.29 | -0.28 | -0.27 | -0.27 | -0.27 | -0.33 |
The full log is provided in run.log.