Pre-Train Your Loss! High-Performance Transfer Learning with Bayesian Neural Networks and Pre-Trained Priors
This repository contains an easy-to-use PyTorch implementation of methods described in Pre-Train Your Loss! High-Performance Transfer Learning with Bayesian Neural Networks and Pre-Trained Priors by Ravid Shwartz-Ziv, Micah Goldblum, Hossein Souri, Sanyam Kapoor, Chen Zhu, Yann LeCun, and Andrew Gordon Wilson.
Idea: We can transfer much more than an initialization. Knowledge of the source task should affect the locations and shape of optima on the downstream task.
Approach: Infer a posterior on the source task to re-scale as an informative prior on the downstream task.
Results: Significantly improved performance over standard transfer learning and fine-tuning, with minimal overhead.
Our Bayesian transfer learning framework transfers knowledge from pre-training to downstream tasks. We fit a probability distribution over the feature extractor's parameters to the pre-training loss and re-scale it as a prior, up-weighting downstream parameter settings that are consistent with pre-training. Adopting this learned prior alters the downstream loss surface and the locations of its optima. By contrast, typical transfer learning methods use only a pre-trained initialization.
Our Bayesian transfer learning pipeline uses only easy-to-implement existing tools. In our experiments, Bayesian transfer learning outperforms both SGD-based transfer learning and non-learned Bayesian inference. A schematic of our framework is found below.
This repo contains the code for extracting your prior parameters and applying them to a downstream task using Bayesian inference. The downstream tasks include both image classification and image segmentation.
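To make the role of the learned prior concrete, here is a minimal sketch of how a low-rank Gaussian prior over feature-extractor parameters can re-weight a standard downstream classification loss. The class and argument names (`LearnedPriorLoss`, `num_data`) are ours for illustration, not the loss implemented in this repository.

```python
import torch
import torch.nn as nn
from torch.distributions import LowRankMultivariateNormal


class LearnedPriorLoss(nn.Module):
    """Cross-entropy plus a learned low-rank Gaussian log-prior (illustrative sketch only)."""

    def __init__(self, mean, variance, cov_factor, prior_scale=1.0):
        super().__init__()
        # mean: (n,), variance: (n,), cov_factor: (n, rank), all fit on the source task.
        # prior_scale rescales the whole covariance, loosely mirroring --prior_scale.
        self.prior = LowRankMultivariateNormal(
            loc=mean,
            cov_factor=prior_scale ** 0.5 * cov_factor,
            cov_diag=prior_scale * variance,
        )
        self.ce = nn.CrossEntropyLoss()

    def forward(self, logits, targets, feature_params, num_data):
        # Flatten the current feature-extractor parameters into a single vector.
        theta = torch.cat([p.reshape(-1) for p in feature_params])
        # Negative log-posterior (up to constants); the prior term is divided by the
        # dataset size so it is counted once per epoch when summed over mini-batches.
        return self.ce(logits, targets) - self.prior.log_prob(theta) / num_data
```

Optimizing or sampling from an objective of this form is what lets pre-training knowledge shape the downstream loss surface, rather than only providing an initialization.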
- torch >= 1.8.1
- torchvision >= 0.9.1
- pytorch-lightning >= 1.4.7
For the complete list of requirements, see requirements.txt.
For your convenience, we provide Python scripts for downloading and organizing the Oxford Flowers 102 and Oxford-IIIT Pet datasets. The scripts can be found here.
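If you prefer not to use those scripts, recent torchvision releases also ship dataset classes for both datasets; the snippet below is a minimal alternative sketch and assumes torchvision >= 0.12 (newer than the minimum version listed in the requirements).

```python
from torchvision import datasets, transforms

# Minimal sketch: download and organize both datasets via torchvision's built-in classes
# (requires torchvision >= 0.12; the repo's own scripts serve the same purpose).
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

flowers = datasets.Flowers102(root="./data", split="train", download=True, transform=transform)
pets = datasets.OxfordIIITPet(root="./data", split="trainval", download=True, transform=transform)

print(f"Flowers102: {len(flowers)} images, Oxford-IIIT Pet: {len(pets)} images")
```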
Use `prior_run_jobs.py` both to learn priors from pre-trained checkpoints and to perform inference on downstream tasks.
```bash
python prior_run_jobs.py --job=<JOB> \
                         --prior_type=<PRIOR_TYPE> \
                         --data_dir=<DATA_DIR> \
                         --train_dataset=<TRAIN_DATASET> \
                         --val_dataset=<VAL_DATASET> \
                         --pytorch_pretrain=<PYTORCH_PRETRAIN> \
                         --prior_scale=<PRIOR_SCALE> \
                         --num_of_train_examples=<NUM_OF_TRAIN_EXAMPLES> \
                         --weights_path=<WEIGHTS_PATH> \
                         --number_of_samples_prior=<NUMBER_OF_SAMPLES_PRIOR> \
                         --encoder=<ENCODER>
```
Parameters:

- `JOB` - `setprior` to learn a prior, or `supervised_bayesian_learning` to perform inference on a downstream task.
- `PRIOR_TYPE` - type of prior used for inference on a downstream task:
  - `normal` - zero-mean isotropic Gaussian prior
  - `shifted_gaussian` - learned prior
- `PRIOR_PATH` - path prefix for loading the learned prior. The prior consists of model weight, mean, variance, and cov_factor files, named `prior_path_model.pt`, `prior_path_mean.pt`, `prior_path_variance.pt`, `prior_path_covmat.pt`. You can download the pretrained priors here.
- `DATA_DIR` - path that contains the data
- `TRAIN_DATASET` - dataset for training
- `VAL_DATASET` - dataset for validation
- `PYTORCH_PRETRAIN` - whether to load the weights from a torchvision pretrained model
- `PRIOR_SCALE` - factor for re-scaling the prior
- `NUM_OF_TRAIN_EXAMPLES` - number of training examples on which to train the model
- `WEIGHTS_PATH` - path for loading pre-trained weights
- `NUMBER_OF_SAMPLES_PRIOR` - number of weight samples used to fit the covariance of the prior (illustrated in the sketch below)
- `ENCODER` - base network architecture. The options include most models supported by torchvision.
For the full list of arguments, see priorBox/options.py. All optional arguments for Bayesian learning are listed here, and optional arguments for learning a prior are listed here.
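For intuition about `NUMBER_OF_SAMPLES_PRIOR`: when learning a prior (`--job=setprior`), the mean, variance, and cov_factor are estimated from a collection of weight samples gathered around the pre-trained solution, in the spirit of SWAG. The sketch below uses our own function name and simplified bookkeeping; it is not the repo's implementation.

```python
import torch

def fit_low_rank_gaussian(weight_samples):
    """Estimate (mean, variance, cov_factor) from K flattened weight vectors (illustrative)."""
    W = torch.stack(weight_samples)                           # (K, n)
    mean = W.mean(dim=0)                                      # (n,)
    variance = W.var(dim=0, unbiased=False).clamp_min(1e-8)   # (n,) diagonal part
    # Deviations from the mean give the low-rank covariance factor, SWAG-style;
    # K (the number of samples) sets the rank of the covariance estimate.
    cov_factor = (W - mean) / max(len(weight_samples) - 1, 1) ** 0.5
    return mean, variance, cov_factor.t()                     # cov_factor stored as (n, K)

# The resulting tensors would then be saved with the PRIOR_PATH naming scheme above,
# e.g. torch.save(mean, prior_path + "_mean.pt"), alongside the model weights.
```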
Our learned priors can be found here. They include torchvision ResNet-50 and ResNet-101 as well as SimCLR ResNet-50, all trained on ImageNet. To use them on a downstream task, pass the argument `--prior_path` with the path to the prior when running `prior_run_jobs.py`. Note that the prior consists of model weight, mean, variance, and cov_factor files, which must follow the naming format "prior_path"_model.pt, "prior_path"_mean.pt, "prior_path"_variance.pt, "prior_path"_covmat.pt.
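As a quick sanity check that a downloaded prior matches this layout, the four files can be loaded directly with torch.load. The path below is a placeholder, and we assume the mean, variance, and covmat files each store a single tensor:

```python
import torch

prior_path = "priors/resnet50_torchvision"  # placeholder prefix

state_dict = torch.load(prior_path + "_model.pt", map_location="cpu")     # pre-trained weights
mean       = torch.load(prior_path + "_mean.pt", map_location="cpu")      # prior mean
variance   = torch.load(prior_path + "_variance.pt", map_location="cpu")  # diagonal variance
cov_factor = torch.load(prior_path + "_covmat.pt", map_location="cpu")    # low-rank covariance factor

print(mean.shape, variance.shape, cov_factor.shape)
```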
@article{shwartz2022pre,
title={Pre-Train Your Loss: Easy Bayesian Transfer Learning with Informative Priors},
author={Shwartz-Ziv, Ravid and Goldblum, Micah and Souri, Hossein and Kapoor, Sanyam and Zhu, Chen and LeCun, Yann and Wilson, Andrew Gordon},
journal={arXiv preprint arXiv:2205.10279},
year={2022}
}